from utils import PathConfig
from nlp_dataset import get_num_class, get_dataset
from feature_extractor import FeatureExtractor
import os
from utils import read_config
from angle_distribution import cal_class_distri, angle_distribution_calibration, cal_class_angles
from diversity_distritbution import cal_diversity_to_set_by_class
import pickle
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from globa_utils import setup_seed
import torch
import numpy as np
from copulas.univariate import GammaUnivariate, BetaUnivariate, GaussianUnivariate, Univariate, TruncatedGaussian, UniformUnivariate, GaussianKDE
from copulas.multivariate import GaussianMultivariate
from copulas import bivariate
import pandas as pd
from draw_utils import draw_real_syn_distribution
from sklearn.neighbors import KernelDensity

# Restrict the CUDA devices visible to this process to the listed indices.
# NOTE(review): this assignment runs after `import torch` and the project
# imports above, so it presumably only takes effect if none of them has
# already initialized CUDA -- confirm.
# NOTE(review): the value contains a space after the comma; verify that the
# CUDA runtime parses '3, 4' as the two intended devices.
os.environ["CUDA_VISIBLE_DEVICES"]='3, 4'

# Module-wide PathConfig instance; used below for the global seed, the
# dataset config directory and the distribution output directory.
PC = PathConfig()


def _draw_real_vs_synthetic(real_df, synthetic):
    """Plot real (angle, diversity) samples next to samples drawn from a fitted model.

    :param real_df: DataFrame with 'angle' and 'diversity' columns (the data
        the model was fitted on).
    :param synthetic: samples drawn from the fitted model, either a DataFrame
        with the same two columns or a 2-column array (angle, diversity).
    """
    if not isinstance(synthetic, pd.DataFrame):
        synthetic = pd.DataFrame(synthetic, columns=['angle', 'diversity'])
    # Tag each row with its origin so the plot can colour by the 'Type' column.
    all_data = pd.concat([real_df.assign(Type="real"), synthetic.assign(Type="syn")])
    draw_real_syn_distribution(all_data, hue_key='Type')


def estimate_bivar_distribution(angle_diversity_data, bi_type='Gaussian', is_draw=False):
    """Fit a joint distribution over (angle, diversity) samples.

    :param angle_diversity_data: 2-row array; row 0 holds the angle values and
        row 1 the diversity values of the same samples.
    :param bi_type: which joint model to fit:
        'Gaussian' - Gaussian copula (copulas ``GaussianMultivariate``) on
        top of two fitted univariate marginals;
        'Gumbel'   - Gumbel bivariate copula fitted on the CDF-transformed
        data, with the two marginals kept separately;
        'KDE'      - scikit-learn ``KernelDensity`` with a Gaussian kernel.
    :param is_draw: when True, draw as many samples from the fitted model as
        there are input samples and plot them against the real data.
    :return: depends on ``bi_type``:
        'Gaussian' - the ``to_dict()`` of the fitted copula;
        'Gumbel'   - dict with keys 'angle_md', 'diversity_md' and 'joint_d'
        holding the ``to_dict()`` of both marginals and the copula;
        'KDE'      - the fitted ``KernelDensity`` object itself (not a dict).
    :raises ValueError: if ``bi_type`` is not one of the supported names.
    """
    # Fail early and explicitly; previously an unknown name fell through all
    # branches and surfaced as an UnboundLocalError at the return statement.
    if bi_type not in ('Gaussian', 'Gumbel', 'KDE'):
        raise ValueError(
            "Unsupported bi_type {!r}; expected 'Gaussian', 'Gumbel' or 'KDE'".format(bi_type))

    pd_data = pd.DataFrame(angle_diversity_data.T, columns=['angle', 'diversity'])
    num_samples = len(pd_data)
    # Parametric families the copulas Univariate wrapper may pick from when
    # fitting each marginal.
    candidates = [BetaUnivariate, GaussianUnivariate, GammaUnivariate, TruncatedGaussian, UniformUnivariate]

    if bi_type == 'Gaussian':
        angle_u = Univariate(candidates=candidates)
        angle_u.fit(pd_data['angle'])
        diversity_u = Univariate(candidates=candidates)
        diversity_u.fit(pd_data['diversity'])

        # Each marginal is rebuilt from its serialized form before being
        # handed to the copula; the copula is seeded with the global seed.
        joint_distr = GaussianMultivariate(distribution={
            "angle": Univariate.from_dict(angle_u.to_dict()),
            "diversity": Univariate.from_dict(diversity_u.to_dict())
        }, random_state=PC.get_global_seed())
        joint_distr.fit(pd_data)
        if is_draw:
            _draw_real_vs_synthetic(pd_data, joint_distr.sample(num_samples))
        joint_distri_params = joint_distr.to_dict()
        print(joint_distri_params['univariates'])
    elif bi_type == 'Gumbel':
        angle_values = angle_diversity_data[0, :]
        diversity_values = angle_diversity_data[1, :]
        am_distri = Univariate(candidates=candidates)
        am_distri.fit(angle_values)
        dm_distri = Univariate(candidates=candidates)
        dm_distri.fit(diversity_values)
        # The bivariate copula is fitted on the marginal CDF values, one
        # sample per row.
        uniform_ad_data = np.stack((am_distri.cdf(angle_values),
                                    dm_distri.cdf(diversity_values)), axis=0).T
        joint_distr = bivariate.Bivariate(copula_type=bivariate.CopulaTypes.GUMBEL)
        joint_distr.fit(uniform_ad_data)
        if is_draw:
            uniform_samples = joint_distr.sample(num_samples)
            # Map the copula samples back to data space through the inverse
            # CDF of each marginal.
            sampled_data = np.stack((am_distri.ppf(uniform_samples[:, 0]),
                                     dm_distri.ppf(uniform_samples[:, 1])), axis=1)
            _draw_real_vs_synthetic(pd_data, sampled_data)
        joint_distri_params = {'angle_md': am_distri.to_dict(),
                               'diversity_md': dm_distri.to_dict(),
                               'joint_d': joint_distr.to_dict()}
    else:  # 'KDE'
        kde_distri = KernelDensity(kernel='gaussian')
        kde_distri.fit(angle_diversity_data.T)
        # The fitted estimator itself is returned for this model type.
        joint_distri_params = kde_distri
        if is_draw:
            _draw_real_vs_synthetic(pd_data, kde_distri.sample(num_samples))
    return joint_distri_params


