import torch
import pytorch_lightning as L
from torch.utils.data import DataLoader, DistributedSampler
from omegaconf import DictConfig
from torchvision.transforms.functional import pil_to_tensor
from functools import partial
import random
from tqdm import tqdm

from datasets import load_dataset
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

from custom_pipeline import CustomPipeline
# from custom_pipeline_new_key import CustomPipeline
from utils.attacks import (attack_blur, attack_br_shift, 
                        attack_contrast, attack_gamma,
                        attack_noise, attack_WB)
from utils.pxl_swap import extract_wm, gen_private_key, gen_public_key, compute_wm_error
from DiffJPEG.DiffJPEG import DiffJPEG
from kornia.enhance import AdjustSaturation, AdjustHue, sharpness

from datetime import datetime, timedelta

# Select the internal precision used for float32 matmuls ('high' is the other setting noted by the author).
torch.set_float32_matmul_precision('medium')


class PromptDS(torch.utils.data.Dataset):
    """Dataset of unique, non-empty text prompts.

    Prompts are de-duplicated, sorted, shuffled with the global ``random``
    module state and finally truncated to at most ``num`` entries.

    Args:
        prompts: Iterable of prompt strings; duplicates and the empty string
            are dropped.
        caption_col_name: Key under which ``__getitem__`` returns a prompt.
        num: Maximum number of prompts to keep (``None`` keeps all of them).
    """

    def __init__(self, prompts, caption_col_name, num):
        self.caption_col_name = caption_col_name

        unique_prompts = set(prompts)
        # discard() instead of list.remove(''): remove() raises ValueError
        # when the input happens to contain no empty prompt.
        unique_prompts.discard('')

        # Sort before shuffling so the resulting order depends only on the
        # state of the ``random`` module, not on set iteration order.
        self.prompts = sorted(unique_prompts)
        random.shuffle(self.prompts)
        self.prompts = self.prompts[:num]

    def __len__(self):
        """Return the number of prompts kept."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return ``{caption_col_name: prompt}`` for the prompt at ``idx``."""
        return {
            self.caption_col_name: self.prompts[idx]
        }


