#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""

import argparse
import logging
import math
import os
import random
import shutil
from contextlib import nullcontext
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from peft import PeftConfig, PeftModel
from peft import LoraConfig, get_peft_model

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Copy ``state_dict`` into ``model`` module by module and report mismatches.

    Problems are printed to stdout rather than raised, and nothing is returned.

    Args:
        model: Module whose parameters/buffers are filled in place.
        state_dict: Mapping of parameter names to tensors. A shallow copy is
            loaded, so the caller's mapping is not modified.
        prefix: Key prefix under which ``model``'s entries live in ``state_dict``.
        ignore_missing: ``'|'``-separated substrings. A missing key containing
            any of them is listed under the "Ignored weights" message instead
            of the regular "not initialized" message.
    """
    missing_keys, unexpected_keys, error_msgs = [], [], []

    # Work on a copy (keeping the optional `_metadata` attribute) so that
    # `_load_from_state_dict` is free to modify the mapping.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _load_recursive(module, module_prefix=''):
        # Metadata is keyed by the module path without its trailing dot.
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(module_prefix[:-1], {})
        module._load_from_state_dict(
            state_dict,
            module_prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for child_name, child in module._modules.items():
            if child is None:
                continue
            _load_recursive(child, module_prefix + child_name + '.')

    _load_recursive(model, prefix)

    # Split the missing keys into those worth reporting and those the caller
    # asked to ignore.
    ignore_patterns = ignore_missing.split('|')
    reported_keys = []
    ignored_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignored_keys.append(key)
        else:
            reported_keys.append(key)

    model_name = model.__class__.__name__
    if reported_keys:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model_name, reported_keys))
    if unexpected_keys:
        print("Weights from pretrained model not used in {}: {}".format(
            model_name, unexpected_keys))
    if ignored_keys:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model_name, ignored_keys))
    if error_msgs:
        print('\n'.join(error_msgs))
        


# NOTE: `check_min_version` is imported above but never called, so no minimum diffusers version is enforced here.

def parse_args(input_args=None):
    """Parse the command-line options for one training round.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which reads the process
            command line.

    Returns:
        argparse.Namespace with ``pretrained_model_name_or_path``,
        ``output_dir``, ``round``, ``seed`` and ``local_rank``. When the
        ``LOCAL_RANK`` environment variable is set, its value overrides
        ``--local_rank``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--round",
        type=int,
        default=0,
        help=(
            "Index of the current training round. Results are written to <output_dir>/round<N>; "
            "rounds greater than 0 start from the results stored under round<N-1>."
        ),
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(input_args)

    # A launcher-provided LOCAL_RANK takes precedence over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

