import numpy as np
import os
import torch
import pandas as pd
from torch.utils.data import Dataset
from module import check_exists, makedir_exist_ok, save, load
from sklearn.preprocessing import MinMaxScaler
import pickle

class Enem(Dataset):
    """Multi-class, multi-group dataset built from the ENEM 2020 microdata.

    Downloaded from
    https://www.gov.br/inep/pt-br/acesso-a-informacao/dadosabertos/microdados/enem

    The target is the human-science grade (``NU_NOTA_CH``) discretised into
    quantile bins (``gradebin``); the sensitive attribute is ``racebin``.
    A cleaned intermediate frame is cached at ``<root>/raw/enem.pkl`` and the
    shuffled train/test splits are cached per seed under
    ``<root>/processed/seed_<seed>``.
    """
    data_name = 'Enem'

    def __init__(self, root, split, seed):
        """Load the requested split, running :meth:`process` first if needed.

        Args:
            root: Dataset root directory (``~`` is expanded).
            split: Name of the saved split to load, e.g. ``'train'`` or ``'test'``.
            seed: Shuffle seed used by :meth:`make_data`; also selects the
                processed cache folder.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.seed = seed
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target, self.sensitive = load(os.path.join(self.processed_folder, self.split))
        self.other = {}
        self.metadata = load(os.path.join(self.processed_folder, 'meta'))

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors.

        Keys are ``'id'``, ``'data'``, ``'target'`` and ``'sensitive'``, plus one
        entry per key registered in ``self.other``.
        """
        sample = {
            'id': torch.tensor(self.id[index]),
            'data': torch.tensor(self.data[index]),
            'target': torch.tensor(self.target[index]),
            'sensitive': torch.tensor(self.sensitive[index]),
        }
        # Extra fields are merged last, so they win on a key collision.
        sample.update({k: torch.tensor(v[index]) for k, v in self.other.items()})
        return sample

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSeed: {}\nSplit: {}\nNClass: {}\nNGroup: {}'.format(
            self.__class__.__name__, self.__len__(), self.root, self.seed, self.split,
            self.metadata['n_classes'], self.metadata['n_groups'])
        return fmt_str

    @property
    def processed_folder(self):
        """Per-seed folder holding the saved ``train``/``test``/``meta`` files."""
        return os.path.join(self.root, 'processed', f'seed_{self.seed}')

    @property
    def raw_folder(self):
        """Folder holding the raw CSV and the cleaned ``enem.pkl`` cache."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Build (if needed) the cleaned cache, then split it and save the results."""
        if not check_exists(os.path.join(self.raw_folder, 'enem.pkl')):
            self.process_raw(file_path=os.path.join(self.raw_folder, "MICRODADOS_ENEM_2020.csv"))
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train'))
        save(test_set, os.path.join(self.processed_folder, 'test'))
        save(meta, os.path.join(self.processed_folder, 'meta'))
        return

    @staticmethod
    def _one_hot(df, column, drop_last):
        """Replace ``column`` with its dummy columns, appended at the end of ``df``."""
        dummies = pd.get_dummies(df[column], prefix=column)
        if drop_last:
            # Drop one level so the dummies of this column are not collinear.
            dummies = dummies.iloc[:, :-1]
        return pd.concat([df.drop([column], axis=1), dummies], axis=1)

    def process_raw(self, file_path, multigroup=True):
        """Clean the raw ENEM CSV and cache the result at ``<raw_folder>/enem.pkl``.

        Keeps candidates present at all four exams, who are not "treineiros",
        with ``TP_ST_CONCLUSAO == 1`` and no missing selected features. Derives
        ``gradebin`` (5 quantile bins of the grade), ``racebin`` and ``sexbin``.
        One-hot encodes the questionnaire (except Q005), age-range and exam-state
        columns. Min-max scales every column except ``gradebin``/``racebin``.
        Keeps at most 50 000 rows.

        Args:
            file_path: Path to the ';'-separated, cp860-encoded microdata CSV.
            multigroup: If True, ``racebin`` holds the five race codes 0-4;
                otherwise it is 1 for 'Branca'/'Amarela' and 0 for the rest.
        """
        df = pd.read_csv(file_path, encoding='cp860', sep=';')
        # Possible labels: NU_NOTA_CH = human sciences, NU_NOTA_LC = languages & codes,
        # NU_NOTA_MT = math, NU_NOTA_CN = natural sciences.
        grade_attribute = ['NU_NOTA_CH']
        group_attribute = ['TP_COR_RACA', 'TP_SEXO']
        # Questionnaire columns Q001..Q024 (layout changed for 2020).
        question_vars = [f'Q{x:03d}' for x in range(1, 25)]
        domestic_vars = ['SG_UF_PROVA', 'TP_FAIXA_ETARIA']
        features = grade_attribute + group_attribute + question_vars + domestic_vars
        n_sample = 50000
        n_classes = 5

        # Keep only candidates whose presence code is 1 for every exam; this drops
        # anyone absent or eliminated in at least one exam (NaN counts as absent).
        presence_cols = ['TP_PRESENCA_CN', 'TP_PRESENCA_CH', 'TP_PRESENCA_LC', 'TP_PRESENCA_MT']
        df = df.loc[(df[presence_cols] == 1.0).all(axis=1)]

        # Remove "treineiros" -- individuals who marked that they take the exam
        # "only to test their knowledge" (common as a dry run mid high school).
        df = df.loc[df['IN_TREINEIRO'] == 0]

        df = df.drop(presence_cols + ['IN_TREINEIRO'], axis=1)

        # Replace race codes with names. Code 0 and any code outside 1-5 become
        # NaN and are removed by the dropna below; a dict lookup (unlike list
        # indexing) tolerates unexpected codes and a float-typed column.
        race_names = {1: 'Branca', 2: 'Preta', 3: 'Parda', 4: 'Amarela', 5: 'Indigena'}
        df = df.assign(TP_COR_RACA=df['TP_COR_RACA'].map(race_names))

        # Keep TP_ST_CONCLUSAO == 1 (originally described as removing repeated
        # exam takers). This step significantly reduces the dataset.
        df = df.loc[df['TP_ST_CONCLUSAO'].isin([1])]

        # Select features and drop rows with any missing value.
        df = df[features].dropna().copy()

        # Derived label / group columns.
        df['gradebin'] = construct_grade(df, grade_attribute, n_classes)
        if multigroup:
            df['racebin'] = construct_race(df, 'TP_COR_RACA')
        else:
            df['racebin'] = df['TP_COR_RACA'].isin(['Branca', 'Amarela']).astype(int)
        df['sexbin'] = (df['TP_SEXO'] == 'M').astype(int)

        df = df.drop([grade_attribute[0], 'TP_COR_RACA', 'TP_SEXO'], axis=1)

        # Q005 ('Including yourself, how many people currently live in your
        # household?') stays numeric; the other answers are one-hot encoded.
        for q in question_vars:
            if q != 'Q005':
                df = self._one_hot(df, q, drop_last=True)

        if 'TP_FAIXA_ETARIA' in features:
            df = self._one_hot(df, 'TP_FAIXA_ETARIA', drop_last=True)

        # State where the exam was taken: every level is kept.
        df = self._one_hot(df, 'SG_UF_PROVA', drop_last=False)

        df = df.dropna()

        # Min-max scale everything except the target and group labels.
        excluded = {'gradebin', 'racebin'}
        scale_columns = [c for c in df.columns if c not in excluded]
        scaler = MinMaxScaler()
        df[scale_columns] = pd.DataFrame(scaler.fit_transform(df[scale_columns]), columns=scale_columns,
                                         index=df.index)

        # NOTE(review): this subsample is unseeded, so the cached pickle differs
        # between machines/runs.
        df = df.sample(n=min(n_sample, df.shape[0]), axis=0, replace=False)

        df.to_pickle(os.path.join(self.raw_folder, 'enem.pkl'))

        return

    def make_data(self):
        """Shuffle the cached frame with ``self.seed`` and split it 80/20.

        Returns:
            ``(train, test, metadata)``. ``train`` and ``test`` are
            ``(id, data, target, sensitive)`` tuples of numpy arrays (``data`` is
            float32, the rest int64). ``metadata`` is
            ``{'n_classes': ..., 'n_groups': ...}`` and is also stored on
            ``self.metadata``.
        """
        df = pd.read_pickle(os.path.join(self.raw_folder, 'enem.pkl'))
        # process_raw keeps at most 50k rows but may keep fewer; sampling more
        # rows than exist without replacement raises, so cap at len(df).
        df = df.sample(n=min(int(5e4), len(df)), random_state=self.seed)
        df = df.reset_index(drop=True)

        # 5 grade bins as integer class labels.
        target = df['gradebin'].astype(int).to_numpy()
        # Multigroup race coding: 'Branca': 0, 'Preta': 1, 'Parda': 2, 'Amarela': 3, 'Indigena': 4.
        sensitive = df['racebin'].to_numpy()
        # Every remaining column (including racebin) is used as a feature.
        features = df.drop(['gradebin'], axis=1).to_numpy()

        split_idx = int(0.8 * len(features))

        train_data = features[:split_idx].astype(np.float32)
        test_data = features[split_idx:].astype(np.float32)
        train_target = target[:split_idx].astype(np.int64)
        test_target = target[split_idx:].astype(np.int64)
        train_sensitive = sensitive[:split_idx].astype(np.int64)
        test_sensitive = sensitive[split_idx:].astype(np.int64)
        # Ids are positions within each split; both splits start at 0.
        train_id = np.arange(len(train_data), dtype=np.int64)
        test_id = np.arange(len(test_data), dtype=np.int64)

        # Classes are taken to be 0..(largest label seen in the training split).
        num_classes = int(train_target.max()) + 1
        num_groups = len(np.unique(sensitive))
        self.metadata = {'n_classes': num_classes, 'n_groups': num_groups}
        return ((train_id, train_data, train_target, train_sensitive),
                (test_id, test_data, test_target, test_sensitive),
                self.metadata)




