import pickle

import torch

from angle_distribution import  cal_class_distri, angle_distribution_calibration, cal_class_center
from imbalanced_datasets import get_dataset, get_transform, PC, DatasetWrapper
from feature_extractor import load_feature_extractor, FeatureExtractor
import tqdm
from torch.utils.data import DataLoader
import os
from augmentation import AugmentType, CVAugment, convert_tensor_to_PILimages
import numpy as np
from globa_utils import  setup_seed
from val_sampler import DistriInfo, DistributionSampler, remove_duplication
import os
from val_sampler import split_train_val
# Pin this script to GPU 6 by default, but respect a CUDA_VISIBLE_DEVICES value the caller
# already exported (an unconditional assignment would silently override it). torch
# initialises CUDA lazily, so setting this after `import torch` is still effective.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '6')


# setup_seed(42)
def save_train_distribution_kfold_main(num_class, fe_path=None, model_name='resnet20'):
    """ Fit class distributions on pooled k-fold validation features and pickle them.

    The imbalanced-CIFAR train set is split into NUM_K train/val folds. For each fold the
    feature extractor stored as ``{k}_model.th`` embeds that fold's validation split. The
    per-class features of all folds are concatenated and passed to ``cal_class_distri``
    and ``angle_distribution_calibration``.

    :param num_class: 10 or 100; selects im_cifar10 / im_cifar100 and the save path.
    :param fe_path: folder holding ``{k}_model.th`` for k in [0, NUM_K). Defaults to the
        historical location ``/data/omf/model/DataValidation/CV/feature_extractor/cifar<num_class>/kfoldCV/``.
    :param model_name: architecture name forwarded to ``load_feature_extractor``.
    :return: None; writes ``class_distri_info.pkl`` under the distribution save path.
    """
    NUM_K = 5
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"
    if fe_path is None:
        fe_path = '/data/omf/model/DataValidation/CV/feature_extractor/cifar{}/kfoldCV/'.format(num_class)

    whole_train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed, is_wrapper=True)
    train_val_index_list = split_train_val(whole_train_dst.indexset,
                                           whole_train_dst.get_label_list(), seed=PC.get_kfoldCV_seed(),
                                           k=NUM_K, val_ratio=0.2)

    fold_features_list = []
    for k in range(NUM_K):
        _, val_index = train_val_index_list[k]
        val_dst = whole_train_dst.get_dataset_by_indexes(val_index)
        # Use the dataset-specific transform, like every other entry point in this file
        # (previously the bare 'im_cifar' key was used regardless of num_class).
        val_dst.transform = get_transform(dataset_name, t_type='test')
        # load_feature_extractor takes (path, model_name, device, ...) at every other call
        # site. Previously `device` was passed in the model-name slot and no device was given.
        feature_extractor = FeatureExtractor('fine-tune', load_feature_extractor(
            os.path.join(fe_path, '{}_model.th'.format(k)), model_name, device, num_classes=num_class), device)
        fold_features_list.append(feature_extractor.extractor_features_from_dst(val_dst, num_classes=num_class))

    # Pool the per-class validation features of all folds.
    class_split_features = {c: torch.cat([fold[c] for fold in fold_features_list], dim=0)
                            for c in range(num_class)}

    mean_var_distri_list, sample_num_list, class_center_list = cal_class_distri(dst=None, feature_extractor=None,
                                                                                num_class=num_class, is_return_center=True,
                                                                                class_split_features=class_split_features)
    calibrated_distri = angle_distribution_calibration(mean_var_distri_list, sample_num_list)
    class_center_list = [cc.cpu().numpy() for cc in class_center_list]

    class_distri_info = {'original_distribution': mean_var_distri_list,
                         'calibrated_distribution': calibrated_distri,
                         'class_centers': class_center_list, 'sample_num_list': sample_num_list}

    save_path = PC.get_cifar10_distribution_save_path() if num_class == 10 else PC.get_cifar100_distribution_save_path()
    with open(save_path + 'class_distri_info.pkl', 'wb') as fo:
        pickle.dump(class_distri_info, fo)


