import numpy as np
import os
import pickle
import torch
import pandas as pd
from torch.utils.data import Dataset
import folktables
from module import check_exists, makedir_exist_ok, save, load


# 5 income classes (bins of PINCP), 5 race groups (merged RAC1P codes)
# 1.6 million instances, 10 raw features
class ACSIncome(Dataset):
    data_name = 'ACSIncome'

    def __init__(self, root, split, seed):
        """Load one split of the dataset, building the processed files if needed.

        Args:
            root: Dataset root directory; ``~`` is expanded.
            split: Name of the processed file to load (``process`` writes
                ``'train'`` and ``'test'``).
            seed: Seed that selects the ``seed_<seed>`` processed folder and
                drives the shuffle in ``make_data``.
        """
        self.root = os.path.expanduser(root)
        self.split, self.seed = split, seed

        # Resolve the folder once; root/seed do not change below.
        processed_dir = self.processed_folder
        if not check_exists(processed_dir):
            self.process()

        split_path = os.path.join(processed_dir, self.split)
        self.id, self.data, self.target, self.sensitive = load(split_path)
        # Optional extra per-sample arrays that __getitem__ also returns.
        self.other = {}
        meta_path = os.path.join(processed_dir, 'meta')
        self.metadata = load(meta_path)

    def __getitem__(self, index):
        id, data, target, sensitive = torch.tensor(self.id[index]), torch.tensor(self.data[index]), torch.tensor(
            self.target[index]), torch.tensor(self.sensitive[index])
        input = {'id': id, 'data': data, 'target': target, 'sensitive': sensitive}
        other = {k: torch.tensor(self.other[k][index]) for k in self.other}
        input = {**input, **other}
        return input

    def __len__(self):
        return len(self.data)

    @property
    def processed_folder(self):
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        return os.path.join(self.root, 'raw')

    def process(self):
        """Build and store the processed train/test/meta files for this seed.

        Fetches the raw data via ``load`` when the cached raw pickle is not
        available, then runs ``make_data`` and writes its three results as
        ``train``, ``test`` and ``meta`` inside ``processed_folder``.
        """
        raw_pickle = os.path.join(self.raw_folder, 'acs_income.pkl')
        # Check for the pickle itself, not just the raw directory: load() only
        # writes the pickle as its last step, so a failed run can leave the
        # directory present without the file make_data() reads. Testing the
        # directory alone would then skip load() on every later attempt.
        if not check_exists(self.raw_folder) or not os.path.isfile(raw_pickle):
            self.load()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))

    def load(self):
        """Fetch the 2018 1-Year ACS person survey and cache it as a pickle.

        The data is obtained through ``folktables.ACSDataSource`` rooted at
        ``raw_folder`` (with ``download=True``) and the resulting frame is
        pickled to ``<raw_folder>/acs_income.pkl``.
        """
        source = folktables.ACSDataSource(
            survey_year='2018',
            horizon='1-Year',
            survey='person',
            root_dir=self.raw_folder,
        )
        raw_df = source.get_data(download=True)

        pickle_path = os.path.join(self.raw_folder, 'acs_income.pkl')
        with open(pickle_path, 'wb') as handle:
            pickle.dump(raw_df, handle)

        
    def __repr__(self):
        fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSeed: {}\nSplit: {}\nNClass: {}\nNGroup: {}'.format(self.__class__.__name__, self.__len__(),
                                                                     self.root,
                                                                     self.seed,
                                                                     self.split,
                                                                     self.metadata['n_classes'],
                                                                     self.metadata['n_groups'])
        return fmt_str


    def make_data(self):
        """Turn the cached raw ACS pickle into train/test arrays plus metadata.

        Steps: apply ``folktables.adult_filter``, bin ``PINCP`` into
        ``n_classes`` roughly equal-sized classes, merge the rarer ``RAC1P``
        codes into three coarser groups, one-hot encode the labelled
        categorical features, shuffle with ``self.seed`` and split 80/20.

        Returns:
            A 3-tuple ``(train_set, test_set, metadata)``. Each set is
            ``(id, data, target, sensitive)``: ``id`` is an int64 ``arange``
            over that split, ``data`` a float32 array, ``target`` the int64
            income-bin label and ``sensitive`` the int64 race-group index.
            ``metadata`` has the keys ``'n_classes'`` and ``'n_groups'`` and
            is also stored on ``self.metadata``.
        """
        n_classes = 5
        target = 'PINCP'
        features = [
            'AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX',
            'RAC1P'
        ]
        # Code -> label maps for the categorical features. The labels are only
        # used to name the one-hot columns (and those names are dropped when
        # the frame is converted to NumPy below), so the adjacent string
        # literals that concatenate without a separating space are harmless.
        categories = {
            "COW": {
                1.0: ("Employee of a private for-profit company or"
                        "business, or of an individual, for wages,"
                        "salary, or commissions"),
                2.0: ("Employee of a private not-for-profit, tax-exempt,"
                        "or charitable organization"),
                3.0:
                    "Local government employee (city, county, etc.)",
                4.0:
                    "State government employee",
                5.0:
                    "Federal government employee",
                6.0: ("Self-employed in own not incorporated business,"
                        "professional practice, or farm"),
                7.0: ("Self-employed in own incorporated business,"
                        "professional practice or farm"),
                8.0:
                    "Working without pay in family business or farm",
                9.0:
                    "Unemployed and last worked 5 years ago or earlier or never worked",
            },
            "SCHL": {
                1.0: "No schooling completed",
                2.0: "Nursery school, preschool",
                3.0: "Kindergarten",
                4.0: "Grade 1",
                5.0: "Grade 2",
                6.0: "Grade 3",
                7.0: "Grade 4",
                8.0: "Grade 5",
                9.0: "Grade 6",
                10.0: "Grade 7",
                11.0: "Grade 8",
                12.0: "Grade 9",
                13.0: "Grade 10",
                14.0: "Grade 11",
                15.0: "12th grade - no diploma",
                16.0: "Regular high school diploma",
                17.0: "GED or alternative credential",
                18.0: "Some college, but less than 1 year",
                19.0: "1 or more years of college credit, no degree",
                20.0: "Associate's degree",
                21.0: "Bachelor's degree",
                22.0: "Master's degree",
                23.0: "Professional degree beyond a bachelor's degree",
                24.0: "Doctorate degree",
            },
            "MAR": {
                1.0: "Married",
                2.0: "Widowed",
                3.0: "Divorced",
                4.0: "Separated",
                5.0: "Never married or under 15 years old",
            },
            "SEX": {
                1.0: "Male",
                2.0: "Female"
            },
            "RAC1P": {
                1.0: "White alone",
                2.0: "Black or African American alone",
                3.0: "American Indian alone",
                4.0: "Alaska Native alone",
                5.0: ("American Indian and Alaska Native tribes specified;"
                        "or American Indian or Alaska Native,"
                        "not specified and no other"),
                6.0: "Asian alone",
                7.0: "Native Hawaiian and Other Pacific Islander alone",
                8.0: "Some Other Race alone",
                9.0: "Two or More Races",
            },
        }

        df = pd.read_pickle(os.path.join(self.raw_folder, 'acs_income.pkl'))
        df = folktables.adult_filter(df)

        # Empirical CDF of PINCP: sorted incomes and their cumulative fraction.
        sorted_income = np.sort(df[target])
        ecdf = np.arange(len(sorted_income)) / float(len(sorted_income))

        # One income threshold per 1/n_classes quantile so that each bin holds
        # roughly the same number of samples; +inf closes the last bin.
        partitions = np.array([
            sorted_income[np.argmax(ecdf >= q)]
            for q in np.arange(1 / n_classes, 1, 1 / n_classes)
        ] + [np.inf])

        def target_transform(values):
            # Label = index of the first threshold strictly above the value.
            return np.argmax(
                np.array(values)[:, None] < partitions[None, :], axis=1)

        # Combine RAC1P categories 3, 4, 5, and 6, 7, and 8, 9 into new categories
        # 10, 11, and 12 respectively, due to small sample size in some groups.
        # This is also consistent with the UCI Adult dataset.
        categories['RAC1P'][10.0] = "American Indian or Alaska Native alone"
        categories['RAC1P'][
            11.0] = "Asian, Native Hawaiian or Other Pacific Islander alone"
        categories['RAC1P'][12.0] = "Other"
        df['RAC1P'] = df['RAC1P'].replace([3.0, 4.0, 5.0], 10.0)
        df['RAC1P'] = df['RAC1P'].replace([6.0, 7.0], 11.0)
        df['RAC1P'] = df['RAC1P'].replace([8.0, 9.0], 12.0)

        data, labels, groups = folktables.BasicProblem(
            features=features,
            target=target,
            target_transform=target_transform,
            group='RAC1P',
            # The second positional parameter of np.nan_to_num is `copy`, not
            # the fill value, so the fill must be passed by keyword for NaNs
            # to actually become -1.
            postprocess=lambda x: np.nan_to_num(x, nan=-1),
        ).df_to_pandas(df, categories=categories, dummies=True)

        labels = labels.values.squeeze()
        groups = groups.values.squeeze()

        # Re-index the remaining RAC1P codes (1, 2, 10, 11, 12 after the merge
        # above) to consecutive integers in ascending code order.
        _, groups = np.unique(groups, return_inverse=True)

        # Shuffle, then split into train and test sets.
        # NOTE(review): indexing the label arrays with the shuffled index
        # assumes `data` comes back with a 0..n-1 index - confirm against
        # folktables' df_to_pandas.
        data = data.sample(frac=1, random_state=self.seed)
        order = data.index.to_numpy()
        labels = labels[order]
        groups = groups[order]

        split_idx = int(len(data) * 0.8)
        train_data = data.iloc[:split_idx].astype(np.float32).to_numpy()
        test_data = data.iloc[split_idx:].astype(np.float32).to_numpy()
        train_target = labels[:split_idx].astype(np.int64)
        test_target = labels[split_idx:].astype(np.int64)
        train_sensitive = groups[:split_idx].astype(np.int64)
        test_sensitive = groups[split_idx:].astype(np.int64)

        train_id = np.arange(len(train_data)).astype(np.int64)
        test_id = np.arange(len(test_data)).astype(np.int64)

        # ndarray.max() avoids a Python-level loop over every training label.
        num_classes = int(train_target.max()) + 1
        num_groups = len(np.unique(groups))

        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}

        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)

