from tabular_dataset import  get_dataset, get_num_class
from utils import TabularConfig
from feature_extractor import FeatureExtractor
import os
from utils import read_config
from angle_distribution import cal_class_distri, angle_distribution_calibration, cal_class_angles
import pickle
from globa_utils import setup_seed
import torch
import numpy as np
from diversity_distritbution import cal_diversity_to_set_by_class
from copulas.univariate import GammaUnivariate, BetaUnivariate, GaussianUnivariate, Univariate, TruncatedGaussian, UniformUnivariate
from draw_utils import draw_real_syn_distribution
from copulas.multivariate import GaussianMultivariate
from copulas import bivariate
import pandas as pd
from args import get_dst_name_args


# os.environ["CUDA_VISIBLE_DEVICES"]='7'

# Shared project configuration object; this module queries it for the global
# seed and for the dataset-config, scaler and distribution paths.
TC = TabularConfig()


def _plot_real_vs_synthetic(real_df, synthetic_df):
    """Tag real and sampled rows with a ``Type`` column and plot them together."""
    all_data = pd.concat([real_df.assign(Type="real"), synthetic_df.assign(Type="syn")])
    draw_real_syn_distribution(all_data, hue_key='Type')


def estimate_bivar_distribution(angle_diversity_data, bi_type='Gaussian', is_draw=False):
    """Fit a joint (angle, diversity) distribution and return its parameters.

    Args:
        angle_diversity_data: array whose row 0 holds angles and row 1 holds
            diversities (it is transposed into ``angle``/``diversity`` columns
            and indexed as ``[0, :]`` / ``[1, :]``).
        bi_type: ``'Gaussian'`` fits a ``GaussianMultivariate`` whose marginals
            are selected by ``Univariate`` from a fixed candidate list;
            ``'Gumbel'`` fits Gamma marginals plus a Gumbel bivariate copula on
            the CDF-transformed data.
        is_draw: when true, sample as many points as there are real ones and
            plot real vs. synthetic via ``draw_real_syn_distribution``.

    Returns:
        dict: for ``'Gaussian'`` the result of ``joint_distr.to_dict()``; for
        ``'Gumbel'`` a dict with keys ``'angle_md'``, ``'diversity_md'`` and
        ``'joint_d'`` holding the two marginals and the copula as dicts.

    Raises:
        ValueError: if ``bi_type`` is neither ``'Gaussian'`` nor ``'Gumbel'``.
    """
    # Fail early with a clear message; previously an unknown type fell through
    # both branches and crashed on the unbound return variable.
    if bi_type not in ('Gaussian', 'Gumbel'):
        raise ValueError(
            "Unsupported bi_type {!r}; expected 'Gaussian' or 'Gumbel'".format(bi_type))

    pd_data = pd.DataFrame(angle_diversity_data.T, columns=['angle', 'diversity'])

    if bi_type == 'Gaussian':
        candidates = [BetaUnivariate, GaussianUnivariate, GammaUnivariate, TruncatedGaussian, UniformUnivariate]
        angle_u = Univariate(candidates=candidates)
        angle_u.fit(pd_data['angle'])
        diversity_u = Univariate(candidates=candidates)
        diversity_u.fit(pd_data['diversity'])

        # The marginals are rebuilt from their dict form before being handed
        # to the multivariate model, exactly as in the original pipeline.
        joint_distr = GaussianMultivariate(distribution={
            "angle": Univariate.from_dict(angle_u.to_dict()),
            "diversity": Univariate.from_dict(diversity_u.to_dict())
        }, random_state=TC.get_global_seed())
        joint_distr.fit(pd_data)
        if is_draw:
            _plot_real_vs_synthetic(pd_data, joint_distr.sample(len(pd_data)))
        joint_distri_params = joint_distr.to_dict()
        print(joint_distri_params['univariates'])
        return joint_distri_params

    # bi_type == 'Gumbel'
    angles = angle_diversity_data[0, :]
    diversities = angle_diversity_data[1, :]
    am_distri = GammaUnivariate()
    am_distri.fit(angles)
    dm_distri = GammaUnivariate()
    dm_distri.fit(diversities)
    # The copula is fitted on the CDF-transformed marginals, one row per sample.
    uniform_ad_data = np.stack((am_distri.cdf(angles), dm_distri.cdf(diversities)), axis=0).T
    joint_distr = bivariate.Bivariate(copula_type=bivariate.CopulaTypes.GUMBEL)
    joint_distr.fit(uniform_ad_data)
    if is_draw:
        sampled_u = joint_distr.sample(len(pd_data))
        # Map copula samples back to the data scale through the inverse CDFs.
        sampled_data = np.stack((am_distri.ppf(sampled_u[:, 0]), dm_distri.ppf(sampled_u[:, 1])), axis=1)
        _plot_real_vs_synthetic(pd_data, pd.DataFrame(sampled_data, columns=['angle', 'diversity']))
    return {'angle_md': am_distri.to_dict(), 'diversity_md': dm_distri.to_dict(), 'joint_d': joint_distr.to_dict()}


def get_angles_and_class_centers(train_set, num_class, col_info, scaler):
    """Extract per-class features from ``train_set`` and derive class angles.

    Builds a ``FeatureExtractor`` from ``scaler``, extracts features split by
    class, and passes them to ``cal_class_angles``.

    Returns:
        tuple: the angle list and sample-count list exactly as returned by
        ``cal_class_angles``, plus the class centers moved to CPU and
        converted to NumPy arrays.
    """
    per_class_features = FeatureExtractor(scaler).extractor_features_from_dst(
        dst=train_set,
        num_classes=num_class,
        cols_info_tuple=col_info,
    )
    angles_per_class, sample_counts, centers = cal_class_angles(num_class, per_class_features)

    # Centers come back as tensors; callers get plain NumPy arrays.
    centers_np = []
    for center in centers:
        centers_np.append(center.cpu().numpy())
    return angles_per_class, sample_counts, centers_np


