
import os

import datatable as dt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.datasets import load_diabetes, load_iris, load_breast_cancer, load_wine, load_boston, fetch_california_housing, fetch_kddcup99, fetch_openml
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from utils.utils import set_seed
from utils.invase_data_generation import generate_dataset
from utils.l2x import generate_data_l2x





class Loader(object):
    """Builds the train / test / validation DataLoaders for a named dataset."""

    def __init__(self, config, dataset_name, drop_last=True, kwargs=None):
        """Pytorch data loader

        Args:
            config (dict): Dictionary containing options and arguments. This
                constructor reads ``config["batch_size"]`` and
                ``config["paths"]["data"]``; the whole dict is forwarded to the
                dataset classes.
            dataset_name (str): Name of the dataset to load
            drop_last (bool): True in training mode, False in evaluation. It is
                applied to the training and validation loaders; the test loader
                never drops its last partial batch.
            kwargs (dict, optional): Additional keyword arguments forwarded to
                every ``DataLoader``. Defaults to no extra arguments.

        """
        # Fresh dict per call instead of a shared mutable default argument.
        kwargs = {} if kwargs is None else kwargs
        batch_size = config["batch_size"]
        self.config = config
        # data > dataset_name
        file_path = os.path.join(config["paths"]["data"], dataset_name)
        train_dataset, test_dataset, validation_dataset = self.get_dataset(dataset_name, file_path)
        # Training batches are reshuffled every epoch.
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                       drop_last=drop_last, **kwargs)
        # The test set is always evaluated in full.
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                      drop_last=False, **kwargs)
        # NOTE(review): validation follows `drop_last`, so with drop_last=True a
        # trailing partial validation batch is skipped; confirm this is intended.
        self.validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False,
                                            drop_last=drop_last, **kwargs)

    def get_dataset(self, dataset_name, file_path):
        """Return the (train, test, validation) datasets for ``dataset_name``.

        Args:
            dataset_name (str): Name of the dataset to load.
            file_path (str): Data directory for the dataset, forwarded as ``datadir``.

        Returns:
            tuple: ``(train_dataset, test_dataset, validation_dataset)``.
        """
        # If you add a new dataset, add its corresponding dataset class here in
        # the form 'dataset_name': ClassName. Unknown names fall back to
        # 'default_loader' (tabular data).
        loader_map = {'default_loader': TabularDataset}
        dataset = loader_map.get(dataset_name, loader_map['default_loader'])
        train_dataset = dataset(self.config, datadir=file_path, dataset_name=dataset_name, mode='train')
        test_dataset = dataset(self.config, datadir=file_path, dataset_name=dataset_name, mode='test')
        validation_dataset = dataset(self.config, datadir=file_path, dataset_name=dataset_name, mode="validation")
        return train_dataset, test_dataset, validation_dataset


class ToTensorNormalize:
    """Callable transform turning a NumPy array into a float32 torch tensor.

    Despite the name, no normalization happens here: the data is expected to
    have been scaled already during pre-processing.
    """

    def __call__(self, sample):
        as_tensor = torch.from_numpy(sample)
        # Cast to float32, the dtype expected by the model layers.
        return as_tensor.float()


