import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
from tqdm import tqdm
from natsort import natsorted
import torchvision.transforms as transforms

def cal_sigma(sig_x, sig_y, radians):
    """Build a batch of 2x2 covariance matrices for rotated Gaussians.

    Each matrix is ``U @ diag(sig_x**2, sig_y**2) @ U.T`` where ``U`` is the
    2-D rotation matrix ``[[cos, -sin], [sin, cos]]`` of the matching angle.

    Args:
        sig_x: Tensor of per-sample scales for the first axis (squared here).
        sig_y: Tensor of per-sample scales for the second axis (squared here).
        radians: Tensor of per-sample rotation angles, in radians.

    Returns:
        Tensor of shape ``(batch, 2, 2)``.
    """
    var_x = sig_x.view(-1, 1, 1) ** 2
    var_y = sig_y.view(-1, 1, 1) ** 2
    angle = radians.view(-1, 1, 1)
    cos_t = angle.cos()
    sin_t = angle.sin()

    # Diagonal matrix diag(var_x, var_y), assembled row by row.
    top_row = torch.cat([var_x, torch.zeros_like(var_x)], 2)
    bottom_row = torch.cat([torch.zeros_like(var_y), var_y], 2)
    diag = torch.cat([top_row, bottom_row], 1)

    # Rotation matrix for each angle.
    rot = torch.cat([torch.cat([cos_t, -sin_t], 2),
                     torch.cat([sin_t, cos_t], 2)], 1)

    return torch.bmm(rot, torch.bmm(diag, rot.transpose(1, 2)))


def anisotropic_gaussian_kernel(batch, kernel_size, covar):
    """Evaluate normalised anisotropic Gaussian kernels on a square grid.

    The grid coordinates are ``arange(kernel_size) - kernel_size // 2`` along
    both axes, and each kernel is ``exp(-0.5 * p^T inv(covar) p)`` divided by
    its own sum.

    Args:
        batch: Number of kernels to produce; must match ``covar``'s batch size.
        kernel_size: Side length of each (square) kernel.
        covar: Tensor of shape ``(batch, 2, 2)`` with the covariance matrices.

    Returns:
        Tensor of shape ``(batch, kernel_size, kernel_size)`` on ``covar``'s
        device; every kernel sums to 1.
    """
    # Build the grid on the covariance's device rather than hard-coding CUDA,
    # so CPU inputs work and GPU inputs stay on their own device.
    ax = torch.arange(kernel_size, dtype=torch.float32, device=covar.device) - kernel_size // 2

    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    # One (x, y) coordinate pair per grid cell: (batch, kernel_size**2, 2).
    xy = torch.stack([xx, yy], -1).view(batch, -1, 2)

    inverse_sigma = torch.inverse(covar)
    # Quadratic form p^T inv(covar) p for every grid point.
    mahalanobis = (torch.bmm(xy, inverse_sigma) * xy).sum(2)
    kernel = torch.exp(-0.5 * mahalanobis).view(batch, kernel_size, kernel_size)

    return kernel / kernel.sum([1, 2], keepdim=True)


def isotropic_gaussian_kernel(batch, kernel_size, sigma):
    """Evaluate normalised isotropic Gaussian kernels on a square grid.

    The grid coordinates are ``arange(kernel_size) - kernel_size // 2`` along
    both axes, and each kernel is ``exp(-(x^2 + y^2) / (2 * sigma^2))``
    divided by its own sum.

    Args:
        batch: Number of kernels to produce.
        kernel_size: Side length of each (square) kernel.
        sigma: Tensor holding one standard deviation per kernel.

    Returns:
        Tensor of shape ``(batch, kernel_size, kernel_size)`` on ``sigma``'s
        device; every kernel sums to 1.
    """
    # Build the grid on sigma's device rather than hard-coding CUDA, so CPU
    # inputs work and GPU inputs stay on their own device.
    ax = torch.arange(kernel_size, dtype=torch.float32, device=sigma.device) - kernel_size // 2
    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sigma.view(-1, 1, 1) ** 2))

    return kernel / kernel.sum([1, 2], keepdim=True)


