import pickle

import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
import numpy as np
# from models.tab_transformer_pytorch import get_model_and_train_components, tab_transformer_train, tab_transformer_eval
from utils import split_cont_categ, print_result,ColInfo, f1_eval
from tabular_dataset import  get_dataset, get_num_class
from sklearn.metrics import f1_score
import warnings
# from SPlit import split_train_val_by_SPlit
warnings.filterwarnings("ignore")



def train(cfg, seed_set, train_set, test_set, cols_info_tuple:ColInfo, label_list):
    """Run one training experiment per seed and aggregate the F1 statistics.

    For every entry of ``seed_set`` the function calls ``one_round_training``
    and stores the mean validation F1, the mean test F1 and the mean
    |validation - test| gap.  When ``cfg['val_method']`` contains ``'jkfold'``
    each seed is first expanded into ``cfg['repeat_j']`` sub-seeds (drawn from
    NumPy's global RNG after seeding it with that seed); the per-round means
    are then averaged once more over the repeats.

    The summary returned by ``print_result`` is pickled to
    ``cfg['result_save_path']``.

    :param cfg: config mapping; reads 'val_method', 'result_save_path' and,
        in jkfold mode, 'repeat_j'
    :param seed_set: sequence of random seeds, one experiment per seed
    :param train_set: whole train set, forwarded to ``one_round_training``
    :param test_set: test set, forwarded to ``one_round_training``
    :param cols_info_tuple: column information, forwarded to ``one_round_training``
    :param label_list: list of category ([0,1,...]), forwarded to ``one_round_training``
    :return: ``(perf_dict, final_result)`` -- the per-seed means and the value
        returned by ``print_result(perf_dict)``
    """
    # (key in the per-round result, key in the aggregated result)
    metric_targets = (
        ('val_f1_score', 'val_f1_score_mean'),
        ('test_f1_score', 'test_f1_score_mean'),
        ('val_test_bias', 'val_test_bias_mean'),
    )

    def _mean(values):
        return np.asarray(values).mean()

    def _round_means(round_seed):
        # One call to one_round_training, reduced to a single mean per metric.
        scores = one_round_training(cfg, round_seed, train_set, test_set, cols_info_tuple,
                                    label_list)
        return {source: _mean(scores[source]) for source, _ in metric_targets}

    # 'val_test_bias_list' and 'f1_score_group_std' are not filled in here;
    # they are handed to print_result as empty lists.
    perf_dict = {
        'val_f1_score_mean': [],
        'test_f1_score_mean': [],
        'val_test_bias_list': [],
        'val_test_bias_mean': [],
        'f1_score_group_std': [],
    }

    repeated = 'jkfold' in cfg['val_method']
    if repeated:
        repeats = cfg['repeat_j']

    for position, base_seed in enumerate(seed_set):
        print(position)
        if repeated:
            # Derive the seeds of the individual k-fold repetitions from the base seed.
            np.random.seed(base_seed)
            sub_seeds = np.random.randint(0, 10000, size=repeats)
            rounds = [_round_means(sub_seed) for sub_seed in sub_seeds]
            summary = {source: _mean([single[source] for single in rounds])
                       for source, _ in metric_targets}
        else:
            summary = _round_means(base_seed)

        for source, target in metric_targets:
            perf_dict[target].append(summary[source])

    final_result = print_result(perf_dict)
    with open(cfg['result_save_path'], 'wb') as out_file:
        pickle.dump(final_result, out_file)
    return perf_dict, final_result


