import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np


class MaskedMNIST(Dataset):
    """MNIST paired with one pre-generated mask per example.

    Subclasses implement :meth:`generate_masks`, which must fill ``self.mask``
    with one mask per example, indexed like ``self.data``. Masks are generated
    once, during construction, and returned unchanged on every access.

    Args:
        data_dir: Directory MNIST is loaded from (downloaded there if missing).
        image_size: Side length of the images. When it is not 28 the images
            are passed through ``transforms.Resize(image_size)`` first.
        random_seed: Seed for ``self.rnd``, the NumPy generator subclasses use
            for mask randomness. ``None`` (default) leaves the generator
            unseeded; pass an int for reproducible ``self.rnd`` draws.
        train: Load the training split if True, otherwise the test split.
    """

    def __init__(self, data_dir='mnist-data', image_size=28, random_seed=None,
                 train=True):
        # random_seed used to be accepted but silently ignored. It is honoured
        # now; the default is None (not 0) so a default-constructed dataset
        # keeps the previous unseeded behaviour instead of every default
        # train/test instance sharing the same mask sequence.
        self.rnd = np.random.RandomState(random_seed)
        self.image_size = image_size
        self.train = train
        if image_size == 28:
            # Already the target size: convert to a tensor only.
            transform = transforms.ToTensor()
        else:
            transform = transforms.Compose(
                [transforms.Resize(image_size), transforms.ToTensor()])
        self.data = datasets.MNIST(
            data_dir, train=self.train, download=True, transform=transform)
        self.generate_masks()

    def __getitem__(self, index):
        """Return ``(image, mask, label, index)`` for example ``index``."""
        image, label = self.data[index]
        return image, self.mask[index], label, index

    def __len__(self):
        """Return the number of examples in the underlying MNIST split."""
        return len(self.data)

    def generate_masks(self):
        """Fill ``self.mask`` with one mask per example; subclasses override."""
        raise NotImplementedError


class BlockMaskedMNIST(MaskedMNIST):
    """MNIST where each example's mask is a single rectangular block of ones.

    Args:
        block_len: Side length of a square block. When None, every block gets
            its own height and width, each drawn with ``self.rnd.randint``
            from ``[7, image_size - 7)`` (upper end exclusive).
        *args, **kwargs: Forwarded to ``MaskedMNIST``.

    After construction ``self.mask_info[i]`` is the tuple
    ``(row_start, col_start, height, width)`` describing ``self.mask[i]``.
    """

    def __init__(self, block_len=None, *args, **kwargs):
        self.block_len = block_len
        super().__init__(*args, **kwargs)

    def _sample_block(self, side):
        """Draw ``(row_start, col_start, height, width)`` inside a side x side grid."""
        shortest = 7
        longest = side - shortest
        if self.block_len is not None:
            height = width = self.block_len
        else:
            height = self.rnd.randint(shortest, longest)
            width = self.rnd.randint(shortest, longest)
        # Start offsets are chosen so the whole block fits in the grid.
        top = self.rnd.randint(0, side - height + 1)
        left = self.rnd.randint(0, side - width + 1)
        return top, left, height, width

    def generate_masks(self):
        side = self.image_size
        # All RNG draws happen here, in the same per-example order as before.
        self.mask_info = [self._sample_block(side) for _ in range(len(self))]
        self.mask = []
        for top, left, height, width in self.mask_info:
            block = torch.zeros((side, side), dtype=torch.uint8)
            block[top:top + height, left:left + width] = 1
            self.mask.append(block)


class IndepMaskedMNIST(MaskedMNIST):
    """MNIST where every mask pixel is an independent Bernoulli draw.

    Args:
        obs_prob: Probability that a mask pixel is 1. If ``obs_prob_high`` is
            given, this is instead the lower bound of a per-example
            probability drawn from ``self.rnd.uniform(obs_prob, obs_prob_high)``.
        obs_prob_high: Optional upper bound; see ``obs_prob``.
        *args, **kwargs: Forwarded to ``MaskedMNIST``.

    Note: only the per-example probability comes from ``self.rnd``; the pixel
    draws use torch's global RNG via ``bernoulli_``.
    """

    def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
        self.prob = obs_prob
        self.prob_high = obs_prob_high
        super().__init__(*args, **kwargs)

    def _draw_mask(self, side):
        """Return one side x side uint8 mask of independent Bernoulli pixels."""
        p = (self.prob if self.prob_high is None
             else self.rnd.uniform(self.prob, self.prob_high))
        return torch.ByteTensor(side, side).bernoulli_(p)

    def generate_masks(self):
        side = self.image_size
        self.mask = [self._draw_mask(side) for _ in range(len(self))]