def save_train_distribution_kfold_main_v2(num_class, fe_path):
    """ Average per-fold class distributions over k-fold validation splits and pickle them.

    For each of the NUM_K folds, the extractor ``fe_path/{k}_model.th`` (resnet20) embeds
    the fold's validation split and ``cal_class_distri`` is run on those features alone.
    The (mean, var) pairs and class centres are averaged over folds, and per-class sample
    counts are summed, before calibration.

    :param num_class: 10 or 100.
    :param fe_path: folder containing ``{k}_model.th`` for every fold k.
    :return: None; writes ``class_distri_info_kfold2_resnet20.pkl``.
    """
    NUM_K = 5
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"

    whole_train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed, is_wrapper=True)
    folds = split_train_val(whole_train_dst.indexset, whole_train_dst.get_label_list(),
                            seed=PC.get_kfoldCV_seed(), k=NUM_K, val_ratio=0.2)

    sum_mean_var, sum_centers, total_counts = [], [], []
    for fold_id in range(NUM_K):
        _, val_index = folds[fold_id]
        val_dst = whole_train_dst.get_dataset_by_indexes(val_index)
        val_dst.transform = get_transform(dataset_name, t_type='test')
        backbone = load_feature_extractor(os.path.join(fe_path, '{}_model.th'.format(fold_id)), 'resnet20',
                                          device, num_classes=num_class)
        extractor = FeatureExtractor('fine-tune', backbone, device)
        fold_features = extractor.extractor_features_from_dst(val_dst, num_classes=num_class)
        fold_mean_var, fold_counts, fold_centers = cal_class_distri(dst=None, feature_extractor=None,
                                                                    num_class=num_class,
                                                                    is_return_center=True,
                                                                    class_split_features=fold_features)
        if fold_id != 0:
            sum_mean_var = [(acc[0] + cur[0], acc[1] + cur[1]) for acc, cur in zip(sum_mean_var, fold_mean_var)]
            sum_centers = [acc + cur for acc, cur in zip(sum_centers, fold_centers)]
            total_counts = [acc + cur for acc, cur in zip(total_counts, fold_counts)]
        else:
            # First fold seeds the accumulators.
            sum_mean_var, sum_centers, total_counts = fold_mean_var, fold_centers, fold_counts

    mean_var_avg = [(pair[0] / NUM_K, pair[1] / NUM_K) for pair in sum_mean_var]
    calibrated = angle_distribution_calibration(mean_var_avg, total_counts)
    centers_avg = [center.cpu().numpy() / NUM_K for center in sum_centers]

    class_distri_info = {
        'original_distribution': mean_var_avg,
        'calibrated_distribution': calibrated,
        'class_centers': centers_avg,
        'sample_num_list': total_counts,
    }

    if num_class == 10:
        save_path = PC.get_cifar10_distribution_save_path()
    else:
        save_path = PC.get_cifar100_distribution_save_path()
    with open(save_path + 'class_distri_info_kfold2_resnet20.pkl', 'wb') as fo:
        pickle.dump(class_distri_info, fo)


def save_train_distribution_main(num_class, fe_type='default'):
    """ Fit Gaussian class distributions and class centres on the whole train set and pickle them.

    :param num_class: 10 or 100; selects im_cifar10 / im_cifar100 and the save path.
    :param fe_type: 'default' loads the fine-tuned extractor from ``PC``; any other value
        loads the hard-coded SimCLR checkpoint. The value is also forwarded to
        ``FeatureExtractor`` / ``load_feature_extractor`` and embedded in the file name.
    :return: None; writes ``class_distri_info_<fe_type>.pkl`` under the distribution save path.
    """
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"
    whole_train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed)
    whole_train_dst.transform = get_transform(dataset_name, t_type='test')

    if fe_type != 'default':
        # NOTE(review): this checkpoint is CIFAR-10 specific even when num_class == 100 — confirm.
        fe_model_path = '/data/omf/model/DataValidation/CV/simclr/encoder/cifar10/SimCLR_cifar10_resnet18_lr_0.5_decay_0.0001_bsz_1024_temp_0.5_trial_0_warm/last.pth'
    elif num_class == 10:
        fe_model_path = PC.get_cifar10_fe_path()
    else:
        fe_model_path = PC.get_cifar100_fe_path()
    backbone = load_feature_extractor(fe_model_path, 'resnet20', device, num_classes=num_class, fe_type=fe_type)
    feature_extractor = FeatureExtractor(fe_type, backbone, device)

    save_path = PC.get_cifar10_distribution_save_path() if num_class == 10 else PC.get_cifar100_distribution_save_path()

    mean_var, counts, centers = cal_class_distri(whole_train_dst, feature_extractor, num_class,
                                                 is_return_center=True)
    calibrated = angle_distribution_calibration(mean_var, counts)
    class_distri_info = {
        'original_distribution': mean_var,
        'calibrated_distribution': calibrated,
        'class_centers': [center.cpu().numpy() for center in centers],
        'sample_num_list': counts,
    }

    with open(save_path + 'class_distri_info_{}.pkl'.format(fe_type), 'wb') as fo:
        pickle.dump(class_distri_info, fo)


