import os
import torch
import numpy  as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder
from sklearn.model_selection import train_test_split

# Fix NumPy's global RNG state at import time so later draws from np.random
# repeat across runs.
np.random.seed(seed=0)

# Base directory that the dataset loaders resolve their "data/..." paths against.
FOLDER_PATH = "."

class TabularDataset():
    """Load a tabular dataset and encode its columns for model training.

    The last column of the loaded data is treated as the target ``y``. The
    dataset loader assigns the other columns to categorical and continuous
    index lists, and each group is encoded with its own method.
    """

    def __init__(self, dataset, cont_method='minmax', categ_method='label', y_method='label'):
        """
        Arguments
        - dataset: dataset name (currently only "income")
        - cont_method: encoding for continuous columns [ minmax, label, onehot, raw ]
        - categ_method: encoding for categorical columns [ minmax, label, onehot, raw ]
        - y_method: encoding for the target (last) column [ minmax, label, onehot, raw ]

        "raw" leaves the corresponding columns unchanged.
        """
        load_data = {
            "income": self.process_income,
        }
        assert dataset in load_data, f"unknown dataset: {dataset!r}"

        # Factories instead of shared instances: each column role gets its own
        # encoder object, so y_encoder never aliases categ_encoder/cont_encoder.
        preprocess_method = {
            "minmax": MinMaxScaler,
            "label": LabelEncoder,
            "onehot": lambda: OneHotEncoder(sparse_output=False),
            "raw": lambda: None,
        }
        cont_method = cont_method.lower()
        categ_method = categ_method.lower()
        y_method = y_method.lower()
        for method in (cont_method, categ_method, y_method):
            assert method in preprocess_method, f"unknown preprocessing method: {method!r}"

        # get data and index of categorical/continuous
        self.train, self.val, self.test, self.categ_index, self.cont_index = load_data[dataset]()

        # The last column is the target; make sure it is not also encoded as a feature.
        last_index = self.train.shape[-1] - 1
        self.y_index = [last_index]
        if last_index in self.categ_index:
            self.categ_index.remove(last_index)
        if last_index in self.cont_index:
            self.cont_index.remove(last_index)

        # get encoder (or scaler); None for "raw"
        self.cont_encoder = preprocess_method[cont_method]()
        self.categ_encoder = preprocess_method[categ_method]()
        self.y_encoder = preprocess_method[y_method]()
        self.cont_method = cont_method
        self.categ_method = categ_method
        self.y_method = y_method

        # "raw" columns fall through to the pass-through branch of preprocessing()
        if cont_method == 'raw':
            self.cont_index = []
        if categ_method == 'raw':
            self.categ_index = []
        if y_method == 'raw':
            self.y_index = []

        # preprocessing
        self.train_x, self.train_y, self.val_x, self.val_y, \
            self.test_x, self.test_y, self.categ_dims = self.preprocessing()

    def get_datas(self, seperate_y=False):
        """Return the encoded splits.

        With ``seperate_y`` the result is (train_x, train_y, val_x, val_y,
        test_x, test_y). Otherwise it is (train, val, test), with y appended as
        the trailing column(s) of each split.
        """
        if seperate_y:
            return self.train_x, self.train_y, self.val_x, self.val_y, self.test_x, self.test_y
        return (np.concatenate([self.train_x, self.train_y], axis=1),
                np.concatenate([self.val_x, self.val_y], axis=1),
                np.concatenate([self.test_x, self.test_y], axis=1))

    def get_index(self):
        """Return (categorical column indices, continuous column indices)."""
        return self.categ_index, self.cont_index

    def get_categ_dims(self):
        """Return the number of training categories for each categorical column."""
        return self.categ_dims

    @staticmethod
    def _encode_column(encoder, method, train_curr, val_curr, test_curr):
        """Fit ``encoder`` on the training column and transform all three splits.

        ``encoder=None`` means pass-through. LabelEncoder expects 1-D input, so
        the columns are flattened for the "label" method.
        """
        if encoder is None:
            return train_curr, val_curr, test_curr
        if method == 'label':
            train_curr, val_curr, test_curr = train_curr.ravel(), val_curr.ravel(), test_curr.ravel()
        return (encoder.fit_transform(train_curr),
                encoder.transform(val_curr),
                encoder.transform(test_curr))

    @staticmethod
    def _n_categories(encoder, train_curr):
        """Return the number of distinct training categories for a fitted encoder."""
        if hasattr(encoder, 'classes_'):      # LabelEncoder
            return len(encoder.classes_)
        if hasattr(encoder, 'categories_'):   # OneHotEncoder fitted on one column
            return len(encoder.categories_[0])
        # Encoders such as MinMaxScaler keep no category list; count the values directly.
        return len(np.unique(train_curr))

    def preprocessing(self):
        """Fit encoders on the training split and apply them to all splits.

        Each column is encoded independently. The encoder for the column's role
        (target, categorical, continuous) is fitted on the training column and
        then used to transform the validation and test columns. Columns that
        are in no index list pass through unchanged.

        Returns
        - train_x, train_y, val_x, val_y, test_x, test_y: 2-D arrays. The
          ``*_y`` arrays hold every column produced by encoding the target,
          which is more than one column for "onehot".
        - categ_dims: the number of training categories for each categorical
          feature column, in column order.
        """
        train_encoded_data = []
        val_encoded_data = []
        test_encoded_data = []
        categ_dims = []

        for i in range(self.train.shape[-1]):
            train_curr = self.train[:, i].copy().reshape(-1, 1)
            val_curr = self.val[:, i].copy().reshape(-1, 1)
            test_curr = self.test[:, i].copy().reshape(-1, 1)

            is_categ = False
            if i in self.y_index:
                encoder, method = self.y_encoder, self.y_method
            elif i in self.categ_index:
                encoder, method = self.categ_encoder, self.categ_method
                is_categ = True
            elif i in self.cont_index:
                encoder, method = self.cont_encoder, self.cont_method
            else:  # raw
                encoder, method = None, 'raw'

            train_encoded, val_encoded, test_encoded = self._encode_column(
                encoder, method, train_curr, val_curr, test_curr)
            if is_categ:
                categ_dims.append(self._n_categories(encoder, train_curr))

            train_encoded_data.append(train_encoded.reshape(len(self.train), -1))
            val_encoded_data.append(val_encoded.reshape(len(self.val), -1))
            test_encoded_data.append(test_encoded.reshape(len(self.test), -1))

        # The target is the last source column, but encoding (e.g. one-hot) can
        # expand it into several columns. Split off all of them so none of the
        # target columns end up in x.
        y_width = train_encoded_data[-1].shape[1]

        train = np.concatenate(train_encoded_data, axis=1)
        val = np.concatenate(val_encoded_data, axis=1)
        test = np.concatenate(test_encoded_data, axis=1)

        train_x, train_y = train[:, :-y_width], train[:, -y_width:]
        val_x, val_y = val[:, :-y_width], val[:, -y_width:]
        test_x, test_y = test[:, :-y_width], test[:, -y_width:]

        return train_x, train_y, val_x, val_y, test_x, test_y, categ_dims

    def process_income(self):
        """Load the income train/test CSVs and take a validation split from train.

        Rows with missing values are dropped. The first ``len(test) // 2`` rows
        of the training file become the validation set, without shuffling.

        Returns (train, val, test, categorical_index, continuous_index).
        """
        train = pd.read_csv(os.path.join(FOLDER_PATH, "data/income_train.csv"))
        test = pd.read_csv(os.path.join(FOLDER_PATH, "data/income_test.csv"))

        train = np.array(train.dropna(axis=0))
        test = np.array(test.dropna(axis=0))

        # The validation size is tied to the test size: half as many rows.
        n_val = len(test) // 2
        train, val = train[n_val:], train[:n_val]

        categorical = [1, 3, 5, 6, 7, 8, 9, 13, 14]
        # Every non-categorical column except the last (target) column, in order.
        continuous = sorted(set(range(train.shape[-1] - 1)) - set(categorical))

        return train, val, test, categorical, continuous
    