def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0):
    """Sample a batch of anisotropic Gaussian blur kernels.

    Args:
        batch: Number of kernels to sample.
        kernel_size: Side length of each (square) kernel.
        lambda_min: Lower bound of the uniform range for both axis scales.
        lambda_max: Upper bound of the uniform range for both axis scales.

    Returns:
        CUDA tensor of shape ``(batch, kernel_size, kernel_size)``.
    """
    # Draw the rotation uniformly from [0, pi). The earlier `/ 180 * math.pi`
    # treated torch.rand's [0, 1) output as degrees, which confined every
    # kernel to less than one degree of rotation. Scaling by pi matches
    # stable_anisotropic_gaussian_kernel, which also multiplies theta by pi.
    theta = torch.rand(batch).cuda() * math.pi
    lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min
    lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min

    covar = cal_sigma(lambda_1, lambda_2, theta)
    return anisotropic_gaussian_kernel(batch, kernel_size, covar)


def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0):
    """Build one deterministic anisotropic Gaussian blur kernel.

    Args:
        kernel_size: Side length of the (square) kernel.
        theta: Rotation, multiplied by pi before use.
        lambda_1: Scale passed as the first axis value to ``cal_sigma``.
        lambda_2: Scale passed as the second axis value to ``cal_sigma``.

    Returns:
        CUDA tensor of shape ``(1, kernel_size, kernel_size)``.
    """
    unit = torch.ones(1).cuda()
    angle = unit * theta * math.pi
    scale_a = unit * lambda_1
    scale_b = unit * lambda_2

    return anisotropic_gaussian_kernel(1, kernel_size, cal_sigma(scale_a, scale_b, angle))


def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0):
    """Sample a batch of isotropic Gaussian blur kernels.

    Each kernel's standard deviation is drawn uniformly from
    ``[sig_min, sig_max)``.

    Returns:
        CUDA tensor of shape ``(batch, kernel_size, kernel_size)``.
    """
    sig_span = sig_max - sig_min
    sigmas = torch.rand(batch).cuda() * sig_span + sig_min
    return isotropic_gaussian_kernel(batch, kernel_size, sigmas)


def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0):
    """Build one deterministic isotropic Gaussian blur kernel of width ``sig``.

    Returns:
        CUDA tensor of shape ``(1, kernel_size, kernel_size)``.
    """
    # Wrap the scalar width in a one-element CUDA tensor for the kernel builder.
    sigma = torch.ones(1).cuda() * sig
    return isotropic_gaussian_kernel(1, kernel_size, sigma)


def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0):
    """Sample a batch of random Gaussian blur kernels of the requested type.

    Args:
        batch: Number of kernels to sample.
        kernel_size: Side length of each (square) kernel.
        blur_type: ``'iso_gaussian'`` or ``'aniso_gaussian'``.
        sig_min: Lower bound of the width range (isotropic only).
        sig_max: Upper bound of the width range (isotropic only).
        lambda_min: Lower bound of the axis-scale range (anisotropic only).
        lambda_max: Upper bound of the axis-scale range (anisotropic only).

    Returns:
        Tensor of shape ``(batch, kernel_size, kernel_size)``.

    Raises:
        ValueError: If ``blur_type`` is not one of the supported names.
    """
    if blur_type == 'iso_gaussian':
        return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max)
    if blur_type == 'aniso_gaussian':
        return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max)
    # Fail here instead of returning None, which only surfaces later as an
    # opaque attribute error when the kernel is used.
    raise ValueError(
        "Unsupported blur_type {!r}; expected 'iso_gaussian' or 'aniso_gaussian'".format(blur_type))


