
import pickle

import matplotlib.pyplot as plt
import torch

from angle_distribution import  cal_class_distri, angle_distribution_calibration, cal_class_center, cal_class_angles
from imbalanced_datasets import get_dataset, get_transform, PC, DatasetWrapper
from feature_extractor import load_feature_extractor, FeatureExtractor
import tqdm
from torch.utils.data import DataLoader
import os
from augmentation import AugmentType, CVAugment, convert_tensor_to_PILimages
import numpy as np
from globa_utils import  setup_seed
from val_sampler import DistriInfo, DistributionSampler, remove_duplication
import os
from val_sampler import split_train_val
import math
import seaborn as sns
from scipy.stats import norm, beta, gamma, wishart
from scipy.stats import multivariate_normal


# Restrict this process to physical GPU 3. The functions below address the device as
# "cuda:0", which then refers to that single visible GPU.
# NOTE(review): this only takes effect if it runs before CUDA is initialised — confirm that
# none of the project imports above touch the GPU at import time.
os.environ["CUDA_VISIBLE_DEVICES"]='3'


def cal_entropy(logits, eps=1e-10):
    """Return the Shannon entropy (base 2) of softmax(logits) for every row.

    Args:
        logits: tensor whose class scores lie along dim 1.
        eps: small constant added inside the logarithm so log2(0) is never evaluated.

    Returns:
        Tensor of entropies obtained by summing over dim 1.
    """
    class_probs = torch.softmax(logits, dim=1)
    weighted_log = class_probs * torch.log2(class_probs + eps)
    return -1 * weighted_log.sum(dim=1)


def cal_diversity_to_set_by_class(target_set, src_set, num_class):
    """Compute, class by class, the min-angle diversity of target features to source features.

    Args:
        target_set: indexable by class id; entry ``c`` holds the target features of class ``c``.
        src_set: indexable by class id; entry ``c`` holds the source features of class ``c``.
        num_class: number of classes to process (ids ``0 .. num_class - 1``).

    Returns:
        List with one result of ``cal_diversity_to_set`` per class, in class order.
    """
    return [cal_diversity_to_set(target_set[cls], src_set[cls])
            for cls in tqdm.tqdm(range(num_class))]


def cal_diversity_to_set(tar_feats, src_feats, batch_size=None):
    """Return, for each target feature, its smallest angle (in degrees) to any source feature.

    Args:
        tar_feats: 2-D tensor, one feature vector per row.
        src_feats: 2-D tensor, one feature vector per row, compared against every target row.
        batch_size: number of target rows processed per step, which bounds the size of the
            intermediate (batch, num_src) angle matrix. ``None`` processes all rows at once.

    Returns:
        1-D tensor with one minimum angle per row of ``tar_feats``; empty when
        ``tar_feats`` has no rows.

    Raises:
        ValueError: if an explicit ``batch_size`` is not positive.
    """
    total_len = tar_feats.shape[0]
    if total_len == 0:
        # Nothing to compare; the former code divided by a zero batch size here.
        return tar_feats.new_empty((0,))
    if batch_size is None:
        batch_size = total_len
    elif batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))

    min_angle_vec = []
    for head in range(0, total_len, batch_size):
        tar_batch = tar_feats[head:head + batch_size, :]
        cos_sim = torch.cosine_similarity(tar_batch.unsqueeze(1), src_feats.unsqueeze(0), dim=-1)
        # Rounding can push the cosine of (near-)identical vectors slightly outside [-1, 1],
        # where arccos yields NaN and would poison the row minimum.
        angle_matrix = torch.arccos(cos_sim.clamp(-1.0, 1.0))
        angle_matrix = angle_matrix * 180 / math.pi
        min_angle_vec.append(torch.min(angle_matrix, dim=1)[0])
    return torch.cat(min_angle_vec, dim=0)


# def cal_diversity_to_set(tar_feats, src_feats):
#     angle_matrix = torch.arccos(torch.cosine_similarity(tar_feats.unsqueeze(1), src_feats.unsqueeze(0), dim=-1))
#     angle_matrix = angle_matrix * 180 / math.pi
#     min_angle_vec = torch.min(angle_matrix, dim=1)[0]
#     return min_angle_vec


from copulas.univariate import GammaUnivariate, BetaUnivariate, GaussianUnivariate
from copulas.bivariate import Clayton, Frank, Gumbel
from copulas.multivariate import GaussianMultivariate
from copulas import bivariate
from copulas import visualization
import pandas as pd


