import sys
sys.path.append("..")
from globa_utils import PathConfig
import pandas as pd
from pandas import DataFrame
import numpy as np

def check_data(data, is_print=False):
    """ General inspection of the data.

    Prints the shape, the first rows, per-column null counts, the
    ``DataFrame.info()`` report, the value counts of every column and
    ``describe()``. Nothing is printed unless ``is_print`` is true.

    :param data: DataFrame to inspect
    :param is_print: when False (the default) the function returns immediately
    :return: None
    """
    if not is_print:
        return
    print(data.shape)
    print(data.head())
    print(data.isnull().sum())
    # DataFrame.info() writes its report itself and returns None; wrapping it
    # in print() only appended a stray "None" line to the output.
    data.info()
    for col in data.columns:
        print(data[col].value_counts())
    print(data.describe())


def read_data(file_path, delimiter=None):
    """ Load a delimited text dataset into a DataFrame.

    :param file_path: location of the dataset; forwarded unchanged to ``pd.read_csv``
    :param delimiter: column separator forwarded to ``pd.read_csv``; ``None``
        leaves pandas' default (comma) in effect
    :return: DataFrame
    """
    return pd.read_csv(file_path, delimiter=delimiter)


class TabularConfig(PathConfig):
    """ Accessors for the paths and settings stored in ``path_config.yaml``.

    Every getter reads ``self.cfg``, which is presumably the parsed YAML
    mapping set up by ``PathConfig.__init__`` (NOTE(review): confirm in
    globa_utils).
    """

    def __init__(self):
        super().__init__(cfg_path="path_config.yaml")

    def _dataset_file(self, name):
        """ Return ``name`` joined onto the configured dataset directory.

        ``os.path.join`` inserts a separator only when the directory does not
        already end with one, so ``dataset_path`` works with or without a
        trailing slash (plain ``+`` silently glued the names together).
        """
        import os
        return os.path.join(self.get_dataset_path(), name)

    def get_dataset_path(self):
        return self.cfg['dataset_path']

    def get_data_pool_path(self):
        return self.cfg['data_pool_path']

    def get_data_pool_info_path(self):
        # Note: stored under the 'augment_data_info_path' key in the config.
        return self.cfg['augment_data_info_path']

    def get_fe_path(self):
        return self.cfg['feature_extractor_save_path']

    def get_distribution_path(self):
        return self.cfg['class_distri_save_path']

    def get_bank_dataset_path(self):
        return self._dataset_file('bank-additional-full.csv')

    def get_car_dataset_path(self):
        return self._dataset_file('car_evaluation.csv')

    def get_pageblocks_dataset_path(self):
        return self._dataset_file('page_blocks.csv')

    def get_mushroom_dataset_path(self):
        return self._dataset_file('mushrooms.csv')

    def get_heart_dataset_path(self):
        return self._dataset_file('heart.csv')

    def get_diabetes_dataset_path(self):
        return self._dataset_file('diabetes.csv')

    def get_job_dataset_path(self):
        # A directory, not a single file; the trailing slash is kept.
        return self._dataset_file('HR-Analytics/')

    def get_dataset_config_path(self):
        return self.cfg['dataset_config_path']

    def get_experiment_save_path(self):
        return self.cfg['experiments_save_path']

    def get_global_seed(self):
        return self.cfg['global_random_seed']

    def get_scaler_save_path(self):
        return self.cfg['dataset_scaler_save_path']



