#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""

import argparse
import logging
import math
import os
import random
import shutil
from contextlib import nullcontext
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from peft import PeftConfig, PeftModel
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from datasets.DomainNet import get_domainnet_dataset_single,get_domainnet_dloader,DomainNet
from models.attention_utils import AttnController,MyCrossAttnProcessor
from datasets.datasets_utils import ForgetMeNotDataset,collate_fn,MyDataset

# NOTE: no `check_min_version(...)` call follows this comment in this script, so the installed
# diffusers version is not actually validated here.
def retrieve_timesteps(
    scheduler,
    num_inference_steps,
    device,
    timesteps = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    Args:
        scheduler: Object exposing ``set_timesteps(...)`` and a ``timesteps``
            attribute that is populated by that call.
        num_inference_steps: Number of steps to request from the scheduler.
            Ignored when ``timesteps`` is given.
        device: Forwarded to ``scheduler.set_timesteps`` as ``device=``.
        timesteps: Optional custom schedule. When provided, the scheduler's
            ``set_timesteps`` must accept a ``timesteps`` keyword.
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        A ``(timesteps, num_inference_steps)`` tuple, where ``timesteps`` is
        ``scheduler.timesteps`` after configuration. With a custom schedule,
        ``num_inference_steps`` is recomputed as ``len(timesteps)``.

    Raises:
        ValueError: If ``timesteps`` is given but ``scheduler.set_timesteps``
            has no ``timesteps`` parameter.
    """
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Imported locally because the module-level imports of this script do not
    # include `inspect`; without it this branch raised NameError.
    import inspect

    accepts_timesteps = "timesteps" in inspect.signature(scheduler.set_timesteps).parameters
    if not accepts_timesteps:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
    
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Copy ``state_dict`` into ``model`` module by module and print a report.

    Each module's ``_load_from_state_dict`` is called directly, so key
    mismatches are collected and printed rather than raised.

    Args:
        model: Root module to load into.
        state_dict: Mapping of parameter names to tensors. A shallow copy is
            used so the caller's mapping is left untouched.
        prefix: Key prefix under which ``model``'s entries are stored.
        ignore_missing: ``|``-separated substrings. A missing key containing
            any of them is listed under "Ignored weights" instead of
            "Weights ... not initialized".

    Returns:
        None. Missing, unexpected and ignored keys as well as loader error
        messages are written to stdout.
    """
    missing, unexpected, errors = [], [], []

    # Shallow-copy so _load_from_state_dict may modify the mapping, carrying
    # over the optional `_metadata` attribute when the source has one.
    meta = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if meta is not None:
        state_dict._metadata = meta

    def _visit(module, path=''):
        # Pre-order walk: load this module's own entries, then its children.
        if meta is None:
            local_meta = {}
        else:
            local_meta = meta.get(path[:-1], {})
        module._load_from_state_dict(
            state_dict, path, local_meta, True, missing, unexpected, errors)
        for child_name, child in module._modules.items():
            if child is None:
                continue
            _visit(child, path + child_name + '.')

    _visit(model, prefix)

    # Split the missing keys into the ones to report and the ones to ignore.
    patterns = ignore_missing.split('|')
    reported, ignored = [], []
    for key in missing:
        bucket = ignored if any(pattern in key for pattern in patterns) else reported
        bucket.append(key)

    owner = model.__class__.__name__
    if reported:
        print("Weights of {} not initialized from pretrained model: {}".format(
            owner, reported))
    if unexpected:
        print("Weights from pretrained model not used in {}: {}".format(
            owner, unexpected))
    if ignored:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            owner, ignored))
    if errors:
        print('\n'.join(errors))
        
