from abc import ABC, abstractmethod
import torch
from guided_diffusion.custom_util import *
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
from util.img_utils import clear_color

from util.logger import get_logger

from guided_diffusion.custom_util import *
import argparse

from scipy.stats import entropy

# Registry mapping a method name to the class registered under that name.
__CONDITIONING_METHOD__ = {}


def register_conditioning_method(name: str):
    """Return a class decorator that stores the class in the registry under ``name``.

    The decorated class is returned unchanged.

    Raises:
        NameError: when applied while ``name`` already maps to a truthy entry.
    """
    def wrapper(cls):
        registry = __CONDITIONING_METHOD__
        if not registry.get(name):
            registry[name] = cls
            return cls
        raise NameError(f"Name {name} is already registered!")
    return wrapper

def get_conditioning_method(name: str, operator, noiser, **kwargs):
    """Instantiate the conditioning method registered under ``name``.

    ``operator``, ``noiser`` and any extra keyword arguments are passed to the
    registered class as keyword arguments.

    Raises:
        NameError: when nothing is registered under ``name``.
    """
    method_cls = __CONDITIONING_METHOD__.get(name)
    if method_cls is None:
        raise NameError(f"Name {name} is not defined!")
    return method_cls(operator=operator, noiser=noiser, **kwargs)


    
class ConditioningMethod(ABC):
    """Base class for conditioning methods built on a forward operator and a noise model.

    It stores the operator and noiser, initialises bookkeeping attributes used
    by subclasses, and offers a projection helper plus a residual-norm
    gradient routine.
    """

    def __init__(self, operator, noiser, **kwargs):
        """Store ``operator`` / ``noiser`` and reset all bookkeeping attributes.

        Extra keyword arguments are only printed; they are not stored here.
        """
        self.operator, self.noiser = operator, noiser
        print("kwargs: ", kwargs)

        # Call counters.
        self.num_call = 0
        self.num_call_ind = 0

        # Per-call statistics and masks (filled by code outside this class).
        self.data_min_list = []
        self.data_max_list = []
        self.data_mean_list = []
        self.mask_patch_list = []
        self.mean_kl_matrix_dec_list = []

        # Threshold settings.
        self.threshold = 1.5
        self.threshold_init = 0
        self.threshold_mid_1 = 0
        self.decay = 0.998

        self.DPSSAG_mask_save = 0

    def project(self, data, noisy_measurement, **kwargs):
        """Delegate to ``operator.project`` with ``noisy_measurement`` as the measurement."""
        projected = self.operator.project(data=data, measurement=noisy_measurement, **kwargs)
        return projected

    # Plain DPS-style gradient (original note, translated: "DPS, G only").
    def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
        """Return ``(norm_grad, norm)``: the residual norm and its gradient w.r.t. ``x_prev``.

        The branch is chosen by ``self.noiser.__name__`` ('gaussian' or
        'poisson'); any other name raises ``NotImplementedError``.
        ``x_prev`` must be part of the autograd graph that produced ``x_0_hat``.
        """
        noise_model = self.noiser.__name__

        if noise_model == 'gaussian':
            border = 128
            target = measurement[0]
            # Pad every spatial side with -1 before applying the forward operator.
            # NOTE(review): the original comments suggest 256x256 -> 512x512 here - confirm.
            padded_estimate = F.pad(x_0_hat, (border,) * 4, mode='constant', value=-1)
            simulated = self.operator.forward(padded_estimate, kernel_size=512, **kwargs)[0]
            # crop_and_noise_2 / normalize are helpers defined outside this block.
            simulated = crop_and_noise_2(simulated, 490, 0)

            residual = normalize(target) - normalize(simulated)
            norm = torch.linalg.norm(residual)
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]

            # Only the gaussian branch advances the call counter.
            self.num_call += 1
            return norm_grad, norm

        if noise_model == 'poisson':
            residual = measurement - self.operator.forward(x_0_hat, **kwargs)
            # Residual norm divided element-wise by |measurement|, then averaged.
            norm = (torch.linalg.norm(residual) / measurement.abs()).mean()
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
            return norm_grad, norm

        raise NotImplementedError
    



