import random

import numpy as np
import torch


class FastTensorDataLoaderWithBinaryMask:
    """
    Inspired by:
    https://github.com/hcarlens/pytorch-tabular/blob/master/fast_tensor_data_loader.py

    Responsible for:
        - creating batches of data
        - for each batch, creating the mask indicating which values are valid, that is
                    mask == 1 --> Valid value
                    mask == 0 --> Missing value

    One mask of shape ``(batch_size, feature_number)`` is drawn each time an iterator is
    created. Every batch of that pass receives this mask, truncated to the batch's row
    count, with its maskable columns freshly permuted. Columns excluded through
    ``mask_on_first_feature`` / ``mask_on_last_feature`` are never masked and never moved.
    """

    def __init__(self,
                 tensor: torch.Tensor,
                 batch_size: int = 32,
                 shuffle: bool = False,
                 drop_last: bool = False,
                 mask_percentage_row: float = 0.0,
                 mask_col_num: int = 0,
                 mask_col_num_constant: bool = True,
                 mask_on_first_feature: bool = True,
                 mask_on_last_feature: bool = True,
                 first_cols_to_exclude: int = 1,
                 last_cols_to_exclude: int = 1):
        """
        Initialize a FastTensorDataLoaderWithBinaryMask.

        :param tensor: 2-D torch.Tensor of features, indexed as (rows, features).
        :param batch_size: int, batch size to load.
        :param shuffle: boolean, if True, rows are shuffled every time an iterator is
                        created (``self.tensor`` is rebound to a shuffled copy; the
                        caller's tensor is not modified).
        :param drop_last: boolean, if True the last batch is dropped when it is
                          incomplete (< batch_size).
        :param mask_percentage_row: float in [0, 1], fraction of rows of the batch mask
                                    that get masked values.
        :param mask_col_num: int, number of columns to mask in each masked row (exact if
                             ``mask_col_num_constant``, otherwise the maximum).
        :param mask_col_num_constant: bool, if True every masked row masks exactly
                                      ``mask_col_num`` columns; otherwise a random count
                                      in [1, mask_col_num] is drawn per row.
        :param mask_on_first_feature: bool, if False the first ``first_cols_to_exclude``
                                      columns are never masked.
        :param mask_on_last_feature: bool, if False the last ``last_cols_to_exclude``
                                     columns are never masked.
        :param first_cols_to_exclude: int, number of leading columns protected when
                                      ``mask_on_first_feature`` is False.
        :param last_cols_to_exclude: int, number of trailing columns protected when
                                     ``mask_on_last_feature`` is False.
        :raises ValueError: if ``mask_percentage_row`` is outside [0, 1].
        """
        if not 0.0 <= mask_percentage_row <= 1.0:
            raise ValueError(f"mask_percentage_row must be in [0, 1], got {mask_percentage_row}")

        self.tensor = tensor

        self.dataset_len = self.tensor.shape[0]
        self.batch_size = batch_size
        self.feature_number = self.tensor.shape[1]
        self.shuffle = shuffle

        # Calculate # batches
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if (remainder > 0) and (not drop_last):
            n_batches += 1
        self.n_batches = n_batches

        # Mask
        self.mask_percentage_row = mask_percentage_row
        self.mask_col_num = mask_col_num
        self.mask_col_num_constant = mask_col_num_constant
        self.mask_on_first_feature = mask_on_first_feature
        self.mask_on_last_feature = mask_on_last_feature
        self.first_cols_to_exclude = first_cols_to_exclude
        self.last_cols_to_exclude = last_cols_to_exclude

        # Batch cursor; initialised here so __next__ is safe even before __iter__.
        self.i = 0
        self.batch_mask = self.reset_batch_mask()

    def _maskable_bounds(self) -> "tuple[int, int]":
        """Return ``(start, stop)``: the half-open column window that may be masked.

        Both exclusions (leading and trailing) are honoured independently, and the
        window is clamped so it is never negative even if the exclusions overlap.
        """
        start = 0 if self.mask_on_first_feature else max(self.first_cols_to_exclude, 0)
        start = min(start, self.feature_number)
        trailing = 0 if self.mask_on_last_feature else max(self.last_cols_to_exclude, 0)
        stop = max(start, self.feature_number - trailing)
        return start, stop

    def _column_permutation(self) -> torch.Tensor:
        """Return a column index permutation that shuffles only the maskable window.

        Excluded leading/trailing columns keep their position, so they stay unmasked.
        """
        start, stop = self._maskable_bounds()
        perm = torch.arange(self.feature_number)
        perm[start:stop] = torch.randperm(stop - start) + start
        return perm

    def reset_batch_mask(self) -> torch.Tensor:
        """Reset the mask tensor

        :return: torch.Tensor of shape (batch_size, feature_number) filled with 1.0
                 (valid) and 0.0 (masked). ``int(batch_size * mask_percentage_row)``
                 randomly placed rows contain masked values, restricted to the
                 maskable column window.
        """
        mask = torch.ones(self.batch_size, self.feature_number)
        if (self.mask_percentage_row == 0.0) or (self.mask_col_num == 0):
            return mask

        start, stop = self._maskable_bounds()
        width = stop - start
        to_mask_num = int(self.batch_size * self.mask_percentage_row)

        for row in range(to_mask_num):
            if self.mask_col_num_constant:
                n_masked = self.mask_col_num
            else:
                n_masked = np.random.randint(1, self.mask_col_num + 1)
            # Zeros first, then shuffled, so exactly min(n_masked, width) columns are masked.
            row_window = [0. if i < n_masked else 1. for i in range(width)]
            random.shuffle(row_window)
            mask[row, start:stop] = torch.tensor(row_window)

        # Scatter the masked rows across the batch, then permute only maskable columns.
        return mask[torch.randperm(self.batch_size)][:, self._column_permutation()]

    def __iter__(self):
        if self.shuffle:
            self.tensor = self.tensor[torch.randperm(self.dataset_len)]
        # A fresh mask is drawn for every pass over the data.
        self.batch_mask = self.reset_batch_mask()
        self.i = 0
        return self

    def __next__(self):
        """Return ``(batch, mask)`` where ``mask`` has the same shape as ``batch``."""
        if self.i >= self.n_batches:
            raise StopIteration
        begin = self.i * self.batch_size
        batch = self.tensor[begin:begin + self.batch_size]
        self.i += 1
        # Re-permute maskable columns per batch; truncate rows for a short last batch.
        mask = self.batch_mask[:batch.shape[0], :][:, self._column_permutation()]
        return batch, mask

    def __len__(self):
        return self.n_batches