def save_test_distribution_main(num_class, fe_type='default'):
    """ Fit Gaussian class distributions and class centres on the test split and pickle them.

    Mirrors ``save_train_distribution_main`` but reads ``split='test'``.

    :param num_class: 10 or 100; selects im_cifar10 / im_cifar100 and the save path.
    :param fe_type: 'default' loads the fine-tuned extractor from ``PC``; any other value
        loads the hard-coded SimCLR checkpoint. The default was previously the typo
        ``'defualt'``, which silently selected the SimCLR branch and wrote
        ``test_class_distri_info_defualt.pkl``. It now matches ``save_train_distribution_main``.
    :return: None; writes ``test_class_distri_info_<fe_type>.pkl`` under the distribution save path.
    """
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"
    test_dst = get_dataset(dataset_name, split='test', rand_number=random_seed)
    test_dst.transform = get_transform(dataset_name, t_type='test')
    if fe_type == 'default':
        fe_model_path = PC.get_cifar10_fe_path() if num_class == 10 else PC.get_cifar100_fe_path()
    else:
        # NOTE(review): this checkpoint is CIFAR-10 specific even when num_class == 100 — confirm.
        fe_model_path = '/data/omf/model/DataValidation/CV/simclr/encoder/cifar10/SimCLR_cifar10_resnet18_lr_0.5_decay_0.0001_bsz_1024_temp_0.5_trial_0_warm/last.pth'
    feature_extractor = FeatureExtractor(fe_type,
                                         load_feature_extractor(fe_model_path, 'resnet20', device, num_classes=num_class,
                                                                fe_type=fe_type),
                                         device)

    save_path = PC.get_cifar10_distribution_save_path() if num_class == 10 else PC.get_cifar100_distribution_save_path()

    mean_var_distri_list, sample_num_list, class_center_list = cal_class_distri(test_dst, feature_extractor,
                                                                                num_class,
                                                                                is_return_center=True)
    calibrated_distri = angle_distribution_calibration(mean_var_distri_list, sample_num_list)
    class_center_list = [cc.cpu().numpy() for cc in class_center_list]
    class_distri_info = {'original_distribution': mean_var_distri_list, 'calibrated_distribution': calibrated_distri,
                         'class_centers': class_center_list, 'sample_num_list': sample_num_list}

    with open(save_path + 'test_class_distri_info_{}.pkl'.format(fe_type), 'wb') as fo:
        pickle.dump(class_distri_info, fo)


def save_test_distribution_kfold_main(num_class, fe_path):
    """ Average test-set class distributions over the NUM_K k-fold extractors and pickle them.

    Every fold's extractor (``fe_path/{k}_model.th``, resnet20) embeds the same
    imbalanced-CIFAR test split. The per-class (mean, var) pairs and class centres from
    ``cal_class_distri`` are averaged over folds before calibration.

    :param num_class: 10 or 100.
    :param fe_path: folder containing ``{k}_model.th`` for k in [0, NUM_K).
    :return: None; writes ``test_class_distri_info_kfold_resnet20.pkl``.
    """
    NUM_K = 5
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"
    test_dst = get_dataset(dataset_name, split='test', rand_number=random_seed)
    test_dst.transform = get_transform(dataset_name, t_type='test')

    sum_mean_var_distri_list = []
    sum_class_center_list = []
    sample_num_list = []
    for k in range(NUM_K):
        feature_extractor = FeatureExtractor('fine-tune', load_feature_extractor(
            os.path.join(fe_path, '{}_model.th'.format(k)), 'resnet20', device, num_classes=num_class), device)
        features = feature_extractor.extractor_features_from_dst(test_dst, num_classes=num_class)
        mean_var_distri_list, sample_num_list, class_center_list = cal_class_distri(dst=None, feature_extractor=None,
                                                                                    num_class=num_class,
                                                                                    is_return_center=True,
                                                                                    class_split_features=features)
        # Seed on the fold index. Testing the (already updated) mean/var accumulator's
        # length meant the centre accumulator was never seeded, and 'class_centers' was
        # always saved as an empty list.
        if k == 0:
            sum_mean_var_distri_list = mean_var_distri_list
            sum_class_center_list = class_center_list
        else:
            sum_mean_var_distri_list = [(acc[0] + cur[0], acc[1] + cur[1])
                                        for acc, cur in zip(sum_mean_var_distri_list, mean_var_distri_list)]
            sum_class_center_list = [acc + cur for acc, cur in zip(sum_class_center_list, class_center_list)]

    avg_mean_var_distri_list = [(mv[0] / NUM_K, mv[1] / NUM_K) for mv in sum_mean_var_distri_list]
    # sample_num_list comes from the last fold. All folds embed the same test split, so the
    # per-class counts are presumably identical across folds.
    calibrated_distri = angle_distribution_calibration(avg_mean_var_distri_list, sample_num_list)
    avg_class_center_list = [cc.cpu().numpy() / NUM_K for cc in sum_class_center_list]

    class_distri_info = {'original_distribution': avg_mean_var_distri_list,
                         'calibrated_distribution': calibrated_distri,
                         'class_centers': avg_class_center_list, 'sample_num_list': sample_num_list}

    save_path = PC.get_cifar10_distribution_save_path() if num_class == 10 else PC.get_cifar100_distribution_save_path()
    with open(save_path + 'test_class_distri_info_kfold_resnet20.pkl', 'wb') as fo:
        pickle.dump(class_distri_info, fo)