def adjust_dataset_size(dst:DataFrame, action_type, y_name, sample_rate=None, unbalanced_ratio=None):
    """ Adjust the number of samples per class in a dataset.

    Classes are ranked by ``value_counts()`` (most frequent first): ``keys[0]``
    is treated as the "negative" class and ``keys[1]`` as the "positive" one.
    Any further classes are kept at their full size.

    :param dst: dataset to resample (not modified in place)
    :param action_type: 1 sample instance in each class by 'sample_rate'
                        2 fix number of instance in negative class, and down sample positive class by unbalanced ratio
                        3 fix number of instance in positive class, and down sample negative class by unbalanced ratio
                        4 fix number of instance in negative class, and up sample positive class by unbalanced ratio
    :param y_name: name of the label column
    :param sample_rate: fraction of every class kept by action 1
    :param unbalanced_ratio: negative/positive size ratio targeted by actions 2-4
    :return: Adjusted dataset; rows are grouped by sorted class label and keep
        their original index labels (duplicated when action 4 upsamples)
    :raises ValueError: unknown ``action_type``, missing ``sample_rate`` /
        ``unbalanced_ratio``, or fewer than two classes for actions 2-4
    """
    if action_type not in (1, 2, 3, 4):
        raise ValueError('unknown action_type: {!r}'.format(action_type))
    if action_type == 1 and sample_rate is None:
        raise ValueError('action_type 1 requires sample_rate')
    if action_type != 1 and unbalanced_ratio is None:
        raise ValueError('action_type {} requires unbalanced_ratio'.format(action_type))

    # group.sample() below is called without random_state, so it draws from
    # NumPy's global RNG; seeding it keeps the output reproducible (and, as
    # before, resets the global RNG as a side effect).
    np.random.seed(seed=0)

    # value_counts() orders classes by decreasing frequency.
    class_dict = dst[y_name].value_counts().to_dict()
    keys = list(class_dict)
    if action_type != 1 and len(keys) < 2:
        raise ValueError('action_type {} needs at least two classes'.format(action_type))

    upsample = False
    if action_type == 1:
        class_dict = {key: int(count * sample_rate) for key, count in class_dict.items()}
    elif action_type == 2:
        class_dict[keys[1]] = int(class_dict[keys[0]] / unbalanced_ratio)
    elif action_type == 3:
        class_dict[keys[0]] = int(class_dict[keys[1]] * unbalanced_ratio)
    else:
        class_dict[keys[1]] = int(class_dict[keys[0]] / unbalanced_ratio)
        upsample = True

    # Iterate the groups explicitly (sorted by label, the same order
    # groupby.apply used) rather than groupby.apply: the old code never
    # sampled for action 4, and pandas >= 2.2 deprecates passing the grouping
    # column into apply. Index labels stay the original ones, so no droplevel.
    sampled = []
    for name, group in dst.groupby(y_name):
        n = class_dict[name]
        # Draw with replacement only when upsampling past the rows available.
        sampled.append(group.sample(n=n, replace=upsample and n > len(group)))
    dst = pd.concat(sampled) if sampled else dst.iloc[:0]

    print('='*80)
    print(dst[y_name].value_counts())
    return dst


def split_labels(dst, y_name:str):
    """ Separate the feature columns from the label column.

    :param dst: dataset that contains the label column
    :param y_name: name of the label column
    :return: (features DataFrame, label Series), both re-indexed 0..n-1
    """
    # check_data only prints when is_print=True, so this call is silent.
    check_data(dst)
    features = dst.drop(columns=[y_name]).reset_index(drop=True)
    labels = dst[y_name].reset_index(drop=True)
    return features, labels


# from pytorch_tabnet.metrics import Metric
from sklearn.metrics import f1_score

# class F1(Metric):
#
#     def __init__(self):
#         self._name = "F1" # write an understandable name here
#         self._maximize = True
#
#     def __call__(self, y_true, y_score):
#         """
#         Compute the F1 score of predictions (argmax over class scores).
#
#         Parameters
#         ----------
#         y_true: np.ndarray
#             Target matrix or vector
#         y_score: np.ndarray
#             Score matrix or vector
#
#         Returns
#         -------
#             float
#         """
#         y_pred= np.argmax(y_score, axis=1)
#         return f1_score(y_true, y_pred)


def f1_eval(y_pred, dtrain):
    """ Custom evaluation metric: ``1 - macro F1``.

    The ``(name, value)`` return and ``dtrain.get_label()`` call suggest an
    XGBoost-style custom eval hook (NOTE(review): confirm against the caller).

    :param y_pred: per-class scores; the argmax over axis 1 is the predicted class
    :param dtrain: object exposing ``get_label()`` with the true labels
    :return: tuple ``('f1-err', 1 - macro F1)``
    """
    labels = dtrain.get_label()
    predicted = np.asarray(y_pred).argmax(axis=1)
    return 'f1-err', 1 - f1_score(labels, predicted, average='macro')


