import torch
import torch.utils.data as data
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms


class MedianScaler(StandardScaler):
    """``StandardScaler`` variant that centers each feature on its median.

    The scale (standard deviation) is computed exactly as in
    ``StandardScaler``; only the centering statistic stored in ``mean_`` is
    replaced by the column-wise median of the fitted data.
    """

    def fit(self, X, y=None, **fit_params):
        """Fit like ``StandardScaler``, then replace ``mean_`` with the median.

        :param X: array-like of shape (n_samples, n_features)
        :param y: ignored, present for scikit-learn API compatibility
        :return: self
        """
        super().fit(X, y, **fit_params)
        # StandardScaler.transform subtracts ``mean_``; overriding it switches
        # the centering from the mean to the median.
        self.mean_ = np.median(np.asarray(X), axis=0)
        return self

    def fit_transform(self, data, y=None, **fit_params):
        """Fit on ``data`` and return it centered on the median and scaled by the std.

        :param data: array-like of shape (n_samples, n_features)
        :param y: ignored, present for scikit-learn API compatibility
        :return: the transformed data
        """
        return self.fit(data, y, **fit_params).transform(data)


class DataLoader_Tabular(data.Dataset):
    """Tabular ``torch`` dataset loaded from a CSV file.

    Each item is a ``(features, target, index)`` triple, where ``features`` is a
    row of the scaled feature matrix ``self.data``.
    """

    def __init__(self, path, filename, label, scale='minmax', target_scale=None):
        """
        Load a data set from ``path + filename`` and scale its features (and optionally targets).

        :param path: string with path to the directory of the CSV; it is concatenated
            directly with ``filename``, so it should end with a path separator
        :param filename: string, CSV file name; its stem selects data-set specific handling
        :param label: string, column name for label
        :param scale: string; either 'minmax' (scale to [-1, 1]) or 'standard'
        :param target_scale: None (targets left unscaled), or 'minmax', 'standard' or
            'median' (will normalize std to 1, but subtract the median)
        :raises ValueError: if ``scale`` or ``target_scale`` is not a supported option
        """
        # Validate the options before touching the file so a typo fails fast with a
        # clear message instead of an AttributeError further down.
        self.scaler = self._make_scaler(scale, ('minmax', 'standard'))
        target_scaler = (None if target_scale is None
                         else self._make_scaler(target_scale, ('minmax', 'standard', 'median')))

        # Load dataset
        self.dataset = pd.read_csv(path + filename)
        self.target = label
        self.target_name = label
        # Infer data set name from the file stem, e.g. "german-train.csv" -> "german-train"
        dataset_name = filename.split(".")[0]

        if dataset_name in ("german-train", "german-test"):
            # For German, using a reduced feature space due to causal baseline's SCM
            num_feat = ["duration", "amount", "age"]
            # NOTE(review): the categorical list spells this column "personal_status_sex"
            # while the selected column below is "personal-status-sex"; confirm which
            # spelling matches the CSV header.
            cat_feat = ["personal_status_sex"]
            feature_names = ["duration", "amount", "age", "personal-status-sex"]
            self.X = self.dataset[feature_names]
            self.num_features_list = num_feat
            self.cat_features_list = cat_feat
        else:
            # Every column except the label is a predictor
            self.X = self.dataset.drop(self.target, axis=1)

        # Save feature names
        self.feature_names = self.X.columns.to_list()

        # Fit the feature scaler and keep the scaled features as a numpy array
        self.data = self.scaler.fit_transform(self.X)
        # pd version of the (unscaled) features
        self.X = pd.DataFrame(self.X)
        self.X.columns = self.feature_names

        # Transform targets
        if target_scaler is not None:
            self.target_scale = target_scaler
            raw_targets = self.dataset[self.target].to_numpy().reshape(-1, 1)
            # Fit through fit_transform because MedianScaler overrides it to
            # substitute the median; transform separately to get the scaled values.
            self.target_scale.fit_transform(raw_targets)
            self.targets = self.target_scale.transform(raw_targets).flatten()
        else:
            self.targets = self.dataset[self.target]

    @staticmethod
    def _make_scaler(kind, allowed):
        """Return a fresh, unfitted scaler for ``kind``, which must be one of ``allowed``."""
        if kind not in allowed:
            raise ValueError(f"Unsupported scaling {kind!r}; expected one of {allowed}")
        if kind == 'minmax':
            return MinMaxScaler(feature_range=(-1, 1))
        if kind == 'standard':
            return StandardScaler()
        return MedianScaler()

    def __len__(self):
        """Number of rows in the loaded CSV."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(scaled features, target, idx)`` for row(s) ``idx``."""
        # select correct row with idx
        if isinstance(idx, torch.Tensor):
            idx = idx.tolist()
        return (self.data[idx], self.targets[idx], idx)

    def get_number_of_features(self):
        """Number of feature columns in the scaled feature matrix."""
        return self.data.shape[1]

    def get_number_of_instances(self):
        """Number of rows in the scaled feature matrix."""
        return self.data.shape[0]
    