def update_class_center(num_class, file_name):
    """ Recompute train-set class centres and overwrite them in an existing distribution pickle.

    Centres come from ``cal_class_distri`` with the default fine-tuned extractor. Only the
    ``'class_centers'`` entry of ``<distribution save path><file_name>`` is replaced; every
    other key is written back unchanged. The freshly computed (mean, var) list is printed
    but not stored.

    :param num_class: 10 or 100.
    :param file_name: pickle file name inside the distribution save path (e.g. 'class_distri_info.pkl').
    """
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)
    device = "cuda:0"
    whole_train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed)
    whole_train_dst.transform = get_transform(dataset_name, t_type='test')

    if num_class == 10:
        fe_model_path = PC.get_cifar10_fe_path()
    else:
        fe_model_path = PC.get_cifar100_fe_path()
    backbone = load_feature_extractor(fe_model_path, 'resnet20', device, num_classes=num_class, fe_type='default')
    extractor = FeatureExtractor('default', backbone, device)

    mean_var, _, centers = cal_class_distri(whole_train_dst, extractor, num_class, is_return_center=True)
    new_centers = [center.cpu().numpy() for center in centers]
    print(mean_var)

    if num_class == 10:
        save_path = PC.get_cifar10_distribution_save_path()
    else:
        save_path = PC.get_cifar100_distribution_save_path()
    target_file = save_path + file_name

    with open(target_file, 'rb') as fi:
        distri_info = pickle.load(fi)
    distri_info['class_centers'] = new_centers
    with open(target_file, 'wb') as fo:
        pickle.dump(distri_info, fo)


