import numpy as np
import pandas as pd
import torch
import pickle
from scipy import sparse
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder, KBinsDiscretizer
from torch_geometric.data import Data
import tqdm
from tqdm import trange

def preprocess_adult(missing_rate=0.0, initial_filling=None):
    """Load the Adult train+test files, optionally inject random missingness, and encode them.

    The last column is the income label. Labels carrying a trailing '.'
    ('<=50K.', '>50K.') are normalised to '<=50K' / '>50K' so both files share
    the same two classes. '?' entries are treated as missing values.

    Args:
        missing_rate: Probability with which each feature cell is additionally
            masked to NaN via ``add_random_mask``; values <= 0 disable masking.
        initial_filling: Filling strategy forwarded to ``process_num_cat_cols``.

    Returns:
        ``(data, missing_feature_mask)``: ``data`` as built by
        ``process_num_cat_cols``; ``missing_feature_mask`` is an all-True
        ``torch`` bool tensor when no masking is applied, otherwise the mask
        returned by ``add_random_mask``.
    """
    # In the UCI distribution, adult.test begins with a '|'-prefixed line that is
    # not a data row; comment='|' skips such lines (a no-op when they are absent).
    # The two-character separator is only supported by the python parser; naming
    # it explicitly avoids pandas' engine-fallback ParserWarning.
    read_opts = dict(delimiter=', ', header=None, engine='python', comment='|')
    df_train = pd.read_csv('./data/adult/raw/adult.data', **read_opts)
    df_test = pd.read_csv('./data/adult/raw/adult.test', **read_opts)
    df = pd.concat([df_train, df_test], ignore_index=True)

    # Preprocess y: unify label spelling between the two files, then split it off.
    df = df.replace(to_replace='<=50K.', value='<=50K')
    df = df.replace(to_replace='>50K.', value='>50K')
    df_y = df.iloc[:, -1]
    df = df.iloc[:, :-1]

    # Formatting missing values
    df = df.replace(to_replace='?', value=np.nan)

    # Drop all-NaN columns, all-NaN rows and rows without a label.
    df, df_y = drop_row_col_labels(df, df_y)
    print('Initial Missing Rate (%):', (df.isnull().sum().sum()*100) / (df.shape[0] * df.shape[1]))

    # Add random mask
    missing_feature_mask = torch.ones(df.shape, dtype=bool)
    if missing_rate > 0:
        df, missing_feature_mask = add_random_mask(df, missing_rate)
        print('Final Missing Rate (%):', (df.isnull().sum().sum()*100) / (df.shape[0] * df.shape[1]))

    # Numerical -> scale || Categorical -> one-hot encode
    data = process_num_cat_cols(df, df_y, initial_filling)

    return data, missing_feature_mask

def preprocess_abide(missing_rate=0.0, initial_filling=None):
    """Load the ABIDE table, optionally inject random missingness, and encode it.

    ``DX_GROUP`` is used as the label and removed from the features. Missing
    values are expected to be NaN already in the raw CSV.

    Args:
        missing_rate: Probability with which each feature cell is additionally
            masked to NaN via ``add_random_mask``; values <= 0 disable masking.
        initial_filling: Filling strategy forwarded to ``process_num_cat_cols``.

    Returns:
        ``(data, missing_feature_mask)``: ``data`` as built by
        ``process_num_cat_cols``; ``missing_feature_mask`` is an all-True
        ``torch`` bool tensor when no masking is applied, otherwise the mask
        returned by ``add_random_mask``.
    """
    def missing_pct(frame):
        # Percentage of NaN cells over the whole feature frame.
        n_rows, n_cols = frame.shape
        return (frame.isnull().sum().sum() * 100) / (n_rows * n_cols)

    features = pd.read_csv('./data/abide/raw/abide.csv', header=0)
    labels = features.pop('DX_GROUP')

    # NaNs are already encoded in the CSV; only remove empty columns/rows and unlabeled rows.
    features, labels = drop_row_col_labels(features, labels)
    print('Initial Missing Rate (%):', missing_pct(features))

    if missing_rate > 0:
        features, missing_feature_mask = add_random_mask(features, missing_rate)
        print('Final Missing Rate (%):', missing_pct(features))
    else:
        missing_feature_mask = torch.ones(features.shape, dtype=bool)

    data = process_num_cat_cols(features, labels, initial_filling)
    return data, missing_feature_mask

