import numpy as np
import os
import torch
import pandas as pd
import requests
from torch.utils.data import Dataset
import sklearn.preprocessing as preprocessing
from module import check_exists, makedir_exist_ok, save, load


# Label '>50K': 23.93% of all instances (24.78% after removing instances with unknown values)
# Label '<=50K': 76.07% of all instances (75.22% after removing instances with unknown values)
# 48842 instances with a mix of continuous and discrete features (train=32561, test=16281)
# 45222 instances remain after removing those with unknown values (train=30162, test=15060)


# target: income (>50K, <=50K)
# sensitive: gender
class Adult(Dataset):
    """UCI Adult (census income) dataset for fairness experiments.

    Each item is a dict with ``id``, ``data`` (one-hot encoded float32 features),
    ``target`` (income: 0 for '<=50K', 1 for '>50K') and ``sensitive`` (gender, taken
    from the ``gender_Male`` one-hot column), plus any entries of ``self.other``.

    The raw train and test files are merged, shuffled with ``seed`` and re-split at the
    size of the cleaned training file, so the official UCI train/test split is NOT
    preserved. Processed splits are cached per seed under ``<root>/processed/seed_<seed>``.
    """
    data_name = 'Adult'
    base_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'
    raw_files = ('adult.data', 'adult.test')
    column_names = ("age", "workclass", "fnlwgt", "education", "education-num",
                    "marital-status", "occupation", "relationship", "race", "gender",
                    "capital-gain", "capital-loss", "hours-per-week", "native-country", "income")

    def __init__(self, root, split, seed):
        """Load the processed ``split``, building the processed data first if it is missing.

        Args:
            root: dataset root directory; ``~`` is expanded.
            split: name of the processed split to load ('train' or 'test').
            seed: random state of the shuffle used to build the splits.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        # Extra per-sample fields; __getitem__ merges them into every item (empty by default).
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors."""
        item = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        # Extra fields are merged last, so they override core keys of the same name.
        item.update({key: torch.tensor(values[index]) for key, values in self.other.items()})
        return item

    def __len__(self):
        return len(self.data)

    @property
    def processed_folder(self):
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        return os.path.join(self.root, 'raw')

    def process(self):
        """Fetch missing raw files, then build and save the train/test splits and metadata."""
        raw_paths = [os.path.join(self.raw_folder, name) for name in self.raw_files]
        # Check the files, not just the folder: a failed or interrupted download can leave
        # the folder behind without the data, which would otherwise never be re-fetched.
        if not all(os.path.exists(path) for path in raw_paths):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))
        return

    def download(self, timeout=60):
        """Download any raw file that is not yet present in ``raw_folder``.

        Args:
            timeout: seconds passed to ``requests.get`` for each file.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        makedir_exist_ok(self.raw_folder)
        for filename in self.raw_files:
            # Name the file after the requested URL (not a possibly redirected one), so the
            # path always matches what make_data reads.
            full_path = os.path.join(self.raw_folder, filename)
            if os.path.exists(full_path):
                continue
            response = requests.get(f'{self.base_url}/{filename}', timeout=timeout)
            # Fail loudly here instead of surfacing later as a missing file in make_data.
            response.raise_for_status()
            with open(full_path, "wb") as file:
                file.write(response.content)
            print("File downloaded successfully.")
        return

    def __repr__(self):
        return (f'Dataset {self.__class__.__name__}\nSize: {len(self)}\nRoot: {self.root}\n'
                f'Seed: {self.seed}\nSplit: {self.split}\n'
                f"NClass: {self.metadata['n_classes']}\nNGroup: {self.metadata['n_groups']}")

    @staticmethod
    def _encode_income(income):
        """Map income labels to int64 (0 for '<=50K', 1 for '>50K').

        adult.test labels carry a trailing '.', which is ignored.

        Raises:
            ValueError: if any label is neither '<=50K' nor '>50K'.
        """
        labels = income.str.rstrip('.').map({'<=50K': 0, '>50K': 1})
        if labels.isna().any():
            unknown = income[labels.isna()].unique().tolist()
            raise ValueError(f'Unexpected income labels: {unknown}')
        return labels.to_numpy(dtype=np.int64)

    @staticmethod
    def _one_hot_encoder():
        """Dense one-hot encoder dropping the first level, across scikit-learn versions."""
        try:
            # scikit-learn >= 1.2 (the old ``sparse`` argument was removed in 1.4).
            return preprocessing.OneHotEncoder(sparse_output=False, drop='first')
        except TypeError:
            return preprocessing.OneHotEncoder(sparse=False, drop='first')

    def preprocess_data(self, X):
        """Encode a raw Adult frame into float32 features and int64 income labels.

        Age is replaced by its ``searchsorted`` bin index (0..10) over fixed edges and
        treated as categorical. ``fnlwgt`` (judged unrelated to the target) and
        ``education`` (duplicated by ``education-num``) are dropped. Every remaining
        object column is one-hot encoded with its first level dropped.

        Args:
            X: DataFrame with the columns listed in ``column_names``.

        Returns:
            (features, labels): float32 DataFrame and int64 ndarray.
        """
        age_bins = np.array([18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
        X = X.reset_index(drop=True)
        X['age'] = age_bins.searchsorted(X['age'].to_numpy()).astype(object)
        y = self._encode_income(X['income'])
        features = X.drop(columns=['fnlwgt', 'education', 'income'])
        categorical = features.select_dtypes(include=['object']).columns
        encoder = self._one_hot_encoder()
        one_hot = encoder.fit_transform(features[categorical])
        encoded_df = pd.DataFrame(one_hot, columns=encoder.get_feature_names_out(categorical),
                                  index=features.index)
        features = pd.concat([features.drop(columns=categorical), encoded_df], axis=1)
        return features.astype(np.float32), y

    def make_data(self):
        """Read, clean, shuffle, encode and split the raw files.

        Returns:
            ``((train_id, train_data, train_target, train_sensitive),
            (test_id, test_data, test_target, test_sensitive), metadata)``, where
            ``metadata`` is ``{'n_classes': ..., 'n_groups': ...}`` (also stored on self).
        """
        read_kwargs = dict(engine='python', header=None, sep=r'\s*,\s*', na_values='?')
        X = pd.read_csv(os.path.join(self.raw_folder, 'adult.data'), **read_kwargs)
        # adult.test begins with one non-data line.
        X_test = pd.read_csv(os.path.join(self.raw_folder, 'adult.test'), skiprows=1, **read_kwargs)

        # Rows with unknown ('?') values are removed before the split size is taken.
        X = X.dropna()
        X_test = X_test.dropna()
        train_size = len(X)
        X_all = pd.concat([X, X_test], axis=0)
        X_all.columns = list(self.column_names)
        X_all = X_all.sample(frac=1, random_state=self.seed).reset_index(drop=True)

        data_all, target_all = self.preprocess_data(X_all)
        train_data, test_data = data_all.iloc[:train_size], data_all.iloc[train_size:]
        train_target = target_all[:train_size].astype(np.int64)
        test_target = target_all[train_size:].astype(np.int64)
        train_id = np.arange(len(train_data), dtype=np.int64)
        test_id = np.arange(len(test_data), dtype=np.int64)
        num_classes = int(train_target.max()) + 1

        # Sensitive attribute is the float32 'gender_Male' one-hot column.
        # NOTE(review): AdultM stores int64 group codes instead; confirm downstream code
        # handles both dtypes.
        train_sensitive = train_data['gender_Male'].to_numpy()
        test_sensitive = test_data['gender_Male'].to_numpy()
        num_groups = len(np.unique(data_all['gender_Male']))

        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}
        return ((train_id, train_data.to_numpy(), train_target, train_sensitive),
                (test_id, test_data.to_numpy(), test_target, test_sensitive),
                self.metadata)


class AdultM(Adult):
    """Adult dataset with configurable intersectional sensitive groups.

    Features and targets are encoded exactly as in ``Adult``. ``sensitive`` instead holds
    int64 group codes built from:

    * ``num_groups == 2``: gender;
    * ``num_groups == 10``: gender x race;
    * ``num_groups == 10 * k`` with ``k > 1``: gender x race x ``k`` quantile bins of raw age.

    The data must contain exactly ``num_groups`` distinct groups, otherwise processing
    raises ``ValueError``. Processed data is cached per ``num_groups`` and seed.
    """
    data_name = 'AdultM'

    def __init__(self, root, split, seed, num_groups):
        # processed_folder depends on num_groups, so it must be set before the base
        # initializer looks for (or builds) the processed data.
        self.num_groups = num_groups
        super().__init__(root, split, seed)

    def _sensitive_groups(self, X):
        """Return int group codes (a Series aligned with ``X``) for ``self.num_groups``.

        Args:
            X: raw Adult frame; ``age`` must still hold raw (unbinned) ages.

        Raises:
            ValueError: if ``num_groups`` is not 2, 10 or a larger multiple of 10, or if
                the data does not contain exactly ``num_groups`` distinct groups.
        """
        num_groups = self.num_groups
        if num_groups == 2:
            keys = X['gender']
        elif num_groups >= 10 and num_groups % 10 == 0:
            keys = X['gender'].str.cat(X['race'], sep='_')
            n_age_groups = int(num_groups // 10)
            if n_age_groups > 1:
                # qcut uses quantiles of the raw age, giving roughly equal-sized bins.
                age_groups = pd.qcut(X['age'], q=n_age_groups,
                                     labels=[f'age{i}' for i in range(n_age_groups)])
                keys = keys.str.cat(age_groups.astype(str), sep='_')
        else:
            raise ValueError(f'num_groups should be 2, 10 or a multiple of 10 greater than 10, '
                             f'but got {num_groups}')
        sens = keys.astype('category').cat.codes
        n_found = sens.nunique()
        if n_found != num_groups:
            raise ValueError(f'num_groups: {num_groups}, but the data contains {n_found} sensitive groups')
        return sens

    def preprocess_data(self, X):
        """Encode ``X`` like ``Adult.preprocess_data`` and also return sensitive group codes.

        Returns:
            (features, labels, sens): float32 DataFrame, label ndarray and int code Series.
        """
        X = X.reset_index(drop=True)
        # Groups are computed before the base class replaces raw ages with bin indices.
        sens = self._sensitive_groups(X)
        features, y = super().preprocess_data(X)
        return features, y, sens

    def make_data(self):
        """Read, clean, shuffle, encode and split the raw files.

        Returns:
            ``((train_id, train_data, train_target, train_sensitive),
            (test_id, test_data, test_target, test_sensitive), metadata)``, where
            ``metadata`` is ``{'n_classes': ..., 'n_groups': ...}`` (also stored on self).
        """
        column_names = ["age", "workclass", "fnlwgt", "education", "education-num",
                        "marital-status", "occupation", "relationship", "race", "gender",
                        "capital-gain", "capital-loss", "hours-per-week", "native-country", "income"]
        read_kwargs = dict(engine='python', header=None, sep=r'\s*,\s*', na_values='?')
        X = pd.read_csv(os.path.join(self.raw_folder, 'adult.data'), **read_kwargs)
        # adult.test begins with one non-data line.
        X_test = pd.read_csv(os.path.join(self.raw_folder, 'adult.test'), skiprows=1, **read_kwargs)

        # Rows with unknown ('?') values are removed before the split size is taken.
        X = X.dropna()
        X_test = X_test.dropna()
        train_size = len(X)
        X_all = pd.concat([X, X_test], axis=0)
        X_all.columns = column_names
        # The merged data is shuffled with the seed and re-split at the original train size.
        X_all = X_all.sample(frac=1, random_state=self.seed).reset_index(drop=True)

        data_all, target_all, sens_all = self.preprocess_data(X_all)
        train_data = data_all.iloc[:train_size].to_numpy()
        test_data = data_all.iloc[train_size:].to_numpy()
        train_target = target_all[:train_size].astype(np.int64)
        test_target = target_all[train_size:].astype(np.int64)
        train_sensitive = sens_all.iloc[:train_size].to_numpy().astype(np.int64)
        test_sensitive = sens_all.iloc[train_size:].to_numpy().astype(np.int64)
        train_id = np.arange(len(train_data), dtype=np.int64)
        test_id = np.arange(len(test_data), dtype=np.int64)
        num_classes = int(train_target.max()) + 1
        num_groups = int(sens_all.nunique())

        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)

    @property
    def processed_folder(self):
        return os.path.join(self.root, 'processed', f'groups_{self.num_groups}', f'seed_{self.seed}')
