import numpy as np
import torch
import cv2
import random
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
from timm.data.transforms import _pil_interp
from PIL import Image
from .random_degradations import RandomBlur, RandomNoise


class FreqMaskGenerator:
    """Fixed circular band mask on a square grid centred at ``input_size // 2``.

    A pixel at column ``x`` / row ``y`` is set to 0 when its squared distance to
    the centre satisfies ``mask_radius1**2 <= d2 < mask_radius2**2`` and to 1
    otherwise. ``__call__`` returns either this mask (labelled "low-pass") or
    its complement (labelled "high-pass").

    Args:
        input_size: side length of the square mask.
        mask_radius1: inner radius (inclusive) of the zeroed band.
        mask_radius2: outer radius (exclusive) of the zeroed band.
        sample_ratio: probability that ``__call__`` returns ``self.mask`` rather
            than ``1 - self.mask``.
    """

    def __init__(self,
                 input_size=224,
                 mask_radius1=16,
                 mask_radius2=999,
                 sample_ratio=0.5):
        self.input_size = input_size
        self.mask_radius1 = mask_radius1
        self.mask_radius2 = mask_radius2
        self.sample_ratio = sample_ratio
        # Vectorised squared distance of every pixel to the centre; this replaces
        # a pure-Python double loop over all input_size**2 pixels.
        offsets = np.arange(self.input_size) - self.input_size // 2
        dist_sq = offsets[:, None] ** 2 + offsets[None, :] ** 2  # indexed [y, x]
        band = (dist_sq >= self.mask_radius1 ** 2) & (dist_sq < self.mask_radius2 ** 2)
        self.mask = np.ones((self.input_size, self.input_size), dtype=int)
        self.mask[band] = 0

    def __call__(self):
        """Return the low-pass mask with probability ``sample_ratio``, else its complement."""
        rnd = torch.bernoulli(torch.tensor(self.sample_ratio, dtype=torch.float)).item()
        if rnd == 0:  # high-pass
            return 1 - self.mask
        elif rnd == 1:  # low-pass
            return self.mask
        else:
            raise ValueError(f'unexpected Bernoulli draw: {rnd}')
        

class FreqCompressionGenerator:
    """Per-image mask that keeps the largest-magnitude Fourier coefficients.

    On construction the image is converted to grayscale, transformed with
    ``np.fft.fft2`` and ``np.fft.fftshift``. A compression rate ``r`` is drawn
    uniformly from ``config.COMPRESSION_RATES``. ``self.mask`` is 1 where the
    coefficient magnitude is strictly greater than the magnitude found at sorted
    position ``floor((1 - r) * N)``, clamped to ``[0, N - 1]``, and 0 elsewhere.

    Args:
        image: a PIL image (channel order RGB), or a numpy array that is either
            2-D grayscale or 3-D with OpenCV's BGR channel order (the BGR
            assumption for arrays is inherited from the original code — TODO
            confirm with callers).
        sample_ratio: probability that ``__call__`` returns ``self.mask`` rather
            than ``1 - self.mask``.
        config: object exposing ``COMPRESSION_RATES`` (a non-empty sequence).

    Raises:
        ValueError: if ``config`` is ``None``.
    """

    def __init__(self, image, sample_ratio=.5, config=None):
        if config is None:
            raise ValueError('FreqCompressionGenerator requires a config with COMPRESSION_RATES')
        self.sample_ratio = sample_ratio
        self.compression_rates = config.COMPRESSION_RATES
        rnd = random.randint(0, len(self.compression_rates) - 1)

        if isinstance(image, Image.Image):
            image = np.array(image)
            if image.ndim > 2:
                # PIL arrays are RGB(A)-ordered; COLOR_BGR2GRAY would swap the
                # red and blue luma weights.
                image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        elif image.ndim > 2:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        fft_shift = np.fft.fftshift(np.fft.fft2(image))
        magnitude = np.abs(fft_shift)
        fourier_sorted = np.sort(magnitude.reshape(-1))

        # floor((1 - r) * N) equals N for r == 0, one past the end. Clamping to
        # the last index makes r == 0 keep nothing, because the comparison is
        # strict and nothing exceeds the maximum magnitude.
        n_coeffs = len(fourier_sorted)
        thresh_idx = int(np.floor((1 - self.compression_rates[rnd]) * n_coeffs))
        thresh_idx = min(max(thresh_idx, 0), n_coeffs - 1)

        mask = magnitude > fourier_sorted[thresh_idx]
        self.mask = mask * 1

    def __call__(self):
        """Return ``self.mask`` with probability ``sample_ratio``, else ``1 - self.mask``."""
        rnd = torch.bernoulli(torch.tensor(self.sample_ratio, dtype=torch.float)).item()
        if rnd == 0:
            return 1 - self.mask
        elif rnd == 1:
            return self.mask
        else:
            raise ValueError(f'unexpected Bernoulli draw: {rnd}')
    