def generate_augmented_data(augment_agent:CVAugment, origin_dst, aug_policy,
                            device="cuda:0", batch_size=64, save_path=None, info_save_path=None):
    """ Generate the auxiliary dataset for the source set by data augmentation.

    For every class of ``origin_dst`` and every member of ``AugmentType``, the class images
    are augmented ``aug_policy[type]`` times (variant ``t`` comes from
    ``augment_agent.get_aug(type, t)``). Each result is written as
    ``<save_path><label>/<n>.jpg``, with ``n`` counting up from 0 within the class.

    :param augment_agent: provider of augmentation callables.
    :param origin_dst: whole train set (Source Set) in imbalanced CIFAR; must expose
        ``class_split_indexes`` and ``get_dataset_by_class``.
    :param aug_policy: number of variants to generate per ``AugmentType`` member.
    :param device: device the image batches are moved to before augmenting.
    :param batch_size: batch size in DataLoader
    :param save_path: folder prefix for the generated (augmented) images (Auxiliary Dataset).
        It is concatenated directly with the label, so it should end with a path separator.
    :param info_save_path: pickle receiving ``{'index_map': ..., 'method_map': ...}``.
        Per label, entry n holds the within-class index of the source image and the
        ``AugmentType`` used. This lines up with ``n.jpg`` assuming each augmentor returns
        one image per input.
    :return: None
    """
    labels = origin_dst.class_split_indexes.keys()
    progress = tqdm.tqdm(total=len(AugmentType) * len(labels))
    index_map, method_map = {}, {}
    for label in labels:
        label_indexes = index_map[label] = []
        label_methods = method_map[label] = []
        loader = DataLoader(origin_dst.get_dataset_by_class(label), batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)
        label_dir = save_path + str(label) + '/'
        if not os.path.exists(label_dir):
            os.makedirs(label_dir)
        next_file_id = 0
        for aug_type in AugmentType:
            for variant in range(aug_policy[aug_type]):
                augmentor = augment_agent.get_aug(aug_type, variant)
                # Position of the current batch inside the (unshuffled) class subset.
                start = 0
                for batch, _ in loader:
                    batch = batch.to(device)
                    n_imgs = batch.shape[0]
                    pil_images = convert_tensor_to_PILimages(augmentor(batch))
                    for file_id, image in enumerate(pil_images, start=next_file_id):
                        image.save(label_dir + str(file_id) + '.jpg')
                    next_file_id += len(pil_images)

                    label_indexes.extend(np.arange(start, start + n_imgs).tolist())
                    label_methods.extend([aug_type] * n_imgs)
                    start += n_imgs
            progress.update(1)
    progress.close()

    provenance = {'index_map': index_map, 'method_map': method_map}
    with open(info_save_path, 'wb') as fo:
        pickle.dump(provenance, fo)


def generate_augmented_data_main(num_class):
    """ Build the augmented data pool for im_cifar<num_class>.

    Every ``AugmentType`` gets 4 variants except ``GridMask``, which gets 5. Images and the
    provenance pickle go to the CIFAR-10 or CIFAR-100 data-pool locations from ``PC``.

    :param num_class: 10 or 100.
    """
    device = "cuda:0"
    dataset_name = 'im_cifar' + str(num_class)
    random_seed = PC.get_global_random_seed(dataset_name)

    aug_policy = {aug_type: (5 if aug_type == AugmentType.GridMask else 4) for aug_type in AugmentType}
    augment_agent = CVAugment(device=device)

    train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed, is_wrapper=True)
    # Raw tensors (no normalisation) so augmented images can be converted back to PIL.
    train_dst.dataset.transform = get_transform(dataset_name, t_type='to_tensor')

    is_cifar10 = num_class == 10
    pool_path = PC.get_cifar10_data_pool_path() if is_cifar10 else PC.get_cifar100_data_pool_path()
    pool_info_path = PC.get_cifar10_data_pool_info() if is_cifar10 else PC.get_cifar100_data_pool_info()
    generate_augmented_data(augment_agent, train_dst, aug_policy, batch_size=64, device=device,
                            save_path=pool_path, info_save_path=pool_info_path)


def cal_entropy(logits, eps=1e-10):
    """ Return the Shannon entropy (in bits) of softmax(logits) along dim 1.

    :param logits: tensor whose dim 1 holds the class scores.
    :param eps: added to the probabilities inside log2 to avoid log(0).
    :return: tensor of entropies with dim 1 reduced.
    """
    p = logits.softmax(dim=1)
    return -(p * (p + eps).log2()).sum(dim=1)


