
import torch
import os
import pandas as pd
import numpy as np


class CIFAR10(torch.utils.data.dataset.Dataset):
    """CIFAR-10 samples loaded from files written with ``torch.save``.

    Each item is a ``(data, target, index)`` triple. The position is returned
    alongside the sample so callers can tell which rows ended up in a batch.
    """

    def __init__(self, root, train=True, subset=None, transform=None, target_transform=None):
        """Load the training or test split stored under ``root``.

        Args:
            root: Directory containing ``cifar10_trainX``/``cifar10_trainY`` and
                ``cifar10_testX``/``cifar10_testY`` (read with ``torch.load``).
            train: Read the training files when True, the test files otherwise.
            subset: Optional index applied along the first axis of both the
                inputs and the labels to keep only some rows.
            transform: Optional callable applied to each input in ``__getitem__``.
            target_transform: Optional callable applied to each label.
        """
        self.classes = ('Plane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog',
                        'Horse', 'Ship', 'Truck')
        self.train = train

        # Pick the (inputs, labels) file pair for the requested split.
        if self.train:
            x_file, y_file = 'cifar10_trainX', 'cifar10_trainY'
        else:
            x_file, y_file = 'cifar10_testX', 'cifar10_testY'

        inputs = torch.load(os.path.join(root, x_file))
        labels = torch.load(os.path.join(root, y_file))

        # Optionally restrict both arrays to the same selection of rows.
        if subset is not None:
            inputs, labels = inputs[subset,], labels[subset]

        self.data = inputs
        self.target = labels
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(data, target, index)``, with the optional transforms applied."""
        sample = self.data[index,]
        label = self.target[index]
        sample = sample if self.transform is None else self.transform(sample)
        label = label if self.target_transform is None else self.target_transform(label)
        return sample, label, index

    def __len__(self):
        """Number of samples, i.e. the size of the labels' first axis."""
        n_items = self.target.shape[0]
        return n_items



class FMNIST(torch.utils.data.dataset.Dataset):
    """Fashion-MNIST samples loaded from files written with ``torch.save``.

    Each item is a ``(data, target, index)`` triple. The position is returned
    alongside the sample so callers can tell which rows ended up in a batch.
    """

    def __init__(self, root, train=True, subset=None, transform=None, target_transform=None):
        """Load the training or test split stored under ``root``.

        Args:
            root: Directory containing ``fmnist_trainX``/``fmnist_trainY`` and
                ``fmnist_testX``/``fmnist_testY`` (read with ``torch.load``).
            train: Read the training files when True, the test files otherwise.
            subset: Optional index applied along the first axis of both the
                inputs and the labels to keep only some rows.
            transform: Optional callable applied to each input in ``__getitem__``.
            target_transform: Optional callable applied to each label.
        """
        self.classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
                        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')
        self.train = train

        # Pick the (inputs, labels) file pair for the requested split.
        if self.train:
            x_file, y_file = 'fmnist_trainX', 'fmnist_trainY'
        else:
            x_file, y_file = 'fmnist_testX', 'fmnist_testY'

        inputs = torch.load(os.path.join(root, x_file))
        labels = torch.load(os.path.join(root, y_file))

        # Optionally restrict both arrays to the same selection of rows.
        if subset is not None:
            inputs, labels = inputs[subset,], labels[subset]

        self.data = inputs
        self.target = labels
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(data, target, index)``, with the optional transforms applied."""
        sample = self.data[index,]
        label = self.target[index]
        sample = sample if self.transform is None else self.transform(sample)
        label = label if self.target_transform is None else self.target_transform(label)
        return sample, label, index

    def __len__(self):
        """Number of samples, i.e. the size of the labels' first axis."""
        n_items = self.target.shape[0]
        return n_items