def main():
    """Train and save one LoRA text-encoder adapter per client concept for the current round.

    For every concept in the hard-coded client list this function:

    1. creates an fp16 ``Accelerator`` and, if ``--seed`` is given, seeds with
       ``seed + round``;
    2. loads the tokenizer plus two copies of the CLIP text encoder (a student
       that receives the LoRA adapter and a frozen teacher);
    3. for rounds > 0, loads the concept's adapter from
       ``<output_dir>/round<N-1>/<concept>`` and the checkpoint
       ``<output_dir>/round<N-1>/agg_text_encoder.ckpt`` into the student;
    4. runs 5 optimisation steps on the cosine similarity between the student's
       features for ``"an image of <concept>"`` and the teacher's features for
       ``"an image of a person"``;
    5. saves the adapter to ``<output_dir>/round<N>/<concept>``.
    """
    args = parse_args()

    client_concepts = [
        'Elon Musk', 'Donald Trump', 'Barack Obama', 'Tom Hiddleston', 'Rihanna',
        'Arnold Schwarzenegger', 'Tom Cruise', 'Leonardo Dicaprio', 'Andrew Garfield', 'Joe Biden',
    ]
    # Alternative experiment (artist styles); also swap the two prompts below:
    # client_concepts = ['Vincent van Gogh', 'Leonardo da Vinci', 'Claude Monet', 'Wassily Kandinsky',
    #                    'J.M.W. Turner', 'Albrecht Anker', 'Francisco Goya', 'Henri Matisse',
    #                    'Hilma af Klint', 'Paul Gauguin']

    output_dir = os.path.join(args.output_dir, 'round' + str(args.round))

    for concept in client_concepts:
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision='fp16',
        )
        if torch.backends.mps.is_available():
            accelerator.native_amp = False
        if args.seed is not None:
            set_seed(args.seed + args.round)

        out_dir = os.path.join(output_dir, concept)
        if accelerator.is_main_process:
            os.makedirs(out_dir, exist_ok=True)

        # Only the tokenizer and text encoders are loaded: the loss below is computed purely on
        # text-encoder features, so the UNet, VAE and noise scheduler are not needed.
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=None
        )
        text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=None
        )
        teacher_text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=None
        )

        # Freeze the base weights; only the LoRA parameters added below are trained.
        text_encoder.requires_grad_(False)
        teacher_text_encoder.requires_grad_(False)

        # Frozen weights are kept in the mixed-precision dtype; trainable LoRA weights are
        # upcast to float32 by `cast_training_params` further down.
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        text_lora_config = LoraConfig(
            r=32,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.0,
            bias="none",
        )

        text_encoder.to(accelerator.device, dtype=weight_dtype)
        teacher_text_encoder.to(accelerator.device, dtype=weight_dtype)

        # Add the adapter and make sure the trainable params are in float32.
        lora_text_encoder = get_peft_model(text_encoder, text_lora_config)
        cast_training_params(lora_text_encoder, dtype=torch.float32)

        if args.round != 0:
            prev_dir = os.path.join(args.output_dir, 'round' + str(args.round - 1))
            prev_text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(prev_dir, concept))
            # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
            state_dict = torch.load(os.path.join(prev_dir, "agg_text_encoder.ckpt"), map_location='cpu')
            load_state_dict(prev_text_encoder, state_dict)
            load_state_dict(lora_text_encoder, prev_text_encoder.state_dict())
            # Make the LoRA weights trainable again after loading, then restore float32.
            for name, param in lora_text_encoder.named_parameters():
                if 'lora' in name:
                    param.requires_grad = True
            cast_training_params(lora_text_encoder, dtype=torch.float32)

        # Materialise the trainable parameters as a list: it is consumed twice (optimizer
        # construction and gradient clipping), which a one-shot `filter` iterator cannot support.
        lora_layers = [p for p in lora_text_encoder.parameters() if p.requires_grad]

        optimizer = torch.optim.AdamW(lora_layers, lr=0.001)

        # Prepare everything with our `accelerator`.
        optimizer, lora_text_encoder = accelerator.prepare(optimizer, lora_text_encoder)

        # Use `accelerator.device` for inputs: the prepared model may be wrapped and is not
        # guaranteed to expose a `.device` attribute.
        forget_text_input_ids = tokenizer(
            "an image of " + concept,
            # "An artwork in " + concept + " style.",
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(accelerator.device)
        refine_text_input_ids = tokenizer(
            "an image of a person",
            # "An artwork in normal style",
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(accelerator.device)

        # The teacher target is constant, so compute it once without tracking gradients.
        with torch.no_grad():
            refine_text_feature = teacher_text_encoder(refine_text_input_ids)[0]
        refine_text_feature = refine_text_feature.detach()

        for epoch in range(5):
            lora_text_encoder.train()
            with accelerator.accumulate(lora_text_encoder):
                forget_text_feature = lora_text_encoder(forget_text_input_ids)[0]

                # NOTE(review): `F.cosine_similarity` defaults to dim=1; if the encoder output is
                # (batch, tokens, hidden) that is the token axis -- confirm this is intended rather
                # than dim=-1. Also confirm that *minimising* the similarity is the intended objective.
                loss = F.cosine_similarity(forget_text_feature.float(), refine_text_feature.float()).mean()
                print(epoch, loss)

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(lora_layers, 1.0)
                optimizer.step()
                optimizer.zero_grad()

        # Unwrap before saving: `prepare` may or may not have wrapped the model, so relying on
        # a `.module` attribute only works in the wrapped (distributed) case.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            accelerator.unwrap_model(lora_text_encoder).save_pretrained(out_dir)
        accelerator.wait_for_everyone()
        accelerator.end_training()
        torch.cuda.empty_cache()


# Run the training loop only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