def construct_grade(df, grade_attribute, n):
    """Discretise a grade column into ``n`` equal-frequency bins labelled ``0..n-1``.

    Args:
        df: DataFrame holding the grade column.
        grade_attribute: Sequence whose first element names the grade column.
        n: Number of quantile bins.

    Returns:
        A ``pandas.Categorical`` of bin labels, one per row of ``df``; missing
        grades map to NaN.

    Raises:
        ValueError: if the quantile edges are not unique (heavily tied grades).
    """
    values = df[grade_attribute[0]].values
    edges = np.nanquantile(values, np.linspace(0.0, 1.0, n + 1))
    # include_lowest=True closes the first interval on the left. Without it the
    # minimum grade falls outside every bin, becomes NaN, and the caller's
    # dropna silently discards those rows.
    return pd.cut(values, edges, labels=np.arange(n), include_lowest=True)

def construct_race(df, protected_attribute):
    """Encode the race names in ``df[protected_attribute]`` as integer group ids.

    'Branca' -> 0, 'Preta' -> 1, 'Parda' -> 2, 'Amarela' -> 3, 'Indigena' -> 4;
    any other value maps to NaN.
    """
    # Ordering follows the ENEM 2020 race numbering, shifted to start at 0.
    ordered_names = ('Branca', 'Preta', 'Parda', 'Amarela', 'Indigena')
    codes = {name: position for position, name in enumerate(ordered_names)}
    return df[protected_attribute].map(codes)