import numpy as np
import pywt
import torch
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
from timm.data.transforms import _pil_interp

from MWM.utils.mask import create_mask

from .random_degradations import RandomBlur, RandomNoise





class WaveletMaskGenerator:
    """Mask an image in the wavelet domain and reconstruct it.

    The image is decomposed with a multilevel 2-D discrete wavelet transform,
    the packed coefficient array is multiplied by a mask, and the result is
    transformed back.  The mask is either magnitude based (keep only the
    largest coefficients) or produced by the project helper ``create_mask``.

    Args:
        level: Number of decomposition levels passed to ``pywt.wavedec2``.
        threshold: Keep ratio in ``[0, 1]``, or a non-empty sequence of
            candidate keep ratios.  With a sequence, one candidate is drawn
            uniformly at random on every call.
        random_mask: If true, the mask comes from ``create_mask`` instead of
            the coefficient magnitudes.
        mask_patch_size: Patch size forwarded to ``create_mask``.
        _not: Invert the magnitude test, i.e. keep the coefficients *below*
            the cut-off instead of those above it.

    Raises:
        ValueError: If ``threshold`` is empty or contains a value outside
            ``[0, 1]``.
    """

    def __init__(self,
                 level=4,
                 threshold=(.05, .01, .005, .003),
                 random_mask=False,
                 mask_patch_size=32,
                 _not=False):
        self.level = level
        # Normalise to a float or an immutable tuple of floats so that a
        # caller-owned list is never shared with (or mutated through) us.
        if np.ndim(threshold) == 0:
            self.threshold = float(threshold)
            candidates = (self.threshold,)
        else:
            self.threshold = tuple(float(t) for t in threshold)
            candidates = self.threshold
        if not candidates:
            raise ValueError("threshold must not be empty")
        if any(not 0.0 <= t <= 1.0 for t in candidates):
            raise ValueError(
                f"threshold values must lie in [0, 1], got {threshold!r}")
        self.random_mask = random_mask
        self.mask_patch_size = mask_patch_size
        self._not = _not
        self.wavelet_type = 'db1'
        # NOTE: a former ``self.mask`` attribute was built from an undefined
        # ``self.input_size`` and made construction always fail; nothing in
        # this class reads it, so it is no longer created.

    def __repr__(self):
        return (f"{type(self).__name__}(level={self.level!r}, "
                f"threshold={self.threshold!r}, "
                f"random_mask={self.random_mask!r}, "
                f"mask_patch_size={self.mask_patch_size!r}, "
                f"_not={self._not!r}, wavelet_type={self.wavelet_type!r})")

    def _keep_ratio(self):
        """Return the keep ratio to use for the current call."""
        if np.ndim(self.threshold) == 0:
            return self.threshold
        # NOTE(review): a sequence is treated as a pool of candidate ratios,
        # one drawn per sample - confirm this matches the intended recipe.
        return self.threshold[np.random.randint(len(self.threshold))]

    @staticmethod
    def _magnitude_cutoff(coeff_arr, keep_ratio):
        """Return the magnitude above which ``keep_ratio`` of entries lie."""
        magnitudes = np.sort(np.abs(coeff_arr), axis=None)
        index = int(np.floor((1 - keep_ratio) * magnitudes.size))
        # keep_ratio == 0 would index one past the end; clamp to the maximum.
        return magnitudes[min(index, magnitudes.size - 1)]

    def __call__(self, img):
        """Return ``(reconstructed_image, mask)`` as two ``uint8`` arrays.

        Args:
            img: Array-like image; it is converted with ``np.asarray``.

        Returns:
            A tuple of the reconstruction from the masked coefficients,
            clipped to ``[0, 255]``, and the mask scaled to ``0``/``255``.
        """
        # Accept PIL images and other array-likes; ``.shape`` is needed below.
        img = np.asarray(img)
        keep_ratio = self._keep_ratio()

        # NOTE(review): pywt transforms the last two axes by default, so a
        # channel-last (H, W, C) input would be decomposed over W and C -
        # confirm callers pass 2-D or channel-first data.
        coeffs = pywt.wavedec2(img, wavelet=self.wavelet_type, level=self.level)
        coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs)

        if self.random_mask:
            mask = create_mask(img.shape[0], self.mask_patch_size,
                               mask_ratio=keep_ratio).tolist()
            mask = np.array(mask)
        else:
            magnitudes = np.abs(coeff_arr)
            cutoff = self._magnitude_cutoff(coeff_arr, keep_ratio)
            # Both comparisons are strict: entries equal to the cut-off are
            # dropped in either mode.
            mask = magnitudes < cutoff if self._not else magnitudes > cutoff

        masked_coeff = coeff_arr * mask
        coeffs_filt = pywt.array_to_coeffs(masked_coeff, coeff_slices,
                                           output_format='wavedec2')
        compressed_image = pywt.waverec2(coeffs_filt, wavelet=self.wavelet_type)

        # Reconstruction can overshoot the 8-bit range; clip first so the
        # uint8 cast saturates instead of wrapping around.
        compressed_image = np.clip(compressed_image, 0, 255)
        return compressed_image.astype('uint8'), (mask * 255).astype('uint8')


def collate_fn(batch):
    """Collate samples whose first element may be a tuple with ``None`` slots.

    Each sample is indexed as ``(data, target)``.  When ``data`` of the first
    sample is not a tuple the whole batch goes straight to
    ``default_collate``.  Otherwise every slot of the ``data`` tuple is
    collated on its own, a slot that is ``None`` in the first sample stays
    ``None``, and the collated targets are appended as the last entry.

    Returns:
        The ``default_collate`` result, or a list with one entry per ``data``
        slot followed by the collated targets.
    """
    first_data = batch[0][0]
    if not isinstance(first_data, tuple):
        return default_collate(batch)

    # Only the first sample decides whether a slot is treated as absent.
    collated = [
        None if first_data[slot] is None
        else default_collate([sample[0][slot] for sample in batch])
        for slot in range(len(first_data))
    ]
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated

def build_loader_mfm(config, logger):
    """Build the distributed pre-training ``DataLoader``.

    Args:
        config: Configuration node; ``config.DATA.DATA_PATH``,
            ``config.DATA.BATCH_SIZE`` and ``config.DATA.NUM_WORKERS`` are
            read here.
        logger: Object with an ``info`` method used for progress messages.

    Returns:
        A ``DataLoader`` over an ``ImageFolder`` dataset, sharded with a
        shuffling ``DistributedSampler`` and batched by ``collate_fn``.

    Note:
        ``dist.get_world_size()`` and ``dist.get_rank()`` are called, so
        ``torch.distributed`` must already be initialised.
    """
    # The generator's first positional parameter is ``level`` (an integer
    # decomposition depth); handing it the whole config node would break the
    # wavelet transform, so the generator's defaults are used instead.
    # TODO(review): wire level/threshold/mask options from ``config`` once the
    # corresponding config keys are defined.
    transform = WaveletMaskGenerator()
    logger.info(f'Pre-train data transform:\n{transform}')

    # NOTE(review): no resize/crop is applied before the wavelet transform;
    # batching presumably needs equally sized images - confirm the dataset.
    dataset = ImageFolder(config.DATA.DATA_PATH, transform)
    logger.info(f'Build dataset: train images = {len(dataset)}')

    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(), shuffle=True)
    dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler,
                            num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True,
                            collate_fn=collate_fn)

    return dataloader