class GaborFilterGenerator:
    """Binary mask obtained by thresholding a randomly parameterised Gabor kernel.

    At construction, the orientation ``theta`` is ``pi / k`` for ``k`` drawn
    from ``1..4``. ``sigma`` is ``3 * k'`` for ``k'`` drawn from ``1..4``. The
    kernel size is drawn from ``radius_list``. ``cv2.getGaborKernel`` builds the
    kernel, which is then resized to ``input_size x input_size`` and min-max
    normalised. Values above 0.6 become 1.0 and the rest 0.0 (float32). The mask
    is fixed after construction; ``__call__`` only chooses between it and its
    complement.

    Args:
        input_size: side length of the resulting square mask.
        sample_ratio: probability that ``__call__`` returns ``self.mask`` rather
            than ``1 - self.mask``.
    """

    def __init__(self, input_size, sample_ratio=.5):
        self.radius_list = [30, 50, 70, 90]
        # The random draws happen in this fixed order: theta divisor, sigma, kernel size.
        self.theta = np.pi / (1.0 * random.randint(1, 4))
        self.lambd = np.pi / 4.0
        self.gamma = 0.5
        self.phi = 0
        self.sample_ratio = sample_ratio
        self.sigma = 3 * random.randint(1, 4)
        ksize = self.radius_list[random.randint(0, 3)]
        self.input_size = input_size

        kernel = cv2.getGaborKernel(
            (ksize, ksize),
            self.sigma,
            self.theta,
            self.lambd,
            self.gamma,
            self.phi,
            ktype=cv2.CV_32F,
        )
        kernel = cv2.resize(kernel, (input_size, input_size))
        lo, hi = kernel.min(), kernel.max()
        normalized = (kernel - lo) / (hi - lo)
        self.mask = (normalized > 0.6).astype(np.float32)

    def __call__(self):
        """Return ``self.mask`` with probability ``sample_ratio``, else ``1 - self.mask``."""
        draw = torch.bernoulli(torch.tensor(self.sample_ratio, dtype=torch.float)).item()
        if draw == 1:
            return self.mask
        if draw == 0:
            return 1 - self.mask
        raise ValueError
    
class MaskFilterGenerator:
    """Random block-wise binary mask over an ``input_size x input_size`` grid.

    At construction, a patch size is drawn from ``[2, 4, 8, 16, 32]`` and a
    masking ratio from ``[.3, .4, .5]``. A grid of ``ceil(input_size / patch)``
    tokens per side is created, and ``ceil(tokens * ratio)`` randomly chosen
    tokens are zeroed. The grid is upsampled by repeating each token
    ``patch x patch`` times and cropped to ``input_size``. When ``input_size``
    is divisible by the patch size, the crop is a no-op and the result matches
    plain floor division.

    NOTE(review): the mask is drawn once here and reused by every ``__call__``
    (possibly inverted) — confirm a per-sample mask is not intended.

    Args:
        input_size: side length of the square mask.
        sample_ratio: probability that ``__call__`` returns ``self.mask`` rather
            than ``1 - self.mask``.
    """

    def __init__(self, input_size=224, sample_ratio=.5):
        mask_patch_sizes = [2, 4, 8, 16, 32]
        mask_ratio_list = [.3, .4, .5]
        self.sample_ratio = sample_ratio
        mask_ratio = mask_ratio_list[random.randint(0, 2)]
        rnd = random.randint(0, 4)
        patch_size = mask_patch_sizes[rnd]

        # Round the token grid up so the upsampled mask always covers the whole
        # image. Floor division would yield a smaller mask (or an empty one when
        # input_size < patch_size) for non-divisible sizes.
        rand_size = (input_size + patch_size - 1) // patch_size

        token_count = rand_size ** 2
        mask_count = int(np.ceil(token_count * mask_ratio))

        mask_idx = np.random.permutation(token_count)[:mask_count]
        mask = np.ones(token_count, dtype=int)
        mask[mask_idx] = 0

        mask = mask.reshape((rand_size, rand_size))
        mask = mask.repeat(patch_size, axis=0).repeat(patch_size, axis=1)
        self.mask = np.ascontiguousarray(mask[:input_size, :input_size])

    def __call__(self):
        """Return ``self.mask`` with probability ``sample_ratio``, else ``1 - self.mask``."""
        rnd = torch.bernoulli(torch.tensor(self.sample_ratio, dtype=torch.float)).item()
        if rnd == 0:
            return 1 - self.mask
        elif rnd == 1:
            return self.mask
        else:
            raise ValueError(f'unexpected Bernoulli draw: {rnd}')