def estimate_bivar_distribution(angle_diversity_data, bi_type='Gaussian', is_draw=False):
    """Fit a joint (angle, diversity) distribution and return its serialised parameters.

    Args:
        angle_diversity_data: array of shape (2, N); row 0 holds angles, row 1 diversities.
        bi_type: ``'Gaussian'`` fits a ``GaussianMultivariate`` with Gamma marginals;
            ``'Gumbel'`` fits two Gamma marginals plus a Gumbel copula on their CDF values.
        is_draw: when true, draw as many synthetic points as there are real ones and plot
            both sets with ``draw_real_syn_distribution``.

    Returns:
        For ``'Gaussian'``: the ``to_dict()`` of the fitted model.
        For ``'Gumbel'``: a dict with keys ``'angle_md'``, ``'diversity_md'`` and ``'joint_d'``
        holding the ``to_dict()`` of the two marginals and of the copula.

    Raises:
        ValueError: if ``bi_type`` is not one of the supported names (this used to surface
            as an UnboundLocalError at the return statement).
    """
    if bi_type not in ('Gaussian', 'Gumbel'):
        raise ValueError("Unsupported bi_type {!r}; expected 'Gaussian' or 'Gumbel'".format(bi_type))

    def _with_type(frame, label):
        # Append a constant 'Type' column so real and synthetic rows can be told apart.
        tags = pd.DataFrame([label] * len(frame), columns=['Type'])
        return pd.concat([frame, tags], axis=1)

    pd_data = pd.DataFrame(angle_diversity_data.T, columns=['angle', 'diversity'])

    if bi_type == 'Gaussian':
        joint_distr = GaussianMultivariate(distribution={
            "angle": GammaUnivariate,
            "diversity": GammaUnivariate
        })
        joint_distr.fit(pd_data)
        if is_draw:
            sampled_data = joint_distr.sample(len(pd_data))
            all_data = pd.concat([_with_type(pd_data, "real"), _with_type(sampled_data, "syn")])
            draw_real_syn_distribution(all_data, hue_key='Type')
        return joint_distr.to_dict()

    # bi_type == 'Gumbel': fit each marginal, then the copula on the CDF-transformed data.
    am_distri = GammaUnivariate()
    am_distri.fit(angle_diversity_data[0, :])
    dm_distri = GammaUnivariate()
    dm_distri.fit(angle_diversity_data[1, :])
    uniform_ad_data = np.stack((am_distri.cdf(angle_diversity_data[0, :]),
                                dm_distri.cdf(angle_diversity_data[1, :])), axis=0).T
    joint_distr = bivariate.Bivariate(copula_type=bivariate.CopulaTypes.GUMBEL)
    joint_distr.fit(uniform_ad_data)
    if is_draw:
        sampled_data = joint_distr.sample(len(pd_data))
        # Map the copula samples back to the data scale through the inverse marginal CDFs.
        sampled_data = np.stack((am_distri.ppf(sampled_data[:, 0]),
                                 dm_distri.ppf(sampled_data[:, 1])), axis=1)
        syn_frame = pd.DataFrame(sampled_data, columns=['angle', 'diversity'])
        all_data = pd.concat([_with_type(pd_data, "real"), _with_type(syn_frame, "syn")])
        draw_real_syn_distribution(all_data, hue_key='Type')
    return {'angle_md': am_distri.to_dict(),
            'diversity_md': dm_distri.to_dict(),
            'joint_d': joint_distr.to_dict()}


def draw_real_syn_distribution(data, hue_key):
    """Show a seaborn joint plot of 'angle' vs 'diversity', coloured by ``hue_key``.

    Args:
        data: table with 'angle' and 'diversity' columns plus the column named by ``hue_key``.
        hue_key: name of the column used to split the data into hue groups.
    """
    grid = sns.JointGrid(data=data, x="angle", y="diversity", hue=hue_key)
    # Joint panel: density contours first, then the individual points on top.
    joint_layers = (
        (sns.kdeplot, {"color": "orange"}),
        (sns.scatterplot, {"alpha": .8, "s": 2}),
    )
    for plot_fn, extra_kwargs in joint_layers:
        grid.plot_joint(plot_fn, hue=hue_key, **extra_kwargs)
    # Marginal panels: probability-normalised histograms with a KDE overlay.
    grid.plot_marginals(sns.histplot, kde=True, bins=50, hue=hue_key, stat='probability')
    plt.show()


def estimate_gaussian_distribution(angle_diversity_data):
    """Fit a multivariate normal to the data and plot it via ``draw_gaussian_distribution``.

    Args:
        angle_diversity_data: array with one variable per row and one observation per column.
    """
    mean = np.mean(angle_diversity_data, axis=1)
    cov = np.cov(angle_diversity_data)
    distr = multivariate_normal(mean=mean, cov=cov, seed=0)
    draw_gaussian_distribution(mean, cov, distr, angle_diversity_data)