class DiffusionModel(L.LightningModule):
    def __init__(self, cfg: DictConfig):
        """Build the models, the prompt dataset and the table of image attacks.

        Args:
            cfg: Run configuration. Besides the scalar settings copied onto
                ``self`` below, it is forwarded to ``setup_models`` and
                ``setup_data`` and supplies the parameters of every attack.
        """
        super().__init__()

        # Run-level settings. setup_models() reads self.seed, so these have to
        # be in place before the setup calls.
        self.pipeline_cfg = cfg.pipeline
        self.mode = cfg.mode
        self.num_trials = cfg.num_trials
        self.batch_size = cfg.batch_size
        self.seed = cfg.seed
        self.use_ddp = cfg.trainer.get('strategy', None) == 'ddp'

        self.setup_models(cfg)
        self.setup_data(cfg)

        # Dataset columns, image size and watermark-key parameters.
        self.caption_column = cfg.caption_column
        self.image_column = cfg.image_column
        self.resolution = cfg.resolution
        self.wm_len = cfg.wm_len
        self.is_structured = cfg.is_structured
        self.inner_region_len = cfg.inner_region_len
        self.delta = cfg.delta
        self.num_collision_trials = cfg.num_collision_trials

        sigma = torch.full((1,), cfg.blur_sigma)

        # DiffJPEG accepts images with values in range [0, 1]!
        self.diff_jpeg = DiffJPEG(
            height=cfg.resolution[0],
            width=cfg.resolution[1],
            differentiable=False,
            quality=cfg.jpeg_quality,
        ).eval()

        adjust_saturation = AdjustSaturation(cfg.saturation_factor)
        adjust_hue = AdjustHue(cfg.hue_factor)
        sharpen = partial(sharpness, factor=cfg.sharpness_factor)

        def bounded(attack_fn, bounds):
            # Bind the ``low``/``high`` pair of one config section to an attack.
            return partial(attack_fn, low=bounds.low, high=bounds.high)

        # The wrappers below rescale by 1/255 before calling the operator;
        # presumably incoming images are in [0, 255] -- TODO confirm.
        def jpeg(x):
            return self.diff_jpeg(x / 255)

        def saturation(x):
            return adjust_saturation(x.detach() / 255)

        def hue(x):
            return adjust_hue(x.detach() / 255)

        def sharpen_image(x):
            return sharpen(x.detach() / 255)

        # Insertion order is the order in which compute_metrics applies the attacks.
        self.attacks = {
            'blur': partial(attack_blur, sigma=sigma),
            'jpeg': jpeg,
            'brightness_shift': bounded(attack_br_shift, cfg.br_shift),
            'positive_contrast': bounded(attack_contrast, cfg.positive_contrast),
            'negative_contrast': bounded(attack_contrast, cfg.negative_contrast),
            'gamma': bounded(attack_gamma, cfg.gamma),
            'saturation': saturation,
            'hue': hue,
            'sharpness': sharpen_image,
            # TODO: multiple noises and white-box settings
            'noise': partial(attack_noise, inf_norm_val=cfg.noise_inf_norm_val),
            'white_box': partial(
                attack_WB,
                num_iter=cfg.white_box.num_iter,
                lr=cfg.white_box.lr,
                attack_budget=cfg.white_box.attack_budget,
                wm_loss_w=cfg.pipeline.wm_loss_w,
                lpips_w=cfg.pipeline.lpips_w,
                mse_w=cfg.pipeline.mse_w,
                eps=cfg.pipeline.eps,
            ),
        }

        # Per-image generation durations collected in validation_step.
        self._times = []


    def setup_data(self, cfg):
        """Load the prompt dataset and build ``self.val_ds``.

        Args:
            cfg: Run configuration; ``dataset_name``, ``dataset_config_name``,
                ``cache_dir``, ``train_data_dir``, ``caption_column`` and
                ``max_val_samples`` are read here.
        """
        dataset = load_dataset(
            cfg.dataset_name,
            cfg.dataset_config_name,
            cache_dir=cfg.cache_dir,
            data_dir=cfg.train_data_dir,
        )

        # NOTE(review): prompts are always read from the 'prompt' column of the
        # 'train' split; cfg.caption_column only names the key under which
        # PromptDS returns them -- confirm this is intended.
        self.val_ds = PromptDS(
            prompts=dataset['train']['prompt'],
            caption_col_name=cfg.caption_column,
            num=cfg.max_val_samples,
        )

        print(f'DS LEN {len(self.val_ds)}')

    def setup_models(self, cfg):
        """Load the pretrained components, assemble the pipeline and seed the RNG.

        Args:
            cfg: Run configuration; the checkpoint location, revisions/variant,
                the ``pipeline`` section and the xformers switch are read here.
        """
        model_id = cfg.pretrained_model_name_or_path
        # Revision/variant pair shared by the text encoder, the VAE and the pipeline.
        weights = {'revision': cfg.revision, 'variant': cfg.variant}

        self.noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_id, subfolder="tokenizer", revision=cfg.revision
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", **weights
        )
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", **weights)
        # The UNet is the only component loaded from the non-EMA revision.
        self.unet = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", revision=cfg.non_ema_revision
        )

        for frozen_module in (self.vae, self.text_encoder):
            frozen_module.requires_grad_(False)

        self.pipeline = CustomPipeline.from_pretrained(
            model_id,
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            safety_checker=None,
            **weights,
        )
        self.pipeline.set_progress_bar_config(disable=True)

        # Every entry of the ``pipeline`` config section becomes a pipeline attribute.
        for option, value in cfg.pipeline.items():
            setattr(self.pipeline, option, value)

        if cfg.enable_xformers_memory_efficient_attention:
            self.pipeline.enable_xformers_memory_efficient_attention()

        # No seed means no explicit generator is handed to the pipeline.
        self.generator = (
            None
            if self.seed is None
            else torch.Generator(device=self.device).manual_seed(self.seed)
        )


    @staticmethod
    def get_sampler(ds, shuffle=True):
        """Return a ``DistributedSampler`` for ``ds`` bound to the current process.

        The world size and rank are queried from ``torch.distributed``, so the
        process group must already be initialised when this is called.
        """
        return DistributedSampler(
            ds,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank(),
            shuffle=shuffle,
        )

    def val_collate_fn(self, examples):
        """Merge per-sample dicts into one dict holding the list of captions."""
        column = self.caption_column
        captions = [sample[column] for sample in examples]
        return {column: captions}

    def val_dataloader(self):
        """Return the validation ``DataLoader`` over ``self.val_ds``.

        A distributed sampler is attached only when the module runs under DDP.
        """
        if self.use_ddp:
            sampler = self.get_sampler(self.val_ds)
        else:
            sampler = None

        return DataLoader(
            self.val_ds,
            sampler=sampler,
            collate_fn=self.val_collate_fn,
            batch_size=self.batch_size,
            num_workers=4,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
        )
    
    def on_validation_epoch_start(self) -> None:
        """Load precomputed images (and keys) unless images are generated from scratch.

        In any mode other than 'from_scratch' the image collection is loaded
        onto the CPU into ``self.all_images``; when a single trial is requested
        the private and public keys are loaded as well.
        """
        if self.mode != 'from_scratch':
            # NOTE(review): every torch.load call below receives a literal `...`
            # (Ellipsis) where a file path belongs -- the real paths must be
            # filled in before this branch can run.
            self.all_images = torch.load(...,
                                        map_location='cpu')
            if self.num_trials == 1:
                # With one trial the stored keys are reused; validation_step
                # generates fresh keys when num_trials != 1.
                self.priv_key = torch.load(...,
                                           map_location='cpu')
                self.pub_key = torch.load(...,
                                          map_location='cpu')
        return super().on_validation_epoch_start()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch in the configured mode.

        * 'from_scratch': generate one image per prompt with freshly drawn keys
          and print the running mean generation time.
        * 'orig' / 'wm': take the matching slice of the preloaded images and
          call ``compute_metrics`` once per trial.

        Raises:
            NotImplementedError: for any other mode.
        """
        if self.mode not in ('from_scratch', 'orig', 'wm'):
            raise NotImplementedError

        if self.mode == 'from_scratch':
            prompts = batch[self.caption_column]
            num_prompts = len(prompts)

            batch_pub_key = gen_public_key(
                num_keys=num_prompts,
                key_len=self.wm_len,
                device=self.device,
            )
            batch_priv_key = gen_private_key(
                img_size=(num_prompts, 3, *self.resolution),
                key_len=self.wm_len,
                is_structured=self.is_structured,
                inner_region_len=self.inner_region_len,
                delta=self.delta,
                device=self.device,
            )

            images = []
            per_sample = zip(prompts, batch_priv_key, batch_pub_key)
            for n_iter, (prompt, priv_key, pub_key) in enumerate(per_sample):
                started_at = datetime.now()
                output = self.pipeline(
                    prompt=prompt,
                    private_key=priv_key,
                    pub_key=pub_key,
                    num_inference_steps=50,
                    generator=self.generator,
                    num_iter=n_iter,
                )
                images.append(output.images[0])
                # Only the pipeline call and the image lookup are timed.
                self._times.append(datetime.now() - started_at)

            # Metrics are currently not computed in this mode; the stacked
            # tensor is built but not consumed.
            torch_image = torch.stack([pil_to_tensor(im) for im in images]).to(
                dtype=self.dtype, device=self.device
            )
            print(f'mean time {sum(self._times, timedelta()) / len(self._times)}')
            return

        # 'orig' / 'wm': the batch index selects a window of the preloaded tensors.
        window = slice(batch_idx * self.batch_size, (batch_idx + 1) * self.batch_size)
        torch_image = self.all_images[window].to(self.device)

        for _ in tqdm(range(self.num_trials)):
            if self.num_trials == 1:
                # Single trial: reuse the keys loaded at epoch start.
                batch_priv_key = self.priv_key[window].to(self.device)
                batch_pub_key = self.pub_key[window].to(self.device)
            else:
                # Several trials: draw a fresh key pair for every trial.
                batch_pub_key = gen_public_key(
                    num_keys=len(torch_image),
                    key_len=self.wm_len,
                    device=self.device,
                )
                batch_priv_key = gen_private_key(
                    img_size=(len(torch_image), 3, *self.resolution),
                    key_len=self.wm_len,
                    is_structured=self.is_structured,
                    inner_region_len=self.inner_region_len,
                    delta=self.delta,
                    device=self.device,
                )

            self.compute_metrics(torch_image, batch_priv_key, batch_pub_key)

        torch.cuda.empty_cache()
            
    @property
    def is_master(self) -> bool:
        """
        True when the caller is the master process: either DDP is not in use,
        or the current distributed rank is 0.
        """
        if self.use_ddp is False:
            return True
        return torch.distributed.get_rank() == 0
    
    
    def compute_metrics(self, torch_image, batch_priv_key, batch_pub_key,
                        taus=(1, 2, 3, 5, 7)):
        """Log the watermark error of the images before and after every attack.

        Args:
            torch_image: Batch of images to evaluate.
            batch_priv_key: Private keys passed to ``compute_wm_error`` (and to
                the white-box attack).
            batch_pub_key: Public keys passed to ``compute_wm_error``.
            taus: Values forwarded one by one as ``tau`` to
                ``compute_wm_error``; each one gets its own ``_<tau>`` suffix
                on the logged keys. Pass ``None`` to evaluate once without a
                ``tau`` argument and log un-suffixed keys. The default
                reproduces the previously hard-coded list.
        """
        if taus is None:
            self._log_wm_errors(torch_image, batch_priv_key, batch_pub_key)
            return

        for tau in taus:
            torch.cuda.empty_cache()
            # NOTE(review): every attack is re-applied for each tau, so the
            # white-box attack runs len(taus) times per call.
            self._log_wm_errors(torch_image, batch_priv_key, batch_pub_key, tau=tau)

    def _log_wm_errors(self, torch_image, batch_priv_key, batch_pub_key, tau=None):
        """Log the generation error and one ``<attack>_error`` per attack for one tau."""
        if tau is None:
            error_kwargs, suffix = {}, ''
        else:
            error_kwargs, suffix = {'tau': tau}, f'_{tau}'
        batch_size = len(torch_image)

        generation_error = compute_wm_error(
            torch_image, batch_priv_key, batch_pub_key, **error_kwargs
        )
        self.log('generation_error' + suffix, generation_error, batch_size=batch_size)

        for name, attack in self.attacks.items():
            # The white-box attack is the only one that needs the private key.
            if 'white_box' in name:
                attacked_img = attack(x=torch_image, priv_key=batch_priv_key)
            else:
                attacked_img = attack(torch_image)

            error = compute_wm_error(
                attacked_img, batch_priv_key, batch_pub_key, **error_kwargs
            )
            if isinstance(error, torch.Tensor):
                error = error.item()
            self.log(f'{name}_error{suffix}', error, batch_size=batch_size)



# Lower bound on the error, obtained by artificially aligning the pub key
    ### 
    
    #     pk1, pk2 = priv_key
    #     t_img = pil_to_tensor(image).long()
    #     pub_key = (t_img[pk1[0], pk1[1], pk1[2]] - t_img[pk2[0], pk2[1], pk2[2]]) >= 0
    #     pub_key = pub_key.long().to(self.device)
        
    #     # print(f'timg {t_img[pk1[0], pk1[1], pk1[2]], t_img[pk2[0], pk2[1], pk2[2]]}')
        
    #     if _new_pub_key is None:
    #         _new_pub_key = pub_key[None]
    #     else:
    #         _new_pub_key = torch.concatenate((_new_pub_key, pub_key[None]))
        
    # batch_pub_key = _new_pub_key.clone() 
    
    ###   
    
# old collision error
    
            # collision_errors = []
        # for _ in range(self.num_collision_trials):        
        #     collision_pub_key = gen_public_key(num_keys=len(prompts), 
        #                             key_len=self.wm_len,
        #                             device=self.device)
        #     collision_priv_key = gen_private_key(
        #         img_size=(len(prompts), 3, *self.resolution), 
        #         key_len=self.wm_len, is_structured=self.is_structured,
        #         inner_region_len=self.inner_region_len, 
        #         delta=self.delta, device=self.device
        #                                 )
            
        #     generation_error = compute_wm_error(torch_image, 
        #                                         collision_priv_key, 
        #                                         collision_pub_key,
        #                                         reduce='none')
            
        #     collision_errors.append(generation_error)
        
        # means = torch.stack(collision_errors).mean(0)
        # coll_mean, coll_std = means.mean(), means.std()
        
        # self.log("collision_error_mean", coll_mean, batch_size=len(prompts))
        # self.log("collision_error_std", coll_std, batch_size=len(prompts))
        