class TabularDataset(Dataset):
    """Tabular dataset serving one of the train / validation / test splits.

    Features and labels come from a dataset-specific loader chosen by
    ``dataset_name`` (see ``_load_data``). The training portion is then split
    into training and validation rows according to ``config["fold_num"]``.
    """

    def __init__(self, config, datadir, dataset_name, mode='train', transform=ToTensorNormalize()):
        """Dataset class for tabular data format.

        Args:
            config (dict): Dictionary containing options and arguments. Reads
                "paths", "training_data_ratio", "validate", "n_classes",
                "kfold", "seed" and "fold_num", and is passed to ``set_seed``.
                "n_classes" (and, for the synthetic L2X datasets,
                "l2x_features") are written back into it.
            datadir (str): The path to the data directory. NOTE(review): not
                used; the path is rebuilt from ``config["paths"]["data"]``.
            dataset_name (str): Name of the dataset to load (case-insensitive).
            mode (str): One of "train", "validation" or "test".
            transform (callable): Transformation function for data.
                NOTE(review): stored but not applied in ``__getitem__``.

        """
        self.config = config
        set_seed(config)
        self.mode = mode
        self.paths = config["paths"]
        self.dataset_name = dataset_name
        self.data_path = os.path.join(self.paths["data"], dataset_name)
        self.data, self.labels = self._load_data()
        self.transform = transform

    def __len__(self):
        """Returns number of samples in the data"""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features_row, integer_label)`` pair at position ``idx``."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    @staticmethod
    def _abort(message):
        """Print ``message`` and terminate with a non-zero exit status."""
        print(message)
        # SystemExit(1) rather than the interactive `exit()`: a fatal
        # configuration error must not report success (status 0), and `exit`
        # is only defined when the `site` module is loaded.
        raise SystemExit(1)

    def _load_data(self):
        """Load the configured dataset and return ``(data, labels)`` for ``self.mode``.

        Side effects: may update ``config["n_classes"]`` to the number of
        distinct training labels. Terminates the process (exit status 1) on an
        unknown dataset name, an invalid validation/ratio combination,
        unscaled features (max |x| > 100) or an unknown mode.
        """
        loaders = {
            "mnist": self._load_mnist,
            "synrank": self._load_synrank,
            "l2x_xor": self._load_l2x_XOR,
            "l2x_orange": self._load_l2x_orange,
            "l2x_switch": self._load_l2x_switch,
            "l2x_additive": self._load_l2x_additive,
        }
        name = self.dataset_name.lower()
        if name not in loaders:
            self._abort(f"Given dataset name {name} is not found. Check for typos, or missing condition "
                        f"in _load_data() of TabularDataset class in utils/load_data.py .")
        x_train, y_train, x_test, y_test = loaders[name]()

        # Ratio of the training-validation split, e.g. 0.8
        training_data_ratio = self.config["training_data_ratio"]

        # Validation needs held-out rows, so the ratio must leave some behind.
        if self.config["validate"] and training_data_ratio >= 1.0:
            self._abort("training_data_ratio must be < 1.0 if you want to run validation during training.")

        # Update number of classes in the config in case it is not correct.
        n_classes = len(np.unique(y_train))
        if self.config["n_classes"] != n_classes:
            # Capture the old value before overwriting so the message is accurate.
            old_n_classes = self.config["n_classes"]
            self.config["n_classes"] = n_classes
            print(f"{50 * '>'} Number of classes changed "
                  f"from {old_n_classes} to {n_classes} {50 * '<'}")

        # Check that feature magnitudes are small enough to work well for a neural network.
        max_abs_value = np.max(np.abs(x_train))
        if max_abs_value > 100:
            self._abort(f"Pre-processing of data does not seem to be correct. "
                        f"Max value found in features is {max_abs_value}\n"
                        f"Please check the values of features...")

        tr_idx, val_idx = self._train_validation_indices(x_train, training_data_ratio)

        # Select features and labels, based on the mode
        if self.mode == "train":
            return x_train[tr_idx, :], y_train[tr_idx]
        if self.mode == "validation":
            return x_train[val_idx, :], y_train[val_idx]
        if self.mode == "test":
            return x_test, y_test
        self._abort("Something is wrong with the data mode. "
                    "Use one of three options: train, validation, and test.")

    def _train_validation_indices(self, x_train, training_data_ratio):
        """Return ``(train_idx, val_idx)`` row indices for ``config["fold_num"]``.

        Candidate splits, indexed by ``config["fold_num"]``:
            0     -> random hold-out split; the first
                     ``int(n * training_data_ratio)`` shuffled rows train.
            1..k  -> the folds of ``KFold(n_splits=config["kfold"], shuffle=True,
                     random_state=config["seed"])``.
        Only index arrays are kept per candidate; the data itself is sliced once
        by the caller, instead of copying the training set for every fold.
        """
        n_samples = x_train.shape[0]
        # Draws from the global NumPy RNG (seeded by set_seed in __init__).
        idx = np.random.permutation(n_samples)
        n_train = int(n_samples * training_data_ratio)
        splits = [(idx[:n_train], idx[n_train:])]

        # K-fold splits of the same training rows.
        kf = KFold(n_splits=self.config["kfold"], shuffle=True, random_state=self.config["seed"])
        splits.extend(kf.split(x_train))

        return splits[self.config["fold_num"]]

    def _load_mnist(self):
        """Load MNIST from ``./data/mnist/{train,test}.npy``.

        Each file holds two consecutive arrays (images, then labels). Images
        are flattened to ``(n, 784)`` and divided by 255.
        """
        # NOTE(review): overrides self.data_path with a hard-coded "./data/"
        # root rather than config["paths"]["data"]; confirm this is intended.
        self.data_path = os.path.join("./data/", "mnist")

        with open(os.path.join(self.data_path, 'train.npy'), 'rb') as f:
            x_train = np.load(f)
            y_train = np.load(f)

        with open(os.path.join(self.data_path, 'test.npy'), 'rb') as f:
            x_test = np.load(f)
            y_test = np.load(f)

        x_train = x_train.reshape(-1, 28 * 28) / 255.
        x_test = x_test.reshape(-1, 28 * 28) / 255.

        return x_train, y_train, x_test, y_test

    def _load_l2x(self, datatype, n_train=10_000, n_test=10_000):
        """Generate a synthetic L2X-style dataset and standardize its features.

        Training data is generated with seed 0 and test data with seed 1 via
        ``generate_data_l2x``. Label arrays (presumably one-hot — confirm) are
        turned into class indices with argmax over axis 1. Features are
        standardized with a StandardScaler fitted on the training set only, and
        training rows are shuffled with the global NumPy RNG.
        Also sets ``config["l2x_features"] = [1, ..., n_features]``.

        Args:
            datatype (str): Generator name passed to ``generate_data_l2x``.
            n_train (int): Number of training samples to generate.
            n_test (int): Number of test samples to generate.

        Returns:
            tuple: ``(x_train, y_train, x_test, y_test)``.
        """
        x_train, y_train, _ = generate_data_l2x(n=n_train, datatype=datatype, seed=0)
        x_test, y_test, _ = generate_data_l2x(n=n_test, datatype=datatype, seed=1)

        y_train = np.argmax(y_train, axis=1)
        y_test = np.argmax(y_test, axis=1)

        self.config["l2x_features"] = list(range(1, x_train.shape[1] + 1))

        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)
        x_test = scaler.transform(x_test)

        idx = np.random.permutation(x_train.shape[0])
        return x_train[idx, :], y_train[idx], x_test, y_test

    def _load_synrank(self):
        """Create the 'synrank' train and test datasets (see ``_load_l2x``)."""
        return self._load_l2x('synrank')

    def _load_l2x_orange(self):
        """Create the 'orange_skin' train and test datasets (see ``_load_l2x``)."""
        return self._load_l2x('orange_skin')

    def _load_l2x_XOR(self):
        """Create the 'XOR' train and test datasets (see ``_load_l2x``)."""
        return self._load_l2x('XOR')

    def _load_l2x_additive(self):
        """Create the 'nonlinear_additive' train and test datasets (see ``_load_l2x``)."""
        return self._load_l2x('nonlinear_additive')

    def _load_l2x_switch(self):
        """Create the 'switch' train and test datasets (see ``_load_l2x``)."""
        return self._load_l2x('switch')