import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import sys
import util

class DataLoader(nn.Module):
    """Container for the train/test splits used to train MCFlow.

    Depending on ``args.data`` the constructor follows one of two pipelines
    (all heavy lifting is delegated to the project's ``util`` module):

    * image datasets (``mnist``, ``fashion``, ``cifar0``, ``cifar1``,
      ``cifar2``): ``util.path_to_matrix`` supplies the train/test data and
      image shape, ``util.create_img_masks`` the missingness masks and
      ``util.fill_img_missingness`` the initially filled data.
    * any other dataset (tabular): the matrix is preprocessed, shuffled with
      ``args.seed``, split with ``util.create_k_fold`` and converted to torch
      tensors.

    Indexing returns ``(filled, complete, mask)`` for the split selected by
    ``mode``; ``len()`` returns the size of that split.

    Args:
        args: Namespace-like object. Fields read here:
            data (str): Determines the dataset to load.
            drp_percent (float): Determines the binomial success rate for
                observing a feature.
            seed (int): Seed forwarded to mask creation and, for tabular
                data, used to shuffle the rows.
            fold_id: Fold selector forwarded to ``util.create_k_fold``
                (tabular data only).
            miss_pattern: Forwarded to ``util.create_tabular_mask``
                (tabular data only).
        mode (int): ``MODE_TRAIN`` or ``MODE_TEST``; selects which split
            ``__len__`` and ``__getitem__`` expose.

    Note:
        The ``train`` attribute shadows ``nn.Module.train()``, so the usual
        ``.train()`` / ``.eval()`` calls are not usable on this object. The
        attribute name is kept because callers read it directly.
    """
    MODE_TRAIN = 0
    MODE_TEST = 1

    # Datasets handled by the image pipeline; anything else is treated as tabular.
    _IMAGE_DATASETS = ('mnist', 'fashion', 'cifar0', 'cifar1', 'cifar2')

    def __init__(self, args, mode=MODE_TRAIN):
        # Initialise nn.Module's internal state so inherited functionality
        # (repr, .to(), state_dict, ...) works on this object.
        super().__init__()

        self.data = args.data
        if self.data in self._IMAGE_DATASETS:
            self.train_complete, self.test_complete, img_shape = util.path_to_matrix(self.data)
            self.mask_train, self.mask_test = util.create_img_masks(
                args.drp_percent,
                img_shape,
                self.train_complete.shape[0],
                self.test_complete.shape[0],
                args.seed,
            )
            # NOTE(review): unlike the tabular branch, nothing is converted
            # with torch.from_numpy here -- presumably the image helpers
            # already return tensors; confirm in util.
            self.train, self.test = util.fill_img_missingness(
                self.train_complete,
                self.test_complete,
                self.mask_train,
                self.mask_test,
                img_shape,
            )
        else:
            data_complete = util.path_to_matrix(self.data)
            self.data_complete = util.preprocess(data_complete)
            # In-place, seed-determined shuffle performed before any split.
            np.random.RandomState(args.seed).shuffle(self.data_complete)
            self.train_complete, self.test_complete = util.create_k_fold(self.data_complete, args.fold_id)

            # The mask is built on the full shuffled matrix and split with the
            # same fold id as the data.
            self.mask = util.create_tabular_mask(self.data_complete, args.drp_percent, args.miss_pattern, args.seed)
            self.mask_train, self.mask_test = util.create_k_fold(self.mask, args.fold_id)

            # Non-zero mask entries mark missing values: blank them out with
            # NaN, then let util.fill_missingness produce the initial fill.
            self.data_incomplete = self.data_complete.copy()
            self.data_incomplete[np.where(self.mask)] = np.nan
            data_filled = util.fill_missingness(self.data_incomplete, self.mask)
            self.train, self.test = util.create_k_fold(data_filled, args.fold_id)

            # Everything exposed through __getitem__ becomes a torch tensor.
            self.train_complete = torch.from_numpy(self.train_complete)
            self.test_complete = torch.from_numpy(self.test_complete)
            self.mask_train = torch.from_numpy(self.mask_train)
            self.mask_test = torch.from_numpy(self.mask_test)
            self.train = torch.from_numpy(self.train)
            self.test = torch.from_numpy(self.test)
            self.data_incomplete = torch.from_numpy(self.data_incomplete)

        self.mode = mode

    def reset_imputed_values(self, batch_EM, nf_model, args):
        """Refresh the imputed entries of both splits from ``nf_model``.

        For each split, ``util.infer_imputation`` is called and its result is
        clipped to ``[0, 1]``. The split is then rebuilt as
        ``(1 - mask) * current + mask * imputed``, so entries whose mask value
        is 0 keep their current value.

        Args:
            batch_EM: Forwarded unchanged to ``util.infer_imputation``.
            nf_model: Model forwarded to ``util.infer_imputation``.
            args: Forwarded unchanged to ``util.infer_imputation``.
        """
        # The train split is processed before the test split, as in the
        # original call order.
        imputed_train = torch.clip(
            util.infer_imputation(batch_EM, nf_model, self.train, self.mask_train, args), 0, 1)
        self.train = (1 - self.mask_train) * self.train + self.mask_train * imputed_train

        imputed_test = torch.clip(
            util.infer_imputation(batch_EM, nf_model, self.test, self.mask_test, args), 0, 1)
        self.test = (1 - self.mask_test) * self.test + self.mask_test * imputed_test

    def _active_split(self):
        """Return ``(filled, complete, mask)`` for the split chosen by ``self.mode``.

        Raises:
            SystemExit: With a message (hence a non-zero exit status) when
                ``self.mode`` is neither ``MODE_TRAIN`` nor ``MODE_TEST``.
        """
        if self.mode == self.MODE_TRAIN:
            return self.train, self.train_complete, self.mask_train
        if self.mode == self.MODE_TEST:
            return self.test, self.test_complete, self.mask_test
        # Passing a string to sys.exit writes it to stderr and exits with
        # status 1, so an invalid mode no longer looks like a clean exit.
        sys.exit(
            "Data loader mode error -- got {!r}, acceptable modes are {} (train) and {} (test)".format(
                self.mode, self.MODE_TRAIN, self.MODE_TEST))

    def __len__(self):
        """Return the number of samples in the split selected by ``self.mode``."""
        filled, _, _ = self._active_split()
        return len(filled)

    def __getitem__(self, idx):
        """Return ``(filled, complete, mask)`` at ``idx`` for the active split."""
        filled, complete, mask = self._active_split()
        return filled[idx], complete[idx], mask[idx]