def preprocess_adni(missing_rate=0.0, initial_filling=None):
    """Load the ADNI table, optionally inject random missingness, and encode it.

    ``DX_bl`` is used as the label and removed from the features. Missing
    values are expected to be NaN already in the raw CSV.

    Args:
        missing_rate: Probability with which each feature cell is additionally
            masked to NaN via ``add_random_mask``; values <= 0 disable masking.
        initial_filling: Filling strategy forwarded to ``process_num_cat_cols``.

    Returns:
        ``(data, missing_feature_mask)``: ``data`` as built by
        ``process_num_cat_cols``; ``missing_feature_mask`` is an all-True
        ``torch`` bool tensor when no masking is applied, otherwise the mask
        returned by ``add_random_mask``.
    """
    label_col = 'DX_bl'
    raw = pd.read_csv('./data/adni/raw/adni.csv', header=0)
    labels = raw[label_col]
    features = raw.drop(columns=[label_col])

    # NaNs are already encoded in the CSV; only remove empty columns/rows and unlabeled rows.
    features, labels = drop_row_col_labels(features, labels)
    n_missing = features.isnull().sum().sum()
    print('Initial Missing Rate (%):', (n_missing * 100) / (features.shape[0] * features.shape[1]))

    if not missing_rate > 0:
        missing_feature_mask = torch.ones(features.shape, dtype=bool)
    else:
        features, missing_feature_mask = add_random_mask(features, missing_rate)
        n_missing = features.isnull().sum().sum()
        print('Final Missing Rate (%):', (n_missing * 100) / (features.shape[0] * features.shape[1]))

    encoded = process_num_cat_cols(features, labels, initial_filling)
    return encoded, missing_feature_mask


def process_num_cat_cols(_df, _df_y, initial_filling):
    """Scale numerical columns, one-hot encode categorical ones, and pack a ``Data`` object.

    Columns whose dtype equals ``int`` or ``float`` are numerical and are
    min-max scaled; every other column is categorical and is one-hot encoded.
    Categorical columns with 1000 or more distinct values are dropped to keep
    the encoded width manageable.

    Args:
        _df: Feature frame, possibly containing NaNs. Not modified.
        _df_y: Labels, one per row of ``_df``. Not modified.
        initial_filling: Pre-filling strategy for missing values. ``'mode'`` or
            ``'median'`` fill numerical NaNs column-wise. Any value other than
            ``None`` / ``'None'`` fills categorical NaNs with the column mode.
            With ``None`` / ``'None'`` missing entries stay NaN (a missing
            categorical entry becomes NaN in all of its one-hot columns).

    Returns:
        ``Data`` with ``x`` (float32 tensor: scaled numerical columns followed
        by one-hot columns), ``y`` (long tensor of factorized labels),
        ``bins_for_each_col`` (1 per numerical column, then the number of
        one-hot columns per categorical column, 0 for dropped ones) and
        ``num_cat_idx`` ([#numerical columns, #encoded categorical columns]).
    """
    df = _df.copy()
    # Positional labels keep column selection and one-hot prefixes ('<col>_') in sync.
    df.columns = np.arange(df.shape[1])

    is_numerical = ((df.dtypes == int) | (df.dtypes == float)).to_numpy(dtype=bool)
    numerical_col_idx = np.flatnonzero(is_numerical)
    # Ascending order (the previous set difference gave a hash-dependent order).
    categorical_col_idx = np.flatnonzero(~is_numerical).tolist()

    parts = []
    if len(numerical_col_idx) > 0:
        print('Processing Numerical columns...')
        parts.append(_scale_numerical_cols(df.iloc[:, numerical_col_idx], initial_filling))

    df_cat = pd.DataFrame(index=df.index)
    if categorical_col_idx:
        print('Processing Categorical columns...')
        df_cat = _encode_categorical_cols(df.iloc[:, categorical_col_idx], initial_filling)
        parts.append(df_cat)

    df_final = pd.concat(parts, axis=1) if parts else pd.DataFrame(index=df.index)

    # Return data
    x = torch.tensor(df_final.to_numpy(dtype=np.float32))
    y = torch.LongTensor(pd.factorize(_df_y)[0])
    bins_for_each_col = np.array(
        [1.0] * len(numerical_col_idx)
        + [int(_one_hot_columns_of(df_cat.columns, i).sum()) for i in categorical_col_idx]
    )
    num_cat_idx = [len(numerical_col_idx), df_final.shape[1] - len(numerical_col_idx)]

    data = Data()
    data.x = x
    data.y = y
    data.bins_for_each_col = bins_for_each_col
    data.num_cat_idx = num_cat_idx

    return data


def _scale_numerical_cols(df_num, initial_filling):
    """Min-max scale numerical columns to [0, 1] and optionally fill their NaNs."""
    scaled = pd.DataFrame(
        preprocessing.MinMaxScaler().fit_transform(df_num),
        index=df_num.index,  # keep row alignment with the categorical block
    )
    if scaled.isnull().to_numpy().any():
        if initial_filling == 'mode':
            scaled = scaled.fillna(scaled.mode().loc[0])
        elif initial_filling == 'median':
            # Per-column medians (a Series), not the median of column 0 only.
            scaled = scaled.fillna(scaled.median())
    return scaled