def draw_gaussian_distribution(mean, cov, distr, data):
    """Plot filled density contours of a fitted 2-D distribution with the data on top.

    Args:
        mean: length-2 sequence giving the centre of the plotted window.
        cov: 2x2 matrix; its diagonal entries set the half-width of the window per axis.
        distr: object with a ``pdf`` method accepting an array whose last axis has size 2
            (e.g. a frozen ``scipy.stats.multivariate_normal``).
        data: array of shape (2, N); row 0 is drawn on the x axis, row 1 on the y axis.
    """
    mean_1, mean_2 = mean[0], mean[1]
    # NOTE(review): cov[i, i] is a variance, not a standard deviation, so the window spans
    # mean +/- variance. Kept as in the original; confirm whether sqrt was intended.
    spread_1, spread_2 = cov[0, 0], cov[1, 1]
    x = np.linspace(mean_1 - spread_1, mean_1 + spread_1, num=100)
    y = np.linspace(mean_2 - spread_2, mean_2 + spread_2, num=100)
    X, Y = np.meshgrid(x, y)

    # Evaluate the density on the whole grid in one call instead of point by point.
    pdf = distr.pdf(np.dstack((X, Y)))

    plt.contourf(X, Y, pdf, levels=10)
    plt.scatter(data[0, :], data[1, :], s=1, alpha=0.5, c='black')
    plt.xlabel("angle")
    plt.ylabel("diversity")
    plt.show()


def draw_margin_distritbuion(angle_vec, diversity_list):
    """Show the marginal distributions of angles (top) and diversities (bottom).

    Each panel holds a 50-bin histogram with a KDE curve, overlaid with fitted norm, beta
    and gamma curves so the candidate families can be compared by eye.

    Args:
        angle_vec: angle values plotted in the upper panel.
        diversity_list: diversity values plotted in the lower panel.
    """
    _, axes = plt.subplots(2, 1)
    angle_ax, diversity_ax = axes[0], axes[1]
    fitted_families = ((norm, "norm", 'yellow'), (beta, "beta", 'red'), (gamma, "gamma", 'black'))

    def overlay_fits(values, ax):
        # Draw only the fitted curve of each family (no histogram, no KDE).
        for family, label, colour in fitted_families:
            sns.distplot(values, kde=False, hist=False, ax=ax, fit=family,
                         fit_kws={"label": label, "color": colour})

    sns.distplot(angle_vec, kde=True, bins=50, ax=angle_ax, kde_kws={"label": "KDE"})
    overlay_fits(angle_vec, angle_ax)

    sns.distplot(diversity_list, kde=True, bins=50, fit=beta, ax=diversity_ax, color="g",
                 kde_kws={"label": "KDE"})
    overlay_fits(diversity_list, diversity_ax)

    angle_ax.legend()
    diversity_ax.legend()
    plt.show()


def draw_joint_kde(angle_vec, diversity_list):
    """Show a joint plot (KDE contours + scatter) of angle against diversity.

    Args:
        angle_vec: values for the x axis ("angle").
        diversity_list: values for the y axis ("diversity").
    """
    columns = {"angle": angle_vec, "diversity": diversity_list}
    grid = sns.JointGrid(data=columns, x="angle", y="diversity")
    joint_layers = (
        (sns.kdeplot, {"color": "orange"}),
        (sns.scatterplot, {"alpha": .8, "s": 2}),
    )
    for plot_fn, extra_kwargs in joint_layers:
        grid.plot_joint(plot_fn, **extra_kwargs)
    grid.plot_marginals(sns.histplot, kde=True, bins=50)
    plt.show()


