#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""

import argparse
import logging
import math
import os
import random
import shutil
from contextlib import nullcontext
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from peft import PeftConfig, PeftModel
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from datasets.DomainNet import get_domainnet_dataset_single,get_domainnet_dloader,DomainNet
from models.attention_utils import AttnController,MyCrossAttnProcessor
from datasets.datasets_utils import ForgetMeNotDataset,collate_fn,MyDataset

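# NOTE: the version guard this comment originally accompanied (presumably a `check_min_version(...)`
# call, which is still imported above) is not present in this file, so no minimal-diffusers-version
# check is actually performed.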

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Copy ``state_dict`` into ``model`` and print a report instead of raising.

    Every sub-module is visited in pre-order and asked to load its own
    parameters/buffers through ``Module._load_from_state_dict``. Missing keys,
    unexpected keys and loader error messages are collected and printed; no
    exception is raised for them by this function.

    Args:
        model: Root module that receives the weights.
        state_dict: Mapping from parameter name to tensor. A shallow copy is
            used internally, so the caller's mapping is left untouched.
        prefix: Prefix prepended to every key looked up for ``model``.
        ignore_missing: ``'|'``-separated substrings. A missing key containing
            any of them is listed under "Ignored weights" rather than
            "Weights ... not initialized".

    Returns:
        None.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # Work on a copy so ``_load_from_state_dict`` may modify it, and carry the
    # optional ``_metadata`` attribute over to that copy.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they are popped in definition order, i.e. the same pre-order walk.
    pending = [(model, prefix)]
    while pending:
        module, module_prefix = pending.pop()
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(module_prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, module_prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        for child_name, child in reversed(list(module._modules.items())):
            if child is not None:
                pending.append((child, module_prefix + child_name + '.'))

    # Partition the missing keys into "report" and "silently expected".
    ignore_patterns = ignore_missing.split('|')
    reported_missing = []
    ignored_missing = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignored_missing.append(key)
        else:
            reported_missing.append(key)

    model_name = model.__class__.__name__
    if reported_missing:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model_name, reported_missing))
    if unexpected_keys:
        print("Weights from pretrained model not used in {}: {}".format(
            model_name, unexpected_keys))
    if ignored_missing:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model_name, ignored_missing))
    if error_msgs:
        print('\n'.join(error_msgs))
        
def parse_args(input_args=None):
    """Parse the command-line options of the training script.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the process command line.

    Returns:
        argparse.Namespace with ``pretrained_model_name_or_path``,
        ``output_dir``, ``round``, ``seed`` and ``local_rank``. When the
        ``LOCAL_RANK`` environment variable is set (and not ``-1``) it
        overrides ``--local_rank``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--round",
        type=int,
        default=0,
        # The help text used to be a copy of the --output_dir description.
        help=(
            "Index of the current training round. Results are written to <output_dir>/round<N>; "
            "for N > 0 the aggregated weights of round N-1 are loaded before training."
        ),
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(input_args)

    # A launcher-provided LOCAL_RANK takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

    
def main():
    """Train one UNet LoRA adapter per concept by distilling from a frozen teacher UNet.

    For every entry of ``client_concepts`` the function:

    1. builds a fresh ``Accelerator`` (fp16 mixed precision) and loads the
       scheduler, tokenizer, text encoder, VAE and two copies of the UNet
       (student and teacher) from ``args.pretrained_model_name_or_path``;
    2. attaches a rank-4 LoRA adapter to the student (and, for ``round != 0``,
       to the teacher as well) and, for ``round != 0``, loads
       ``<output_dir>/round<N-1>/agg_unet.ckpt`` into both;
    3. trains the student's LoRA weights for 20 epochs so that its noise
       prediction for the ``input_ids`` prompt matches the teacher's
       prediction for the ``refine_input_ids`` prompt (sum-reduced MSE);
    4. writes the student's state dict to
       ``<output_dir>/round<N>/<concept>/checkpoint-<global_step>.ckpt``.

    Returns:
        None. All results are written to disk.
    """
    args = parse_args()

    # Each round writes into its own sub-directory of the requested output dir.
    output_dir = os.path.join(args.output_dir, 'round' + str(args.round))

    # Alternative concept sets kept for reference:
    #concepts = os.listdir("./data/celebs")
    #concepts.remove('paths.txt')
    #client_num = 10
    #client_concepts = ['Elon Musk','Donald Trump','Barack Obama','Tom Hiddleston','Rihanna','Arnold Schwarzenegger','Tom Cruise','Leonardo Dicaprio','Andrew Garfield','Joe Biden']

    #concepts = os.listdir("./data/artists")
    #concepts.remove('paths.txt')
    #client_num = 10
    #client_concepts = ['Vincent van Gogh','Leonardo da Vinci','Claude Monet','Wassily Kandinsky','J.M.W. Turner','Albrecht Anker','Francisco Goya','Henri Matisse','Hilma af Klint','Paul Gauguin']

    client_concepts = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

    num_train_epochs = 20

    for concept in client_concepts:
        torch.cuda.empty_cache()
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision='fp16'
        )

        def unwrap_model(model):
            """Strip the accelerate wrapper and any torch.compile wrapper from ``model``."""
            model = accelerator.unwrap_model(model)
            model = model._orig_mod if is_compiled_module(model) else model
            return model

        # Disable AMP for MPS.
        if torch.backends.mps.is_available():
            accelerator.native_amp = False
        if args.seed is not None:
            # Offset by the round index so different rounds do not reuse the same seed.
            set_seed(args.seed + args.round)

        out_dir = os.path.join(output_dir, concept)
        if accelerator.is_main_process:
            os.makedirs(out_dir, exist_ok=True)

        # Load scheduler, tokenizer and models.
        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=None)
        text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=None)
        vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=None, variant=None)
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None)
        teacher_unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None)

        # Freeze all base weights; only the LoRA weights added below are trained.
        teacher_unet.requires_grad_(False)
        unet.requires_grad_(False)
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        unet_lora_config = LoraConfig(
            r=4,
            lora_alpha=2,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q"],
        )

        # Move unet, vae and text_encoder to device and cast to weight_dtype
        teacher_unet.to(accelerator.device, dtype=weight_dtype)
        unet.to(accelerator.device, dtype=weight_dtype)
        vae.to(accelerator.device, dtype=weight_dtype)
        text_encoder.to(accelerator.device, dtype=weight_dtype)

        # Add adapter and make sure the trainable params are in float32.
        unet.add_adapter(unet_lora_config)
        cast_training_params(unet, dtype=torch.float32)

        if args.round != 0:
            # Later rounds start from the previous round's aggregated weights;
            # the teacher receives the same adapter so the checkpoint fits it too.
            teacher_unet.add_adapter(unet_lora_config)
            prev_dir = os.path.join(args.output_dir, 'round' + str(args.round - 1))
            # NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
            state_dict = torch.load(os.path.join(prev_dir, "agg_unet.ckpt"), map_location='cpu')
            load_state_dict(unet, state_dict)
            load_state_dict(teacher_unet, state_dict)
            for name, param in unet.named_parameters():
                if 'lora' in name:
                    param.requires_grad = True
            cast_training_params(unet, dtype=torch.float32)

        # The teacher is never trained, including any adapter weights it just received.
        teacher_unet.requires_grad_(False)

        # Materialise the trainable parameters as a list. A bare ``filter`` iterator
        # is exhausted by the optimizer constructor, which left the later gradient
        # clipping call with nothing to clip.
        lora_layers = [p for p in unet.parameters() if p.requires_grad]

        # Initialize the optimizer
        optimizer = torch.optim.AdamW(
            lora_layers,
            lr=0.001
        )

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

        # Route every cross-attention layer ("attn2") of the student through a
        # processor bound to its own controller.
        attn_controller = AttnController()
        module_count = 0
        for n, m in unet.named_modules():
            if n.endswith('attn2'):
                m.set_processor(MyCrossAttnProcessor(attn_controller, n))
                module_count += 1
        print(f"cross attention module count: {module_count}")

        # Same for the teacher, with a separate controller.
        module_count = 0
        teacher_attn_controller = AttnController()
        for n, m in teacher_unet.named_modules():
            if n.endswith('attn2'):
                m.set_processor(MyCrossAttnProcessor(teacher_attn_controller, n))
                module_count += 1
        print(f"cross attention module count: {module_count}")

        # Prepare everything with our `accelerator`.
        unet, optimizer, text_encoder, tokenizer, lr_scheduler, teacher_unet = accelerator.prepare(
            unet, optimizer, text_encoder, tokenizer, lr_scheduler, teacher_unet
        )

        # NOTE(review): the concepts above look like DomainNet domain names while the
        # dataset group passed here is 'artists' -- confirm this pairing is intended.
        #train_dataset = MyDataset(tokenizer,'celebs',concept)
        train_dataset = MyDataset(tokenizer, 'artists', concept)

        print(len(train_dataset))
        # Pass ``collate_fn`` directly: a wrapping lambda adds nothing and cannot be
        # pickled for worker processes under the "spawn" start method.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=4,
            num_workers=8,
            shuffle=True,
            pin_memory=True,
            collate_fn=collate_fn,
        )

        # Train!
        global_step = 0
        # gradient_accumulation_steps is 1, so every batch is one optimizer step.
        max_train_steps = num_train_epochs * len(train_loader)
        progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
        progress_bar.set_description("Steps")

        debug_once = True
        for epoch in range(num_train_epochs):
            unet.train()
            for step, batch in enumerate(train_loader):
                with accelerator.accumulate(unet):

                    if debug_once:
                        # Print one prompt pair so the run log shows what is being distilled.
                        print(batch["instance_prompts"][0])
                        print(batch["instance_refine_prompts"][0])
                        debug_once = False

                    # Convert images to latent space
                    pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                    # Sample noise that we'll add to the latents
                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    with torch.no_grad():
                        # Teacher target: conditioned on the *refined* prompt.
                        teacher_hidden_states = text_encoder(batch["refine_input_ids"].to(accelerator.device))[0]
                        # set concept_positions for this batch
                        teacher_attn_controller.set_concept_positions(
                            batch["refine_concept_positions"].to(accelerator.device)
                        )
                        tea_model_pred = teacher_unet(noisy_latents, timesteps, teacher_hidden_states).sample

                    # Student prediction: conditioned on the original prompt.
                    encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]
                    # set concept_positions for this batch
                    attn_controller.set_concept_positions(batch["concept_positions"].to(accelerator.device))

                    # Predict the noise residual and compute loss
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    mse_loss = torch.nn.functional.mse_loss(model_pred, tea_model_pred, reduction='sum')

                    # The attention-based terms are currently disabled:
                    #loss_a = attn_controller.loss()
                    #loss_b = mean over layers of mse(attn_controller.attn_probs[i].mean(dim=1),
                    #                                 teacher_attn_controller.attn_probs[i].mean(dim=1))
                    #loss = loss_a + loss_b
                    # loss_a / loss_b therefore alias the MSE loss so the logged keys stay the same.
                    loss = mse_loss
                    loss_a = mse_loss
                    loss_b = mse_loss

                    # Backpropagate
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(lora_layers, 1.0)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=False)
                    # Reset what both controllers collected during this step.
                    attn_controller.zero_attn_probs()
                    teacher_attn_controller.zero_attn_probs()

                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                logs = {"loss": loss.detach().item(),"loss_a": loss_a.detach().item(),"loss_b": loss_b.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

        progress_bar.close()

        if accelerator.is_main_process:
            save_path = os.path.join(out_dir, f"checkpoint-{global_step}.ckpt")
            # ``unet.module`` only exists when the model is wrapped for distributed
            # training; unwrap_model works for both wrapped and unwrapped models.
            torch.save(unwrap_model(unet).state_dict(), save_path)

        accelerator.wait_for_everyone()
        accelerator.end_training()


# Run the training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