def _encode_categorical_cols(df_cat_, initial_filling):
    """One-hot encode categorical columns; unfilled missing entries become all-NaN one-hot rows."""
    # Filter high-cardinality columns to prevent the one-hot dimension from exploding.
    df_cat_ = df_cat_.loc[:, df_cat_.nunique() < 1000]
    if df_cat_.shape[1] == 0:
        # pd.get_dummies cannot concatenate zero encoded columns.
        return pd.DataFrame(index=df_cat_.index)

    missing_cols = list(df_cat_.columns[df_cat_.isnull().any().to_numpy()])
    if missing_cols and initial_filling not in (None, 'None'):
        df_cat_ = df_cat_.fillna(df_cat_.mode().loc[0])

    # Encode every column (not only object/category dtypes) as float so NaN can be
    # written back and the result converts cleanly to a float tensor.
    df_cat = pd.get_dummies(df_cat_, columns=list(df_cat_.columns), dtype=float)

    # NaN entries produce an all-zero one-hot row; mark those rows as missing.
    for col in missing_cols:
        dummies = _one_hot_columns_of(df_cat.columns, col)
        never_set = df_cat.loc[:, dummies].sum(axis=1) == 0
        df_cat.loc[never_set, dummies] = np.nan
    return df_cat


def _one_hot_columns_of(columns, col):
    """Return a boolean mask over ``columns`` selecting the one-hot columns generated from ``col``."""
    prefix = f'{col}_'
    return np.array([str(c).startswith(prefix) for c in columns], dtype=bool)

def add_random_mask(_df, missing_rate):
    """Randomly mask cells of ``_df`` to NaN while guaranteeing every row keeps a value.

    Each cell is masked independently with probability ``missing_rate``
    (``torch.bernoulli``). A row left with no non-NaN value gets one cell
    restored: a column picked with ``np.random.choice`` among the cells that
    are non-NaN in ``_df``.

    Args:
        _df: Feature frame; not modified. Every row must contain at least one
            non-NaN value, otherwise a fully masked row cannot be recovered.
        missing_rate: Per-cell masking probability in [0, 1].

    Returns:
        ``(df, mask)``: the masked copy, and a NumPy bool array shaped like
        ``_df`` that is ``False`` exactly where the random mask was applied and
        not undone by recovery. The type and meaning of the mask are the same
        whether or not any row had to be recovered.
    """
    print('Applying Random Masking...')

    missing_mask = torch.bernoulli(torch.Tensor([missing_rate]).repeat(_df.shape)).bool().numpy()
    df = _df.mask(missing_mask, np.nan)
    all_nan_rows = df.isnull().all(axis=1).to_numpy()

    if all_nan_rows.any():  # sample one originally present value per such row and restore it
        print(f'Found {all_nan_rows.sum()} rows with all NaNs! Recovering rows with all NaNs....')
        recovered_cols = _df[all_nan_rows].apply(lambda row: np.random.choice(row.dropna().index), axis=1)

        restore = np.zeros_like(missing_mask)
        # Map column labels to positions so non-positional labels also work.
        restore[np.flatnonzero(all_nan_rows), _df.columns.get_indexer(recovered_cols)] = True
        df = df.where(~restore, _df)
        # Restored cells are no longer masked; keep the mask a NumPy array of random-mask flags.
        missing_mask &= ~restore

        assert not df.isnull().all(axis=1).any()

    return df, ~missing_mask

def drop_row_col_labels(_df, _df_y):
    """Drop all-NaN columns, then rows whose features are all NaN or whose label is missing.

    Args:
        _df: Feature frame. Not modified.
        _df_y: Label Series aligned with ``_df`` by index. Not modified.

    Returns:
        ``(df, df_y)``: the filtered features with columns relabelled
        ``0..n_cols-1``, and the filtered labels. Both share a fresh
        ``0..n_rows-1`` index named ``'new_idx'``, so they stay aligned.
    """
    print('Drop row, col, labels filled with all NaNs...')

    # Drop col
    df = _df.dropna(axis=1, how='all')

    # Keep rows that have a label and at least one feature value.
    keep = (~pd.isnull(_df_y)) & (~pd.isnull(df).all(axis=1))
    df = df.loc[keep]
    df_y = _df_y[keep]

    # Re-index both outputs together (assigning attributes avoids SettingWithCopyWarning).
    new_index = pd.RangeIndex(len(df), name='new_idx')
    df.index = new_index
    df_y.index = new_index
    df.columns = np.arange(df.shape[1])

    return df, df_y