def save_train_distribution(num_class, fe_path, bi_type):
    """Estimate per-class angle/diversity distributions over the CV folds and pickle them.

    For each of the ``NUM_K`` folds, the fold's feature extractor is loaded from
    ``fe_path/<k>_model.th``, features of the validation and training split are extracted,
    and per-class angles (``cal_class_angles``) and diversities
    (``cal_diversity_to_set_by_class``) of the validation features are collected. The
    per-class values of all folds are concatenated, summarised by mean/variance, fitted with
    ``estimate_bivar_distribution`` and written to a pickle in the distribution save path.

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 save path, anything else the
            CIFAR-100 one.
        fe_path: directory that contains the per-fold model files.
        bi_type: joint-distribution type forwarded to ``estimate_bivar_distribution``; it is
            also embedded in the output file name.
    """
    NUM_K = 5
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"

    whole_train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed, is_wrapper=True)
    train_val_index_list = split_train_val(whole_train_dst.indexset,
                                           whole_train_dst.get_label_list(), seed=PC.get_kfoldCV_seed(),
                                           k=NUM_K, val_ratio=0.2)

    whole_angle_list = []
    whole_diversity_list = []
    avg_class_center_list = []
    total_sample_num_list = []
    for k in range(NUM_K):
        train_index, val_index = train_val_index_list[k]
        train_dst = whole_train_dst.get_dataset_by_indexes(train_index)
        train_dst.transform = get_transform(dataset_name, t_type='test')
        val_dst = whole_train_dst.get_dataset_by_indexes(val_index)
        val_dst.transform = get_transform(dataset_name, t_type='test')

        fold_model_file = os.path.join(fe_path, '{}_model.th'.format(k))
        fold_model = load_feature_extractor(fold_model_file, 'resnet20', device, num_classes=num_class)
        feature_extractor = FeatureExtractor('fine-tune', fold_model, device)

        val_fold_features = feature_extractor.extractor_features_from_dst(val_dst, num_classes=num_class)
        train_fold_features = feature_extractor.extractor_features_from_dst(train_dst, num_classes=num_class)
        fold_angles, fold_sample_nums, fold_centers = cal_class_angles(num_class, val_fold_features)
        fold_diversities = cal_diversity_to_set_by_class(val_fold_features, train_fold_features, num_class)

        if k == 0:
            # The first fold seeds the accumulators.
            whole_angle_list = fold_angles
            whole_diversity_list = fold_diversities
            avg_class_center_list = fold_centers
            total_sample_num_list = fold_sample_nums
            continue

        # Later folds: concatenate per-class samples, sum per-class centres and counts.
        whole_angle_list = [torch.cat([prev, new], dim=0)
                            for prev, new in zip(whole_angle_list, fold_angles)]
        whole_diversity_list = [torch.cat((prev, new), dim=0)
                                for prev, new in zip(whole_diversity_list, fold_diversities)]
        avg_class_center_list = [prev + new for prev, new in zip(avg_class_center_list, fold_centers)]
        total_sample_num_list = [prev + new for prev, new in zip(total_sample_num_list, fold_sample_nums)]

    angle_mean_var_distri_list = []
    diversity_mean_var_distri_list = []
    angle_diversity_data_list = []
    joint_distri_params_list = []
    for i in range(num_class):
        angle_tensor = whole_angle_list[i]
        diversity_tensor = whole_diversity_list[i]

        angle_mean_var_distri_list.append((torch.mean(angle_tensor).item(),
                                           torch.var(angle_tensor, unbiased=False).item()))
        diversity_mean_var_distri_list.append((torch.mean(diversity_tensor).item(),
                                               torch.var(diversity_tensor, unbiased=False).item()))

        # Row 0: angles, row 1: diversities, the layout estimate_bivar_distribution reads.
        angle_diversity_data = np.vstack((angle_tensor.cpu().numpy(), diversity_tensor.cpu().numpy()))
        angle_diversity_data_list.append(angle_diversity_data)
        joint_distri_params_list.append(estimate_bivar_distribution(angle_diversity_data, bi_type=bi_type))

    calibrated_distri = angle_distribution_calibration(angle_mean_var_distri_list, total_sample_num_list)
    # Centres were summed over the folds above; dividing by NUM_K gives the fold average.
    avg_class_center_list = [center.cpu().numpy() / NUM_K for center in avg_class_center_list]

    class_distri_info = {'original_distribution': angle_mean_var_distri_list,
                         'calibrated_distribution': calibrated_distri,
                         'diversity_distribution': diversity_mean_var_distri_list,
                         'angle_distribution': angle_mean_var_distri_list,
                         'joint_distribution_params': joint_distri_params_list,
                         'angle_diversity_data': angle_diversity_data_list,
                         'class_centers': avg_class_center_list,
                         'sample_num_list': total_sample_num_list}
    print(diversity_mean_var_distri_list)
    print(angle_mean_var_distri_list)

    if num_class == 10:
        save_path = PC.get_cifar10_distribution_save_path()
    else:
        save_path = PC.get_cifar100_distribution_save_path()
    with open(save_path + 'class_distri_info_kfold3_resnet20_{}.pkl'.format(bi_type), 'wb') as fo:
        pickle.dump(class_distri_info, fo)