class Adult(torch.utils.data.dataset.Dataset):
    """UCI Adult (census income) tabular dataset held in pandas DataFrames.

    ``self.data`` holds the feature columns and ``self.target`` the response
    column(s). Items are ``(data, target, index)`` triples where ``data`` and
    ``target`` are DataFrame row selections, optionally passed through the
    transforms.
    """

    # Categorical variables
    categorical = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']

    def __init__(self, root, target_name, train=True, preprocess=None, transform=None, target_transform=None):
        """Read ``adult.data`` (train) or ``adult.test`` (test) from ``root``.

        Args:
            root: Directory containing ``adult.data`` and ``adult.test``.
            target_name: Prefix of the response column(s). Every column whose
                name matches the regex ``^target_name`` is moved into
                ``self.target``. This lets ``preprocess`` one-hot encode the
                response (e.g. ``income_...`` columns). It is a prefix match,
                so ``'education'`` would also select ``'educational-num'``.
            train: Load the training file when True, the test file otherwise.
            preprocess: Optional callable applied to the whole DataFrame after
                the categorical columns are typed and before the response is
                split off. It must return a DataFrame.
            transform: Optional callable applied to each feature selection.
            target_transform: Optional callable applied to each target selection.
        """
        self.classes = ('<= 50k', '> 50k')
        self.train = train

        # Column names (the files are read with header=None)
        column_names = ['age', 'workclass', 'fnlwgt', 'education', 'educational-num',
                        'marital-status', 'occupation', 'relationship', 'race', 'gender',
                        'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income']

        # Fields are separated by a comma followed by one whitespace character.
        # The regex separator requires the python engine. It is a raw string
        # because "\s" in a normal literal is an invalid escape sequence.
        if self.train:
            self.data = pd.read_csv(os.path.join(root, 'adult.data'), sep=r",\s", header=None,
                                    names=column_names, engine='python')
        else:
            # skiprows=1 drops the first line of adult.test (presumably a
            # non-data header line -- confirm against the raw file).
            self.data = pd.read_csv(os.path.join(root, 'adult.test'), sep=r",\s", header=None,
                                    names=column_names, skiprows=1, engine='python')
            # Remove '.' characters from the test labels (presumably '<=50K.'
            # -> '<=50K' so they match the training labels). The result is
            # assigned back: calling replace(..., inplace=True) on a column
            # is chained assignment and does not update the DataFrame under
            # pandas Copy-on-Write.
            self.data['income'] = self.data['income'].replace(regex=True, to_replace=r'\.', value='')

        # Declare categorical variables
        for var_name in Adult.categorical:
            self.data[var_name] = self.data[var_name].astype('category')

        if preprocess is not None:
            self.data = preprocess(self.data)

        # Split off the response column(s) by name prefix
        self.target = self.data.filter(regex=f'^{target_name}', axis=1)
        self.data = self.data.drop(self.target.columns, axis=1)

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for a row label or a collection of labels.

        ``index`` is used with ``.loc``, i.e. as index *labels*. These are
        presumably 0..N-1 as produced by ``read_csv``, unless ``preprocess``
        changed the index. A scalar integer label (Python ``int`` or NumPy
        integer) is wrapped in a list so a one-row DataFrame is returned
        rather than a Series. Any other index is passed through unchanged.
        """
        if isinstance(index, (int, np.integer)) and not isinstance(index, bool):
            data, target = self.data.loc[[index]], self.target.loc[[index]]
        else:
            data, target = self.data.loc[index], self.target.loc[index]

        if self.transform is not None:
            data = self.transform(data)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return data, target, index

    def __len__(self):
        """Number of rows in the dataset."""
        return self.target.shape[0]
    
    
    
class COMPAS(torch.utils.data.dataset.Dataset):
    """ProPublica COMPAS two-year recidivism dataset held in pandas DataFrames.

    The CSV is filtered, split into train/test by a seeded random permutation,
    and reduced to ``COMPAS.variables``. Items are ``(data, target, index)``
    triples of DataFrame row selections, optionally passed through the
    transforms.
    """

    # List of clean variables
    variables = ['sex', 'age', 'age_cat', 'race', 'decile_score', 'score_text',
                 'v_decile_score', 'v_score_text', 'juv_misd_count', 'juv_other_count',
                 'priors_count', 'c_charge_degree', 'is_recid', 'is_violent_recid',
                 'two_year_recid']

    # Categorical variables
    categorical = ['sex', 'age_cat', 'race', 'score_text', 'v_score_text',
                   'c_charge_degree', 'is_recid', 'is_violent_recid', 'two_year_recid']

    def __init__(self, root, target_name, train=True, split=0.7, preprocess=None, transform=None,
                 target_transform=None, seed=42):
        """Load ``compas-scores-two-years.csv`` from ``root`` and keep one split.

        Args:
            root: Directory containing ``compas-scores-two-years.csv``.
            target_name: Prefix of the response column(s). Every column whose
                name matches the regex ``^target_name`` is moved into
                ``self.target``. It is a prefix match, so ``'age'`` would also
                select ``'age_cat'``.
            train: Keep the first ``ceil(N * split)`` shuffled rows when True,
                the remaining rows otherwise.
            split: Fraction of rows assigned to the training split, in [0, 1].
            preprocess: Optional callable applied to the DataFrame after the
                categorical columns are typed and before the response is split
                off. It must return a DataFrame.
            transform: Optional callable applied to each feature selection.
            target_transform: Optional callable applied to each target selection.
            seed: Seed for the ``np.random.RandomState`` shuffle. The default
                (42) reproduces the historical fixed split. Train and test
                instances must use the same ``seed`` and ``split`` to be
                disjoint.

        Raises:
            ValueError: If ``split`` is not within [0, 1].
        """
        if not 0.0 <= split <= 1.0:
            raise ValueError(f'split must be in [0, 1], got {split}')
        self.train = train

        # Read CSV file
        self.data = pd.read_csv(os.path.join(root, 'compas-scores-two-years.csv'))

        # Drop repeated columns
        self.data = self.data.drop(['decile_score.1', 'priors_count.1'], axis=1)

        # Keep |days_b_screening_arrest| <= 30 (as in the ProPublica analysis).
        # Rows with a missing value fail both comparisons and are dropped too.
        days = self.data['days_b_screening_arrest']
        self.data = self.data[(days >= -30) & (days <= 30)]

        # Seeded random split of the N remaining rows
        N = self.data.shape[0]
        idx_list = np.random.RandomState(seed=seed).permutation(N)
        split_idx = int(np.ceil(N * split))
        selected = idx_list[:split_idx] if self.train else idx_list[split_idx:]

        # Relabel rows 0..N-1 so the permutation can be used with .loc, take
        # the split (in permutation order), then relabel again to 0..len-1.
        self.data = self.data.reset_index(drop=True)
        self.data = self.data.loc[selected].reset_index(drop=True)

        # Keep only columns of interest. The explicit copy makes the column
        # assignments below act on an independent frame (no SettingWithCopy).
        self.data = self.data[COMPAS.variables].copy()

        # Declare categorical variables
        for var_name in COMPAS.categorical:
            self.data[var_name] = self.data[var_name].astype('category')

        if preprocess is not None:
            self.data = preprocess(self.data)

        # Split off the response column(s) by name prefix
        self.target = self.data.filter(regex=f'^{target_name}', axis=1)
        self.data = self.data.drop(self.target.columns, axis=1)

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for a row label or a collection of labels.

        ``index`` is used with ``.loc``, i.e. as index labels. These are
        0..len-1 after ``__init__``, unless ``preprocess`` changed the index.
        A scalar integer label (Python ``int`` or NumPy integer) is wrapped
        in a list so a one-row DataFrame is returned rather than a Series.
        """
        if isinstance(index, (int, np.integer)) and not isinstance(index, bool):
            data, target = self.data.loc[[index]], self.target.loc[[index]]
        else:
            data, target = self.data.loc[index], self.target.loc[index]

        if self.transform is not None:
            data = self.transform(data)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return data, target, index

    def __len__(self):
        """Number of rows in this split."""
        return self.target.shape[0]
