'''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''

from abc import ABC, abstractmethod
from functools import partial
import yaml
from torch.nn import functional as F
from torchvision import torch
from motionblur.motionblur import Kernel

from util.resizer import Resizer
from util.img_utils import Blurkernel, fft2_m

import cv2
import numpy as np
from torch.fft import fft2, ifft2, fftshift, ifftshift
from guided_diffusion.custom_util import *

import argparse



def do_erase(data, ratio):
    """
    Return a copy of ``data`` with a random fraction of its elements set to zero.

    :param data: NumPy array or torch tensor of any shape.
    :param ratio: Fraction of elements to zero. The count is
        ``int(number_of_elements * ratio)`` (truncated). Indices are drawn
        without replacement from NumPy's global RNG, so a value outside
        [0, 1] makes ``np.random.choice`` raise ``ValueError``.
    :return: Array/tensor with the same shape as ``data``. The input itself
        is never modified.
    """
    origin_shape = data.shape

    # Flatten to 1-D so that elements can be picked by a single flat index.
    # NumPy's flatten() always copies, but torch's flatten() may return the
    # original tensor or a view of it. Clone explicitly so that erasing never
    # touches the caller's tensor.
    flattened = data.flatten()
    if isinstance(flattened, torch.Tensor):
        flattened = flattened.clone()

    # Number of elements to zero out: the requested fraction of all elements.
    num_zeroes = int(len(flattened) * ratio)

    # Pick that many distinct positions at random and zero them.
    indices = np.random.choice(len(flattened), num_zeroes, replace=False)
    flattened[indices] = 0

    return flattened.reshape(origin_shape)


def sharpen_image(image, strength=1.0):
    """
    Apply sharpening to an image tensor using unsharp masking.

    Each channel is blurred independently with a normalized 3x3 binomial
    kernel. Reflect padding keeps the spatial size unchanged. ``strength``
    times the high-frequency residual ``image - blurred`` is then added back.

    :param image: Image tensor of shape [batch_size, channels, height, width].
        The blur kernel is created with the image's dtype and on its device,
        so float32, float64 and half inputs all work.
    :param strength: Strength of the sharpening effect; 0 returns the input values.
    :return: Sharpened tensor with the same shape, dtype and device as ``image``.
    :raises ValueError: If ``image`` is not 4-dimensional.
    """
    if image.dim() != 4:
        raise ValueError("Image tensor must be 4-dimensional.")

    channels = image.size(1)

    # Blur kernel for unsharp masking, built in the image's dtype so that
    # conv2d does not fail on a dtype mismatch (e.g. float64 inputs).
    blur_kernel = torch.tensor([[1, 2, 1],
                                [2, 4, 2],
                                [1, 2, 1]], dtype=image.dtype, device=image.device) / 16.0
    # One copy of the kernel per channel, for a depthwise (grouped) convolution.
    blur_kernel = blur_kernel.reshape(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    # Pad by 1 on each side so the 3x3 convolution preserves height/width.
    padded_image = F.pad(image, (1, 1, 1, 1), mode='reflect')
    blurred_image = F.conv2d(padded_image, blur_kernel, groups=channels)

    # The high-pass residual, added back scaled by `strength`.
    unsharp_mask = image - blurred_image
    return image + strength * unsharp_mask

# =================
# Operation classes
# =================

# Registry mapping an operator name to its class; filled by @register_operator
# and looked up by get_operator.
__OPERATOR__ = {}

def register_operator(name: str):
    """Return a class decorator that records the class in ``__OPERATOR__`` under ``name``.

    :raises NameError: (from the decorator) if ``name`` already has an entry.
    """
    def _register(cls):
        existing = __OPERATOR__.get(name)
        if existing:
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        # Hand the class back unchanged so it stays usable under its own name.
        return cls

    return _register


def get_operator(name: str, **kwargs):
    """Instantiate the operator class registered under ``name``.

    :param name: Key used with ``@register_operator``.
    :param kwargs: Forwarded verbatim to the operator's constructor.
    :raises NameError: If no operator is registered under ``name``.
    """
    operator_cls = __OPERATOR__.get(name)
    if operator_cls is None:
        raise NameError(f"Name {name} is not defined.")
    return operator_cls(**kwargs)


class LinearOperator(ABC):
    """Interface for a linear measurement operator A.

    Subclasses implement ``forward`` (A x) and ``transpose`` (A^T x). The two
    projection helpers are built only from those methods, and every keyword
    argument is passed through to them.
    """

    @abstractmethod
    def forward(self, data, **kwargs):
        """Return A * data."""

    @abstractmethod
    def transpose(self, data, **kwargs):
        """Return A^T * data."""

    def ortho_project(self, data, **kwargs):
        """Return (I - A^T A) data."""
        back_projected = self.transpose(self.forward(data, **kwargs), **kwargs)
        return data - back_projected

    def project(self, data, measurement, **kwargs):
        """Return (I - A^T A) measurement - A data."""
        residual = self.ortho_project(measurement, **kwargs)
        return residual - self.forward(data, **kwargs)


@register_operator(name='noise')
class DenoiseOperator(LinearOperator):
    """Identity operator (A = I) for pure denoising.

    Every method returns its input unchanged. The methods also accept the
    extra arguments the rest of the operator interface uses (``measurement``,
    and keywords such as ``mask``, which ``InpaintingOperator.forward``
    reads), so callers can treat this operator like any other.
    """

    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        """Return ``data`` unchanged (A x = x); extra keywords are ignored."""
        return data

    def transpose(self, data, **kwargs):
        """Return ``data`` unchanged (A^T x = x); extra keywords are ignored."""
        return data

    def ortho_project(self, data, **kwargs):
        """Return ``data`` unchanged.

        NOTE(review): the generic (I - A^T A) formula would give zero for the
        identity; this override keeps the original behavior of returning data.
        """
        return data

    def project(self, data, measurement=None, **kwargs):
        """Return ``data`` unchanged.

        ``measurement`` and extra keywords are accepted for compatibility with
        ``LinearOperator.project(data, measurement, **kwargs)`` but are ignored.
        """
        return data


@register_operator(name='super_resolution')
class SuperResolutionOperator(LinearOperator):
    """Downsampling operator for super-resolution.

    ``forward`` shrinks the image with ``Resizer`` by ``1/scale_factor``.
    ``transpose`` enlarges it with ``F.interpolate`` by ``scale_factor``.
    """

    def __init__(self, in_shape, scale_factor, device):
        self.device = device
        self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
        shrink_ratio = 1 / scale_factor
        self.down_sample = Resizer(in_shape, shrink_ratio).to(device)

    def forward(self, data, **kwargs):
        """Return the downsampled image (A x)."""
        return self.down_sample(data)

    def transpose(self, data, **kwargs):
        """Return the upsampled image (used as A^T)."""
        return self.up_sample(data)

    def project(self, data, measurement, **kwargs):
        """Return ``data - A^T A data + A^T measurement``."""
        reprojected = self.transpose(self.forward(data))
        lifted = self.transpose(measurement)
        return data - reprojected + lifted
    

    
    
@register_operator(name='motion_blur')
class MotionBlurOperator(LinearOperator):
    """Motion-blur forward operator applied through ``conv_psf_fft2``.

    A random motion kernel from ``motionblur.Kernel`` is zero-padded towards
    256x256 and stored as ``self.new_kernel`` with shape (1, 1, H, W).
    ``forward`` pads the input and the kernel by 128 on each side and
    convolves them. ``transpose`` is the identity.
    """
    def __init__(self,args, kernel_size, intensity, device):
        """
        :param args: Not used by this class.
        :param kernel_size: Side length of the square motion kernel.
        :param intensity: Motion intensity, passed to ``Kernel`` and, as
            ``std``, to ``Blurkernel``.
        :param device: Device for ``self.conv`` and ``self.new_kernel``.
        """
        self.device = device
        self.kernel_size = kernel_size
        # NOTE(review): self.conv is not used by forward (the calls that used
        # it are commented out below); it is kept as an attribute.
        self.conv = Blurkernel(blur_type='motion',
                               kernel_size=kernel_size,
                               std=intensity,
                               device=device).to(device)  # should we keep this device term?

        self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
        kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
        
        # Zero-pad the kernel symmetrically towards 256x256.
        # NOTE(review): for odd kernel_size the integer division gives a
        # 255x255 kernel (e.g. 61 -> 97 + 61 + 97). Confirm conv_psf_fft2 is
        # fine with that size before changing it.
        if kernel_size < 256:
            pad_size = int((256-kernel_size)/2)
            kernel = F.pad(kernel,(pad_size,pad_size,pad_size,pad_size))
    
        #self.conv.update_weights(kernel)

        # Add batch and channel dims: (H, W) -> (1, 1, H, W).
        self.new_kernel = kernel.unsqueeze(0).unsqueeze(0).to(device)
        #self.new_kernel = sharpen_image(self.new_kernel, 25.0)
        #self.new_kernel[self.new_kernel<0.0001] = 0
        
        # Debug output: max and min value of the padded kernel.
        print ("motion blur range :",self.new_kernel.max(),self.new_kernel.min())

    def forward(self, data, **kwargs):
        """Compute A x: blur ``data`` with ``self.new_kernel`` via ``conv_psf_fft2``.

        The input is padded circularly and the kernel with zeros, by 128 on
        every side. This presumably assumes 256x256 inputs (the kernel is padded
        towards 256) — TODO confirm.
        """
        #return self.conv(data)
        pad_size = 128
        
        data_pad = F.pad(data,[pad_size,pad_size,pad_size,pad_size],mode='circular')#replicate,reflect
        kernel_pad = F.pad(self.new_kernel,[pad_size,pad_size,pad_size,pad_size])

        #return crop_and_noise_2(conv_psf_fft2(data_pad,kernel_pad),250,0)
        return conv_psf_fft2(data_pad,kernel_pad)

    def transpose(self, data, **kwargs):
        """Identity: ``data`` is returned unchanged."""
        return data

    def get_kernel(self):
        """Return the unpadded motion kernel as a float32 (1, 1, k, k) tensor."""
        kernel = torch.from_numpy(self.kernel.kernelMatrix).type(torch.float32)
        return kernel.view(1, 1, self.kernel_size, self.kernel_size)


@register_operator(name='gaussian_blur')
class GaussialBlurOperator(LinearOperator):
    """Gaussian-blur forward operator applied through ``conv_psf_fft2``.

    (The misspelled class name is kept as-is for compatibility.)
    """

    def __init__(self, args, kernel_size, intensity, device):
        """
        :param args: Not used by this class.
        :param kernel_size: Side length of the Gaussian kernel.
        :param intensity: Passed to ``Blurkernel`` as ``std``.
        :param device: Device for ``self.conv`` and ``self.new_kernel``.
        """
        self.device = device
        self.kernel_size = kernel_size
        blur = Blurkernel(blur_type='gaussian',
                          kernel_size=kernel_size,
                          std=intensity,
                          device=device)
        self.conv = blur.to(device)
        self.kernel = self.conv.get_kernel()
        self.conv.update_weights(self.kernel.type(torch.float32))
        # Add batch and channel dims and store as float32 on `device`.
        kernel_4d = self.conv.get_kernel()[None, None]
        self.new_kernel = kernel_4d.to(device).type(torch.float32)

    def forward(self, data, **kwargs):
        """Compute A x: replicate-pad the input, zero-pad the kernel (128 per side), convolve."""
        margin = (128,) * 4
        padded_data = F.pad(data, margin, mode='replicate')
        padded_kernel = F.pad(self.new_kernel, margin)
        return conv_psf_fft2(padded_data, padded_kernel)

    def transpose(self, data, **kwargs):
        """Identity: ``data`` is returned unchanged."""
        return data

    def get_kernel(self):
        """Return the Gaussian kernel viewed as (1, 1, k, k)."""
        side = self.kernel_size
        return self.kernel.view(1, 1, side, side)

@register_operator(name='inpainting')
class InpaintingOperator(LinearOperator):
    '''Masking operator: multiplies the image by a caller-supplied mask.

    The mask is passed to ``forward`` / ``ortho_project`` as the ``mask``
    keyword and is moved to ``self.device`` before use.
    '''
    def __init__(self, device):
        self.device = device

    def forward(self, data, **kwargs):
        """Return ``data * mask``.

        :raises ValueError: If the ``mask`` keyword is missing or None.
        """
        mask = kwargs.get('mask', None)
        if mask is None:
            raise ValueError("Require mask")
        # Only a missing mask is reported as "Require mask". Device or shape
        # errors from the multiplication propagate unchanged instead of being
        # masked by a bare `except`.
        return data * mask.to(self.device)

    def transpose(self, data, **kwargs):
        """Identity: ``data`` is returned unchanged."""
        return data

    def ortho_project(self, data, **kwargs):
        """Return ``data - data * mask``, i.e. ``data * (1 - mask)``."""
        return data - self.forward(data, **kwargs)


class NonLinearOperator(ABC):
    """Interface for a nonlinear measurement operator A(x)."""

    @abstractmethod
    def forward(self, data, **kwargs):
        """Return A(data)."""

    def project(self, data, measurement, **kwargs):
        """Return ``data + measurement - A(data)``.

        Keyword arguments are forwarded to ``forward``, matching
        ``LinearOperator.project``. Previously they were silently dropped.
        """
        return data + measurement - self.forward(data, **kwargs)

@register_operator(name='phase_retrieval')
class PhaseRetrievalOperator(NonLinearOperator):
    """Phase-retrieval operator: magnitude of ``fft2_m`` of the zero-padded image.

    The zero-padding width on each side is ``int((oversample / 8.0) * 256)``.
    """

    def __init__(self, oversample, device):
        self.pad = int((oversample / 8.0) * 256)
        self.device = device

    def forward(self, data, **kwargs):
        """Zero-pad the last two dims by ``self.pad`` and return ``fft2_m(...).abs()``."""
        margin = (self.pad,) * 4
        spectrum = fft2_m(F.pad(data, margin))
        return spectrum.abs()

@register_operator(name='nonlinear_blur')
class NonlinearBlurOperator(NonLinearOperator):
    """Nonlinear blur forward model using a pretrained bkse ``KernelWizard``."""

    def __init__(self, opt_yml_path, device):
        """
        :param opt_yml_path: YAML file with a ``KernelWizard`` section whose
            ``pretrained`` entry is the checkpoint path.
        :param device: Device the blur model is moved to.
        """
        self.device = device
        self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)

    def prepare_nonlinear_blur_model(self, opt_yml_path):
        '''
        Nonlinear deblur requires external codes (bkse).

        Build ``KernelWizard`` from the YAML options, load its pretrained
        weights, and return it in eval mode on ``self.device``.
        '''
        from bkse.models.kernel_encoding.kernel_wizard import KernelWizard

        with open(opt_yml_path, "r") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
        model_path = opt["pretrained"]
        blur_model = KernelWizard(opt)
        blur_model.eval()
        # Load onto the CPU first so that a checkpoint saved on a GPU also
        # loads on CPU-only hosts or other GPUs. load_state_dict copies the
        # values into the model, which is then moved to self.device.
        # torch.load unpickles the file: only point this at trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")
        blur_model.load_state_dict(state_dict)
        blur_model = blur_model.to(self.device)
        return blur_model

    def forward(self, data, **kwargs):
        """Blur ``data`` with a freshly sampled random kernel code.

        NOTE(review): a new random kernel is drawn on every call, so repeated
        calls on the same input give different outputs.
        """
        random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2
        data = (data + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        blurred = self.blur_model.adaptKernel(data, kernel=random_kernel)
        blurred = (blurred * 2.0 - 1.0).clamp(-1, 1)  # [0, 1] -> [-1, 1]
        return blurred

# =============
# Noise classes
# =============


# Registry mapping a noise name to its class; filled by @register_noise and
# looked up by get_noise.
__NOISE__ = {}

def register_noise(name: str):
    """Return a class decorator that records the class in ``__NOISE__`` under ``name``.

    :raises NameError: (from the decorator) if ``name`` already has an entry.
    """
    def _register(cls):
        existing = __NOISE__.get(name)
        if existing:
            raise NameError(f"Name {name} is already defined!")
        __NOISE__[name] = cls
        # Hand the class back unchanged so it stays usable under its own name.
        return cls

    return _register

def get_noise(name: str, **kwargs):
    """Instantiate the noise class registered under ``name``.

    The instance's ``__name__`` attribute is set to ``name``.

    :param kwargs: Forwarded verbatim to the noise constructor.
    :raises NameError: If no noise is registered under ``name``.
    """
    noise_cls = __NOISE__.get(name)
    if noise_cls is None:
        raise NameError(f"Name {name} is not defined.")
    noiser = noise_cls(**kwargs)
    noiser.__name__ = name
    return noiser

class Noise(ABC):
    """Base class for noise models; calling an instance delegates to ``forward``."""

    def __call__(self, data):
        noisy = self.forward(data)
        return noisy

    @abstractmethod
    def forward(self, data):
        """Return ``data`` with this model's noise applied."""

@register_noise(name='clean')
class Clean(Noise):
    """No-op noise model: the input is returned as-is."""

    def forward(self, data):
        noiseless = data
        return noiseless

@register_noise(name='gaussian')
class GaussianNoise(Noise):
    """Additive zero-mean Gaussian noise with standard deviation ``sigma``."""

    def __init__(self, sigma):
        self.sigma = sigma

    def forward(self, data):
        """Return ``data + sigma * N(0, 1)``, with noise drawn like ``data`` (same shape/device)."""
        perturbation = self.sigma * torch.randn_like(data)
        return data + perturbation
    
    
@register_noise(name='gaussian_mean')
class GaussianMeanNoise(Noise):
    """Additive Gaussian noise whose mean is the mean of the input.

    Registered as 'gaussian_mean'. This class used to be a second definition
    named ``GaussianNoise``, which shadowed the 'gaussian' class of the same
    name at module level; it now has its own name.
    """

    def __init__(self, sigma):
        """
        :param sigma: Standard deviation of the Gaussian noise.
        """
        self.sigma = sigma

    def forward(self, data):
        """
        Add ``sigma * N(0, 1) + mean(data)`` to ``data``.

        ``torch.mean(data)`` reduces over every element of ``data`` (all batch
        entries and channels), so one scalar offset is shared by the whole batch.

        :param data: Input tensor to which noise will be added.
        :return: Noisy tensor with the same shape as ``data``.
        """
        noise = torch.randn_like(data, device=data.device) * self.sigma + torch.mean(data)
        return data + noise


@register_noise(name='poisson')
class PoissonNoise(Noise):
    """Poisson (shot) noise for images in [-1, 1], scaled by ``rate``."""

    def __init__(self, rate):
        self.rate = rate

    def forward(self, data):
        '''
        Follow skimage.util.random_noise.

        Map ``data`` to [0, 1] and draw Poisson counts with NumPy's global RNG
        at ``value * 255 * rate``. Rescale, map back to [-1, 1] and clamp. The
        result is returned on the input's device and, for floating inputs, in
        the input's dtype (NumPy produces float64).
        '''
        # TODO: set one version of poisson

        # version 3 (stack-overflow)
        device = data.device
        orig_dtype = data.dtype
        data = ((data + 1.0) / 2.0).clamp(0, 1)
        lam = (data.detach().cpu() * 255.0 * self.rate).numpy()
        counts = np.random.poisson(lam)
        noisy = torch.from_numpy(counts / 255.0 / self.rate)
        noisy = (noisy * 2.0 - 1.0).clamp(-1, 1)
        noisy = noisy.to(device)
        # Restore the caller's floating dtype; otherwise the float64 from
        # NumPy would leak into downstream float32 computations.
        if orig_dtype.is_floating_point:
            noisy = noisy.to(orig_dtype)
        return noisy


@register_operator(name='lensless_real_voronoi')
class PsfOperator(LinearOperator):
    """Lensless-camera forward operator built from a measured PSF image.

    ``forward`` pads the scene, convolves it with the 512-sized PSF via
    ``conv_psf_fft``, then crops (``crop_and_noise_2``) and applies
    ``vignetting``. ``transpose`` is the identity.
    """
    def __init__(self, args, device, **kwargs) -> None:
        """
        :param args: Run arguments. Passed to ``load_psf_real``/``load_psf``;
            ``args.crop`` is read in ``forward``.
        :param device: Device on which the PSF tensors are placed.
        :param kwargs: Extra operator-config entries; accepted and ignored.
        """
        self.args = args
        self.device = device

        psf_file = "dataset/ys_flickr_100/psf/psf_camera1_original.tiff"
        # Place the PSFs on the operator's own `device` rather than a
        # hard-coded `.cuda(args.gpu)`, which ignored `device` and failed on CPU.
        # NOTE(review): the 512 PSF uses load_psf_real and the 256 PSF uses
        # load_psf — confirm the two different loaders are intended.
        self.psf_512 = load_psf_real(args=self.args, psf_file=psf_file, resize_dim=512).unsqueeze(0).to(self.device)
        self.psf_256 = load_psf(args=self.args, psf_file=psf_file, resize_dim=256).unsqueeze(0).to(self.device)

        self.new_kernel = self.psf_512
        # The PSF is used as loaded; normalizing it to sum to 1 is disabled:
        # self.new_kernel = self.new_kernel / self.new_kernel.sum()

    def forward(self, data, **kwargs):
        """Compute A x: pad, convolve with the 512 PSF, crop, then vignette.

        :param data: Image batch, e.g. [1, 3, 256, 256], padded to [1, 3, 512, 512].
        """
        pad_size = 128
        # Pad with the constant -1 (presumably black for images in [-1, 1] — confirm).
        data_pad = F.pad(data, (pad_size, pad_size, pad_size, pad_size), mode='constant', value=-1)

        Ax = self.apply_kernel(data_pad, kernel_size=512)
        # Crop by args.crop (no added noise: last argument 0).
        crop_Ax = crop_and_noise_2(Ax, self.args.crop, 0)
        # Vignetting with a fixed parameter of 175.
        return vignetting(crop_Ax, 175)

    def forward_nopad(self, data, **kwargs):
        """Convolve without padding; ``kernel_size`` keyword selects the PSF (default: 512)."""
        return self.apply_kernel(data, kwargs.get('kernel_size'))

    def transpose(self, data, **kwargs):
        """Identity: ``data`` is returned unchanged."""
        return data

    def apply_kernel(self, data, kernel_size):
        """Convolve ``data`` with the 256 PSF if ``kernel_size == 256``, else the 512 PSF."""
        # TODO: faster way to apply conv?
        kernel = self.psf_256 if kernel_size == 256 else self.psf_512
        return conv_psf_fft(data, kernel)