def split_cont_categ(dst:DataFrame, var_numerical, cat_idxs):
    """ Split a DataFrame into continuous and categorical NumPy arrays.

    :param dst: dataset
    :param var_numerical: column labels of the continuous features
    :param cat_idxs: positional column indices of the categorical features
    :return: (continuous array, categorical array)
    """
    continuous = dst[var_numerical].to_numpy()
    full_matrix = dst.to_numpy()
    categorical = full_matrix[:, cat_idxs]
    return continuous, categorical


def print_result(perf_dict):
    """ Print and return summary statistics of repeated experiment runs.

    :param perf_dict: mapping whose ``'val_f1_score_mean'``,
        ``'val_test_bias_mean'`` and ``'test_f1_score_mean'`` entries are
        array-like collections of per-run values
    :return: dict with the std of the validation F1 scores, the mean
        validation/test bias and the mean test F1 score; the
        ``'val_auc_mean_std'`` and ``'test_auc_mean_mean'`` entries are
        always None
    """
    print(perf_dict)

    print('metric-mean std')
    val_f1_std = np.asarray(perf_dict['val_f1_score_mean']).std()
    print('val_f1_score std', val_f1_std)

    print('metric-bias mean')
    bias_mean = np.asarray(perf_dict['val_test_bias_mean']).mean()
    print('val_test_bias mean', bias_mean)

    print('perf-metric mean')
    test_f1_mean = np.asarray(perf_dict['test_f1_score_mean']).mean()
    print('test_f1_score mean', test_f1_mean)

    return {
        'val_f1_score_mean_std': val_f1_std,
        'val_auc_mean_std': None,
        'val_test_bias_mean_mean': bias_mean,
        'test_f1_score_mean_mean': test_f1_mean,
        'test_auc_mean_mean': None,
    }


def generate_seed_set():
    """ Produce a fixed list of 100 integer seeds drawn from ``[0, 10000)``.

    NumPy's global RNG is re-seeded with 0 first, so every call returns the
    same list (and the global RNG state is reset as a side effect).

    :return: list of 100 Python ints
    """
    np.random.seed(0)
    draws = np.random.randint(low=0, high=10000, size=100)
    return draws.tolist()


import yaml
def read_config(cfg_path, args=None):
    """ Load a YAML config file.

    :param cfg_path: path of the YAML file
    :param args: optional object passed with the loaded config to
        ``complete_cfg_by_args``, whose result replaces the config
    :return: the value produced by ``yaml.safe_load`` (presumably a dict),
        after the optional ``args`` completion
    """
    # UTF-8 is YAML's default encoding; don't depend on the platform locale.
    # The file is closed as soon as it has been parsed.
    with open(cfg_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    if args is not None:
        # NOTE(review): complete_cfg_by_args is not defined in the visible part
        # of this file — confirm it is importable here.
        cfg = complete_cfg_by_args(cfg, args)
    print(cfg)
    return cfg


from collections import namedtuple

# Column metadata for a tabular dataset:
#   cate_idxs - positions of the categorical features among the feature columns
#   cate_dims - number of distinct values of each of those categorical features
#   cont_name - names of the continuous (numerical) columns
#   cate_name - names of the categorical columns
ColInfo = namedtuple('ColInfo', 'cate_idxs cate_dims cont_name cate_name')


def contruct_col_info(categorical_columns, numerical_columns, target_col_name, dst):
    """ Build the ``ColInfo`` column metadata for a dataset.

    :param categorical_columns: names of the categorical columns
    :param numerical_columns: names of the continuous columns (stored as-is)
    :param target_col_name: label column, excluded from the feature list
    :param dst: DataFrame the columns belong to
    :return: ColInfo whose ``cate_idxs`` are the positions of the categorical
        columns among the non-target columns of ``dst`` and whose
        ``cate_dims`` are their ``nunique()`` counts
    """
    n_unique = {name: dst[name].nunique() for name in categorical_columns}

    # Only the target is excluded (the original's unused-feature list was empty).
    excluded = [target_col_name]
    feature_names = [name for name in dst.columns if name not in excluded]

    cate_positions = []
    cate_cardinalities = []
    for position, feature in enumerate(feature_names):
        if feature in categorical_columns:
            cate_positions.append(position)
            cate_cardinalities.append(n_unique[feature])

    return ColInfo(cate_idxs=cate_positions, cate_dims=cate_cardinalities,
                   cont_name=numerical_columns, cate_name=categorical_columns)