import torch
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset


class AdultDataset(Dataset):
    """
    Adult dataset.
    http://archive.ics.uci.edu/ml/datasets/Adult

    Every item is a dict with three entries:

    * ``'attribute'``: float tensor of shape ``(1, 103)`` -- six scaled numeric
      columns interleaved with the one-hot encodings of seven categorical
      columns (sex is *not* part of it).
    * ``'label'``: int, ``0`` for ``'>50K'`` and ``1`` for ``'<=50K'``.
    * ``'group'``: int, ``0`` for ``'Female'`` and ``1`` for ``'Male'``.
    """

    def __init__(self, file):
        """
        Load the CSV file and build the category -> index lookup tables.

        Rows that contain a missing value (encoded as ``?`` in the raw files)
        are dropped, because no lookup table has an entry for them.

        :param file: Path to the data file (comma separated, no header row,
                     15 columns in the order of the UCI ``adult.data`` file).
        """
        super().__init__()

        # The raw UCI files separate fields with ", ": without
        # ``skipinitialspace`` every categorical value keeps a leading blank
        # (' Private') and never matches the dictionaries below.
        # '?' marks a missing value; parse it as NA so the row can be dropped.
        self.data = pd.read_csv(file, sep=',', header=None,
                                skipinitialspace=True, na_values='?')
        self.data = self.data.dropna().reset_index(drop=True)

        self.workclass_dict = {'Private': 0, 'Self-emp-not-inc': 1, 'Self-emp-inc': 2, 'Federal-gov': 3, 'Local-gov': 4,
                               'State-gov': 5, 'Without-pay': 6, 'Never-worked': 7}
        self.education_dict = {'Bachelors': 0, 'Some-college': 1, '11th': 2, 'HS-grad': 3, 'Prof-school': 4,
                               'Assoc-acdm': 5, 'Assoc-voc': 6, '9th': 7, '7th-8th': 8, '12th': 9, 'Masters': 10,
                               '1st-4th': 11, '10th': 12, 'Doctorate': 13, '5th-6th': 14, 'Preschool': 15}
        self.marital_status_dict = {'Married-civ-spouse': 0, 'Divorced': 1, 'Never-married': 2, 'Separated': 3,
                                    'Widowed': 4, 'Married-spouse-absent': 5, 'Married-AF-spouse': 6}
        self.occupation_dict = {'Tech-support': 0, 'Craft-repair': 1, 'Other-service': 2, 'Sales': 3,
                                'Exec-managerial': 4, 'Prof-specialty': 5, 'Handlers-cleaners': 6,
                                'Machine-op-inspct': 7, 'Adm-clerical': 8, 'Farming-fishing': 9, 'Transport-moving': 10,
                                'Priv-house-serv': 11, 'Protective-serv': 12, 'Armed-Forces': 13}
        self.relationship_dict = {'Wife': 0, 'Own-child': 1, 'Husband': 2, 'Not-in-family': 3, 'Other-relative': 4,
                                  'Unmarried': 5}
        self.race_dict = {'White': 0, 'Asian-Pac-Islander': 1, 'Amer-Indian-Eskimo': 2, 'Other': 3, 'Black': 4}
        self.sex_dict = {'Female': 0, 'Male': 1}
        self.native_country_dict = {'United-States': 0, 'Cambodia': 1, 'England': 2, 'Puerto-Rico': 3, 'Canada': 4,
                                    'Germany': 5, 'Outlying-US(Guam-USVI-etc)': 6, 'India': 7, 'Japan': 8, 'Greece': 9,
                                    'South': 10, 'China': 11, 'Cuba': 12, 'Iran': 13, 'Honduras': 14, 'Philippines': 15,
                                    'Italy': 16, 'Poland': 17, 'Jamaica': 18, 'Vietnam': 19, 'Mexico': 20,
                                    'Portugal': 21,
                                    'Ireland': 22, 'France': 23, 'Dominican-Republic': 24, 'Laos': 25, 'Ecuador': 26,
                                    'Taiwan': 27, 'Haiti': 28, 'Columbia': 29, 'Hungary': 30, 'Guatemala': 31,
                                    'Nicaragua': 32, 'Scotland': 33, 'Thailand': 34, 'Yugoslavia': 35,
                                    'El-Salvador': 36,
                                    'Trinadad&Tobago': 37, 'Peru': 38, 'Hong': 39, 'Holand-Netherlands': 40}
        self.label_dict = {'>50K': 0, '<=50K': 1}

    def __len__(self):
        """
        :return: Number of rows kept after dropping rows with missing values.
        """
        return len(self.data)

    @staticmethod
    def _scaled(value, scale):
        """Return ``value / scale`` as a float tensor of shape ``(1, 1)``."""
        return torch.FloatTensor([[value / scale]])

    @staticmethod
    def _one_hot(mapping, key):
        """Return the one-hot float tensor of shape ``(1, len(mapping))`` for ``key``."""
        # Deriving num_classes from the table keeps width and table in sync.
        index = torch.LongTensor([mapping[key]])
        return F.one_hot(index, num_classes=len(mapping)).float()

    def __getitem__(self, idx):
        """
        Encode one row of the table.

        :param idx: Row position (an int, or a tensor holding one).
        :return: Dict with keys ``'attribute'`` (float tensor), ``'label'``
                 (int) and ``'group'`` (int, the encoded sex column).
        :raises KeyError: If a categorical value is not in its lookup table.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.data.iloc[idx]

        # Column order follows the raw file. The divisors are the fixed scale
        # constants this dataset class has always used.
        # NOTE(review): they look like per-column upper bounds -- confirm.
        attribute = torch.cat(
            (
                self._scaled(row[0], 90.0),                         # age
                self._one_hot(self.workclass_dict, row[1]),
                self._scaled(row[2], 1184622.0),                    # fnlwgt
                self._one_hot(self.education_dict, row[3]),
                self._scaled(row[4], 16.0),                         # education-num
                self._one_hot(self.marital_status_dict, row[5]),
                self._one_hot(self.occupation_dict, row[6]),
                self._one_hot(self.relationship_dict, row[7]),
                self._one_hot(self.race_dict, row[8]),
                self._scaled(row[10], 99999.0),                     # capital-gain
                self._scaled(row[11], 4356.0),                      # capital-loss
                self._scaled(row[12], 99.0),                        # hours-per-week
                self._one_hot(self.native_country_dict, row[13]),
            ),
            dim=1,
        )

        # Sex is returned separately as the group and is not part of `attribute`.
        sex = self.sex_dict[row[9]]
        # The UCI test split writes labels with a trailing period ('>50K.').
        label = self.label_dict[row[14].rstrip('.')]

        return {'attribute': attribute, 'label': label, 'group': sex}