def update_class_center(num_class, file_name):
    """Recompute the class centres on the whole training set and store them in a saved pickle.

    The centres come from ``cal_class_distri`` using the default feature extractor. The
    pickle ``file_name`` inside the distribution save path is loaded, its ``'class_centers'``
    entry is replaced, and the file is written back in place.

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 paths, anything else CIFAR-100.
        file_name: name of the pickle file (relative to the distribution save path) to update.
    """
    use_cifar10 = num_class == 10
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"

    whole_train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed)
    whole_train_dst.transform = get_transform(dataset_name, t_type='test')

    fe_model_path = PC.get_cifar10_fe_path() if use_cifar10 else PC.get_cifar100_fe_path()
    backbone = load_feature_extractor(fe_model_path, 'resnet20', device, num_classes=num_class,
                                      fe_type='default')
    feature_extractor = FeatureExtractor('default', backbone, device)

    mean_var_distri_list, _, center_tensors = cal_class_distri(whole_train_dst, feature_extractor,
                                                               num_class,
                                                               is_return_center=True)
    class_center_list = [center.cpu().numpy() for center in center_tensors]
    print(mean_var_distri_list)

    if use_cifar10:
        save_path = PC.get_cifar10_distribution_save_path()
    else:
        save_path = PC.get_cifar100_distribution_save_path()
    info_file = save_path + file_name

    with open(info_file, 'rb') as fi:
        distri_info = pickle.load(fi)

    distri_info['class_centers'] = class_center_list

    with open(info_file, 'wb') as fo:
        pickle.dump(distri_info, fo)


def cal_data_pool_info(train_dst, data_pool, raw_info, feature_extractor:FeatureExtractor,
                        class_center_set, device="cuda:0"):
    """Compute per-sample statistics for every augmented sample in the data pool.

    For each class the augmented samples are pushed through ``feature_extractor`` and compared
    with the original training samples of the same class.

    Args:
        train_dst: original training dataset; its per-class features and logits are extracted.
        data_pool: augmented dataset exposing ``get_dataset_by_class``.
        raw_info: dict with ``'index_map'`` and ``'method_map'`` entries, both indexed by class.
            Presumably ``index_map[l][j]`` is the index of the original sample that augmented
            sample ``j`` of class ``l`` was derived from — verify against the pool builder.
        feature_extractor: extractor used for both original and augmented samples.
        class_center_set: per-class centres as numpy arrays; its length defines the class count.
        device: torch device the images and class centres are moved to.

    Returns:
        Dict with ``'index_map'`` and ``'method_map'`` copied from ``raw_info`` and
        ``'rank_info'``, which maps each class id to a list of tuples
        ``(aug_angle, angle_change, aug_entropy, entropy_change, aug_diversity, j, method)``
        with ``j`` the position of the sample within the class's pool.
    """
    from angle_distribution import cal_angle_to_center
    num_class = len(class_center_set)
    # Features and logits of the original samples; the results are indexed by class id below.
    origin_class_features = feature_extractor.extractor_features_from_dst(train_dst, num_class)
    origin_class_logits = feature_extractor.extractor_features_from_dst(train_dst, num_class, is_logits=True)
    augmented_data_info = {}
    augmented_data_info['rank_info'] = {}
    for l in tqdm.tqdm(range(num_class)):
        dst = data_pool.get_dataset_by_class(l)
        # shuffle=False keeps pool order, so position j below lines up with raw_info's maps.
        loader = DataLoader(dst, batch_size=256, shuffle=False, num_workers=4)

        origin_center = torch.from_numpy(class_center_set[l]).to(device)
        origin_angles = cal_angle_to_center(origin_class_features[l], origin_center)
        origin_entropys = cal_entropy(origin_class_logits[l])
        # origin_entropys = origin_angles

        aug_class_features = []
        aug_class_logits = []
        for aug_images, _ in loader:
            aug_images = aug_images.to(device)
            # NOTE(review): squeeze(dim=0) would drop the batch axis if get_features returns
            # (batch, dim) and the last batch holds a single sample — confirm its output shape.
            aug_features = feature_extractor.get_features(aug_images).squeeze(dim=0)
            aug_logits = feature_extractor.get_features(aug_images, is_logits=True).squeeze(dim=0)
            aug_class_features.append(aug_features)
            aug_class_logits.append(aug_logits)

        aug_class_features = torch.cat(aug_class_features, dim=0)
        aug_class_logits = torch.cat(aug_class_logits, dim=0)

        # Diversity: smallest angle from each augmented feature to any original class feature.
        aug_diversities = cal_diversity_to_set(aug_class_features, origin_class_features[l], batch_size=512)
        aug_diversities = aug_diversities.cpu().numpy()

        aug_angles = cal_angle_to_center(aug_class_features, origin_center)
        # Re-index the original angles so that entry j pairs with augmented sample j.
        origin_angles = origin_angles[raw_info['index_map'][l]]

        angle_changes = torch.abs(aug_angles - origin_angles).cpu().numpy()
        aug_angles = aug_angles.cpu().numpy()

        # Same re-indexing for the entropies; the change is signed (augmented minus original).
        origin_entropys = origin_entropys[raw_info['index_map'][l]].cpu().numpy()
        aug_entropy = cal_entropy(aug_class_logits).cpu().numpy()
        entropy_changes = aug_entropy-origin_entropys

        augmented_data_info['rank_info'][l] = [(aug_angles[j], angle_changes[j], aug_entropy[j], entropy_changes[j],
                                                aug_diversities[j],
                                                j, raw_info['method_map'][l][j]) for j in range(aug_angles.shape[0])]
        # cal_angle_p2p(origin_features, aug_features).cpu().numpy().tolist()

    augmented_data_info['index_map'] = raw_info['index_map']
    augmented_data_info['method_map'] = raw_info['method_map']
    return augmented_data_info