def cal_angle_diversity_distribution_kfold(dst_name, bi_type, num_k=5, is_draw=True):
    """Estimate per-class angle/diversity distributions with stratified k-fold and pickle them.

    For every fold a ``FeatureExtractor`` is loaded from the sub-directory
    named after the fold index, features are extracted for the fold's train
    and validation parts, and per-class angles (validation features) and
    diversities (validation features against the train features) are
    computed. The per-fold results are concatenated per class, a joint
    (angle, diversity) model is fitted for each class, and everything is
    written to ``kfold_class_distri_info_<bi_type>.pkl`` inside
    ``PC.get_distribution_path(dst_name)``.

    :param dst_name: dataset name; also selects ``<dst_name>.yaml`` in the
        dataset config directory and the feature-extractor weight directory.
    :param bi_type: joint model type forwarded to
        ``estimate_bivar_distribution`` and used in the output file name.
    :param num_k: number of stratified folds.
    :param is_draw: forwarded to ``estimate_bivar_distribution``; the default
        (True) keeps the previous behaviour of plotting every class.
    :return: None; the result is written to disk and summaries are printed.
    """
    from sklearn.model_selection import StratifiedKFold

    num_class = get_num_class(dst_name)
    cfg = read_config(cfg_path=PC.get_dataset_config_path() + dst_name + '.yaml')
    whole_train_set, test_set = get_dataset(dst_name, model_name="bert-base-uncased")

    # --- per-fold angles and diversities -------------------------------
    kfold = StratifiedKFold(n_splits=num_k, shuffle=True, random_state=cfg['split_seed'])
    # Only the number of samples matters for the split; labels drive the
    # stratification.
    sample_indices = list(range(len(whole_train_set)))
    fold_angle_lists = []
    fold_diversity_lists = []
    for k_index, (train_ind, val_ind) in enumerate(
            kfold.split(sample_indices, whole_train_set['labels'])):
        train_dst = whole_train_set.select(train_ind)
        val_dst = whole_train_set.select(val_ind)

        # Each fold uses the weights stored under its own index.
        feature_extractor = FeatureExtractor(
            num_class=num_class,
            weight_path=os.path.join(PathConfig().get_fe_path(dst_name), str(k_index)))
        train_fold_features = feature_extractor.extractor_features_from_dst(
            train_dst.remove_columns(['labels']), train_dst['labels'])
        val_fold_features = feature_extractor.extractor_features_from_dst(
            val_dst.remove_columns(['labels']), val_dst['labels'])

        class_split_angle_list, _, _ = cal_class_angles(num_class, val_fold_features)
        diversity_list = cal_diversity_to_set_by_class(val_fold_features, train_fold_features, num_class)
        fold_angle_lists.append(class_split_angle_list)
        fold_diversity_lists.append(diversity_list)

        # Drop the extractor before the next fold creates a new one.
        del feature_extractor

    # Concatenate the folds class by class, in fold order.
    whole_angle_list = [torch.cat(parts, dim=0) for parts in zip(*fold_angle_lists)]
    whole_diversity_list = [torch.cat(parts, dim=0) for parts in zip(*fold_diversity_lists)]

    # --- class centers / sample counts on the whole training set -------
    feature_extractor = FeatureExtractor(
        num_class=num_class,
        weight_path=os.path.join(PathConfig().get_fe_path(dst_name), 'default'))
    class_split_features = feature_extractor.extractor_features_from_dst(
        whole_train_set.remove_columns(['labels']), whole_train_set['labels'])
    _, total_sample_num_list, class_center_list = cal_class_angles(num_class, class_split_features)
    class_center_list = [item.cpu().numpy() for item in class_center_list]
    # The test-set angle distribution is only printed below, not saved.
    test_features = feature_extractor.extractor_features_from_dst(
        test_set.remove_columns(['labels']), test_set['labels'])
    test_angle_distri, _ = cal_class_distri(None, None, num_class, is_return_center=False,
                                            class_split_features=test_features)

    # --- per-class statistics and joint distribution -------------------
    angle_mean_var_distri_list = []
    diversity_mean_var_distri_list = []
    angle_diversity_data_list = []
    joint_distri_params_list = []
    for i in range(num_class):
        class_angles = whole_angle_list[i]
        class_diversity = whole_diversity_list[i]
        # NOTE(review): the angle variance uses torch.var's default estimator
        # while the diversity variance passes unbiased=False -- confirm that
        # this difference is intended.
        angle_mean_var_distri_list.append((torch.mean(class_angles).item(),
                                           torch.var(class_angles).item()))
        diversity_mean_var_distri_list.append((torch.mean(class_diversity).item(),
                                               torch.var(class_diversity, unbiased=False).item()))
        # Row 0: angles, row 1: diversities (layout expected by
        # estimate_bivar_distribution).
        angle_diversity_data = np.vstack((class_angles.cpu().numpy(),
                                          class_diversity.cpu().numpy()))
        angle_diversity_data_list.append(angle_diversity_data)
        joint_distri_params_list.append(
            estimate_bivar_distribution(angle_diversity_data, bi_type=bi_type, is_draw=is_draw))

    calibrated_distri = angle_distribution_calibration(angle_mean_var_distri_list, total_sample_num_list)

    # --- save distribution ---------------------------------------------
    # 'original_distribution' and 'angle_distribution' intentionally keep
    # pointing at the same list, as before, so existing readers of the
    # pickle keep working.
    train_class_distri_info = {'original_distribution': angle_mean_var_distri_list,
                               'calibrated_distribution': calibrated_distri,
                               'diversity_distribution': diversity_mean_var_distri_list,
                               'angle_distribution': angle_mean_var_distri_list,
                               'joint_distribution_params': joint_distri_params_list,
                               'angle_diversity_data': angle_diversity_data_list,
                               'class_centers': class_center_list,
                               'sample_num_list': total_sample_num_list}

    print('train angle_distribution', angle_mean_var_distri_list)
    print('test angle_distribution', test_angle_distri)
    print('diversity_distribution', diversity_mean_var_distri_list)

    distr_save_path = PC.get_distribution_path(dst_name)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(distr_save_path, exist_ok=True)

    with open(os.path.join(distr_save_path, 'kfold_class_distri_info_{}.pkl'.format(bi_type)), 'wb') as fo:
        pickle.dump(train_class_distri_info, fo)


if __name__ == '__main__':
    # Apply the project-wide seed before running the estimation.
    global_seed = PC.get_global_seed()
    setup_seed(global_seed)

    # Script entry: Gaussian-copula estimation on the 'reuters' dataset.
    cal_angle_diversity_distribution_kfold(
        dst_name='reuters',
        bi_type='Gaussian',
    )