def one_round_training(cfg, seed, train_set, test_set, cols_info_tuple:ColInfo, label_list=None):
    """Fit the configured model on every train/validation split and score it.

    The split strategy is chosen by ``cfg['val_method']``:

    * contains ``'kfold'``: stratified k-fold over ``train_set`` with
      ``cfg['k']`` shuffled folds; one model is trained per fold;
    * anything else: exactly one split, built by ``_single_split``.

    For each split the model named by ``cfg['model']`` ('rfc', 'xgb', 'lr' or
    'tabtransformer') is trained and evaluated on the validation part and on
    ``test_set``.  For the non-tabtransformer models the score is the
    macro-averaged F1.

    :param cfg: config mapping; reads 'model' and 'val_method' and, depending
        on those, 'k', 'val_ratio', 'dst_name', 'val_num_per_class',
        'model_params', 'xgb_model_params' and 'train_params'
    :param seed: random seed used for the split (and passed to ``get_dataset``
        as ``rand_number`` where that is called)
    :param train_set: ``(features, labels)`` pair of pandas objects
    :param test_set: ``(features, labels)`` pair of pandas objects
    :param cols_info_tuple: column information; only read by the 'SPlit'
        validation method and the 'tabtransformer' model
    :param label_list: unused, kept for interface compatibility
    :return: dict with the keys 'val_f1_score', 'val_test_bias' and
        'test_f1_score'; every value is a list with one entry per split, and
        'val_test_bias' holds ``abs(val_f1 - test_f1)``
    :raises ValueError: if ``cfg['model']`` is not a supported model name
    """
    model = cfg['model']
    supported_models = ('rfc', 'xgb', 'lr', 'tabtransformer')
    if model not in supported_models:
        # Fail before any splitting/training instead of hitting an unbound
        # classifier variable at prediction time.
        raise ValueError(f"Unsupported model {model!r}; expected one of {supported_models}")

    if model == 'tabtransformer':
        # Imported lazily: the module-level import of these helpers is
        # commented out in this file, so the names are otherwise undefined.
        from models.tab_transformer_pytorch import (get_model_and_train_components,
                                                    tab_transformer_train,
                                                    tab_transformer_eval)

    train_set = (train_set[0].reset_index(drop=True), train_set[1].reset_index(drop=True))

    if 'kfold' in cfg['val_method']:
        kfold = StratifiedKFold(n_splits=cfg['k'], shuffle=True, random_state=seed)
        fold_indices = list(kfold.split(train_set[0], train_set[1]))
        # Materialise the data of each fold only when it is needed.
        splits = ((train_set[0].iloc[train_ind], train_set[1].iloc[train_ind],
                   train_set[0].iloc[val_ind], train_set[1].iloc[val_ind])
                  for train_ind, val_ind in fold_indices)
    else:
        splits = [_single_split(cfg, seed, train_set, test_set, cols_info_tuple)]

    performance_dic = {
        'val_f1_score': [],
        'val_test_bias': [],
        'test_f1_score': []
    }

    for init_x_train, init_y_train, init_x_val, init_y_val in splits:
        x_train = init_x_train.reset_index(drop=True)
        y_train = init_y_train.reset_index(drop=True)
        x_val = init_x_val.reset_index(drop=True)
        y_val = init_y_val.reset_index(drop=True)

        print('train set shape', len(x_train))
        print('val set shape', len(x_val))

        if model != 'tabtransformer':
            # The scikit-learn style estimators are fed plain arrays.
            x_train = x_train.values
            y_train = y_train.values
            x_val = x_val.values
            y_val = y_val.values

        print('training .....')
        if model == 'rfc':
            cls = RandomForestClassifier(random_state=0, n_jobs=cfg['model_params']['n_jobs'])
            cls.fit(x_train, y_train)
        elif model == 'xgb':
            cls = XGBClassifier(**cfg['xgb_model_params'],
                                objective='multi:softprob', booster='gbtree',
                                num_class=get_num_class(cfg['dst_name']))
            # The validation split drives early stopping via the f1_eval metric.
            cls.fit(x_train, y_train, eval_metric=f1_eval, eval_set=[(x_val, y_val)],
                    verbose=False, early_stopping_rounds=15)
        elif model == 'lr':
            cls = LogisticRegression(multi_class="multinomial", solver="newton-cg",
                                     max_iter=1000, n_jobs=32)
            cls.fit(x_train, y_train)
        else:  # model == 'tabtransformer'
            train_cont, train_categ = split_cont_categ(x_train, cols_info_tuple.cont_name,
                                                       cols_info_tuple.cate_idxs)
            val_cont, val_categ = split_cont_categ(x_val, cols_info_tuple.cont_name,
                                                   cols_info_tuple.cate_idxs)
            train_target = y_train.to_numpy()
            val_target = y_val.to_numpy()
            cls, criterion, optimizer, scheduler = get_model_and_train_components(
                cols_info_tuple.cate_dims,
                len(cols_info_tuple.cont_name),
                cfg['model_params'])
            best_model_dict = tab_transformer_train(
                cls,
                criterion,
                optimizer,
                scheduler,
                train_cont,
                train_categ,
                train_target,
                val_cont,
                val_categ,
                val_target,
                device="cuda:0",
                **cfg['train_params'],
                save_best_model_dict=False,
                save_metrics=False,
            )
            # Evaluate with the best weights returned by the training routine.
            cls.load_state_dict(best_model_dict)

        if model == 'tabtransformer':
            f1_val = tab_transformer_eval(cls, val_cont, val_categ, val_target)
            test_cont, test_categ = split_cont_categ(test_set[0], cols_info_tuple.cont_name,
                                                     cols_info_tuple.cate_idxs)
            test_target = test_set[1].to_numpy()
            f1_test = tab_transformer_eval(cls, test_cont, test_categ, test_target)
        else:
            y_test_pred = cls.predict(test_set[0].values)
            y_val_pred = cls.predict(x_val)

            f1_val = f1_score(y_val, y_val_pred, average='macro')
            f1_test = f1_score(test_set[1].values, y_test_pred, average='macro')

        performance_dic['val_f1_score'].append(f1_val)
        performance_dic['val_test_bias'].append(abs(f1_val - f1_test))
        performance_dic['test_f1_score'].append(f1_test)

    return performance_dic