def cal_data_pool_info_main(num_class, fe_type='default'):
    """Load the augmented data pool, compute its statistics and overwrite the pool-info pickle.

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 paths, anything else CIFAR-100.
        fe_type: feature-extractor type. ``'default'`` loads the model from the configured
            feature-extractor path; any other value loads the hard-coded SimCLR checkpoint.
    """
    from imbalanced_datasets import AugmentedDataset

    device = "cuda:0"
    use_cifar10 = num_class == 10
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)

    def pool_info_path():
        # The same pickle is read as raw info and later overwritten with the computed info.
        if use_cifar10:
            return PC.get_cifar10_data_pool_info()
        return PC.get_cifar100_data_pool_info()

    with open(pool_info_path(), 'rb') as fi:
        raw_info = pickle.load(fi)

    pool_path = PC.get_cifar10_data_pool_path() if use_cifar10 else PC.get_cifar100_data_pool_path()
    pool_dataset = AugmentedDataset(pool_path, transform=get_transform('cifar'+str(num_class), 'test'))
    data_pool = DatasetWrapper(pool_dataset)

    train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed, is_wrapper=True)
    train_dst.dataset.transform = get_transform(dataset_name, t_type='test')

    if fe_type == 'default':
        fe_model_path = PC.get_cifar10_fe_path() if use_cifar10 else PC.get_cifar100_fe_path()
    else:
        fe_model_path = '/data/omf/model/DataValidation/CV/simclr/encoder/cifar10/SimCLR_cifar10_resnet18_lr_0.5_decay_0.0001_bsz_1024_temp_0.5_trial_0_warm/last.pth'
    backbone = load_feature_extractor(fe_model_path, 'resnet20', device, num_classes=num_class,
                                      fe_type=fe_type)
    feature_extractor = FeatureExtractor(fe_type, backbone, device)

    if use_cifar10:
        distri_path = PC.get_cifar10_distribution_save_path()
    else:
        distri_path = PC.get_cifar100_distribution_save_path()
    info = DistriInfo(distri_path)

    augmented_data_info = cal_data_pool_info(train_dst, data_pool, raw_info,
                                             feature_extractor,
                                             class_center_set=info.get_class_centers())

    with open(pool_info_path(), 'wb') as fo:
        pickle.dump(augmented_data_info, fo)


import functools
def cmp_func(a, b):
    """Three-way comparator for rank-info tuples, for use with ``functools.cmp_to_key``.

    Ordering: ascending by field 0; ties on field 0 are broken by field 4 in descending
    order. In the tuples built by ``cal_data_pool_info`` field 0 is the augmented angle and
    field 4 the diversity (other candidate tie-breakers there: 1 angle change, 2 entropy,
    3 entropy change).

    Args:
        a: first tuple (at least five fields).
        b: second tuple (at least five fields).

    Returns:
        -1 if ``a`` sorts before ``b``, 1 if it sorts after, and 0 when both fields are equal.
        The former version returned 1 for full ties in both directions, which violates the
        comparator contract; ``sorted`` results are unchanged because it only tests ``< 0``.
    """
    if a[0] != b[0]:
        return 1 if a[0] > b[0] else -1
    if a[4] != b[4]:
        return -1 if a[4] > b[4] else 1
    return 0


def rank_main(num_class):
    """Sort each class's rank info with ``cmp_func`` and write the pool-info pickle back.

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 pool-info file, anything else
            the CIFAR-100 one.
    """
    def pool_info_path():
        if num_class == 10:
            return PC.get_cifar10_data_pool_info()
        return PC.get_cifar100_data_pool_info()

    with open(pool_info_path(), 'rb') as fi:
        info = pickle.load(fi)

    ranked_info = info['rank_info']
    sort_key = functools.cmp_to_key(cmp_func)
    for class_id in tqdm.tqdm(range(num_class)):
        ranked_info[class_id] = sorted(ranked_info[class_id], key=sort_key)

    with open(pool_info_path(), 'wb') as fo:
        pickle.dump({'index_map':info['index_map'], 'method_map':info['method_map'], 'rank_info':ranked_info}, fo)


