import random
import wandb
import argparse
import copy
import hashlib
import itertools
import logging
import os
from pathlib import Path
import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from robust_facecloak.model.db_train import  DreamBoothDatasetFromTensor
from robust_facecloak.model.db_train import import_model_class_from_model_name_or_path
# from share_args import parse_args
from robust_facecloak.generic.data_utils import PromptDataset, load_data
from robust_facecloak.generic.share_args import share_parse_args


logger = get_logger(__name__)

import torch
import torch.nn.functional as F


def train_few_step(
    args,
    models,
    tokenizer,
    noise_scheduler,
    vae,
    data_tensor: torch.Tensor,
    num_steps=20,
    step_wise_save=False,
    save_step=100, 
):
    # Load the tokenizer

    unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
    params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())

    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    train_dataset = DreamBoothDatasetFromTensor(
        data_tensor,
        args.instance_prompt,
        tokenizer,
        args.class_data_dir,
        args.class_prompt,
        args.resolution,
        args.center_crop,
    )

    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    
    step2modelstate={}
    if step_wise_save:
        step2modelstate[0] = {
             "unet": copy.deepcopy(unet.cpu().state_dict()),
            "text_encoder": copy.deepcopy(text_encoder.cpu().state_dict()),
        }
        # move the model back to gpu
        unet.to(device, dtype=weight_dtype)
        text_encoder.to(device, dtype=weight_dtype)
        

    for step in range(num_steps):
        unet.train()
        text_encoder.train()

        step_data = train_dataset[step % len(train_dataset)]
        pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]]).to(
            device, dtype=weight_dtype
        )
        input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)

        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(input_ids)[0]
        
        # print(' this is the encoder hidden states')
        # print(encoder_hidden_states.shape)
        
        # args=self.args
        if "robust_instance_conditioning_vector" in vars(args).keys() and args.robust_instance_conditioning_vector:
            condition_vector = args.robust_instance_conditioning_vector_data
            # print('this is your condition vector')
            # print(condition_vector.shape)
            encoder_hidden_states[0,:7,:] = condition_vector.to(device, dtype=weight_dtype)

        # Predict the noise residual
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # with prior preservation loss
        if args.with_prior_preservation:
            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
            target, target_prior = torch.chunk(target, 2, dim=0)

            # Compute instance loss
            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Compute prior loss
            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

            # Add the prior loss to the instance loss.
            loss = instance_loss + args.prior_loss_weight * prior_loss

        else:
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(params_to_optimize, 1.0, error_if_nonfinite=True)
        optimizer.step()
        optimizer.zero_grad()
        if step_wise_save and step % save_step == 0:
            # make sure the model state dict is put to cpu
            step2modelstate[step] = {
                "unet": copy.deepcopy(unet.cpu().state_dict()),
                "text_encoder": copy.deepcopy(text_encoder.cpu().state_dict()),
            }
            # move the model back to gpu
            unet.to(device, dtype=weight_dtype); text_encoder.to(device, dtype=weight_dtype)
    if step_wise_save:
        return [unet, text_encoder], step2modelstate
    else:     
        return [unet, text_encoder]


def load_model(args, model_path):
    print(model_path)
    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(model_path, args.revision)

    # Load scheduler and models
    text_encoder = text_encoder_cls.from_pretrained(
        model_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=args.revision)

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", revision=args.revision)

    vae.requires_grad_(False)

    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.enable_xformers_memory_efficient_attention:
        print("You selected to used efficient xformers")
        print("Make sure to install the following packages before continue")
        print("pip install triton==2.0.0.dev20221031")
        print("pip install pip install xformers==0.0.17.dev461")

        unet.enable_xformers_memory_efficient_attention()

    return text_encoder, unet, tokenizer, noise_scheduler, vae



