from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
import torch
from torch.nn.functional import mse_loss
from torch.optim import Adam, SGD

from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
# from torcheval.metrics import FrechetInceptionDistance
from torcheval.metrics.functional import peak_signal_noise_ratio as psnr

from collections import defaultdict

from utils.pxl_swap import compute_wm_error, swap_pxls
from utils.attacks import linear_transform_to_range

class CustomPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that embeds a pixel-pair watermark by optimizing latents.

    ``__call__`` obtains latents (from the parent pipeline or the caller),
    decodes them with the VAE and runs Adam on the latents to minimise
    ``wm_loss``: a hinge penalty pushing selected pixel pairs apart in the
    direction given by the public key, plus LPIPS/MSE terms that keep the
    result close to the first decoded (un-watermarked) image.

    Hyper-parameters read from the instance but not defined in this class
    (presumably assigned by the caller after construction -- TODO confirm):
    ``lr``, ``thr_num_iter``, ``loss_thr``, ``max_num_iter``,
    ``early_stop_iter_num``, ``mse_w``, ``lpips_w``, ``wm_loss_w``, ``eps``,
    ``grad_thr``.

    NOTE(review): the metric modules, ``errors`` and ``metrics`` below are
    class attributes: they are created when the class body executes and are
    shared by every instance of this class.
    """

    # Reference (un-watermarked) image of the current call; reset to None when the call ends.
    real_img = None
    lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex')
    # FID is moved to the image device but never updated (see update_metrics).
    fid = FrechetInceptionDistance(feature=2048, normalize=True, input_img_size=(3, 512, 512))
    ssim = StructuralSimilarityIndexMeasure()

    # Watermark errors and image-quality metrics, accumulated across calls.
    errors = []
    metrics = defaultdict(list)

    def __call__(self, *args, **kwargs):
        """Generate latents, embed the watermark by optimizing them, and decode.

        Keyword arguments consumed here (everything else is forwarded to the
        parent pipeline when latents have to be generated):
            private_key: integer index tensor selecting the watermark pixel
                pairs (see ``wm_loss``).
            pub_key: per-pair bits giving the desired sign of each pair
                difference (mapped to +-1 by ``2 * pub_key - 1``).
            num_iter: label printed with the final watermark error (used as a
                prompt number, not as an iteration count).
            latents: optional latents to optimize instead of generating new
                ones. NOTE: they are updated in place by the optimizer.

        Returns:
            StableDiffusionPipelineOutput with PIL images and the NSFW flags
            returned by ``run_safety_checker``.

        Raises:
            RuntimeError: if the optimization never produced a finite loss.
        """
        private_key = kwargs.pop('private_key')
        pub_key = kwargs.pop('pub_key')
        prompt_num = kwargs.pop('num_iter')
        latents = kwargs.pop('latents', None)

        private_key = private_key.to(self.device)
        pub_key = pub_key.to(self.device)

        if latents is None:
            kwargs['output_type'] = 'latent'
            with torch.no_grad():
                latents = super().__call__(*args, **kwargs).images

        try:
            # Scope autograd to the optimization instead of switching it on
            # globally (and permanently) with torch.set_grad_enabled(True).
            with torch.enable_grad():
                image = self._optimize_latents(latents, private_key, pub_key)

            # Nothing below needs gradients.
            with torch.no_grad():
                self.update_metrics(image)

                # if self.post_swap:
                #     image = swap_pxls(image, private_key, pub_key, eps=self.eps)

                norm_num_errors = compute_wm_error(image, private_key, pub_key)
                print(f'{prompt_num}. gen error in pipeline {norm_num_errors}')
                self.errors.append(norm_num_errors)

                # The private key holds integer indices, so its dtype must not be
                # used for the safety checker's input (presumably cast to `dtype`,
                # as in the parent diffusers pipeline); use the image dtype.
                image, has_nsfw_concept = self.run_safety_checker(image, image.device, image.dtype)
                if has_nsfw_concept is None:
                    do_denormalize = [True] * image.shape[0]
                else:
                    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

                image = self.image_processor.postprocess(image, output_type='pil',
                                                         do_denormalize=do_denormalize)  # FIXME hardcoded pil

            # Offload all models
            self.maybe_free_model_hooks()
        finally:
            # Always drop the reference image so the next call starts fresh,
            # even when optimization or post-processing raised.
            self.real_img = None

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def _optimize_latents(self, latents, private_key, pub_key):
        """Optimize ``latents`` in place with Adam and return the detached decoded image.

        Stops as soon as the loss falls below ``self.loss_thr`` (returning that
        image). Also stops once ``self.max_num_iter`` iterations have run, or
        after more than ``self.early_stop_iter_num`` iterations without
        improvement; in those cases it returns the lowest-loss image seen. The
        first decoded image becomes the reference ``self.real_img`` when none
        is set yet.

        Raises:
            RuntimeError: if no iteration produced a loss below ``inf`` (e.g.
                the loss is NaN), so there is no best image to return.
        """
        latents.requires_grad_(True)
        opt = Adam((latents,), lr=self.lr)
        # Halve the learning rate every `thr_num_iter` optimizer steps.
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=self.thr_num_iter, gamma=0.5)

        best_loss = float('inf')
        best_image = None
        num_iter = 0
        cnt_iter = 0  # iterations since the best loss last improved

        while True:
            image = self.vae.decode(latents / self.vae.config.scaling_factor,
                                    return_dict=False)[0]

            if self.real_img is None:
                self._init_real_img(image, private_key)

            loss = self.wm_loss(image, private_key, pub_key)

            if loss < self.loss_thr:
                break
            if loss < best_loss:
                # Keep detached copies: holding `loss`/`image` themselves would
                # keep their autograd history alive.
                best_loss = loss.detach()
                best_image = image.detach()
                cnt_iter = 0

            if num_iter >= self.max_num_iter or cnt_iter > self.early_stop_iter_num:
                if best_image is None:
                    raise RuntimeError('watermark optimization produced no finite loss')
                image = best_image
                break

            # Differentiate w.r.t. the latents only, so no gradients are
            # accumulated (and never cleared) on the VAE / LPIPS weights.
            latents.grad = torch.autograd.grad(loss, latents)[0]
            opt.step()
            scheduler.step()

            num_iter += 1
            cnt_iter += 1

        return image.detach()

    def _init_real_img(self, image, private_key):
        """Store ``image`` as the reference image and cache key-pixel statistics.

        Also moves the FID and SSIM metric modules to the image's device.
        """
        self.real_img = image.detach().clone()
        # Key-pixel values and local gradients are taken from the unclipped
        # reference; the reference is clipped to [-1, 1] only afterwards.
        self.get_real_img_grads(private_key)
        self.real_img = self.real_img.clip(-1, 1)

        if self.fid.device != image.device:
            self.fid = self.fid.to(image.device)
        if self.ssim.device != image.device:
            self.ssim = self.ssim.to(image.device)

    @staticmethod
    def _squeeze_priv_key(priv_key):
        """Return ``priv_key`` with a leading batch dim removed (entry 0 kept) if it is 4-D."""
        return priv_key[0] if priv_key.ndim == 4 else priv_key

    def wm_loss(self, imgs, priv_key, pub_key):
        """Return the weighted watermark + perceptual loss for decoded images.

        Args:
            imgs: decoded images, indexed as ``imgs[:, c, y, x]`` (4-D;
                presumably (B, C, H, W) in roughly [-1, 1] -- TODO confirm).
            priv_key: integer index tensor whose first dim (size 2) selects the
                pair member and whose next dim holds 3 index rows for dims
                1, 2, 3 of ``imgs`` (presumably (2, 3, K)). An optional leading
                batch dim is allowed; only entry 0 is used.
            pub_key: per-pair bits; ``2 * pub_key - 1`` is the required sign
                of ``pxls1 - pxls2``.

        Returns:
            ``wm_loss_w * wm_error + lpips_w * (LPIPS + mse_w * MSE)``, where
            ``wm_error = sum(max(0, eps - sgn * (pxls1 - pxls2)))``.
        """
        priv_key = self._squeeze_priv_key(priv_key)

        # LPIPS (on clipped images) plus weighted MSE to the reference image.
        if self.lpips.device != imgs.device:
            self.lpips = self.lpips.to(imgs.device)
        lpips_loss = self.lpips(imgs.clip(-1, 1), self.real_img)
        lpips_loss = lpips_loss + self.mse_w * mse_loss(imgs, self.real_img)

        # WM error: hinge on the signed difference of each pixel pair.
        pk1, pk2 = priv_key
        pxls1 = imgs[:, pk1[0], pk1[1], pk1[2]]
        pxls2 = imgs[:, pk2[0], pk2[1], pk2[2]]
        sgn = 2 * pub_key - 1
        if self.grad_thr is not None:
            # Pixels whose reference local gradient is <= grad_thr are replaced
            # by their constant reference values, so they get no gradient.
            pxls1, pxls2 = self.penalize_pxls(pxls1, pxls2)

        wm_error = - torch.min(sgn * (pxls1 - pxls2) - self.eps,
                               torch.zeros_like(pxls1)).sum()

        return self.wm_loss_w * wm_error + self.lpips_w * lpips_loss

    def update_metrics(self, image):
        """Append PSNR, SSIM and LPIPS between ``self.real_img`` and ``image`` to ``self.metrics``.

        ``self.real_img`` is already clipped to [-1, 1]; ``image`` is clipped
        only for LPIPS.
        """
        self.metrics['psnr'].append(
            psnr(self.real_img, image).item()
        )
        self.metrics['ssim'].append(
            self.ssim(self.real_img, image).item()
        )
        self.metrics['lpips'].append(
            self.lpips(self.real_img, image.clip(-1, 1)).item()
        )

        # FID computation is currently disabled:
        # image = linear_transform_to_range(image, new_max=1)#.to(torch.uint8)
        # real_img = linear_transform_to_range(self.real_img, new_max=1)#.to(torch.uint8)

        # self.fid.update(real_img, real=True)
        # self.fid.update(image, real=False)

        # self.metrics['fid'].append(
        #     self.fid.compute().item()
        # )
        # self.fid.reset()

    def get_real_img_grads(self, priv_key):
        """Cache reference pixel values and local gradients at the key positions.

        Sets ``real_img_pxls_1/2`` and ``real_img_grad_1/2`` from
        ``self.real_img``; they are read by ``penalize_pxls``. Accepts the same
        key layouts as ``wm_loss`` (optionally with a leading batch dim).
        """
        pk1, pk2 = self._squeeze_priv_key(priv_key)

        img_grad = self.get_pxl_grad(self.real_img)

        # Detached copies so the cached values never carry autograd history.
        self.real_img_pxls_1 = self.real_img[:, pk1[0], pk1[1], pk1[2]].detach().clone()
        self.real_img_pxls_2 = self.real_img[:, pk2[0], pk2[1], pk2[2]].detach().clone()

        self.real_img_grad_1 = img_grad[:, pk1[0], pk1[1], pk1[2]].detach().clone()
        self.real_img_grad_2 = img_grad[:, pk2[0], pk2[1], pk2[2]].detach().clone()

    def penalize_pxls(self, pxls1, pxls2):
        """Keep key pixels only where the reference local gradient exceeds ``self.grad_thr``.

        Elsewhere the pixels are replaced by the cached reference values, which
        carry no gradient, so the optimizer cannot change them.
        """
        pxls1 = torch.where(
            self.real_img_grad_1 > self.grad_thr,
            pxls1,
            self.real_img_pxls_1
        )

        pxls2 = torch.where(
            self.real_img_grad_2 > self.grad_thr,
            pxls2,
            self.real_img_pxls_2
        )

        return pxls1, pxls2

    @staticmethod
    def get_pxl_grad(x):
        """Return the mean absolute difference of each pixel to its 4 neighbours.

        ``x`` is 4-D with dims 2 and 3 as the spatial axes (presumably H, W).
        Differences that ``torch.roll`` would wrap across the image border are
        set to 0, but the sum is still divided by 4.
        """
        dy_down = torch.abs(x - torch.roll(x, shifts=1, dims=2))
        dy_down[:, :, 0, :] = 0

        dy_up = torch.abs(x - torch.roll(x, shifts=-1, dims=2))
        dy_up[:, :, -1, :] = 0

        dx_right = torch.abs(x - torch.roll(x, shifts=1, dims=3))
        dx_right[:, :, :, 0] = 0

        dx_left = torch.abs(x - torch.roll(x, shifts=-1, dims=3))
        dx_left[:, :, :, -1] = 0

        pxl_grad = (dy_down + dy_up + dx_left + dx_right) / 4
        return pxl_grad
    