def show_aug_data_distribution(num_class):
    """Plot, for every class, the augmented pool against the stored training distribution.

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 paths, anything else CIFAR-100.
    """
    use_cifar10 = num_class == 10
    if use_cifar10:
        distri_path = PC.get_cifar10_distribution_save_path()
    else:
        distri_path = PC.get_cifar100_distribution_save_path()
    info = DistriInfo(distri_path, is_train=True)
    print(info.get_origin_distribution())

    pool_info_path = PC.get_cifar10_data_pool_info() if use_cifar10 else PC.get_cifar100_data_pool_info()
    with open(pool_info_path, 'rb') as fi:
        augmented_data_info = pickle.load(fi)
    sample_rank_list = remove_duplication(augmented_data_info['rank_info'], tolerance=0)

    # To plot only every tenth class, skip iterations where i % 10 != 0.
    for i in range(num_class):
        aug_samples = sample_rank_list[i]
        # Field 0 of a rank-info tuple is the angle, field 4 the diversity.
        aug_angle_vec = np.asarray([sample[0] for sample in aug_samples])
        aug_diversity_vec = np.asarray([sample[4] for sample in aug_samples])
        train_angle_diversity_vec = info.get_angle_diversity_data(i)

        aug_frame = pd.DataFrame({'angle': aug_angle_vec,
                                  'diversity': aug_diversity_vec,
                                  'Type': ["aug"] * len(aug_samples)})
        train_frame = pd.DataFrame({'angle': train_angle_diversity_vec[0, :],
                                    'diversity': train_angle_diversity_vec[1, :],
                                    'Type': ["train"] * train_angle_diversity_vec.shape[1]})
        draw_real_syn_distribution(pd.concat([aug_frame, train_frame]), hue_key='Type')


