import numpy as np
import pandas as pd
import pickle
import torch
from sklearn.preprocessing import LabelEncoder, normalize
from torch.utils.data import Dataset


# Column names, in order, for the raw UCI Adult data files; intended for
# `pd.read_csv(names=...)`. The last column is the income label.
ADULT_FEATURES = [
    "Age",
    "Workclass",
    "fnlwgt",
    "Education",
    "Education-Num",
    "Marital Status",
    "Occupation",
    "Relationship",
    "Race",
    "Sex",
    "Capital Gain",
    "Capital Loss",
    "Hours per week",
    "Country",
    "Target",
]
# Download locations of the raw training and test splits.
ADULT_TRAIN_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'
ADULT_TEST_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'


class _IndexedAdult(Dataset):
    """Adult Income dataset backed by a pre-processed pickle file.

    The pickle must hold a 4-tuple
    ``(train_data, train_targets, test_data, test_targets)``.

    Args:
        train: If ``True`` expose the training split, otherwise the test split.
        indexed: Training split only. If ``True``, each target becomes a
            ``[label, position]`` row, where ``position`` is the sample's index
            in the training set. Ignored for the test split, whose targets are
            always returned unchanged.
        path: Location of the pickle file. Defaults to ``'adult.pkl'``,
            resolved relative to the current working directory.
    """

    def __init__(self, train=True, indexed=True, path='adult.pkl'):
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point `path` at a trusted, locally produced pickle.
        # The `with` statement closes the file, so no explicit close() is needed.
        with open(path, 'rb') as f:
            train_data, train_targets, test_data, test_targets = pickle.load(f)

        if not train:
            self.data = test_data
            self.targets = test_targets
        else:
            self.data = train_data
            if not indexed:
                self.targets = train_targets
            else:
                # Column 0: original label, column 1: position in the training set.
                positions = torch.arange(len(train_targets)).reshape(-1, 1)
                self.targets = torch.cat((train_targets.reshape(-1, 1), positions), dim=1)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(datum, target)``, with ``target`` cast to ``int32``."""
        datum, target = self.data[index], self.targets[index].int()
        return datum, target



# Legacy implementation, kept only as an inert module-level string (it never runs).
# It does the following:
#   * reads the raw UCI train/test files from ADULT_TRAIN_URL / ADULT_TEST_URL;
#   * strips the trailing '.' from the test labels;
#   * drops the fnlwgt, Education-Num, Country and Workclass columns, then drops
#     rows with missing values;
#   * merges several Education and Race categories;
#   * label-encodes every column, then max-normalizes the features per column;
#   * builds the same (data, targets) layout as the active class.
# Presumably this is how the tensors stored in 'adult.pkl' were produced.
# TODO confirm before relying on it to regenerate that file.
# NOTE(review) issues to fix if this code is ever revived:
#   * `data[col].replace(..., inplace=True)` is a chained in-place update. Under
#     pandas Copy-on-Write (the default from pandas 3.0) it does not modify
#     `data`; assign the result back instead.
#   * The Race replacement value is ' Other' (leading space).
#   * X_test is normalized with its own per-column max rather than the
#     training max, so train and test features are scaled differently.
#   * LabelEncoder is also applied to numeric columns (e.g. Age, Capital Gain),
#     replacing each value with its rank among that column's unique values.
"""
class _IndexedAdult(Dataset):
    def __init__(self, train=True, indexed=True):
        original_train = pd.read_csv(ADULT_TRAIN_URL, names=ADULT_FEATURES, sep=r'\s*,\s*', 
                                     engine='python', na_values="?")
        original_test = pd.read_csv(ADULT_TEST_URL, names=ADULT_FEATURES, sep=r'\s*,\s*', 
                                    engine='python', na_values="?", skiprows=1)
        original_test["Target"] = original_test["Target"].str.rstrip('.')
        original_train.drop(["fnlwgt",'Education-Num',"Country","Workclass"], axis = 1, inplace = True)
        original_test.drop(["fnlwgt",'Education-Num',"Country","Workclass"], axis = 1, inplace = True)
        original_train = original_train.dropna()
        original_test  = original_test.dropna()
        n_train = len(original_train)
        data = pd.concat([original_train, original_test], ignore_index=True)
        data['Education'].replace(['11th', '9th', '7th-8th', '5th-6th', '10th', '1st-4th', 'Preschool', '12th'],
                                'BelowHS', inplace = True)
        data['Race'].replace(['Black','Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other'],' Other', inplace = True)   
        encoder = LabelEncoder()
        data = data.apply(encoder.fit_transform)
        my = data.pop('Target')
        X_train, X_test, y_train, y_test = data[:n_train], data[n_train:], my[:n_train], my[n_train:]

        frac = np.unique(y_train.values, return_counts= True)[1]
        frac = frac.astype(np.float32)
        frac /= frac.sum()

        X_train = X_train.values
        X_test = X_test.values
        y_train = y_train.values
        y_test = y_test.values

        X_train = torch.from_numpy(normalize(X_train, axis=0, norm='max')).float()
        X_test = torch.from_numpy(normalize(X_test, axis=0, norm='max')).float()
        y_train = torch.from_numpy(y_train)
        y_test = torch.from_numpy(y_test)

        if not train:
            self.data = X_test
            self.targets = y_test
        else:
            self.data = X_train
            if not indexed:
                self.targets = y_train
            else:
                self.targets = torch.cat((y_train.reshape(-1, 1), torch.arange(len(y_train)).reshape(-1, 1)), dim=1)
        

    def __len__(self):
        return len(self.targets)
    

    def __getitem__(self, index):
        datum, target = self.data[index], self.targets[index].int()
        return datum, target
"""



def load_adult(*, test=False):
    """
    Return the requested split of the Adult Income dataset.

    Args:
        test (bool): When `True`, return the test split with plain labels.
            When `False` (the default), return the training split, whose
            targets are `[label, position]` rows.

    Returns:
        An `_IndexedAdult` dataset for the requested split.
    """
    use_train_split = not test
    # Only the training split carries per-sample positions in its targets.
    return _IndexedAdult(train=use_train_split, indexed=use_train_split)