import numpy as np
import os
import torch
import pandas as pd
from torch.utils.data import Dataset
from module import check_exists, save, load
from sklearn.preprocessing import MinMaxScaler


class Hsls2C(Dataset):
    """Tabular dataset built from ``hsls_df_knn_impute_past_v2.pkl``.

    The target is the ``grade9thbin`` column and the sensitive attribute is a
    binarised version of ``studentrace``. The raw pickle is read from
    ``<root>/raw``; the processed train/test splits and metadata are cached in
    ``<root>/processed/seed_<seed>`` and reused on later instantiations.

    Each item is a dict with the keys ``'id'``, ``'data'``, ``'target'`` and
    ``'sensitive'`` plus one entry per key of ``self.other``.
    """
    data_name = 'Hsls2C'

    def __init__(self, root, split, seed):
        """Load one split, building the processed cache first if it is missing.

        Args:
            root: Dataset root directory (``~`` is expanded).
            split: Name of the processed file to load; ``process`` writes
                ``'train'`` and ``'test'``.
            seed: Seed used for shuffling; also selects the cache sub-folder.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        # Optional extra per-sample arrays; every entry is indexed in __getitem__.
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict of tensors."""
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        # Entries of ``other`` are merged last, so they override a core key of the same name.
        sample.update((key, torch.tensor(values[index])) for key, values in self.other.items())
        return sample

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def __repr__(self):
        """Return a multi-line summary of the dataset configuration."""
        return (f'Dataset {self.__class__.__name__}\nSize: {len(self)}\nRoot: {self.root}\n'
                f'Seed: {self.seed}\nSplit: {self.split}\n'
                f"NClass: {self.metadata['n_classes']}\nNGroup: {self.metadata['n_groups']}")

    @property
    def processed_folder(self):
        """Directory holding the cached splits for the current seed."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Directory expected to contain the raw pickle file."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Build the splits with ``make_data`` and save them to ``processed_folder``."""
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))

    def make_data(self):
        """Read the raw pickle and build the train/test arrays.

        Returns:
            A tuple ``(train_set, test_set, metadata)`` where each set is
            ``(id, data, target, sensitive)`` as numpy arrays (``int64`` ids,
            ``float32`` features, ``int64`` targets and sensitive values) and
            ``metadata`` is ``{'n_classes': ..., 'n_groups': ...}``.

        Raises:
            ValueError: If too few rows survive cleaning to form a non-empty
                training split.
        """
        df = pd.read_pickle(os.path.join(self.raw_folder, 'hsls_df_knn_impute_past_v2.pkl'))

        # Entries less than or equal to -7 are treated as missing values ...
        df[df <= -7] = np.nan
        # ... and every row containing a missing value is dropped
        # (this step significantly reduces the number of samples).
        df = df.dropna()

        # Derived columns.
        # gradebin: copy of the binary 9th-grade label (the prediction target).
        # racebin: 1 when the rescaled race code (studentrace * 7, truncated to
        #          int) equals 1 or 7, else 0 (original note: 0 -- BHN, 1 -- WA).
        # sexbin: studentgender cast to int.
        # NOTE(review): the original note describes sexbin as 0 -- Female,
        # 1 -- Male; only the int cast is visible here -- confirm the raw coding.
        race_code = (df['studentrace'] * 7).astype(int)
        df['gradebin'] = df['grade9thbin']
        df['racebin'] = race_code.isin([1, 7]).astype(int).to_numpy()
        df['sexbin'] = df['studentgender'].astype(int)

        # Drop the source columns and the 12th-grade label to focus on the
        # 9th-grade prediction. 'racebin' stays in the frame, so it is part of
        # the feature matrix as well as the sensitive array.
        df = df.drop(columns=['studentgender', 'grade9thbin', 'grade12thbin', 'studentrace'])

        # Shuffle the rows reproducibly; no class or group re-balancing is done.
        df = df.sample(frac=1, random_state=self.seed).reset_index(drop=True)

        sensitive = df['racebin'].to_numpy().astype(np.int64)
        target = df['gradebin'].to_numpy().astype(np.int64)
        features = df.drop(columns=['gradebin']).to_numpy().astype(np.float32)

        # First 80% of the shuffled rows form the training split.
        split_idx = int(0.8 * len(features))
        if split_idx == 0:
            raise ValueError(
                'Not enough samples left after cleaning to build a training split '
                f'({len(features)} row(s) available).')

        train_data, test_data = features[:split_idx], features[split_idx:]
        train_target, test_target = target[:split_idx], target[split_idx:]
        train_sensitive, test_sensitive = sensitive[:split_idx], sensitive[split_idx:]
        # Ids are positions within each split, not within the full table.
        train_id = np.arange(len(train_data)).astype(np.int64)
        test_id = np.arange(len(test_data)).astype(np.int64)

        # Count classes over all labels so a label that only appears in the
        # test split is still a valid class index.
        num_classes = int(target.max()) + 1
        num_groups = len(np.unique(sensitive))
        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)
    