def pair_dist(a: torch.Tensor, b: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Return the p-norm distance between ``a`` and ``b`` along the last dimension.

    ``a`` and ``b`` are subtracted with normal broadcasting, so passing
    ``a.unsqueeze(1)`` and ``b.unsqueeze(0)`` yields a full pairwise distance matrix.
    """
    abs_diff = torch.abs(a - b)
    powered = torch.pow(abs_diff, p)
    summed = torch.sum(powered, dim=-1)
    return torch.pow(summed, 1 / p)


def sample_by_joint_distribution(num_class, sample_num, bi_type):
    """Select augmented samples whose (angle, diversity) match draws from the fitted joint.

    For every class, ``sample_num`` points are drawn from the stored joint distribution.
    Each drawn point is greedily matched to its nearest not-yet-selected augmented sample
    among its closest candidates (Euclidean distance in the angle/diversity plane).

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 paths, anything else CIFAR-100.
        sample_num: number of points drawn per class.
        bi_type: ``'Gaussian'`` or ``'Gumbel'``; must match how the stored joint distribution
            was fitted.

    Returns:
        Dict with two per-class maps:
        ``'info'``: list of tuples ``(pool_position, distance, angle_gap, diversity_gap,
        *rank_info_tuple)`` for the selected samples;
        ``'index'``: field 9 of each of those tuples, i.e. field 5 of the rank-info tuple.

    Raises:
        ValueError: if ``bi_type`` is not supported (formerly an unbound-variable error
            inside the class loop).
    """
    if bi_type not in ('Gaussian', 'Gumbel'):
        raise ValueError("Unsupported bi_type {!r}; expected 'Gaussian' or 'Gumbel'".format(bi_type))

    # Upper bound on how many nearest candidates are tried per drawn point.
    max_candidates = 20

    distri_file_path = PC.get_cifar10_distribution_save_path() if num_class == 10 else PC.get_cifar100_distribution_save_path()
    class_distri_info = DistriInfo(distri_file_path, is_train=True)
    # NOTE: pickle.load executes arbitrary code from the file; only use trusted pool-info files.
    with open(PC.get_cifar10_data_pool_info() if num_class == 10 else PC.get_cifar100_data_pool_info(), 'rb') as fi:
        augmented_data_info = pickle.load(fi)
    # sample_rank_list = remove_duplication(augmented_data_info['rank_info'], tolerance=0)
    sample_rank_list = augmented_data_info['rank_info']
    results = {'index': {}, 'info': {}}

    for i in range(num_class):
        class_rank_info = sample_rank_list[i]
        # Field 0 is the angle and field 4 the diversity; reshape keeps an empty pool 2-D.
        aug_angle_diversity_vec = torch.tensor([[item[0], item[4]] for item in class_rank_info],
                                               dtype=float).reshape(-1, 2)

        if bi_type == 'Gaussian':
            joint_distri = GaussianMultivariate.from_dict(class_distri_info.get_joint_distribution(i))
            sampled_points = joint_distri.sample(sample_num)
            sampled_points = torch.from_numpy(sampled_points.values)
        else:  # 'Gumbel'
            distr_dict = class_distri_info.get_joint_distribution(i)
            joint_distri = bivariate.Bivariate(bivariate.CopulaTypes.GUMBEL).from_dict(distr_dict['joint_d'])
            am_distri = GammaUnivariate.from_dict(distr_dict['angle_md'])
            dm_distri = GammaUnivariate.from_dict(distr_dict['diversity_md'])
            sampled_points = joint_distri.sample(sample_num)
            # Copula samples are mapped back to the data scale via the inverse marginal CDFs.
            sampled_points = np.stack((am_distri.ppf(sampled_points[:, 0]), dm_distri.ppf(sampled_points[:, 1])), axis=1)
            sampled_points = torch.from_numpy(sampled_points)

        match_dist_matrix = pair_dist(sampled_points.unsqueeze(1), aug_angle_diversity_vec.unsqueeze(0))
        # topk raises when k exceeds the pool size, so cap it for small classes.
        num_candidates = min(max_candidates, match_dist_matrix.shape[1])
        dists, indices = match_dist_matrix.topk(num_candidates, dim=1, largest=False)
        dists, indices = dists.numpy(), indices.numpy()

        selected_index_set = set()
        selected_sample_info_list = []
        total_not_in_num = 0
        for j in range(indices.shape[0]):
            for k in range(indices.shape[1]):
                curr_select_aug_index = indices[j, k]
                if curr_select_aug_index in selected_index_set:
                    continue
                selected_index_set.add(curr_select_aug_index)
                selected_sample_info_list.append((curr_select_aug_index, dists[j, k],
                                                  abs(sampled_points[j, 0] - aug_angle_diversity_vec[curr_select_aug_index, 0]).item(),
                                                  abs(sampled_points[j, 1] - aug_angle_diversity_vec[curr_select_aug_index, 1]).item(),
                                                  *class_rank_info[curr_select_aug_index]))
                break
            else:
                # Every candidate of this drawn point was already taken by an earlier one.
                total_not_in_num += 1
        print(total_not_in_num)
        # Four fields were prepended above, so rank-info field 5 now sits at position 5 + 4.
        results['index'][i] = [item[5+4] for item in selected_sample_info_list]
        results['info'][i] = selected_sample_info_list

    return results


def show_selected_aug_data_distribution(num_class, aug_angle_diversity_list):
    """Plot the selected augmented samples against the training data for every tenth class.

    Args:
        num_class: number of classes; 10 selects the CIFAR-10 path, anything else CIFAR-100.
        aug_angle_diversity_list: per-class arrays of shape (N, 2); column 0 holds angles,
            column 1 diversities.
    """
    if num_class == 10:
        distri_path = PC.get_cifar10_distribution_save_path()
    else:
        distri_path = PC.get_cifar100_distribution_save_path()
    info = DistriInfo(distri_path, is_train=True)
    print(info.get_origin_distribution())

    # Only classes 0, 10, 20, ... are drawn.
    for i in range(0, num_class, 10):
        selected = aug_angle_diversity_list[i]
        train_angle_diversity_vec = info.get_angle_diversity_data(i)
        selected_frame = pd.DataFrame({'angle': selected[:, 0],
                                       'diversity': selected[:, 1],
                                       'Type': ["selected"] * selected.shape[0]})
        train_frame = pd.DataFrame({'angle': train_angle_diversity_vec[0, :],
                                    'diversity': train_angle_diversity_vec[1, :],
                                    'Type': ["train"] * train_angle_diversity_vec.shape[1]})
        draw_real_syn_distribution(pd.concat([selected_frame, train_frame]), hue_key='Type')


if __name__ == '__main__':
    # Pipeline: fit the per-class distributions, refresh the class centres in the saved
    # file, compute statistics for the augmented pool, then rank the pool per class.
    setup_seed(42)
    num_class = 100
    bi_type = 'Gaussian'
    fe_dir = '/data/omf/model/DataValidation/CV/feature_extractor/cifar{}/kfoldCV_resnet20/'.format(num_class)
    distri_file_name = 'class_distri_info_kfold3_resnet20_{}.pkl'.format(bi_type)

    save_train_distribution(num_class, fe_dir, bi_type=bi_type)
    update_class_center(num_class, file_name=distri_file_name)
    cal_data_pool_info_main(num_class)
    rank_main(num_class)
    # show_aug_data_distribution(num_class)

    # #
    # results = sample_by_joint_distribution(num_class=num_class, sample_num=40, bi_type='Gumbel')
    #
    # aug_angle_diversity_list = []
    # for i in range(num_class):
    #     aug_angle_diversity_list.append(np.asarray([[item[4], item[8]] for item in results['info'][i]], dtype=float))
    # show_selected_aug_data_distribution(num_class, aug_angle_diversity_list)