import numpy as np
import os
import pickle
import requests
import torch
import pandas as pd
from torch.utils.data import Dataset
from module import check_exists, makedir_exist_ok, save, load
from sklearn.preprocessing import StandardScaler


# target: recidivism, column `is_recid` (0, 1)
# sensitive: race (African-American: 0, Caucasian: 1); rows of any other race are dropped
class Compas(Dataset):
    """ProPublica COMPAS recidivism data exposed as a ``torch`` ``Dataset``.

    On first use the raw CSV is downloaded into ``<root>/raw``, preprocessed,
    shuffled with ``seed`` and split 80/20 into train/test. The result is
    cached under ``<root>/processed/seed_<seed>`` and reloaded afterwards.

    Each item is a dict with the keys ``'id'``, ``'data'``, ``'target'`` and
    ``'sensitive'`` (plus one entry per key of ``self.other``):

    * ``data``: float32 feature vector with the columns ``age``,
      ``c_charge_degree`` (felony = 1, misdemeanor = -1), ``race``,
      ``gender`` (male = 1, female = 0), ``priors_count``,
      ``c_days_from_compas``, ``v_decile_score``, ``is_violent_recid`` and
      ``length_of_stay`` (whole days in jail).
    * ``target``: ``is_recid`` (0 or 1).
    * ``sensitive``: race (African-American = 0, Caucasian = 1).

    Args:
        root: Directory holding the ``raw`` and ``processed`` sub-folders;
            ``~`` is expanded.
        split: Name of the processed split to load; ``process`` writes
            ``'train'`` and ``'test'``.
        seed: Seed of the shuffle that precedes the train/test split.
    """

    data_name = 'Compas'

    # Where the raw CSV comes from and the name it is stored under locally.
    _raw_url = 'https://raw.githubusercontent.com/HsiangHsu/Fair-Projection/main/data/COMPAS/compas-scores-two-years.csv'
    _raw_filename = 'compas-scores-two-years.csv'
    # Seconds to wait for the download server before giving up.
    _download_timeout = 60

    def __init__(self, root, split, seed):
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        # Optional extra per-sample arrays, merged into every item by __getitem__.
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors."""
        # Local names avoid shadowing the builtins ``id`` and ``input``.
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        extras = {k: torch.tensor(self.other[k][index]) for k in self.other}
        return {**sample, **extras}

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    @property
    def processed_folder(self):
        """Seed-specific folder holding the cached train/test/meta files."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Folder holding the downloaded CSV."""
        return os.path.join(self.root, 'raw')

    @property
    def _raw_path(self):
        """Full path of the downloaded CSV inside ``raw_folder``."""
        return os.path.join(self.raw_folder, self._raw_filename)

    def process(self):
        """Download the raw CSV if needed, then build and cache both splits."""
        # Check for the file itself rather than the folder: an earlier failed
        # download may have left an empty ``raw`` folder behind.
        if not os.path.isfile(self._raw_path):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))
        return

    def download(self):
        """Fetch the raw COMPAS CSV into ``raw_folder``.

        Raises:
            requests.RequestException: On connection problems, timeouts or
                an HTTP error status.
            RuntimeError: If the server answers with a non-200 success code.
        """
        response = requests.get(self._raw_url, timeout=self._download_timeout)
        # Fail loudly here instead of letting ``make_data`` trip over a
        # missing file later.
        response.raise_for_status()
        if response.status_code != 200:
            raise RuntimeError(
                f'Unexpected HTTP status {response.status_code} while downloading {self._raw_url}')
        # Create the folder only once there is content to put into it.
        makedir_exist_ok(self.raw_folder)
        with open(self._raw_path, 'wb') as file:
            file.write(response.content)
        print("File downloaded successfully.")

    def __repr__(self):
        return (f'Dataset {self.__class__.__name__}\nSize: {len(self)}\nRoot: {self.root}\n'
                f'Seed: {self.seed}\nSplit: {self.split}\nNClass: {self.metadata["n_classes"]}\n'
                f'NGroup: {self.metadata["n_groups"]}')

    def make_data(self):
        """Preprocess the raw CSV and split it 80/20 into train and test.

        Also stores the returned metadata dict on ``self.metadata``.

        Returns:
            ``(train, test, meta)`` where ``train`` and ``test`` are tuples
            ``(id, data, target, sensitive)`` of NumPy arrays (``data`` is
            float32, the others int64) and ``meta`` is
            ``{'n_classes': int, 'n_groups': int}``.
        """
        columns = ['age', 'c_charge_degree', 'race', 'sex', 'priors_count',
                   'days_b_screening_arrest', 'is_recid', 'c_jail_in', 'c_jail_out', 'c_days_from_compas',
                   'v_decile_score', 'is_violent_recid']
        df = pd.read_csv(self._raw_path, index_col=0)[columns]

        # Row filters (following ProPublica's analysis):
        #   * arrest within 30 days of screening (missing values are dropped),
        #   * a COMPAS case was found (is_recid != -1),
        #   * no ordinary traffic offenses (charge degree "O"),
        #   * only African-American and Caucasian defendants.
        keep = (df['days_b_screening_arrest'].between(-30, 30)
                & (df['is_recid'] != -1)
                & (df['c_charge_degree'] != 'O')
                & df['race'].isin(['African-American', 'Caucasian']))
        # Explicit copy so the column assignments below modify an independent
        # frame instead of a view of the unfiltered data.
        df = df.loc[keep].copy()

        # New attribute "length of stay": total jail time in whole days.
        df['length_of_stay'] = (pd.to_datetime(df['c_jail_out']) - pd.to_datetime(df['c_jail_in'])).dt.days
        df = df.drop(columns=['c_jail_in', 'c_jail_out', 'days_b_screening_arrest'])

        # African-American: 0, Caucasian: 1
        df['race'] = (df['race'] == 'Caucasian').astype(np.int64)
        # Male: 1, Female: 0; the column is renamed from 'sex' to 'gender'.
        df['sex'] = (df['sex'] == 'Male').astype(np.int64)
        df = df.rename(columns={'sex': 'gender'})
        # Felony: 1, Misdemeanor: -1
        df['c_charge_degree'] = np.where(df['c_charge_degree'] == 'F', 1, -1)
        df = df.reset_index(drop=True)

        # Seeded shuffle so the split is reproducible for a given seed.
        perm = np.random.default_rng(seed=self.seed).permutation(len(df))
        df = df.iloc[perm]

        sensitive = df['race'].to_numpy()
        target = df['is_recid'].to_numpy()
        # NOTE(review): 'race' stays in the feature matrix as well as being
        # returned as the sensitive attribute -- confirm this is intended.
        data = df.drop(columns=['is_recid']).to_numpy()

        split_idx = int(data.shape[0] * 0.8)
        train_data = data[:split_idx].astype(np.float32)
        test_data = data[split_idx:].astype(np.float32)
        train_sensitive = sensitive[:split_idx].astype(np.int64)
        test_sensitive = sensitive[split_idx:].astype(np.int64)
        train_target = target[:split_idx].astype(np.int64)
        test_target = target[split_idx:].astype(np.int64)
        # Ids are positions within each split, not ids from the raw file.
        train_id = np.arange(len(train_data)).astype(np.int64)
        test_id = np.arange(len(test_data)).astype(np.int64)

        # Classes are assumed to be labelled 0..max, so the count is max + 1.
        num_classes = int(train_target.max()) + 1
        num_groups = len(np.unique(sensitive))

        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)