def _single_split(cfg, seed, train_set, test_set, cols_info_tuple):
    """Return ``(x_train, y_train, x_val, y_val)`` for the non-k-fold validation methods."""
    val_method = cfg['val_method']

    if 'split_free' in val_method:
        if val_method == 'split_free_test':
            val_set = test_set
        elif val_method == 'split_free_noval':
            val_set = train_set
        elif val_method == 'split_free_holdout':
            val_set, _ = get_dataset(cfg['dst_name'], split='val', val_method='DB_ADJOINT',
                                     rand_number=seed, val_num_per_class=cfg['val_num_per_class'])
            # Train on the reduced training part only; the held-out part of
            # this split is discarded because validation uses val_set.
            x_train, _, y_train, _ = train_test_split(train_set[0], train_set[1],
                                                      test_size=cfg['val_ratio'],
                                                      random_state=seed,
                                                      stratify=train_set[1])
            train_set = (x_train, y_train)
        else:
            val_method_map = {'split_free':'DB', 'split_free_random':'RANDOM', 'split_free_joint':'DB_ADJOINT'}
            val_set, _ = get_dataset(cfg['dst_name'], split='val', val_method=val_method_map[val_method],
                                     rand_number=seed, val_num_per_class=cfg['val_num_per_class'])
        x_val, y_val = val_set
        return train_set[0], train_set[1], x_val, y_val

    if val_method == 'LZO':
        val_set, _ = get_dataset(cfg['dst_name'], split='val', val_method='LZO', rand_number=seed)
        x_val, y_val = val_set
        return train_set[0], train_set[1], x_val, y_val

    if val_method == 'SPlit':
        # Imported lazily: the module-level import of this helper is commented
        # out in this file, so the name is otherwise undefined.
        from SPlit import split_train_val_by_SPlit
        train_indexes, val_indexes = \
            split_train_val_by_SPlit(pd.concat((train_set[0],
                                                pd.DataFrame(train_set[1], index=None)), axis=1),
                                     val_ratio=cfg['val_ratio'],
                                     numeric_col=cols_info_tuple.cont_name)
        return (train_set[0].iloc[train_indexes], train_set[1].iloc[train_indexes],
                train_set[0].iloc[val_indexes], train_set[1].iloc[val_indexes])

    # Default: plain stratified hold-out split of the training data.
    x_train, x_val, y_train, y_val = train_test_split(train_set[0], train_set[1],
                                                      test_size=cfg['val_ratio'],
                                                      random_state=seed,
                                                      stratify=train_set[1])
    return x_train, y_train, x_val, y_val