class Hsls4C(Hsls2C):
    """Four-class variant of ``Hsls2C`` whose target is derived from ``S1M8GRADE``.

    Only ``make_data`` is overridden; loading, caching and item access are
    inherited from ``Hsls2C``. Because the inherited ``__init__`` reuses an
    existing processed folder, caches written with a different label mapping
    must be deleted before the labels produced here take effect.
    """
    data_name = 'Hsls4C'

    # Rounded ``S1M8GRADE`` value -> class index. The indices are contiguous
    # (0..3) so that ``n_classes`` is 4 and no class index is left unused; the
    # three highest grade levels share the last class.
    _GRADE_TO_CLASS = {0.0: 0, 0.2: 1, 0.4: 2, 0.6: 3, 0.8: 3, 1.0: 3}

    def make_data(self):
        """Read the raw pickle and build the train/test arrays.

        Returns:
            A tuple ``(train_set, test_set, metadata)`` where each set is
            ``(id, data, target, sensitive)`` as numpy arrays (``int64`` ids,
            ``float32`` features, ``int64`` targets and sensitive values) and
            ``metadata`` is ``{'n_classes': ..., 'n_groups': ...}``.

        Raises:
            ValueError: If too few rows survive cleaning to form a non-empty
                training split.
        """
        df = pd.read_pickle(os.path.join(self.raw_folder, 'hsls_df_knn_impute_past_v2.pkl'))

        # Entries less than or equal to -7 are treated as missing values ...
        df[df <= -7] = np.nan
        # ... and every row containing a missing value is dropped
        # (this step significantly reduces the number of samples).
        df = df.dropna()

        # Derived columns.
        # gradebin: class index looked up from S1M8GRADE rounded to one decimal.
        # racebin: 1 when the rescaled race code (studentrace * 7, truncated to
        #          int) equals 1 or 7, else 0 (original note: 0 -- BHN, 1 -- WA).
        # sexbin: studentgender cast to int.
        # NOTE(review): the original note describes sexbin as 0 -- Female,
        # 1 -- Male; only the int cast is visible here -- confirm the raw coding.
        race_code = (df['studentrace'] * 7).astype(int)
        df['gradebin'] = df['S1M8GRADE'].round(1).map(self._GRADE_TO_CLASS)
        df['racebin'] = race_code.isin([1, 7]).astype(int).to_numpy()
        df['sexbin'] = df['studentgender'].astype(int)

        # Grades that are not in the mapping come back as NaN; drop those rows
        # instead of letting the later int64 cast turn NaN into a garbage label.
        df = df.dropna(subset=['gradebin'])

        # NOTE(review): unlike Hsls2C, 'grade9thbin' is not dropped here and
        # therefore remains a feature -- confirm this is intended.
        df = df.drop(columns=['S1M8GRADE', 'studentgender', 'grade12thbin', 'studentrace'])

        # Shuffle the rows reproducibly.
        df = df.sample(frac=1, random_state=self.seed).reset_index(drop=True)

        sensitive = df['racebin'].to_numpy().astype(np.int64)
        target = df['gradebin'].to_numpy().astype(np.int64)
        features = df.drop(columns=['gradebin']).to_numpy().astype(np.float32)

        # First 80% of the shuffled rows form the training split.
        split_idx = int(0.8 * len(features))
        if split_idx == 0:
            raise ValueError(
                'Not enough samples left after cleaning to build a training split '
                f'({len(features)} row(s) available).')

        train_data, test_data = features[:split_idx], features[split_idx:]
        train_target, test_target = target[:split_idx], target[split_idx:]
        train_sensitive, test_sensitive = sensitive[:split_idx], sensitive[split_idx:]
        # Ids are positions within each split, not within the full table.
        train_id = np.arange(len(train_data)).astype(np.int64)
        test_id = np.arange(len(test_data)).astype(np.int64)

        # Count classes over all labels so a label that only appears in the
        # test split is still a valid class index.
        num_classes = int(target.max()) + 1
        num_groups = len(np.unique(sensitive))
        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)