class MFMTransform:
    """Single-view pre-training transform driven by ``config.DATA.FILTER_TYPE``.

    Every image is converted to RGB if needed, randomly resized-cropped to
    ``config.DATA.IMG_SIZE`` and randomly flipped. Depending on the filter type:

    * ``'deblur'`` / ``'denoise'``: a degraded copy is produced with
      ``RandomBlur`` / ``RandomNoise`` (configured from ``config.DATA.BLUR`` /
      ``config.DATA.NOISE``).
    * ``'mfm'``, ``'gabor'``, ``'r_mask'``: a mask is taken from a generator
      built once in ``__init__``.
    * ``'fftComp'``: a ``FreqCompressionGenerator`` is built per image.

    ``__call__`` returns ``(img, img_lq, mask)``: ``img`` is the ``T.ToTensor``
    of the augmented view, ``img_lq`` is a CxHxW tensor or ``None``, and
    ``mask`` is the generator output or ``None``.

    Raises:
        NotImplementedError: if ``config.MODEL.TYPE`` is not ``'swin'``,
            ``'vit'`` or ``'resnet'``.
    """

    def __init__(self, config):
        data_cfg = config.DATA
        self.transform_img = T.Compose([
            T.Lambda(lambda pil: pil if pil.mode == 'RGB' else pil.convert('RGB')),
            T.RandomResizedCrop(
                data_cfg.IMG_SIZE,
                scale=(data_cfg.MIN_CROP_SCALE, 1.),
                interpolation=_pil_interp(data_cfg.INTERPOLATION),
            ),
            T.RandomHorizontalFlip(),
        ])

        self.filter_type = data_cfg.FILTER_TYPE
        self.sample_ratio = data_cfg.SAMPLE_RATIO

        # The backbone patch size is looked up only to reject unsupported model
        # types; its value is not used anywhere in this transform.
        model_cfg = config.MODEL
        patch_size_of = {
            'swin': lambda: model_cfg.SWIN.PATCH_SIZE,
            'vit': lambda: model_cfg.VIT.PATCH_SIZE,
            'resnet': lambda: 1,
        }
        if model_cfg.TYPE not in patch_size_of:
            raise NotImplementedError
        model_patch_size = patch_size_of[model_cfg.TYPE]()  # noqa: F841 (unused)
        self.config = config

        if self.filter_type == 'deblur':
            blur_cfg = data_cfg.BLUR
            self.degrade_transform = RandomBlur(params={
                'kernel_size': blur_cfg.KERNEL_SIZE,
                'kernel_list': blur_cfg.KERNEL_LIST,
                'kernel_prob': blur_cfg.KERNEL_PROB,
                'sigma_x': blur_cfg.SIGMA_X,
                'sigma_y': blur_cfg.SIGMA_Y,
                'rotate_angle': blur_cfg.ROTATE_ANGLE,
                'beta_gaussian': blur_cfg.BETA_GAUSSIAN,
                'beta_plateau': blur_cfg.BETA_PLATEAU,
            })
        elif self.filter_type == 'denoise':
            noise_cfg = data_cfg.NOISE
            self.degrade_transform = RandomNoise(params={
                'noise_type': noise_cfg.TYPE,
                'noise_prob': noise_cfg.PROB,
                'gaussian_sigma': noise_cfg.GAUSSIAN_SIGMA,
                'gaussian_gray_noise_prob': noise_cfg.GAUSSIAN_GRAY_NOISE_PROB,
                'poisson_scale': noise_cfg.POISSON_SCALE,
                'poisson_gray_noise_prob': noise_cfg.POISSON_GRAY_NOISE_PROB,
            })
        elif self.filter_type == 'mfm':
            self.freq_mask_generator = FreqMaskGenerator(
                input_size=data_cfg.IMG_SIZE,
                mask_radius1=data_cfg.MASK_RADIUS1,
                mask_radius2=data_cfg.MASK_RADIUS2,
                sample_ratio=data_cfg.SAMPLE_RATIO,
            )
        elif self.filter_type == 'gabor':
            self.gabor_filter_generator = GaborFilterGenerator(
                input_size=data_cfg.IMG_SIZE,
                sample_ratio=data_cfg.SAMPLE_RATIO,
            )
        elif self.filter_type == 'r_mask':
            self.mask_generator = MaskFilterGenerator(
                input_size=data_cfg.IMG_SIZE,
                sample_ratio=data_cfg.SAMPLE_RATIO,
            )

    def __call__(self, img):
        view = self.transform_img(img)  # PIL Image (HxWxC, 0-255), no normalization

        img_lq = None
        if self.filter_type in ('deblur', 'denoise'):
            degraded = self.degrade_transform(np.array(view).astype(np.float32) / 255.)
            img_lq = torch.from_numpy(degraded.transpose(2, 0, 1))

        # Each entry is evaluated lazily, so only the generator for the active
        # filter type is touched.
        mask_sources = {
            "fftComp": lambda: FreqCompressionGenerator(
                view, sample_ratio=self.sample_ratio, config=self.config)(),
            "gabor": lambda: self.gabor_filter_generator(),
            "r_mask": lambda: self.mask_generator(),
            'mfm': lambda: self.freq_mask_generator(),
        }
        make_mask = mask_sources.get(self.filter_type)
        mask = None if make_mask is None else make_mask()

        tensor_view = T.ToTensor()(view)  # Tensor (CxHxW, 0-1)
        return tensor_view, img_lq, mask
    