@register_conditioning_method(name='ps_roi_deconv_patch_kl_stochastic')
class PosteriorSamplingROI(ConditioningMethod):
    """Gradient guidance combining a measurement residual with a deconvolution guide.

    ``conditioning`` obtains a guidance gradient from
    ``grad_and_value_ROI_pat_KL_stoch`` and subtracts ``scale`` times that
    gradient from ``x_t`` in place.  The gradient routine also builds several
    region masks (patch-wise residual, patch-wise KL, high-pass based) which
    are returned for inspection; ``conditioning`` itself does not use them.

    Helpers such as ``vignetting``, ``normalize``, ``high_pass``, ``low_pass``,
    ``high_pass_filter`` and ``cyclic_shift_torch`` are defined outside this
    class (presumably supplied by the module's star import - confirm).
    """

    # Quantile levels used when turning score maps into binary masks.
    _HIGHPASS_QUANTILE = 0.40
    _KL_PIXEL_QUANTILE = 0.65
    # Border (in pixels) cropped from every side of the measurement-domain images.
    _MEASUREMENT_CROP = 128

    def __init__(self, args, operator, noiser, device='cuda', **kwargs):
        """Read the guidance hyper-parameters from ``kwargs`` and create the results directory.

        Args:
            args: Run configuration object; stored on ``self.args``.
            operator: Forward operator, handed to the base class.
            noiser: Noise model, handed to the base class.
            device: Stored on ``self.device``.
            **kwargs: Looked up with ``dict.get`` (a missing key yields ``None``):
                ``scale``, ``diff_scale``, ``deconv_scale``, ``kernel_size``,
                ``exp_task``, ``patch_size``, ``threshold_init``, ``loss_mode``,
                ``start_point_1``, ``start_point_2``, ``skip_point``,
                ``vignette_scale``.  ``diff_scale`` and ``deconv_scale`` must be
                numbers because they are formatted with ``%f`` into the results path.
        """
        self.args = args

        super().__init__(operator, noiser)
        self.scale = kwargs.get('scale')
        print("self.scale:", self.scale)

        self.diff_scale = kwargs.get('diff_scale')
        self.deconv_scale = kwargs.get('deconv_scale')
        print("diff scale: ", self.diff_scale)
        print("deconv scale: ", self.deconv_scale)

        self.device = device
        self.kernel_size = kwargs.get('kernel_size')
        self.exp_task = kwargs.get('exp_task')
        self.patch_size = kwargs.get('patch_size')
        # Overrides the default set by the base-class initialiser.
        self.threshold_init = kwargs.get('threshold_init')

        self.loss_mode = kwargs.get('loss_mode')
        self.start_point_1 = kwargs.get('start_point_1')
        self.start_point_2 = kwargs.get('start_point_2')

        self.skip_point = kwargs.get('skip_point')
        self.vignette_scale = kwargs.get('vignette_scale')

        self.results_dir = './results_sub_test_0510_final/0510_ROI_deconv_patchwise_kl_stochastic_%s_%f_%f/' % (self.exp_task,self.diff_scale,self.deconv_scale)
        os.makedirs(self.results_dir, exist_ok=True)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, deconv_guide, s_theta, **kwargs):
        """Apply one guidance step to ``x_t`` (modified in place).

        Args:
            x_prev: Tensor the gradient is taken with respect to.
            x_t: Current sample; updated in place with ``-=``.
            x_0_hat: Clean-image estimate whose autograd graph reaches ``x_prev``.
            measurement: Observed measurement.
            deconv_guide: Reference image compared against ``x_0_hat``.
            s_theta: Passed through unchanged.
            **kwargs: Forwarded to the gradient routine (and from there to
                ``operator.forward``).

        Returns:
            ``(x_t, norm_total, s_theta, norm_grad)``.
        """
        outputs = self.grad_and_value_ROI_pat_KL_stoch(
            x_prev=x_prev,
            x_0_hat=x_0_hat,
            measurement=measurement,
            deconv_guide=deconv_guide,
            diff_scale=self.diff_scale,
            deconv_scale=self.deconv_scale,
            patch_size=self.patch_size,
            threshold_init=self.threshold_init,
            loss_mode=self.loss_mode,
            start_point_1=self.start_point_1,
            start_point_2=self.start_point_2,
            skip_point=self.skip_point,
            vignette_scale=self.vignette_scale,
            **kwargs,
        )
        # Only four of the fourteen returned values are needed here; the
        # remaining ones are diagnostic masks/images.
        norm_grad, norm_total = outputs[0], outputs[1]
        arg_norm_true, arg_trans_S = outputs[8], outputs[13]

        if arg_norm_true:
            x_t -= norm_grad * self.scale
            print("conditional diffusion!")
        elif self.num_call_ind <= 200 or not arg_trans_S:
            # After step 200 the update is skipped only when arg_trans_S is set.
            x_t -= norm_grad * self.scale

        return x_t, norm_total, s_theta, norm_grad

    @staticmethod
    def _as_three_channels(mask_2d):
        """Stack a 2-D mask three times along a new leading dimension."""
        return mask_2d.unsqueeze(0).repeat(3, 1, 1)

    @classmethod
    def _quantile_mask(cls, score_2d, level):
        """Int mask (3 stacked copies) of entries strictly above the ``level`` quantile."""
        selected = (score_2d > torch.quantile(score_2d, level)).int()
        return cls._as_three_channels(selected)

    @classmethod
    def _patch_quantile_mask(cls, score_2d, patch_size, level):
        """Int mask (3 stacked copies) of the patches whose summed score exceeds the ``level`` quantile.

        The map is cut into non-overlapping ``patch_size`` x ``patch_size``
        patches; the per-patch decision is expanded back to pixel resolution.
        """
        patch_sums = score_2d.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size)
        patch_sums = patch_sums.sum(dim=-1).sum(dim=-1)
        selected = (patch_sums > torch.quantile(patch_sums, level)).int()
        expanded = selected.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
        return cls._as_three_channels(expanded)

    @staticmethod
    def _patchwise_symmetric_kl(guide_2d, estimate_2d, patch_size):
        """Return a float32 (rows, cols) tensor of symmetrised KL values, one per patch.

        Each patch of both maps is shifted by 1e-10, normalised to sum to one
        and compared with ``(entropy(p, q) + entropy(q, p)) / 2``.  The patch
        grid is derived from the map size, so every complete
        ``patch_size`` x ``patch_size`` patch is visited exactly once.
        """
        n_rows = guide_2d.shape[0] // patch_size
        n_cols = guide_2d.shape[1] // patch_size
        values = []
        for i in range(n_rows):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            row_values = []
            for j in range(n_cols):
                cols = slice(j * patch_size, (j + 1) * patch_size)
                p = guide_2d[rows, cols].flatten() + 1e-10
                q = estimate_2d[rows, cols].flatten() + 1e-10
                p = (p / p.sum()).detach().cpu().numpy()
                q = (q / q.sum()).detach().cpu().numpy()
                row_values.append(float((entropy(p, q) + entropy(q, p)) / 2))
            values.append(row_values)
        return torch.tensor(values, dtype=torch.float32)

    def grad_and_value_ROI_pat_KL_stoch(self, x_prev, x_0_hat, measurement, deconv_guide,
                                        diff_scale, deconv_scale, patch_size, threshold_init,
                                        loss_mode, skip_point, vignette_scale,
                                        start_point_1=None, start_point_2=None, **kwargs):
        """Compute the guidance gradient and the diagnostic masks.

        Args:
            x_prev: Tensor the gradient is taken with respect to.
            x_0_hat: Clean-image estimate (graph must reach ``x_prev``).
            measurement: Observed measurement; element ``[0]`` is used in the gaussian path.
            deconv_guide: Reference image; element ``[0]`` is compared with ``x_0_hat[0]``.
            diff_scale: Factor applied to the measurement-residual gradient (``loss_mode`` 10).
            deconv_scale: Factor applied to the guide-residual gradient (``loss_mode`` 59).
            patch_size: Side length, in pixels, of the square patches used for the masks.
            threshold_init: Quantile level passed to ``torch.quantile`` for the patch masks.
            loss_mode: ``10`` (measurement residual) or ``59`` (guide residual weighted
                by a cyclically shifted patch mask, with every other step skipped while
                ``num_call_ind <= skip_point``).
            skip_point: See ``loss_mode``.
            vignette_scale: Passed to ``vignetting`` together with ``x_0_hat``.
            start_point_1, start_point_2: Accepted so they are not forwarded to
                ``operator.forward``; currently unused.
            **kwargs: Forwarded to ``operator.forward``.

        Returns:
            A 14-tuple ``(norm_grad, norm, conv_mask, kl_dec_mask, conv_guide_fft,
            deconv_guide_fft, x_0_hat_01, KL_deconv_mask, arg_norm_true,
            x_0_hat_highpass_mask, diff_highpass_mask, admm_highpass_mask,
            mean_kl_matrix_dec_list, arg_trans_S)``.  ``norm_grad`` is the integer
            ``0`` on skipped steps.  In the poisson path the masks and FFT images
            are ``None``.

        Raises:
            NotImplementedError: for an unknown noiser name or ``loss_mode``.
        """
        noise_model = self.noiser.__name__

        if noise_model == 'poisson':
            Ax = self.operator.forward(x_0_hat, **kwargs)
            difference = measurement - Ax
            norm = (torch.linalg.norm(difference) / measurement.abs()).mean()
            norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
            # No masks are built in this path; return placeholders so the
            # result has the same layout as the gaussian path.
            return (norm_grad, norm, None, None, None, None, (1 + x_0_hat[0]) / 2, None,
                    True, None, None, None, self.mean_kl_matrix_dec_list, False)

        if noise_model != 'gaussian':
            raise NotImplementedError

        high_band = 10
        low_band = 100
        arg_trans_S = False
        crop = self._MEASUREMENT_CROP

        x_0_hat = vignetting(x_0_hat, vignette_scale)

        # a/b live in the measurement domain, c/d in the image domain.
        a = measurement[0][:, crop:-crop, crop:-crop]
        b = self.operator.forward(x_0_hat, **kwargs)[0][:, crop:-crop, crop:-crop]
        c = deconv_guide[0]
        d = x_0_hat[0]

        difference_conv = normalize(a) - normalize(b)
        difference_deconv = abs(c - d)

        # NOTE: this draw is not used; it is kept so seeded runs consume the
        # global RNG stream exactly as earlier versions of this method did.
        torch.randn_like(c)

        # Channel-summed absolute residuals used for the patch masks.
        residual_conv_2d = torch.sum(abs(a - b), dim=0)
        residual_deconv_2d = torch.sum(difference_deconv, dim=0)

        deconv_guide_fft = abs(high_pass(c.unsqueeze(0), high_band) - high_pass(d.unsqueeze(0), high_band))
        conv_guide_fft = abs(low_pass(c.unsqueeze(0), low_band) - low_pass(d.unsqueeze(0), low_band))

        # Guide and estimate mapped with (x + 1) / 2, then summed over channels.
        guide_01 = (c + 1) / 2
        estimate_01 = (d + 1) / 2
        guide_pixel = torch.sum(guide_01, dim=0)
        estimate_pixel = torch.sum(estimate_01, dim=0)

        # High-pass magnitude maps and their quantile masks.
        guide_highpass = torch.sum(abs(high_pass_filter(guide_01.unsqueeze(0), high_band)).squeeze(), dim=0)
        estimate_highpass = torch.sum(abs(high_pass_filter(estimate_01.unsqueeze(0), high_band)).squeeze(), dim=0)
        admm_highpass_mask = self._quantile_mask(guide_highpass, self._HIGHPASS_QUANTILE)
        x_0_hat_highpass_mask = self._quantile_mask(estimate_highpass, self._HIGHPASS_QUANTILE)
        diff_highpass_mask = self._quantile_mask(guide_highpass - estimate_highpass, self._HIGHPASS_QUANTILE)

        # Pixel-wise KL terms between the two intensity maps, pooled per patch.
        guide_probs = guide_pixel / guide_pixel.sum()
        estimate_probs = estimate_pixel / estimate_pixel.sum() + 1e-10
        log_guide_probs = torch.log(guide_probs + 1e-10)
        kl_pixel = F.kl_div(log_guide_probs, estimate_probs, reduction='none')
        KL_deconv_mask = self._patch_quantile_mask(kl_pixel, patch_size, self._KL_PIXEL_QUANTILE)

        # Patch-wise symmetrised KL; its mean is logged on the instance.
        kl_matrix = self._patchwise_symmetric_kl(guide_pixel, estimate_pixel, patch_size).to(c.device)
        self.mean_kl_matrix_dec_list.append(kl_matrix.mean())
        kl_selected = (kl_matrix > torch.quantile(kl_matrix, threshold_init)).int()
        kl_dec_mask = self._as_three_channels(
            kl_selected.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
        )

        # Patch mask of the guide residual after a cyclic shift.  The shift uses
        # the step index stored by the previous call, because the counter is
        # only advanced further below.
        shift_size = self.num_call_ind % 16
        shifted_residual = cyclic_shift_torch(residual_deconv_2d, shift_size, shift_size)
        deconv_mask_cyc = self._patch_quantile_mask(shifted_residual, patch_size, threshold_init)

        conv_mask = self._patch_quantile_mask(residual_conv_2d, patch_size, threshold_init)

        norm = torch.linalg.norm(difference_conv)
        norm_dec = torch.linalg.norm(difference_deconv)

        self.num_call_ind = self.num_call % 1000
        self.num_call += 1

        if loss_mode == 10:
            norm_conv = torch.sum(norm)
            arg_norm_true = True
            print("10: DPS_loss")
            norm_grad = diff_scale * torch.autograd.grad(outputs=norm_conv, inputs=x_prev)[0]
        elif loss_mode == 59:
            norm_deconv = torch.sum(norm_dec)
            arg_norm_true = True
            print("59 - 8: Deconv_loss * Deconv_mask_difference-Cyclic Shift + skip guidance")
            print("skip_point: ", skip_point)
            if self.num_call_ind <= skip_point:
                if self.num_call_ind % 2 == 0:
                    norm_grad = deconv_scale * deconv_mask_cyc * torch.autograd.grad(outputs=norm_deconv, inputs=x_prev)[0]
                else:
                    # Guidance is skipped on odd steps before skip_point.
                    norm_grad = 0
            else:
                print("on 22222")
                norm_grad = deconv_scale * deconv_mask_cyc * torch.autograd.grad(outputs=norm_deconv, inputs=x_prev)[0]
        else:
            raise NotImplementedError(
                f"Unsupported loss_mode {loss_mode!r}; expected 10 or 59"
            )

        print("conv_scale: ", diff_scale)
        print("deconv_scale: ", deconv_scale)

        return (norm_grad, norm, conv_mask, kl_dec_mask, conv_guide_fft, deconv_guide_fft,
                (1 + d) / 2, KL_deconv_mask, arg_norm_true, x_0_hat_highpass_mask,
                diff_highpass_mask, admm_highpass_mask, self.mean_kl_matrix_dec_list, arg_trans_S)
    
@register_conditioning_method(name='vanilla')
class Identity(ConditioningMethod):
    """Pass-through conditioning: the sample is returned without any guidance."""

    def conditioning(self, x_t, **kwargs):
        """Return ``x_t`` unchanged.

        Extra keyword arguments are accepted and ignored, so this method can be
        called with the same keyword set as the other conditioning methods in
        this module.
        """
        return x_t
    
@register_conditioning_method(name='projection')
class Projection(ConditioningMethod):
    """Conditioning that replaces the sample with its projection onto the measurement."""

    def conditioning(self, x_t, noisy_measurement, **kwargs):
        """Return ``self.project`` applied to ``x_t`` and ``noisy_measurement``.

        Extra keyword arguments are accepted but not forwarded to ``project``.
        """
        return self.project(data=x_t, noisy_measurement=noisy_measurement)
