import numpy as np
import os
import torch
import pandas as pd
import requests
from torch.utils.data import Dataset
from sklearn.preprocessing import OrdinalEncoder
from module import check_exists, makedir_exist_ok, save, load

# target: GPA, min-max normalized to [0, 1]
# sensitive: race (asian, black, hisp, white), ordinal-encoded with one group per race
class LawSchool(Dataset):
    """LSAC Law School dataset for fair regression.

    Each sample is a dict with an ``id``, a float32 feature vector ``data``,
    a float32 ``target`` (GPA, min-max normalized to [0, 1]) and an int64
    ``sensitive`` label (race, ordinal-encoded over asian / black / hisp /
    white).  Note that the encoded ``race`` column is also kept among the
    features.

    On first use, ``<root>/raw/bar_pass_prediction.csv`` is processed into a
    seed-specific train/test split stored under
    ``<root>/processed/seed_<seed>``.
    """
    data_name = 'LawSchool'

    def __init__(self, root, split, seed):
        """
        Args:
            root: dataset root directory (``~`` is expanded).
            split: name of the processed split to load ('train' or 'test').
            seed: seed of the shuffle that defines the train/test split; it
                also selects the processed sub-folder.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        # Extra per-sample arrays; every entry is added to the dict returned by __getitem__.
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors, plus one tensor per entry of ``self.other``."""
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        extra = {key: torch.tensor(values[index]) for key, values in self.other.items()}
        return {**sample, **extra}

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return ('Dataset {}\nSize: {}\nRoot: {}\nSeed: {}\nSplit: {}\nNGroup: {}'
                .format(self.__class__.__name__, len(self), self.root, self.seed, self.split,
                        self.metadata['n_groups']))

    @property
    def processed_folder(self):
        """Folder holding the processed splits for this seed."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Folder expected to contain the downloaded raw CSV."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Build the splits from the raw CSV (which must already be downloaded) and save them."""
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))

    def make_data(self):
        """Load, clean, encode, shuffle and split (80/20) the raw LSAC data.

        Returns:
            ``(train_set, test_set, metadata)``.  Each set is a tuple
            ``(id, data, target, sensitive)``:

            - int64 ids numbered from 0 within the split;
            - a float32 feature matrix (every kept column except ``gpa``);
            - a float32 target of shape ``(n, 1)``;
            - int64 race labels.

            ``metadata`` is ``{'n_groups': ..., 'n_classes': 1}`` and is also
            stored on ``self.metadata``.
        """
        data = pd.read_csv(os.path.join(self.raw_folder, 'bar_pass_prediction.csv'))
        # Only keep these columns, based on the EDA in
        # https://www.kaggle.com/code/eds8531/lsac-dataset-eda-and-predictions/notebook
        include_cols = ['decile3', 'decile1', 'lsat', 'gpa', 'grad', 'fulltime', 'fam_inc',
                        'gender', 'pass_bar', 'tier', 'indxgrp', 'indxgrp2',
                        'dnn_bar_pass_prediction', 'race1']
        data = data[include_cols].dropna()

        # Keep four race categories as a multi-valued sensitive attribute.
        # .copy() makes the filtered frame independent, so the column
        # assignments below do not trigger SettingWithCopyWarning.
        data = data.rename(columns={'race1': 'race'})
        data = data[data['race'].isin(['asian', 'black', 'hisp', 'white'])].copy()

        # Min-max normalize GPA over the rows actually kept (i.e. after the
        # race filter), so the target spans [0, 1].
        # NOTE: the statistics are computed before the train/test split.
        gpa = data['gpa']
        data['gpa'] = (gpa - gpa.min()) / (gpa.max() - gpa.min())

        # Ordinal-encode every object (string) column, race included.
        cat_cols = data.select_dtypes(include=['object']).columns
        data[cat_cols] = OrdinalEncoder().fit_transform(data[cat_cols])

        # Drop columns holding a single value, then shuffle deterministically by seed.
        data = data.loc[:, data.nunique() != 1]
        data = data.sample(frac=1, random_state=self.seed).reset_index(drop=True)

        sensitive = data['race'].to_numpy().astype(np.int64)
        target = data['gpa'].to_numpy().reshape(-1, 1).astype(np.float32)
        features = data.drop(columns=['gpa']).to_numpy().astype(np.float32)
        split_idx = int(0.8 * len(features))

        train_id = np.arange(split_idx, dtype=np.int64)
        test_id = np.arange(len(features) - split_idx, dtype=np.int64)
        self.metadata = {'n_groups': len(np.unique(sensitive)), 'n_classes': 1}
        return ((train_id, features[:split_idx], target[:split_idx], sensitive[:split_idx]),
                (test_id, features[split_idx:], target[split_idx:], sensitive[split_idx:]),
                self.metadata)