def parse_args(): 
    
    parser = share_parse_args()
    
    parser.add_argument(
        "--transform_hflip",
        action="store_true",
        help="Whether to use horizontal flip for transform.",
    )
    
    parser.add_argument(
        "--instance_data_dir_for_train",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--instance_data_dir_for_adversarial",
        type=str,
        default=None,
        required=True,
        help="A folder containing the images to add adversarial noise",
    )
    
    parser.add_argument(
        "--defense_pgd_ascending",
        action="store_true",
        help="Whether to use ascending order for pgd.",
    )
    
    parser.add_argument(
        "--defense_pgd_radius",
        type=float,
        default=8,
        help="The radius for defense pgd.",
    )
    
    parser.add_argument(
        "--defense_pgd_step_size",
        type=float,
        default=2,
        help="The step size for defense pgd.",
    )
    parser.add_argument(
        "--defense_pgd_step_num",
        type=int,
        default=8,
        help="The number of steps for defense pgd.",
    )
    
    parser.add_argument(
        "--defense_pgd_random_start",
        action="store_true",
        help="Whether to use random start for pgd.",
    )
    
    parser.add_argument(
        "--attack_pgd_radius",
        type=float,
        default=4,
        help="The radius for attack pgd.",
    )
    parser.add_argument(
        "--attack_pgd_step_size",
        type=int,
        default=2,
        help="The step size for attack pgd.",
    )
    parser.add_argument(
        "--attack_pgd_step_num",
        type=int,
        default=4,
        help="The number of steps for attack pgd.",
    )
    parser.add_argument(
        "--attack_pgd_ascending",
        action="store_true",
        help="Whether to use ascending order for pgd.",
    )
    
    parser.add_argument(
        "--attack_pgd_random_start",
        action="store_true",
        help="Whether to use random start for pgd.",
    )
    
    parser.add_argument(
        "--target_image_path",
        default=None,
        help="target image for attacking",
    )
    
    # args.gau_kernel_size
    parser.add_argument(
        "--gau_kernel_size",
        type=int,
        default=5,
        help="The kernel size for gaussian filter.",
    )
    # defense_sample_num
    parser.add_argument(
        "--defense_sample_num",
        type=int,
        default=1,
        help="The number of samples for defense.",
    )
    
    parser.add_argument(
        "--rot_degree",
        type=int,
        default=5,
        help="The degree for rotation.",
    )
    
    parser.add_argument(
        "--transform_rot", 
        action="store_true",
        help="Whether to use rotation for transform.",
        
    )
    
    parser.add_argument(
        "--transform_gau",
        action="store_true",
        help="Whether to use gaussian filter for transform.",
    )
    
    parser.add_argument(
        "--original_flow", 
        action="store_true",
        help="Whether to use original flow for transform.",
    )
    
    parser.add_argument(
        "--total_trail_num",
        type=int,
        default=60,
    )
    parser.add_argument(
        "--unroll_steps",
        type=int,
        default=2,
    )
    
    parser.add_argument(
        "--interval",
        type=int,
        default=40,
    )
    
    parser.add_argument(
        "--total_train_steps",
        type=int,
        default=1000, 
    )
    
    # remove_kstep
    parser.add_argument( 
        "--remove_kstep",
        action="store_true",
        help="Whether to remove kstep",
    )
    
    # remove_eot
    parser.add_argument(
        "--remove_eot",
        action="store_true",
        help="Whether to remove eot",
    )
    
    # remove_cir_ensemble
    parser.add_argument(
        "--remove_cir_ensemble",
        action="store_true",
        help="Whether to remove cir ensemble",
    )
    
    parser.add_argument(
        "--shuffle_cir_ensemble",
        action="store_true",
        help="Whether to shuffle the ensembling order",
    )
    
    parser.add_argument(
        "--clean_trace_back",
         action="store_true",
         help="Whether to clean the trace back",
    )
    
    
    args = parser.parse_args()
    return args