def parse_args(input_args=None):
    """Parse the command-line options of this script.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the process arguments.

    Returns:
        The populated ``argparse.Namespace``. If the ``LOCAL_RANK`` environment
        variable is set to something other than ``-1``, it overrides
        ``args.local_rank``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="sd-model-finetuned-lora",
        # This directory is only read from: main() loads
        # <ckpt_dir>/<concept>/checkpoint-50.ckpt and the PEFT adapter there.
        help=(
            "The input directory holding one sub-directory per concept with the teacher UNet"
            " checkpoint (checkpoint-50.ckpt) and the PEFT text-encoder adapter to load."
        ),
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(input_args)

    # A LOCAL_RANK set in the environment takes precedence over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

    
def main():
    """Train a LoRA student UNet per concept against a frozen LoRA teacher UNet.

    Runs 10 outer rounds. In every round, for each name in a hard-coded concept
    list, the function:

    1. creates an fp16 ``Accelerator`` and loads scheduler, tokenizer, text
       encoder, VAE and two UNet copies from ``--pretrained_model_name_or_path``;
    2. turns one UNet into the teacher by adding a LoRA adapter and loading
       ``<ckpt_dir>/<concept>/checkpoint-50.ckpt`` into it, and wraps the text
       encoder with the PEFT adapter found in ``<ckpt_dir>/<concept>``;
    3. adds a fresh LoRA adapter to the other UNet (the student) and optimises
       only the parameters that require grad with AdamW;
    4. for 10 epochs iterates a 50-step scheduler timetable: steps with
       ``timestep > 800`` run the student without gradients, the remaining
       steps minimise ``attn_controller.loss()`` plus the mean MSE between
       student and teacher cross-attention maps;
    5. saves the student ``state_dict`` under ``<output_dir>/<concept>``.
    """
    args = parse_args()

    output_dir = args.output_dir
    ckpt_dir = args.ckpt_dir

    # NOTE(review): `concepts` is never read below; the two calls are kept
    # because they make the script fail early when ./data/celebs or its
    # paths.txt entry is missing.
    concepts = os.listdir("./data/celebs")
    concepts.remove('paths.txt')
    client_concepts = ['Elon Musk','Donald Trump','Barack Obama','Tom Hiddleston','Rihanna','Arnold Schwarzenegger','Tom Cruise','Leonardo Dicaprio','Andrew Garfield','Joe Biden']

    #concepts = os.listdir("./data/artists")
    #concepts.remove('paths.txt')
    #client_concepts = ['Vincent van Gogh','Leonardo da Vinci','Claude Monet','Wassily Kandinsky','J.M.W. Turner','Albrecht Anker','Francisco Goya','Henri Matisse','Hilma af Klint','Paul Gauguin']

    for global_step in range(10):
        for concept in client_concepts:

            accelerator = Accelerator(
                gradient_accumulation_steps=1,
                mixed_precision='fp16'
            )

            # Disable AMP for MPS.
            if torch.backends.mps.is_available():
                accelerator.native_amp = False
            if args.seed is not None:
                set_seed(args.seed)

            # Per-concept input (teacher checkpoint + text-encoder adapter) and output directories.
            concept_ckpt_dir = os.path.join(ckpt_dir, concept)
            out_dir = os.path.join(output_dir, concept)
            if accelerator.is_main_process:
                os.makedirs(out_dir, exist_ok=True)

            # Load scheduler, tokenizer and models.
            noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=None)
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=None)
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=None, variant=None)
            unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None)

            # Teacher: base UNet + LoRA adapter, weights restored from the concept checkpoint.
            teacher_unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None)
            state_dict = torch.load(os.path.join(concept_ckpt_dir, "checkpoint-50.ckpt"), map_location='cpu')
            teacher_lora_config = LoraConfig(
                r=4,
                lora_alpha=2,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q"],
            )
            teacher_unet.add_adapter(teacher_lora_config)
            load_state_dict(teacher_unet, state_dict)

            # Freeze every pre-existing weight; the student's adapter is added
            # afterwards, so only the student adapter ends up trainable.
            teacher_unet.requires_grad_(False)
            unet.requires_grad_(False)
            vae.requires_grad_(False)
            text_encoder.requires_grad_(False)

            # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
            # as these weights are only used for inference, keeping weights in full precision is not required.
            weight_dtype = torch.float32
            if accelerator.mixed_precision == "fp16":
                weight_dtype = torch.float16
            elif accelerator.mixed_precision == "bf16":
                weight_dtype = torch.bfloat16

            unet_lora_config = LoraConfig(
                r=4,
                lora_alpha=2,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q"],
            )

            text_encoder = PeftModel.from_pretrained(text_encoder, concept_ckpt_dir)

            # Move unet, vae and text_encoder to device and cast to weight_dtype
            teacher_unet.to(accelerator.device, dtype=weight_dtype)
            unet.to(accelerator.device, dtype=weight_dtype)
            vae.to(accelerator.device, dtype=weight_dtype)
            text_encoder.to(accelerator.device, dtype=weight_dtype)

            # Add adapter and make sure the trainable params are in float32.
            unet.add_adapter(unet_lora_config)
            cast_training_params(unet, dtype=torch.float32)

            # Materialise as a list: the parameters are needed twice (optimizer
            # construction and gradient clipping). A bare `filter` iterator is
            # exhausted by the optimizer, which silently disabled clipping.
            lora_layers = [p for p in unet.parameters() if p.requires_grad]

            # Initialize the optimizer
            optimizer_cls = torch.optim.AdamW

            optimizer = optimizer_cls(
                lora_layers,
                lr=0.001
            )

            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

            # Route every cross-attention ("attn2") module of the student through
            # a processor bound to `attn_controller`.
            attn_controller = AttnController()
            module_count = 0
            for n, m in unet.named_modules():
                if n.endswith('attn2'):
                    m.set_processor(MyCrossAttnProcessor(attn_controller, n))
                    module_count += 1
            print(f"cross attention module count: {module_count}")

            # Same for the teacher, with its own controller.
            module_count = 0
            teacher_attn_controller = AttnController()
            for n, m in teacher_unet.named_modules():
                if n.endswith('attn2'):
                    m.set_processor(MyCrossAttnProcessor(teacher_attn_controller, n))
                    module_count += 1
            print(f"cross attention module count: {module_count}")

            # Prepare everything with our `accelerator`.
            unet, optimizer, text_encoder,tokenizer,lr_scheduler,teacher_unet = accelerator.prepare(
                unet, optimizer, text_encoder,tokenizer,lr_scheduler,teacher_unet
            )

            train_dataset = MyDataset(tokenizer,'celebs',concept)

            #train_dataset = MyDataset(tokenizer,'artists',concept)

            print(len(train_dataset))
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=4,
                num_workers=8,
                shuffle=True,
                pin_memory=True,
                collate_fn=collate_fn,
            )
            # NOTE(review): every batch is VAE-encoded here, but only the LAST
            # `batch` / `latents` survive the loop and are used for training
            # below. Confirm whether training was meant to run per batch.
            for step, batch in enumerate(train_loader):
                latents = vae.encode(batch["pixel_values"].to(dtype=torch.float16).cuda()).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

            # Train!
            for epoch in range(10):
                unet.train()
                teacher_unet.eval()
                # Student and teacher start every epoch from the same noise sample.
                noise = torch.randn_like(latents)
                stu_latents = noise.detach().clone()
                teacher_latents = noise.detach().clone()
                timesteps, _ = retrieve_timesteps(noise_scheduler, 50, noise.device)
                # Timesteps above this value only advance the student without training it.
                warm_step = 800
                for timestep in timesteps:

                    with accelerator.accumulate(unet):

                        optimizer.zero_grad(set_to_none=False)
                        attn_controller.zero_attn_probs()
                        teacher_attn_controller.zero_attn_probs()

                        # Teacher forward pass; never needs gradients.
                        with torch.no_grad():
                            # Get the text embedding for conditioning
                            encoder_hidden_states = text_encoder(batch["input_ids"].cuda())[0]
                            # set concept_positions for this batch
                            teacher_attn_controller.set_concept_positions(batch["concept_positions"].cuda())

                            teacher_model_pred = teacher_unet(teacher_latents.cuda(), timestep, encoder_hidden_states.cuda()).sample
                            # NOTE(review): the scheduler is stepped from `latents`
                            # (the VAE-encoded batch), not from `teacher_latents`.
                            # Left unchanged; confirm this is intended.
                            teacher_latents = noise_scheduler.step(teacher_model_pred, timestep, latents, return_dict=False)[0]

                        if timestep > warm_step:
                            # Warm-up: advance the student latents without training.
                            with torch.no_grad():
                                # Get the text embedding for conditioning
                                encoder_hidden_states = text_encoder(batch["input_ids"].cuda())[0]
                                attn_controller.set_concept_positions(batch["concept_positions"].cuda())
                                model_pred = unet(stu_latents.cuda(), timestep, encoder_hidden_states.cuda()).sample
                                stu_latents = noise_scheduler.step(model_pred, timestep, latents, return_dict=False)[0]
                                print("warming: ",timestep)
                        else :
                            # Get the text embedding for conditioning
                            encoder_hidden_states = text_encoder(batch["input_ids"].cuda())[0]
                            # set concept_positions for this batch
                            attn_controller.set_concept_positions(batch["concept_positions"].cuda())

                            # Student forward pass with gradients enabled.
                            model_pred = unet(stu_latents.cuda(), timestep, encoder_hidden_states.cuda()).sample
                            loss_a = attn_controller.loss()

                            # Per-layer MSE between student and teacher attention
                            # maps, each averaged over dim 1.
                            attn_losses = [
                                F.mse_loss(stu_probs.mean(dim=1), tea_probs.mean(dim=1), reduction='mean')
                                for stu_probs, tea_probs in zip(attn_controller.attn_probs, teacher_attn_controller.attn_probs)
                            ]
                            # torch.stack keeps the autograd graph; torch.tensor(list)
                            # built a fresh tensor without history, so this term
                            # produced no gradient before.
                            loss_b = torch.stack(attn_losses).mean()
                            loss = loss_a + loss_b
                            # Backpropagate
                            accelerator.backward(loss)

                            if accelerator.sync_gradients:
                                accelerator.clip_grad_norm_(lora_layers, 1.0)

                            optimizer.step()
                            lr_scheduler.step()
                            optimizer.zero_grad(set_to_none=False)
                            attn_controller.zero_attn_probs()
                            teacher_attn_controller.zero_attn_probs()
                            logs = {"loss": loss.detach().item(),"loss_a": loss_a.detach().item(),"loss_b": loss_b.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                            accelerator.log(logs, step=global_step)

                            # Detach so the next step does not backpropagate through this one.
                            stu_latents = noise_scheduler.step(model_pred, timestep, latents, return_dict=False)[0].detach()
                            print(timestep,logs)

            # `epoch` still holds the last value of the loop above, so the file
            # is named checkpoint-9.ckpt and is overwritten on every outer round.
            if accelerator.is_main_process:
                save_path = os.path.join(out_dir, f"checkpoint-{epoch}.ckpt")
                torch.save(unet.state_dict(),save_path)

        # Runs once per outer round, on the accelerator of the last concept.
        accelerator.wait_for_everyone()
        accelerator.end_training()


if __name__ == "__main__":
    main()
