import torch
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset


class DutchDataset(Dataset):
    """
    Dutch dataset.
    https://sites.google.com/site/faisalkamiran/

    Each sample is a dict with:
      'attribute' -- float tensor of shape (1, 59): concatenation of the
                     one-hot encodings of columns 1-10 (age through marital
                     status), in column order,
      'label'     -- occupation class: 0 for '2_1', 1 for '5_4_9',
      'group'     -- value of column 0 (sex) minus 1.

    Columns are addressed by position, so the CSV column order matters.
    """

    def __init__(self, file):
        """
        :param file: Path (or file-like object) of the CSV data file; passed
                     straight to ``pandas.read_csv``.
        """
        self.data = pd.read_csv(file, sep=',')

        # Category code lists: a code's position in the list is its one-hot index.
        self.household_position_dict = [1110, 1121, 1122, 1131, 1132, 1140, 1210, 1220]
        self.cur_eco_activity_dict = [111, 122, 124, 131, 132, 133, 134, 135, 136, 137, 138, 139]
        # Occupation code -> binary class label.
        self.occupation_dict = {'2_1': 0, '5_4_9': 1}

    def __len__(self):
        """Return the number of rows in the data file."""
        return len(self.data)

    @staticmethod
    def _one_hot(index, num_classes):
        """Return a float one-hot tensor of shape (1, num_classes) for ``index``."""
        return F.one_hot(torch.LongTensor([index]), num_classes=num_classes).float()

    def __getitem__(self, idx):
        """
        Encode one row of the data file as a sample.

        :param idx: Integer row position (a 0-d tensor is also accepted).
        :return: dict with 'attribute' (float tensor of shape (1, 59)),
                 'label' (occupation class) and 'group' (sex column - 1).
        :raises ValueError: if a household-position or economic-activity code
                 is not in the known code lists.
        :raises KeyError: if the occupation code is not '2_1' or '5_4_9'.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Positional accessor over the row's columns. Plain ``row[i]`` with an
        # integer key is deprecated as positional access (pandas >= 2.1) and
        # will be interpreted as a column *label* in future pandas versions.
        col = self.data.iloc[idx].iloc

        sex = col[0] - 1
        features = (
            self._one_hot(col[1] - 4, 12),                                 # age: code - 4
            self._one_hot(self.household_position_dict.index(col[2]), 8),  # household position
            self._one_hot(col[3] % 10 - 1, 6),                             # household size: last digit - 1
            self._one_hot(col[4] - 1, 2),                                  # previous residence place
            self._one_hot(col[5] - 1, 3),                                  # citizenship
            self._one_hot(col[6] - 1, 3),                                  # country of birth
            self._one_hot(col[7], 6),                                      # education level
            self._one_hot(col[8] % 10, 3),                                 # economic status: last digit
            self._one_hot(self.cur_eco_activity_dict.index(col[9]), 12),   # current economic activity
            self._one_hot(col[10] - 1, 4),                                 # marital status
        )
        occupation = self.occupation_dict[col[11]]

        attribute = torch.cat(features, dim=1)
        return {'attribute': attribute, 'label': occupation, 'group': sex}