class MFMTransformV2(MFMTransform):
    """Two-view variant of ``MFMTransform``.

    Each call draws two independently augmented views of the same input image
    and stacks the per-view outputs along a new leading axis of size 2.

    Returns:
        img: tensor stacking the two ``T.ToTensor`` views (2 x C x H x W, 0-1).
        img_lq: stacked degraded views (2 x C x H x W) for ``'deblur'`` /
            ``'denoise'``, otherwise ``None``.
        mask: numpy array stacking the two masks along axis 0 for
            ``'fftComp'`` / ``'gabor'`` / ``'r_mask'`` / ``'mfm'``, otherwise
            ``None``.
    """

    def __call__(self, img):
        # Both crops are drawn first, then both degradations, then both masks.
        # This keeps the random-number consumption order of the original code.
        views = [self.transform_img(img) for _ in range(2)]  # PIL Images (HxWxC, 0-255)

        img_lq = None
        if self.filter_type in ['deblur', 'denoise']:
            degraded = []
            for view in views:
                arr = np.array(view).astype(np.float32) / 255.
                arr = self.degrade_transform(arr)
                degraded.append(torch.from_numpy(arr.transpose(2, 0, 1)))
            img_lq = torch.stack(degraded, dim=0)

        if self.filter_type == "fftComp":
            # Forward the configured sample ratio as MFMTransform does. It was
            # previously omitted here, silently falling back to the default 0.5.
            masks = [
                FreqCompressionGenerator(view, sample_ratio=self.sample_ratio, config=self.config)()
                for view in views
            ]
        elif self.filter_type == "gabor":
            masks = [self.gabor_filter_generator() for _ in views]
        elif self.filter_type == "r_mask":
            masks = [self.mask_generator() for _ in views]
        elif self.filter_type == 'mfm':
            masks = [self.freq_mask_generator() for _ in views]
        else:
            masks = None
        mask = None if masks is None else np.stack(masks, axis=0)

        img = torch.stack([T.ToTensor()(view) for view in views], dim=0)
        return img, img_lq, mask


def collate_fn(batch):
    """Collate ``(sample, target)`` pairs whose sample may be a tuple containing ``None``.

    If the first sample is not a tuple, the batch is handed to
    ``default_collate`` unchanged. Otherwise, each tuple position is collated
    across the batch, and ``None`` is kept as-is when the first sample holds
    ``None`` at that position. The collated targets are appended last, and the
    result is returned as a list.
    """
    head = batch[0][0]
    if not isinstance(head, tuple):
        return default_collate(batch)

    collated = []
    for pos, first in enumerate(head):
        if first is None:
            collated.append(None)
            continue
        collated.append(default_collate([sample[0][pos] for sample in batch]))
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated


def build_loader_mfm(config, logger, multi_view=False):
    """Build the distributed pre-training ``DataLoader`` over an ``ImageFolder``.

    Uses ``MFMTransformV2`` (two views per image) when ``multi_view`` is true,
    otherwise ``MFMTransform``. Samples are sharded with a shuffling
    ``DistributedSampler`` that reads the world size and rank from
    ``torch.distributed``, so a process group must already be initialised.
    Batches are built with ``collate_fn``, pinned memory and ``drop_last=True``.
    """
    transform_cls = MFMTransformV2 if multi_view else MFMTransform
    transform = transform_cls(config)
    logger.info(f'Pre-train data transform:\n{transform}')

    dataset = ImageFolder(config.DATA.DATA_PATH, transform)
    logger.info(f'Build dataset: train images = {len(dataset)}')

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    data_cfg = config.DATA
    dataloader = DataLoader(
        dataset,
        data_cfg.BATCH_SIZE,
        sampler=sampler,
        num_workers=data_cfg.NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    logger.info(f'Build dataset: train images dataloader len = {len(dataloader)}')
    return dataloader