def cal_data_pool_info(train_dst, data_pool, raw_info, feature_extractor:FeatureExtractor,
                        class_center_set, device="cuda:0"):
    """ Score every augmented image in the data pool against its source image and class centre.

    ``rank_info[l][j]`` is the tuple
    (angle of aug image j to class-l centre, |that angle - source image's angle|,
    entropy of aug image j, aug entropy - source entropy, j, augmentation method).
    The source image for pool entry j is ``raw_info['index_map'][l][j]``.

    :param train_dst: source (train) dataset.
    :param data_pool: augmented dataset exposing ``get_dataset_by_class``.
    :param raw_info: dict with 'index_map' and 'method_map', as written by ``generate_augmented_data``.
    :param feature_extractor: provides both features and logits.
    :param class_center_set: per-class centres as numpy arrays; its length sets the class count.
    :param device: device the centres and pool images are moved to.
    :return: dict with 'rank_info', 'index_map', 'method_map'.
    """
    from angle_distribution import cal_angle_to_center
    num_class = len(class_center_set)
    src_features = feature_extractor.extractor_features_from_dst(train_dst, num_class)
    src_logits = feature_extractor.extractor_features_from_dst(train_dst, num_class, is_logits=True)
    rank_info = {}
    augmented_data_info = {'rank_info': rank_info}
    for label in tqdm.tqdm(range(num_class)):
        loader = DataLoader(data_pool.get_dataset_by_class(label), batch_size=256, shuffle=False, num_workers=4)

        center = torch.from_numpy(class_center_set[label]).to(device)
        src_angles = cal_angle_to_center(src_features[label], center)
        src_entropy = cal_entropy(src_logits[label])

        feat_chunks, logit_chunks = [], []
        for images, _ in loader:
            images = images.to(device)
            # NOTE(review): squeeze(dim=0) would also drop the batch axis of a size-1 final
            # batch; presumably get_features returns a leading singleton dim — confirm.
            feat_chunks.append(feature_extractor.get_features(images).squeeze(dim=0))
            logit_chunks.append(feature_extractor.get_features(images, is_logits=True).squeeze(dim=0))
        pool_features = torch.cat(feat_chunks, dim=0)
        pool_logits = torch.cat(logit_chunks, dim=0)

        source_index = raw_info['index_map'][label]
        pool_angles = cal_angle_to_center(pool_features, center)
        angle_changes = torch.abs(pool_angles - src_angles[source_index]).cpu().numpy()
        aug_angles = pool_angles.cpu().numpy()

        matched_src_entropy = src_entropy[source_index].cpu().numpy()
        aug_entropy = cal_entropy(pool_logits).cpu().numpy()
        entropy_changes = aug_entropy - matched_src_entropy

        methods = raw_info['method_map'][label]
        rank_info[label] = [(aug_angles[j], angle_changes[j], aug_entropy[j], entropy_changes[j], j, methods[j])
                            for j in range(aug_angles.shape[0])]

    augmented_data_info['index_map'] = raw_info['index_map']
    augmented_data_info['method_map'] = raw_info['method_map']
    return augmented_data_info


def cal_data_pool_info_main(num_class, fe_type='default'):
    """ Score the augmented data pool and overwrite the pool-info pickle with the result.

    Reads the raw index/method maps, scores every pool image with ``cal_data_pool_info``
    against the stored class centres, and writes the enriched dict (rank_info + the maps)
    back to the same pool-info file.

    :param num_class: 10 or 100.
    :param fe_type: 'default' uses the fine-tuned extractor from ``PC``; anything else uses
        the hard-coded SimCLR checkpoint.
    """
    device = "cuda:0"
    is_cifar10 = num_class == 10
    random_seed = PC.get_global_random_seed('im_cifar' + str(num_class))

    pool_info_path = PC.get_cifar10_data_pool_info() if is_cifar10 else PC.get_cifar100_data_pool_info()
    with open(pool_info_path, 'rb') as fi:
        raw_info = pickle.load(fi)

    from imbalanced_datasets import AugmentedDataset
    pool_dir = PC.get_cifar10_data_pool_path() if is_cifar10 else PC.get_cifar100_data_pool_path()
    data_pool = DatasetWrapper(AugmentedDataset(pool_dir, transform=get_transform('cifar' + str(num_class), 'test')))
    train_dst = get_dataset('im_cifar' + str(num_class), split='train', rand_number=random_seed, is_wrapper=True)
    train_dst.dataset.transform = get_transform('im_cifar' + str(num_class), t_type='test')

    if fe_type != 'default':
        # NOTE(review): this checkpoint is CIFAR-10 specific even when num_class == 100 — confirm.
        fe_model_path = '/data/omf/model/DataValidation/CV/simclr/encoder/cifar10/SimCLR_cifar10_resnet18_lr_0.5_decay_0.0001_bsz_1024_temp_0.5_trial_0_warm/last.pth'
    else:
        fe_model_path = PC.get_cifar10_fe_path() if is_cifar10 else PC.get_cifar100_fe_path()
    backbone = load_feature_extractor(fe_model_path, 'resnet20', device, num_classes=num_class, fe_type=fe_type)
    feature_extractor = FeatureExtractor(fe_type, backbone, device)

    distri_dir = PC.get_cifar10_distribution_save_path() if is_cifar10 else PC.get_cifar100_distribution_save_path()
    info = DistriInfo(distri_dir)

    augmented_data_info = cal_data_pool_info(train_dst, data_pool, raw_info, feature_extractor,
                                             class_center_set=info.get_class_centers())

    with open(pool_info_path, 'wb') as fo:
        pickle.dump(augmented_data_info, fo)