class ShadowMaskedMNIST(MaskedMNIST):
    """MNIST with "shadow" masks cast from a randomly chosen image edge.

    For each example a direction is drawn with ``self.rnd.randint(1, 5)``:
    1 = top->bottom, 2 = left->right, 3 = bottom->top, 4 = right->left.
    Every column (directions 1/3) or row (directions 2/4) is scanned from that
    edge. Mask pixels are 1 up to and including the first pixel whose value
    exceeds ``depth``, and 0 after it.

    Args:
        depth: Pixel-value threshold that blocks a scan line. Defaults to 0.89
            (previously this argument was stored but the threshold was
            hard-coded to 0.89).
        *args, **kwargs: Forwarded to ``MaskedMNIST``.
    """

    def __init__(self, depth=0.89, *args, **kwargs):
        self.depth = depth
        super().__init__(*args, **kwargs)

    @staticmethod
    def _shadow_mask(image, direction, depth):
        """Return the uint8 shadow mask of a 2-D ``image`` for one direction."""
        # Odd directions scan along rows (axis 0), even ones along columns.
        axis = 0 if direction % 2 == 1 else 1
        # Directions 3 and 4 start from the far edge: flip so the scan is
        # always "forward", then flip the result back.
        reverse = direction > 2
        bright = (image > depth).to(torch.int64)
        if reverse:
            bright = bright.flip(axis)
        # Count of bright pixels strictly before each position on its line;
        # a pixel stays unblocked (1) while that count is still zero.
        seen_before = bright.cumsum(axis) - bright
        mask = (seen_before == 0).to(torch.uint8)
        if reverse:
            mask = mask.flip(axis)
        return mask

    def generate_masks(self):
        imsize = self.image_size
        n_masks = len(self)
        self.mask = [None] * n_masks
        for i in range(n_masks):
            image = self.data[i][0].view(imsize, imsize)
            direction = self.rnd.randint(1, 5)
            self.mask[i] = self._shadow_mask(image, direction, self.depth)

class PatchMaskedMNIST(MaskedMNIST):
    """MNIST whose masks are all ones except for several zeroed patches.

    A catalogue of candidate single-patch masks is built once. For every flat
    start pixel and every width ``w`` in ``1.._MAX_PATCH_WIDTH``, the patch
    covers ``_MAX_PATCH_AREA // w`` successive image rows of ``w`` pixels each.
    Each example's mask is the element-wise product of ``num_patches``
    catalogue entries drawn uniformly, with replacement, via ``self.rnd``. It
    is 0 inside any chosen patch and 1 elsewhere.

    Args:
        num_patches: Patches per mask. Defaults to 27.
        *args, **kwargs: Forwarded to ``MaskedMNIST``.
    """

    _MAX_PATCH_WIDTH = 9   # patch widths are 1.._MAX_PATCH_WIDTH
    _MAX_PATCH_AREA = 25   # a width-w patch spans _MAX_PATCH_AREA // w rows

    def __init__(self, num_patches=27, *args, **kwargs):
        self.num_patches = num_patches
        super().__init__(*args, **kwargs)

    def _build_patch_masks(self):
        """Return the uint8 catalogue of flat single-patch masks.

        Shape is ``(n_pixels * _MAX_PATCH_WIDTH, n_pixels)``. Row
        ``_MAX_PATCH_WIDTH * start + width - 1`` holds the patch with that
        flat start pixel and width.
        """
        imsize = self.image_size
        n_pixels = imsize * imsize
        n_widths = self._MAX_PATCH_WIDTH
        patch_masks = torch.ones((n_pixels * n_widths, n_pixels),
                                 dtype=torch.uint8)
        # NOTE(review): patches are laid out on the flattened image. A patch
        # whose start column plus width exceeds image_size therefore spills
        # onto the start of the next row, and rows past the last pixel are
        # dropped. This is kept as-is to preserve the mask distribution;
        # confirm whether the wrap-around is intended.
        for start in range(n_pixels):
            for width in range(1, n_widths + 1):
                # Integer indexing gives a view, so the writes below land in
                # patch_masks directly.
                row = patch_masks[n_widths * start + width - 1]
                for layer in range(self._MAX_PATCH_AREA // width):
                    offset = start + imsize * layer
                    row[offset:offset + width] = 0
        return patch_masks

    def generate_masks(self):
        imsize = self.image_size
        patch_masks = self._build_patch_masks()
        n_candidates = patch_masks.shape[0]
        n_masks = len(self)
        self.mask = [None] * n_masks
        for i in range(n_masks):
            chosen = self.rnd.randint(0, n_candidates, self.num_patches)
            mask = torch.ones(imsize * imsize, dtype=torch.uint8)
            for entry in chosen:
                # Product of 0/1 masks == logical AND: zero inside any patch.
                mask *= patch_masks[entry]
            self.mask[i] = mask.view(imsize, imsize)