def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        # logging_dir=logging_dir,
    )

    
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            if args.mixed_precision == "fp32":
                torch_dtype = torch.float32
            elif args.mixed_precision == "fp16":
                torch_dtype = torch.float16
            elif args.mixed_precision == "bf16":
                torch_dtype = torch.bfloat16
            pipeline = DiffusionPipeline.from_pretrained(
                list(args.pretrained_model_name_or_path.split(","))[-1], 
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=not accelerator.is_local_main_process,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    clean_data = load_data(
        args.instance_data_dir_for_train,
        # size=args.resolution,
        # center_crop=args.center_crop,
    )
    
    perturbed_data = load_data(
        args.instance_data_dir_for_adversarial,
        # size=args.resolution,
        # center_crop=args.center_crop,
    )
    
    original_data= copy.deepcopy(perturbed_data)
    
    if args.robust_instance_conditioning_vector:
        robust_instance_conditioning_vector = torch.load(args.robust_instance_conditioning_vector_path)
        args.robust_instance_conditioning_vector_data = robust_instance_conditioning_vector
        
    import torchvision
    train_aug = [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
    ]
    rotater = torchvision.transforms.RandomRotation(degrees=(0, args.rot_degree))
    gau_filter = transforms.GaussianBlur(kernel_size=args.gau_kernel_size,)
    defense_transform = [
    ]
    if args.transform_hflip:
        defense_transform = defense_transform + [transforms.RandomHorizontalFlip(p=0.5)]
    if args.transform_rot:
        defense_transform = defense_transform + [rotater]
    if args.transform_gau:
        defense_transform = [gau_filter] + defense_transform
    
    tensorize_and_normalize = [
        transforms.Normalize([0.5*255]*3,[0.5*255]*3),
    ]
    
    all_trans = train_aug + defense_transform + tensorize_and_normalize
    all_trans = transforms.Compose(all_trans)
    
    non_trans = train_aug + tensorize_and_normalize
    non_trans = transforms.Compose(non_trans)
    if args.remove_eot:
        all_trans = non_trans
    
    from robust_facecloak.attacks.worker.robust_pgd_worker import RobustPGDAttacker
    from robust_facecloak.attacks.worker.pgd_worker import PGDAttacker
    attacker = PGDAttacker(
        radius=args.attack_pgd_radius, 
        steps=args.attack_pgd_step_num, 
        step_size=args.attack_pgd_step_size,
        random_start=args.attack_pgd_random_start,
        ascending=args.attack_pgd_ascending,
        args=args, 
        x_range=[-1, 1],
    )
    defender = RobustPGDAttacker(
        radius=args.defense_pgd_radius,
        steps=args.defense_pgd_step_num,
        step_size=args.defense_pgd_step_size,
        random_start=args.defense_pgd_random_start,
        ascending=args.defense_pgd_ascending,
        args=args,
        attacker=attacker, 
        trans=all_trans,
        sample_num=args.defense_sample_num,
        x_range=[0, 255],
    )
    model_paths = list(args.pretrained_model_name_or_path.split(","))
    num_models = len(model_paths)

    MODEL_BANKS = [load_model(args, path) for path in model_paths]
    MODEL_STATEDICTS = [
        {
            "text_encoder": MODEL_BANKS[i][0].state_dict(),
            "unet": MODEL_BANKS[i][1].state_dict(),
        }
        for i in range(num_models)
    ]
    
    def save_image(perturbed_data, id_stamp):
        save_folder = f"{args.output_dir}/noise-ckpt/{id_stamp}"
        os.makedirs(save_folder, exist_ok=True)
        noised_imgs = perturbed_data.detach()
        img_names = [
            str(instance_path).split("/")[-1]
            for instance_path in list(Path(args.instance_data_dir_for_adversarial).iterdir())
        ]
        for img_pixel, img_name in zip(noised_imgs, img_names):
            save_path = os.path.join(save_folder, f"noisy_{img_name}")
            Image.fromarray(
                img_pixel.float().detach().cpu().permute(1, 2, 0).numpy().squeeze().astype(np.uint8)
            ).save(save_path)
    
    if args.remove_cir_ensemble:
        total = args.total_trail_num * args.total_train_steps * num_models
        pbar = tqdm(total=total)
        for j in range(args.total_trail_num):
            for model_i in range(num_models):
                text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[model_i]
                f=[unet, text_encoder]
                for i in range(args.total_train_steps):
                    f_per = copy.deepcopy(f)
                    if args.remove_kstep:
                        perturbed_data = defender.perturb(f_per, perturbed_data, original_data, vae, tokenizer, noise_scheduler,)
                    else:
                        f_per = train_few_step(
                            args,
                            f_per,
                            tokenizer,
                            noise_scheduler,
                            vae,
                            perturbed_data.float(),
                            args.unroll_steps,
                        )
                        perturbed_data = defender.perturb(f_per, perturbed_data, original_data, vae, tokenizer, noise_scheduler, )
                    del f_per
                    f = train_few_step(
                        args,
                        f,
                        tokenizer,
                        noise_scheduler,
                        vae,
                        perturbed_data.float(),
                        args.advance_steps,
                    )
                    pbar.update(1)
                del f, unet, text_encoder, tokenizer, noise_scheduler, vae
                torch.cuda.empty_cache()
        save_image(perturbed_data, "final")
        pbar.close()
        return 
    
    
    init_model_state_pool = {}
    pbar = tqdm(total=num_models, desc="initializing models")
    for j in range(num_models):
        init_model_state_pool[j] = {}
        text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[j]
        
        unet.load_state_dict(MODEL_STATEDICTS[j]["unet"])
        text_encoder.load_state_dict(MODEL_STATEDICTS[j]["text_encoder"])
        f_ori = [unet, text_encoder]
        f_ori, step2state_dict = train_few_step(
                args,
                f_ori,
                tokenizer,
                noise_scheduler,
                vae,
                perturbed_data.float(),
                args.total_train_steps,
                step_wise_save=True,
                save_step=args.interval,
        )  
        init_model_state_pool[j] = step2state_dict
        del f_ori, unet, text_encoder, tokenizer, noise_scheduler, vae
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        pbar.update(1)
    pbar.close()
    
            
    steps_list = list(init_model_state_pool[0].keys())
    pbar = tqdm(total=args.total_trail_num * num_models * (args.interval // args.advance_steps) * len(steps_list), desc="meta poison with model ensemble")
    cnt=0
    for _ in range(args.total_trail_num):
        # now_model_state_pool = copy.deepcopy(init_model_state_pool)
        for model_i in range(num_models):
            text_encoder, unet, tokenizer, noise_scheduler, vae = MODEL_BANKS[model_i]
            if args.shuffle_cir_ensemble:
               random.shuffle(steps_list)
            for split_step in steps_list: 
                unet.load_state_dict(init_model_state_pool[model_i][split_step]["unet"])
                text_encoder.load_state_dict(init_model_state_pool[model_i][split_step]["text_encoder"])
                f = [unet, text_encoder]
                if args.clean_trace_back:
                    f_ori = copy.deepcopy(f)
                for j in range(args.interval // args.advance_steps):
                # en_data = 0.0
                    pbar.update(1)
                    if args.remove_kstep:
                        f_per = copy.deepcopy(f)
                        perturbed_data = defender.perturb(f_per, perturbed_data, original_data, vae, tokenizer, noise_scheduler,)
                    else:
                        f_per = copy.deepcopy(f)
                        f_per = train_few_step(
                            args,
                            f_per,
                            tokenizer,
                            noise_scheduler,
                            vae,
                            perturbed_data.float(),
                            args.unroll_steps,
                        )
                        perturbed_data = defender.perturb(f_per, perturbed_data, original_data, vae, tokenizer, noise_scheduler,)
                    cnt+=1
                    del f_per
                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()
                    f = train_few_step(
                        args,
                        f,
                        tokenizer,
                        noise_scheduler,
                        vae,
                        perturbed_data.float(),
                        args.advance_steps,
                    )
                    # f = [f[0].to("cpu"), f[1].to("cpu")]
                    # now_model_state_pool[model_i][split_step]["text_encoder"] = f[1].state_dict()
                    # now_model_state_pool[model_i][split_step]["unet"] = f[0].state_dict()
                    if cnt % 100 == 0:
                        save_image(perturbed_data, f"{cnt}")
                del f 
                if args.clean_trace_back:
                    init_model_state_pool[model_i][split_step]["text_encoder"] = f_ori[1].state_dict()
                    init_model_state_pool[model_i][split_step]["unet"] = f_ori[0].state_dict()
                    del f_ori
                
            del unet, text_encoder, tokenizer, noise_scheduler, vae
            # clean some cache
            # import gc
            # gc.collect()       
            if torch.cuda.is_available():
                torch.cuda.empty_cache() 
            # perturbed_data = en_data
            # del en_data
            
            
        # del now_model_state_pool 
        # del init_model_state_pool
        import gc
        gc.collect()
        torch.cuda.empty_cache()      
    pbar.close()

    save_image(perturbed_data, "final")


if __name__ == "__main__":
    wandb.init(project="", entity="")
    args = parse_args()
    wandb.config.update(args)
    wandb.log({'status': 'gen'})
    main(args)