import functools
def cmp_func(a, b):
    """ Three-way comparator for ranked data-pool entries (use with ``functools.cmp_to_key``).

    Orders ascending by ``a[0]`` (the angle bucket prepended by ``rank_main``). Ties are
    broken by ``a[2]`` descending; after bucketing, index 2 is the angle change, 3 the
    entropy and 4 the entropy change.

    :return: -1 if a sorts first, 1 if b sorts first, 0 if both keys are equal. Returning 0
        on full ties keeps the comparator antisymmetric. The previous version returned 1 for
        both cmp(a, b) and cmp(b, a).
    """
    if a[0] == b[0]:
        if a[2] > b[2]:
            return -1
        if a[2] == b[2]:
            return 0
        return 1
    else:
        if a[0] > b[0]:
            return 1
        else:
            return -1


def rank_main(num_class):
    """ Prefix each data-pool rank entry with a half-unit angle bucket and sort every class.

    Entries produced by ``cal_data_pool_info`` are 6-tuples whose first element is the
    angle to the class centre. This function prepends a bucket value (``n`` for angles in
    [n, n+0.5), ``n + 0.5`` for [n+0.5, n+1)) and sorts each class with ``cmp_func``
    (bucket ascending, then angle change descending). Classes whose entries are already
    7-tuples are only re-sorted, so running this twice is harmless. The pool-info pickle
    is overwritten in place.

    :param num_class: 10 or 100.
    """
    info_path = PC.get_cifar10_data_pool_info() if num_class == 10 else PC.get_cifar100_data_pool_info()
    with open(info_path, 'rb') as fi:
        info = pickle.load(fi)

    def _bucket(angle):
        # Half-unit floor that keeps the original int / float typing. The previous
        # `round(x) >= x` test put exact integers in the upper half-bucket and, because of
        # banker's rounding, sent 2.5 -> 2 but 3.5 -> 3.5.
        whole = int(angle)
        return whole + 0.5 if angle - whole >= 0.5 else whole

    ranked_info = info['rank_info']
    for c in tqdm.tqdm(range(num_class)):
        entries = ranked_info[c]
        # Empty classes have nothing to bucket (and entries[0] would raise IndexError).
        if entries and len(entries[0]) == 6:
            entries = [(_bucket(entry[0]), *entry) for entry in entries]
        ranked_info[c] = sorted(entries, key=functools.cmp_to_key(cmp_func))

    with open(info_path, 'wb') as fo:
        pickle.dump({'index_map': info['index_map'], 'method_map': info['method_map'], 'rank_info': ranked_info}, fo)


def binary_search(nums, target: float) -> int:
    """ Return an index i with ``nums[i] == target`` in ascending-sorted *nums*, or -1.

    With duplicate values, whichever matching index the search probes first is returned
    (not necessarily the leftmost one).
    """
    lo = 0
    hi = len(nums) - 1
    while hi >= lo:
        mid = (lo + hi) // 2
        value = nums[mid]
        if value == target:
            return mid
        if value > target:
            hi = mid - 1
        else:
            lo = mid + 1
    return -1


def sample_main(num_class=10, val_num_per_class=200):
    """ Sample a validation set from the augmented data pool using the calibrated distributions.

    Duplicates are removed from the ranked pool (``remove_duplication`` with tolerance 0)
    before it is handed to ``DistributionSampler`` with the calibrated class distributions.

    :param num_class: 10 or 100. Previously hard-coded to 10 even though the body already
        branched on it.
    :param val_num_per_class: number of samples requested for every class.
    :return: whatever ``DistributionSampler.sample`` returns. Previously the result was discarded.
    """
    distri_file_path = PC.get_cifar10_distribution_save_path() if num_class == 10 else PC.get_cifar100_distribution_save_path()
    class_distri_info = DistriInfo(distri_file_path)

    with open(PC.get_cifar10_data_pool_info() if num_class == 10 else PC.get_cifar100_data_pool_info(), 'rb') as fi:
        augmented_data_info = pickle.load(fi)

    sampler = DistributionSampler(class_distri_info.get_calibrated_distribution(),
                                  remove_duplication(augmented_data_info['rank_info'], tolerance=0))

    sample_num_dict = {c: val_num_per_class for c in range(num_class)}
    return sampler.sample(sample_num_dict)
    # print(val_results)