def _scale_test_like_train(dataset_test, dataset_train):
    """Re-scale ``dataset_test`` in place with the scalers fitted on ``dataset_train``.

    ``DataLoader_Tabular`` fits its scalers on whichever file it loads. Without this
    step the test split would be scaled with its own statistics, so the same raw
    value could map to different scaled values in train and test.
    """
    dataset_test.scaler = dataset_train.scaler
    # ``X`` holds the unscaled feature frame; ``data`` is what __getitem__ serves
    dataset_test.data = dataset_train.scaler.transform(dataset_test.X)
    # ``target_scale`` is only set on the dataset when a target scaler was requested
    target_scaler = getattr(dataset_train, 'target_scale', None)
    if target_scaler is not None:
        raw_targets = dataset_test.dataset[dataset_test.target].to_numpy().reshape(-1, 1)
        dataset_test.target_scale = target_scaler
        dataset_test.targets = target_scaler.transform(raw_targets).flatten()


def return_loaders(data_name, is_tabular, batch_size=32, transform=None, scaler='minmax', target_scaler=None):
    """
    Build shuffling train and test ``DataLoader``s for a named data set.

    Both splits are scaled with the scalers fitted on the training split.

    :param data_name: data-set key, e.g. 'adult', 'compas', 'german'
    :param is_tabular: whether the data set is tabular; for non-tabular data a default
        torchvision transform is prepared (only 'mnist' is recognized)
    :param batch_size: batch size for both loaders
    :param transform: optional torchvision transform for non-tabular data
    :param scaler: feature scaling forwarded to ``DataLoader_Tabular``
    :param target_scaler: target scaling forwarded to ``DataLoader_Tabular``
    :return: tuple ``(trainloader, testloader)``
    :raises ValueError: for a non-tabular data set other than 'mnist' without a transform
    :raises KeyError: if ``data_name`` has no tabular meta-data entry
    """
    if is_tabular:
        transform = None
    elif transform is None:
        # Standard Transforms
        if data_name == 'mnist':
            # optionally append transforms.Normalize((0.1307,), (0.3081,))
            transform = transforms.Compose([transforms.ToTensor()])
        else:
            # Not supported data sets
            raise ValueError(f"No default transform for data set {data_name!r}")
    # NOTE(review): ``transform`` is not used below; only tabular loading is wired up.

    # Meta-data: data-set key -> (folder under ./Data_Sets/, label column)
    meta_data = {
        'admission': ('Admission', 'zfya'),
        'heloc': ('Heloc', 'ExternalRiskEstimate'),
        'adult': ('Adult', 'income'),
        'compas': ('COMPAS', 'risk'),
        'diabetes': ('Diabetes', 'readmitted'),
        'german': ('German', 'credit-risk'),
        'twomoons': ('TwoMoons', 'label'),
    }
    try:
        folder, label = meta_data[data_name]
    except KeyError:
        raise KeyError(f"No tabular meta-data for data set {data_name!r}; "
                       f"expected one of {sorted(meta_data)}") from None

    path = './Data_Sets/' + folder + '/'
    dataset_train = DataLoader_Tabular(path=path, filename=data_name + '-train.csv', label=label,
                                       scale=scaler, target_scale=target_scaler)
    dataset_test = DataLoader_Tabular(path=path, filename=data_name + '-test.csv', label=label,
                                      scale=scaler, target_scale=target_scaler)
    _scale_test_like_train(dataset_test, dataset_train)

    trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

    return trainloader, testloader