def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0):
    """Build one deterministic Gaussian blur kernel of the requested type.

    Args:
        kernel_size: Side length of the (square) kernel.
        blur_type: ``'iso_gaussian'`` or ``'aniso_gaussian'``.
        sig: Kernel width (isotropic only).
        lambda_1: First axis scale (anisotropic only).
        lambda_2: Second axis scale (anisotropic only).
        theta: Rotation parameter (anisotropic only).

    Returns:
        Tensor of shape ``(1, kernel_size, kernel_size)``.

    Raises:
        ValueError: If ``blur_type`` is not one of the supported names.
    """
    if blur_type == 'iso_gaussian':
        return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig)
    if blur_type == 'aniso_gaussian':
        return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta)
    # Fail here instead of returning None, which only surfaces later as an
    # opaque attribute error when the kernel is used.
    raise ValueError(
        "Unsupported blur_type {!r}; expected 'iso_gaussian' or 'aniso_gaussian'".format(blur_type))


# implementation of matlab bicubic interpolation in pytorch
class bicubic(nn.Module):
    """Bicubic resampling of ``(B, C, H, W)`` tensors, modelled on MATLAB's imresize.

    The resize is separable: rows are resampled first, then columns. When
    ``scale < 1`` the cubic kernel is widened by ``1 / scale``, so downscaling
    is low-pass filtered.
    """

    def __init__(self):
        super(bicubic, self).__init__()

    def cubic(self, x):
        """Evaluate the cubic convolution kernel (a = -0.5) element-wise.

        The result is zero wherever ``|x| > 2``.
        """
        absx = torch.abs(x)
        absx2 = absx * absx
        absx3 = absx2 * absx

        condition1 = (absx <= 1).to(torch.float32)
        condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32)

        f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2
        return f

    def _contribute_axis(self, in_len, out_len, scale, device):
        """Compute interpolation weights and 1-based source indices for one axis.

        Returns:
            ``(weight, indice)``, both of shape ``(1, out_len, taps)``.
        """
        kernel_width = 4
        if scale < 1:
            kernel_width = 4 / scale

        # 1-based output coordinates mapped back into input space.
        x = torch.arange(start=1, end=out_len + 1, dtype=torch.float32, device=device)
        u = x / scale + 0.5 * (1 - 1 / scale)
        left = torch.floor(u - kernel_width / 2)

        # Two extra taps guarantee the kernel support is fully covered.
        taps = math.ceil(kernel_width) + 2
        offsets = torch.arange(taps, dtype=torch.float32, device=device)
        indice = left.unsqueeze(1) + offsets.unsqueeze(0)

        mid = u.unsqueeze(1) - indice.unsqueeze(0)
        if scale < 1:
            weight = scale * self.cubic(mid * scale)
        else:
            weight = self.cubic(mid)
        weight = weight / (torch.sum(weight, 2).unsqueeze(2))

        # Out-of-range taps are clamped to the nearest valid (1-based) index.
        indice = torch.clamp(indice, 1, in_len).unsqueeze(0)

        # Drop a tap only when its weight is zero for EVERY output position.
        # Looking at the first output row alone could discard a tap that other
        # rows still need when the sub-pixel phase differs between rows.
        keep = (weight != 0).any(dim=1)[0]
        return weight[:, :, keep], indice[:, :, keep]

    def contribute(self, in_size, out_size, scale, device=None):
        """Compute per-axis interpolation weights and source indices.

        Args:
            in_size: ``[height, width]`` of the input.
            out_size: ``[height, width]`` of the output.
            scale: Resize factor (``< 1`` shrinks).
            device: Device for the returned tensors. Defaults to CUDA when it
                is available, otherwise CPU.

        Returns:
            ``(weight0, weight1, indice0, indice1)``: weights and 1-based
            indices for the height axis (0) and width axis (1), each of shape
            ``(1, out_len, taps)``.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        weight0, indice0 = self._contribute_axis(in_size[0], out_size[0], scale, device)
        weight1, indice1 = self._contribute_axis(in_size[1], out_size[1], scale, device)

        return weight0, weight1, indice0, indice1

    def forward(self, input, scale=1/4):
        """Resize ``input`` of shape ``(B, C, H, W)`` by ``scale``.

        Returns:
            Tensor of shape ``(B, C, int(H * scale), int(W * scale))`` on the
            same device as ``input``.
        """
        b, c, h, w = input.shape

        # Build the weights on the input's device so CPU and GPU inputs both work.
        weight0, weight1, indice0, indice1 = self.contribute(
            [h, w], [int(h * scale), int(w * scale)], scale, device=input.device)
        weight0 = weight0[0]
        weight1 = weight1[0]

        # Indices are 1-based (MATLAB convention); shift by one when gathering.
        indice0 = indice0[0].long()
        indice1 = indice1[0].long()

        # Resample along height: gather the taps, weight them, sum over taps.
        out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = (torch.sum(out, dim=3))
        A = out.permute(0, 1, 3, 2)

        # Resample along width on the transposed result, then transpose back.
        out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = out.sum(3).permute(0, 1, 3, 2)

        return out


class Gaussin_Kernel(object):
    """Callable that produces Gaussian blur kernels from stored settings.

    ``sig``, ``lambda_1``, ``lambda_2`` and ``theta`` configure the fixed
    ("stable") kernel; the ``*_min`` / ``*_max`` pairs bound the random one.
    """

    def __init__(self, kernel_size=21, blur_type='iso_gaussian',
                 sig=2.6, sig_min=0.2, sig_max=4.0,
                 lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0):
        # Settings shared by both kernel families.
        self.kernel_size, self.blur_type = kernel_size, blur_type

        # Isotropic settings.
        self.sig, self.sig_min, self.sig_max = sig, sig_min, sig_max

        # Anisotropic settings.
        self.lambda_1, self.lambda_2, self.theta = lambda_1, lambda_2, theta
        self.lambda_min, self.lambda_max = lambda_min, lambda_max

    def __call__(self, batch, random):
        """Return ``batch`` random kernels, or a single fixed kernel.

        ``batch`` is only used on the random path.
        """
        # Equality test, not truthiness: only values equal to True take the
        # random path.
        if random == True:
            return random_gaussian_kernel(
                batch,
                kernel_size=self.kernel_size,
                blur_type=self.blur_type,
                sig_min=self.sig_min,
                sig_max=self.sig_max,
                lambda_min=self.lambda_min,
                lambda_max=self.lambda_max,
            )

        return stable_gaussian_kernel(
            kernel_size=self.kernel_size,
            blur_type=self.blur_type,
            sig=self.sig,
            lambda_1=self.lambda_1,
            lambda_2=self.lambda_2,
            theta=self.theta,
        )

class BatchBlur(nn.Module):
    """Blur a batch of images with one shared kernel or one kernel per image.

    Inputs are reflection-padded so the output keeps the input's spatial size.
    """

    def __init__(self, kernel_size=21):
        super(BatchBlur, self).__init__()
        self.kernel_size = kernel_size
        half = kernel_size // 2
        if kernel_size % 2 == 1:
            self.pad = nn.ReflectionPad2d(half)
        else:
            # Even kernels pad one pixel less on the right/bottom side.
            self.pad = nn.ReflectionPad2d((half, half - 1, half, half - 1))

    def forward(self, input, kernel):
        """Convolve ``input`` (B, C, H, W) with ``kernel``.

        Args:
            input: Image batch of shape ``(B, C, H, W)``.
            kernel: Either a single ``(k, k)`` kernel applied to every image
                and channel, or ``B`` kernels (one per image) shared across
                that image's channels.

        Returns:
            Blurred tensor of shape ``(B, C, H, W)``.
        """
        B, C, H, W = input.size()
        padded = self.pad(input)
        H_p, W_p = padded.size()[-2:]
        k = self.kernel_size

        if len(kernel.size()) == 2:
            # Shared kernel: treat every (image, channel) plane as its own sample.
            planes = padded.view((C * B, 1, H_p, W_p))
            shared = kernel.contiguous().view((1, 1, k, k))
            return F.conv2d(planes, shared, padding=0).view((B, C, H, W))

        # Per-image kernels: fold batch and channels into the channel axis and
        # run a grouped convolution with one group per plane.
        planes = padded.view((1, C * B, H_p, W_p))
        per_image = kernel.contiguous().view((B, 1, k, k))
        per_plane = per_image.repeat(1, C, 1, 1).view((B * C, 1, k, k))
        return F.conv2d(planes, per_plane, groups=B * C).view((B, C, H, W))
        
def add_stripe(im, gauss_std_range=(15, 16), stripe_std_range=(25, 30)):
    """Add zero-mean Gaussian noise and column-wise stripe noise to images.

    Args:
        im: Image batch of shape ``(B, C, H, W)``.
        gauss_std_range: ``(low, high)`` range from which one Gaussian-noise
            standard deviation is drawn per image.
        stripe_std_range: ``(low, high)`` range from which one stripe-noise
            standard deviation is drawn per column.

    Returns:
        Tensor of the same shape and device as ``im`` with both noises added.
    """
    batch, channels, height, width = im.size(0), im.size(1), im.size(2), im.size(-1)

    # Noise is drawn on the CPU and then moved, so the CPU random stream is
    # consumed in the same order as before.

    # Gaussian noise: one std per image, spread over that image's pixels.
    # The std is reshaped to (B, 1, 1, 1) so it follows the batch axis; a bare
    # (B,) tensor would broadcast against the width axis instead.
    std_g = torch.empty(batch, dtype=torch.float32).uniform_(*gauss_std_range)
    noise_g = torch.normal(torch.zeros(im.size()),
                           std_g.view(-1, 1, 1, 1).expand(im.size()))

    # Stripe noise: one value per (image, column), constant down each column
    # and across channels.
    beta = torch.empty(width, dtype=torch.float32).uniform_(*stripe_std_range)
    noise_col = torch.normal(torch.zeros((batch, width)), beta.expand(batch, width))
    # Shape (B, 1, 1, W) keeps each image's stripes on its own batch slot, so
    # any batch size works (not just batch size 1).
    stripes = noise_col.view(batch, 1, 1, width).expand(batch, channels, height, width)

    # Follow the input's device instead of forcing CUDA.
    return im + noise_g.to(im.device) + stripes.to(im.device)


class SRMDPreprocessing(object):
    """Degrade high-resolution images: optional blur, downsampling, optional noise."""

    def __init__(self,
                 scale,
                 mode='bicubic',
                 kernel_size=21,
                 blur_type='iso_gaussian',
                 sig=2.6,
                 sig_min=0.2,
                 sig_max=4.0,
                 lambda_1=0.2,
                 lambda_2=4.0,
                 theta=0,
                 lambda_min=0.2,
                 lambda_max=4.0,
                 noise=0.0
                 ):
        '''
        Isotropic Gaussian blur uses sig, sig_min and sig_max:
            training phase (random=True):
                the kernel width is sampled from [sig_min, sig_max]
            test phase (random=False):
                the kernel width is fixed to sig

        Anisotropic Gaussian blur uses lambda_1, lambda_2, theta, lambda_min
        and lambda_max:
            training phase (random=True):
                the covariance eigenvalues are sampled from [lambda_min, lambda_max]
                the angle is sampled at random (intended range [0, pi]; see
                random_anisotropic_gaussian_kernel for the actual sampling)
            test phase (random=False):
                the covariance eigenvalues are fixed to lambda_1 and lambda_2
                the angle is fixed to theta

        scale is the downsampling factor and mode selects 'bicubic' or
        's-fold' downsampling. Any noise value greater than 0 enables the
        stripe-noise stage; its magnitude is not otherwise used.
        '''
        self.scale = scale
        self.mode = mode
        self.kernel_size = kernel_size
        self.noise = noise

        self.bicubic = bicubic()
        self.blur = BatchBlur(kernel_size=kernel_size)
        self.gen_kernel = Gaussin_Kernel(
            kernel_size=kernel_size, blur_type=blur_type,
            sig=sig, sig_min=sig_min, sig_max=sig_max,
            lambda_1=lambda_1, lambda_2=lambda_2, theta=theta,
            lambda_min=lambda_min, lambda_max=lambda_max,
        )

    def __call__(self, hr_tensor, random=True):
        """Degrade ``hr_tensor`` (B, C, H, W) and return ``(lr, kernels)``.

        The input is moved to CUDA. ``kernels`` is ``None`` when blurring is
        skipped (isotropic blur with ``sig == 0``). The low-resolution result
        is rounded and clamped to [0, 255].
        """
        hr_tensor = hr_tensor.cuda()
        with torch.no_grad():
            B, C, H, W = hr_tensor.size()
            skip_blur = self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0

            if skip_blur:
                # Downsampling only.
                b_kernels = None
                hr_blured = hr_tensor.view(-1, C, H, W)
            else:
                # Gaussian blur followed by downsampling.
                b_kernels = self.gen_kernel(B, random)  # B degradations
                blurred = self.blur(hr_tensor.view(B, -1, H, W), b_kernels)
                hr_blured = blurred.view(-1, C, H, W)  # BN, C, H, W

            # Downsample.
            if self.mode == 'bicubic':
                lr_blured = self.bicubic(hr_blured, scale=1/self.scale)
            elif self.mode == 's-fold':
                stride = int(self.scale)
                lr_blured = hr_blured[:, :, ::stride, ::stride]

            # Optional stripe + Gaussian noise.
            if self.noise > 0:
                lr_blured = add_stripe(lr_blured)

            lr_blured = torch.clamp(lr_blured.round(), 0, 255)
            out_h = H//int(self.scale)
            out_w = W//int(self.scale)
            return lr_blured.view(B, C, out_h, out_w), b_kernels
        
if __name__ == '__main__':
    import random

    # Degradation settings: 4x bicubic downsampling with isotropic Gaussian blur.
    # noise=0 leaves the noise stage disabled (it only runs when noise > 0).
    preprocessor = SRMDPreprocessing(
        scale=4,
        mode='bicubic',
        kernel_size=21,
        blur_type='iso_gaussian',
        sig=2.6,
        sig_min=0.2,
        sig_max=4.0,
        lambda_1=0.2,
        lambda_2=4.0,
        theta=0,
        lambda_min=0.2,
        lambda_max=4.0,
        noise=0,
    )

    # Source folder and the two destinations: degraded images resized back to
    # the original size, and the raw low-resolution results.
    input_folder = "/data/timer/DRMF/data/LLVIP/val/ir"
    output_folder = "/data/timer/DRMF/data/LLVIP/val/ir_noise"
    low_output_folder = "/data/timer/DRMF/data/LLVIP/val/ir_noise_low"

    for folder in (output_folder, low_output_folder):
        os.makedirs(folder, exist_ok=True)

    # Visit the input files in random order.
    filelist = os.listdir(input_folder)
    random.shuffle(filelist)

    # The converters are stateless, so build them once.
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    filebar = tqdm(filelist)
    for filename in filebar:
        if not filename.endswith(('.jpg', '.png')):
            continue

        # Load as single-channel and rescale from [0, 1] to [0, 255].
        image = Image.open(os.path.join(input_folder, filename)).convert('L')
        image_tensor = to_tensor(image).unsqueeze(0) * 255.

        # Degrade with a randomly sampled blur kernel.
        lr_blurred, _ = preprocessor(image_tensor, random=True)

        # Save the low-resolution result as-is.
        output_image = to_pil(lr_blurred[0].byte())
        output_image.save(os.path.join(low_output_folder, filename))

        # Save a copy resized back to the source image's size.
        output_image = output_image.resize(image.size)
        output_image.save(os.path.join(output_folder, filename))

        filebar.set_description("{} |".format(filename))