def cal_angle_diversity_distribution_kfold(dst_name, bi_type, num_k=5, is_draw=True):
    """Estimate per-class angle/diversity distributions and pickle them.

    Steps, in order:
      1. Load the dataset config, the training split and the saved scaler.
      2. Run a stratified k-fold over the training split; for every fold,
         compute the diversity of the validation features with respect to
         the training features, per class, and concatenate across folds.
      3. Compute per-class angles and class centers on the full training set.
      4. For every class, record mean/variance of angle and diversity and fit
         a joint distribution with ``estimate_bivar_distribution``.
      5. Save everything to
         ``<distribution_path>/<dst_name>/kfold_class_distri_info_<bi_type>.pkl``.

    Args:
        dst_name: dataset name; used to locate the YAML config, the scaler
            pickle and the output directory.
        bi_type: joint-distribution type forwarded to
            ``estimate_bivar_distribution``.
        num_k: number of stratified folds.
        is_draw: forwarded to ``estimate_bivar_distribution``; defaults to
            ``True`` to keep the previous always-plot behaviour.

    Returns:
        None. The result dictionary is written to disk.
    """
    from sklearn.model_selection import StratifiedKFold
    num_class = get_num_class(dst_name)
    cfg = read_config(cfg_path=TC.get_dataset_config_path() + dst_name + '.yaml')
    train_set, col_info = get_dataset(dst_name, split='train', rand_number=cfg['split_seed'],
                                      test_ratio=cfg['test_ratio'])

    # Load the fitted scaler; the context manager closes the file handle.
    # NOTE: pickle.load can execute arbitrary code - only load scaler files
    # produced by this project.
    scaler_path = os.path.join(TC.get_scaler_save_path(), dst_name + '_scaler.pkl')
    with open(scaler_path, 'rb') as fi:
        scaler = pickle.load(fi)

    # calculate diversity by kfold
    kfold = StratifiedKFold(n_splits=num_k, shuffle=True, random_state=cfg['split_seed'])
    fold_diversities = []
    for train_ind, val_ind in kfold.split(train_set[0], train_set[1]):
        feature_extractor = FeatureExtractor(scaler)
        train_fold_features, val_fold_features = feature_extractor.extractor_features_from_dst(
            dst=train_set, train_val_index=(train_ind, val_ind),
            num_classes=num_class, cols_info_tuple=col_info)
        fold_diversities.append(
            cal_diversity_to_set_by_class(val_fold_features, train_fold_features, num_class))
    # One concatenation per class over all folds (same result as chaining
    # pairwise torch.cat calls fold by fold).
    whole_diversity_list = [torch.cat(per_class, dim=0) for per_class in zip(*fold_diversities)]

    # calculate angle distribution and class centers
    class_split_angle_list, total_sample_num_list, class_center_list = \
        get_angles_and_class_centers(train_set, num_class, col_info, scaler)

    # calculate joint distribution
    diversity_mean_var_distri_list = []
    angle_diversity_data_list = []
    joint_distri_params_list = []
    angle_mean_var_distri_list = []
    for i in range(num_class):
        class_angles = class_split_angle_list[i]
        class_diversity = whole_diversity_list[i]
        # NOTE(review): the angle variance uses torch's default estimator while
        # the diversity variance passes unbiased=False - confirm this asymmetry
        # is intended before changing either.
        angle_mean_var_distri_list.append((torch.mean(class_angles).item(),
                                           torch.var(class_angles).item()))
        diversity_mean_var_distri_list.append((torch.mean(class_diversity).item(),
                                               torch.var(class_diversity, unbiased=False).item()))
        # Stack angles (row 0) over diversities (row 1); presumably both are
        # 1-D with one entry per training sample of the class - verify upstream.
        angle_diversity_data = np.vstack((class_angles.cpu().numpy(), class_diversity.cpu().numpy()))
        angle_diversity_data_list.append(angle_diversity_data)
        joint_distri_params_list.append(
            estimate_bivar_distribution(angle_diversity_data, bi_type=bi_type, is_draw=is_draw))

    calibrated_distri = angle_distribution_calibration(angle_mean_var_distri_list, total_sample_num_list)

    # 'original_distribution' and 'angle_distribution' deliberately hold the
    # same list, as in the original output format.
    train_class_distri_info = {'original_distribution': angle_mean_var_distri_list,
                               'calibrated_distribution': calibrated_distri,
                               'diversity_distribution': diversity_mean_var_distri_list,
                               'angle_distribution': angle_mean_var_distri_list,
                               'joint_distribution_params': joint_distri_params_list,
                               'angle_diversity_data': angle_diversity_data_list,
                               'class_centers': class_center_list,
                               'sample_num_list': total_sample_num_list}

    print('angle_distribution', angle_mean_var_distri_list)
    print('diversity_distribution', diversity_mean_var_distri_list)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    out_dir = os.path.join(TC.get_distribution_path(), dst_name)
    os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(out_dir, 'kfold_class_distri_info_{}.pkl'.format(bi_type)), 'wb') as fo:
        pickle.dump(train_class_distri_info, fo)


if __name__ == '__main__':
    # Script entry point: parse the dataset name, seed the run with the
    # project-wide seed, then fit and save the Gaussian joint distributions.
    cli_args = get_dst_name_args()
    global_seed = TC.get_global_seed()
    setup_seed(global_seed)
    cal_angle_diversity_distribution_kfold(
        dst_name=cli_args.dst_name,
        bi_type='Gaussian',
    )