import functools
import pickle
import numpy as np
import random
from utils import CVConfig
import torch
from copulas.multivariate import GaussianMultivariate
from copulas.univariate import GammaUnivariate
from copulas.bivariate import Bivariate, CopulaTypes

# Shared configuration object; sample_val() below queries it for the locations
# of the augmented-data pool file and the class-distribution file.
PC = CVConfig()


def pair_dist(a: torch.Tensor, b: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Return the L-p distance between ``a`` and ``b`` along the last axis.

    The two tensors are subtracted with ordinary broadcasting, so the result
    has the broadcast shape of ``a`` and ``b`` minus its last dimension.

    Args:
        a: First operand.
        b: Second operand; must broadcast against ``a``.
        p: Order of the norm (``2`` gives the Euclidean distance).

    Returns:
        ``(sum_i |a_i - b_i| ** p) ** (1 / p)`` reduced over the last axis.
    """
    abs_diff = torch.abs(a - b)
    summed_powers = torch.sum(torch.pow(abs_diff, p), dim=-1)
    return torch.pow(summed_powers, 1 / p)


def cmp_fc(a, b):
    """Three-way comparator ordering records by their ``[2]`` field, largest first.

    Intended for ``functools.cmp_to_key``.  The field position selects the
    ranking criterion; per the original author's note the candidates are
    2: angle changes, 3: entropy, 4: entropy change.

    Args:
        a: First record (indexable, position 2 is compared).
        b: Second record.

    Returns:
        ``-1`` when ``a`` should come before ``b`` (``a[2] > b[2]``), ``0``
        when both keys are equal, ``1`` otherwise.
    """
    if a[2] > b[2]:
        return -1
    if a[2] == b[2]:
        # Ties must compare equal so the comparator is consistent
        # (cmp(x, x) == 0); sorting only checks for "< 0", so the resulting
        # order is the same as before.
        return 0
    return 1


class DistriInfo(object):
    """Accessor around a pickled dictionary of per-class distribution data.

    The loaded dictionary is stored in ``self.distri_info`` and is expected to
    provide the keys read by the getters below.
    """

    def __init__(self, save_path, is_train=True):
        """Load the distribution dictionary from disk.

        Args:
            save_path: When ``is_train`` is true, a path prefix to which the
                fixed file name
                ``class_distri_info_kfold3_resnet20_Gaussian.pkl`` is appended
                by plain string concatenation (so a directory prefix needs its
                trailing separator).  Otherwise the full path of the pickle
                file.
            is_train: Selects how ``save_path`` is interpreted (see above).
        """
        if is_train:
            pickle_path = save_path + 'class_distri_info_kfold3_resnet20_Gaussian.pkl'
        else:
            pickle_path = save_path
        # SECURITY NOTE: pickle.load can execute arbitrary code contained in
        # the file; only point this at files produced by trusted code.
        with open(pickle_path, 'rb') as fi:
            self.distri_info = pickle.load(fi)

    def get_origin_distribution(self):
        """Return the ``'original_distribution'`` entry."""
        return self.distri_info['original_distribution']

    def get_calibrated_distribution(self):
        """Return the ``'calibrated_distribution'`` entry."""
        return self.distri_info['calibrated_distribution']

    def get_class_centers(self):
        """Return the ``'class_centers'`` entry."""
        return self.distri_info['class_centers']

    def get_joint_distribution(self, class_index=None):
        """Return ``'joint_distribution_params'``, whole or for one class.

        Args:
            class_index: Key of the class to look up; ``None`` returns the
                complete container.
        """
        joint_params = self.distri_info['joint_distribution_params']
        if class_index is None:
            return joint_params
        return joint_params[class_index]

    def get_angle_diversity_data(self, class_index=None):
        """Return ``'angle_diversity_data'``, whole or for one class.

        Args:
            class_index: Key of the class to look up; ``None`` (the default,
                matching ``get_joint_distribution``) returns the complete
                container.
        """
        angle_diversity = self.distri_info['angle_diversity_data']
        if class_index is None:
            return angle_diversity
        return angle_diversity[class_index]


class DistributionSampler(object):
    """Pick augmented samples of each class according to a target distribution.

    ``distribution`` maps a class key to a pair: element 0 is used as the mean
    and the square root of element 1 as the standard deviation of a normal
    distribution from which angles are drawn.

    ``sample_rank_list`` maps a class key to a list of per-sample records.
    The methods read these positions of every record:

    * ``[0]`` -- bin key used for the candidate range and the per-bin cursors
      (presumably the angle quantised to 0.5 steps -- TODO confirm),
    * ``[1]`` -- value compared against the target range in
      ``sample_in_class_in_intersect_interval``,
    * ``[4]`` -- second coordinate matched in ``sample_in_class_v2``
      (the first coordinate is ``[0]``),
    * ``[self.index_pos]`` -- the value returned as the selected index.
    """

    # Upper bound on the number of nearest candidates examined per sampled
    # point in ``sample_in_class_v2``.
    _NEAREST_K = 50

    def __init__(self, distribution, sample_rank_list):
        self.distribution = distribution
        self.sample_rank_list = sample_rank_list
        # Position inside each record of the value returned as selected index.
        self.index_pos = 5
        # Fraction of the requested samples that ``sample_in_class`` draws by
        # angle matching; the remainder are taken as "out" samples.
        # Author's note: 0.65 for cifar10 v4, 0.55 for cifar100 v3.
        self.in_out_sample_ratio = 1
        self.joint_distribution = None
        self.bi_type = None

    def set_bi_type(self, bi_type):
        """Choose the joint model for ``sample_in_class_v2``: 'Gaussian' or 'Gumbel'."""
        self.bi_type = bi_type

    def set_joint_distribution(self, joint_distribution):
        """Set the per-class joint-distribution parameters used by ``sample_in_class_v2``."""
        self.joint_distribution = joint_distribution

    def set_in_out_sample_ratio(self, ratio):
        """Set the fraction of samples drawn by angle matching in ``sample_in_class``."""
        self.in_out_sample_ratio = ratio

    def _preprocess(self, class_index, sample_num):
        """Draw sorted angles for a class and intersect them with the candidate range.

        Returns:
            ``(sample_list, candidate_range, sampled_angles, target_range)``
            where ``candidate_range`` is ``([0] of first record, [0] of last
            record)``, ``sampled_angles`` is the ascending array of
            ``sample_num`` normal draws and ``target_range`` is the overlap of
            the sampled span with the candidate range.
        """
        mean, variance = self.distribution[class_index][0], self.distribution[class_index][1]
        sampled_angle_list = np.sort(np.random.normal(mean, np.sqrt(variance), sample_num))
        sample_list = self.sample_rank_list[class_index]

        candiate_angle_range = (sample_list[0][0], sample_list[-1][0])
        print("augmented angle range: ", candiate_angle_range[0], '--', candiate_angle_range[1])
        print("sampled angle range: ", sampled_angle_list[0], '--', sampled_angle_list[-1])
        target_angle_range = (
            sampled_angle_list[0] if sampled_angle_list[0] > candiate_angle_range[0] else candiate_angle_range[0],
            candiate_angle_range[1] if candiate_angle_range[1] < sampled_angle_list[-1] else sampled_angle_list[-1])
        print('target angle range: ', target_angle_range[0], '--', target_angle_range[1])
        return sample_list, candiate_angle_range, sampled_angle_list, target_angle_range

    def sample_in_class_in_intersect_interval(self, class_index, sample_num, is_debug=False):
        """Randomly pick records whose ``[1]`` value lies inside the target range.

        Args:
            class_index: Key of the class to sample.
            sample_num: Number of angles to draw and maximum number of records
                to return.
            is_debug: When true, also return the two coverage ratios.

        Returns:
            ``(selected_index, data_info)`` or, with ``is_debug``,
            ``(selected_index, data_info, ratio1, ratio2)``.  ``data_info``
            entries are ``(position in sample list, *record)``; ``ratio1`` is
            the share of sampled angles inside the target range and ``ratio2``
            the share of records inside it.
        """
        print("=" * 20 + str(class_index) + "=" * 20)
        sample_list, _, sampled_angle_list, target_angle_range = self._preprocess(class_index, sample_num)
        low, high = target_angle_range

        in_range_count = sum(1 for angle in sampled_angle_list if low <= angle <= high)
        ratio1 = in_range_count / len(sampled_angle_list)
        print('intersection part ratio of sampled angle list: ', ratio1)

        selected_index, data_info = [], []
        for index, info in enumerate(sample_list):
            if low <= info[1] <= high:
                selected_index.append(info[self.index_pos])
                data_info.append((index, *info))

        ratio2 = len(selected_index) / len(sample_list)
        print('intersection part ratio of aug angle list: ', ratio2)

        # Too many matches: keep a uniformly random subset of the requested size.
        if sample_num < len(selected_index):
            indexes = np.random.permutation(len(selected_index))[0:sample_num].tolist()
            selected_index = [selected_index[i] for i in indexes]
            data_info = [data_info[i] for i in indexes]

        if is_debug:
            return selected_index, data_info, ratio1, ratio2
        return selected_index, data_info

    def sample_in_class(self, class_index, sample_num):
        """Pick records bin by bin so that their bin keys follow the sampled angles.

        ``int(sample_num * in_out_sample_ratio)`` angles are drawn; each is
        mapped to a bin key and consumes the next unused record of that bin.
        The remaining ``sample_num - in`` "out" samples are the top records,
        ranked with ``cmp_fc``, from the bin of the target range's upper bound
        to the end of the list.

        Returns:
            ``(selected_index, data_info)``; ``data_info`` entries are
            ``(sampled angle, *record)`` for matched samples and
            ``(rank, *record)`` for "out" samples.
        """
        in_sample_num = int(sample_num*self.in_out_sample_ratio)
        out_sample_num = sample_num - in_sample_num
        print("="*20+str(class_index)+"="*20)
        sample_list, candiate_angle_range, sampled_angle_list, target_angle_range = self._preprocess(class_index, in_sample_num)

        # Cursor per bin key, initialised to the first record of that bin and
        # advanced every time the bin is hit.
        start_index = {}
        for index, info in enumerate(sample_list):
            if info[0] not in start_index:
                start_index[info[0]] = index

        not_in_range_num = 0
        exceed_interval_num = 0
        selected_index, data_info = [], []
        for angle in sampled_angle_list:
            angle = 0 if angle < 0 else angle
            # NOTE(review): this maps exact integers (including the clamped 0)
            # to the ".5" bin; confirm that is intended.
            target_angle = int(angle)+0.5 if round(angle) >= angle else int(angle)
            if target_angle < candiate_angle_range[0] or target_angle > candiate_angle_range[1]:
                not_in_range_num += 1
                continue
            if target_angle not in start_index:
                print('angle not found ', angle)
                continue

            cursor = start_index[target_angle]
            next_bin = target_angle + 0.5
            if next_bin in start_index:
                # NOTE(review): the boundary is the *cursor* of the next bin,
                # which moves once that bin has been consumed; when the next
                # bin is absent only the end of the list bounds the cursor, so
                # it can run into later bins.  Kept as is to preserve results.
                next_bin_cursor = start_index[next_bin]
                if cursor < next_bin_cursor:
                    selected_index.append(sample_list[cursor][self.index_pos])
                    data_info.append((angle, *sample_list[cursor]))
                elif cursor == next_bin_cursor:
                    exceed_interval_num += 1
                    print("reach the end: ", target_angle)
                else:
                    exceed_interval_num += 1
                    print("exceed the end ", target_angle)
            elif cursor >= len(sample_list):
                exceed_interval_num += 1
                print("exceed the max angle in sample list ", target_angle)
            else:
                selected_index.append(sample_list[cursor][self.index_pos])
                data_info.append((angle, *sample_list[cursor]))
            start_index[target_angle] += 1
        print("not in range sample times: ", not_in_range_num)
        print("not in interval sample times: ", exceed_interval_num)

        if out_sample_num >= 1:
            print("out sample num ", out_sample_num)
            upper = target_angle_range[1]
            out_angle = int(upper)+0.5 if round(upper) >= upper else int(upper)
            if out_angle in start_index:
                out_angle_head = start_index[out_angle]
                out_sample_list = sorted(sample_list[out_angle_head:], key=functools.cmp_to_key(cmp_fc))
                # Never take more "out" samples than are actually available.
                picked = out_sample_list[:min(out_sample_num, len(out_sample_list))]
                selected_index.extend(item[self.index_pos] for item in picked)
                data_info.extend((rank, *item) for rank, item in enumerate(picked))
            else:
                print("no out samples")
        return selected_index, data_info

    def sample_in_class_v2(self, class_index, sample_num):
        """Pick records closest to points drawn from the class's joint distribution.

        Points are sampled from the configured joint model and each is matched
        to its nearest not-yet-used record, comparing ``(record[0],
        record[4])`` with ``pair_dist``.  At most ``_NEAREST_K`` nearest
        records are tried per point; a point whose neighbours are all taken
        contributes nothing (the count of such points is printed).

        Returns:
            ``(selected_index, info)`` where each ``info`` entry is
            ``(position in rank list, distance, |delta dim 0|, |delta dim 1|,
            *record)``.

        Raises:
            ValueError: If ``bi_type`` is neither ``'Gaussian'`` nor ``'Gumbel'``.
        """
        if self.bi_type not in ('Gaussian', 'Gumbel'):
            raise ValueError(
                "unsupported bi_type {!r}; call set_bi_type with 'Gaussian' or 'Gumbel'".format(self.bi_type))

        rank_list = self.sample_rank_list[class_index]
        if len(rank_list) == 0:
            # Nothing to match against.
            return [], []
        aug_angle_diversity_vec = torch.tensor([[item[0], item[4]] for item in rank_list], dtype=float)

        if self.bi_type == 'Gaussian':
            joint_distri = GaussianMultivariate.from_dict(self.joint_distribution[class_index])
            sampled_points = torch.from_numpy(joint_distri.sample(sample_num).values)
        else:
            distr_dict = self.joint_distribution[class_index]
            joint_distri = Bivariate(CopulaTypes.GUMBEL).from_dict(distr_dict['joint_d'])
            am_distri = GammaUnivariate.from_dict(distr_dict['angle_md'])
            dm_distri = GammaUnivariate.from_dict(distr_dict['diversity_md'])
            copula_points = joint_distri.sample(sample_num)
            # Map the copula samples through the two marginal inverse CDFs.
            sampled_points = torch.from_numpy(
                np.stack((am_distri.ppf(copula_points[:, 0]), dm_distri.ppf(copula_points[:, 1])), axis=1))

        match_dist_matrix = pair_dist(sampled_points.unsqueeze(1), aug_angle_diversity_vec.unsqueeze(0))
        # topk raises if k exceeds the number of candidates, so clamp it.
        k = min(self._NEAREST_K, len(rank_list))
        dists, indices = match_dist_matrix.topk(k, dim=1, largest=False)
        dists, indices = dists.numpy(), indices.numpy()

        selected_index_set = set()
        selected_sample_info_list = []
        total_not_in_num = 0
        for j in range(indices.shape[0]):
            inside_indices = False
            for col in range(indices.shape[1]):
                curr_select_aug_index = indices[j, col]
                if curr_select_aug_index in selected_index_set:
                    continue
                selected_index_set.add(curr_select_aug_index)
                selected_sample_info_list.append((
                    curr_select_aug_index, dists[j, col],
                    abs(sampled_points[j, 0] - aug_angle_diversity_vec[curr_select_aug_index, 0]).item(),
                    abs(sampled_points[j, 1] - aug_angle_diversity_vec[curr_select_aug_index, 1]).item(),
                    *rank_list[curr_select_aug_index]))
                inside_indices = True
                break
            if not inside_indices:
                total_not_in_num += 1
        print(total_not_in_num)
        # The record starts after the four leading fields of each info tuple.
        return [item[self.index_pos + 4] for item in selected_sample_info_list], selected_sample_info_list

    def sample(self, sample_num_dict, s_type='default'):
        """Sample every class in ``sample_rank_list``.

        Args:
            sample_num_dict: Maps each class key to the number of samples wanted.
            s_type: ``'default'`` (``sample_in_class``), ``'joint'``
                (``sample_in_class_v2``) or ``'random'``
                (``sample_in_class_in_intersect_interval``).

        Returns:
            ``{'index': {class: [...]}, 'info': {class: [...]}}``.

        Raises:
            ValueError: If ``s_type`` is not one of the supported values.
        """
        results = {'index': {}, 'info': {}}
        r1_list, r2_list = [], []
        for c in self.sample_rank_list:
            if s_type == 'default':
                inclass_indexes, selected_sample_info = self.sample_in_class(c, sample_num_dict[c])
            elif s_type == 'joint':
                inclass_indexes, selected_sample_info = self.sample_in_class_v2(c, sample_num_dict[c])
            elif s_type == 'random':
                inclass_indexes, selected_sample_info, r1, r2 \
                    = self.sample_in_class_in_intersect_interval(c, sample_num_dict[c], is_debug=True)
                r1_list.append(r1)
                r2_list.append(r2)
            else:
                raise ValueError(
                    "unsupported s_type {!r}; expected 'default', 'joint' or 'random'".format(s_type))
            results['index'][c] = inclass_indexes
            results['info'][c] = selected_sample_info
        if r2_list:
            print("average intersection part ratio of sampled angle list", np.mean(r1_list), np.min(r1_list), np.max(r1_list))
            print('average intersection part ratio of aug angle list: ', np.mean(r2_list), np.min(r2_list), np.max(r2_list))

        return results


def remove_duplication(sample_rank_list, tolerance=0):
    """Drop consecutive near-duplicate records from every class list, in place.

    A record counts as a duplicate of the current reference record when
    their ``[1]`` fields are equal and their ``[2]`` or ``[3]`` fields are
    equal.  The reference is the first record of a run; it changes only when a
    non-duplicate record is met.  Up to ``tolerance`` duplicates per run are
    kept, the rest are removed.

    The lists inside ``sample_rank_list`` are modified in place (callers in
    this file rely on that) and the same list objects are returned.

    Args:
        sample_rank_list: Maps a class key to a mutable list of records.
        tolerance: Number of duplicates tolerated per run before removal.

    Returns:
        A new dict mapping every class key to its (now deduplicated) list.
    """
    shift = 0  # offset applied to the compared field positions
    new_rank_list = {}
    for c, rank_list in sample_rank_list.items():
        print(c, len(rank_list))
        # An empty class list has nothing to deduplicate (and no first record).
        if len(rank_list) > 0:
            pre_sample = rank_list[0]
            kept = [pre_sample]
            curr_tolerance = tolerance
            for candidate in rank_list[1:]:
                is_duplicate = candidate[1+shift] == pre_sample[1+shift] and (
                    candidate[2+shift] == pre_sample[2+shift]
                    or candidate[3+shift] == pre_sample[3+shift])
                if not is_duplicate:
                    pre_sample = candidate
                    curr_tolerance = tolerance
                    kept.append(candidate)
                elif curr_tolerance > 0:
                    curr_tolerance -= 1
                    kept.append(candidate)
            # Single slice assignment instead of repeated O(n) deletions; the
            # list object itself is preserved.
            rank_list[:] = kept
        print(len(rank_list))
        new_rank_list[c] = rank_list
    return new_rank_list


def sample_val(num_class, method, val_num_per_class, random_seed, is_return_info=False):
    """Select a validation subset from the augmented data pool.

    Args:
        num_class: Number of classes; ``10`` selects the CIFAR-10 pool and
            distribution files from ``PC``, any other value the CIFAR-100 ones.
        method: ``'RANDOM'``, ``'DB'``, ``'DB_RANDOM'`` or ``'DB_ADJOINT'``.
        val_num_per_class: Either a dict mapping class to sample count, or a
            single count applied to classes ``0 .. num_class - 1``.
        random_seed: Seed for both ``numpy.random`` and ``random``.
        is_return_info: When true return the full ``{'index', 'info'}`` dict,
            otherwise only its ``'index'`` part.

    Returns:
        See ``is_return_info``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ('RANDOM', 'DB', 'DB_RANDOM', 'DB_ADJOINT')
    if method not in supported_methods:
        # Previously an unknown method ended in an UnboundLocalError.
        raise ValueError("unsupported method {!r}; expected one of {}".format(method, supported_methods))

    np.random.seed(random_seed)
    random.seed(random_seed)

    pool_path = PC.get_cifar10_data_pool_info() if num_class == 10 else PC.get_cifar100_data_pool_info()
    # SECURITY NOTE: pickle.load can execute code from the file; the pool file
    # must come from a trusted source.
    with open(pool_path, 'rb') as fi:
        augmented_data_info = pickle.load(fi)

    if isinstance(val_num_per_class, dict):
        sample_num_dict = val_num_per_class
    else:
        sample_num_dict = {i: val_num_per_class for i in range(num_class)}

    pool_info = remove_duplication(augmented_data_info['rank_info'], tolerance=0)

    if method == 'RANDOM':
        # NOTE(review): here 'index' holds positions inside the deduplicated
        # pool list, whereas the DB methods return a field of the record;
        # confirm callers expect that difference.
        val_info = {'index': {}, 'info': {}}
        for c in range(num_class):
            # Size the permutation from the list that is actually indexed.
            total_num = len(pool_info[c])
            val_info['index'][c] = np.random.permutation(total_num)[0:sample_num_dict[c]].tolist()
            val_info['info'][c] = [pool_info[c][i] for i in val_info['index'][c]]
    else:
        distri_file_path = PC.get_cifar10_distribution_save_path() if num_class == 10 \
            else PC.get_cifar100_distribution_save_path()
        class_distri_info = DistriInfo(distri_file_path, is_train=True)
        sampler = DistributionSampler(class_distri_info.get_origin_distribution(), pool_info)
        if method == 'DB':
            # Author's note: 0.65 for cifar10 v4, 0.3 for cifar100 v4.
            sampler.set_in_out_sample_ratio(ratio=1.0)
            val_info = sampler.sample(sample_num_dict)
        elif method == 'DB_RANDOM':
            val_info = sampler.sample(sample_num_dict, s_type='random')
        else:  # 'DB_ADJOINT'
            sampler.set_bi_type('Gaussian')
            sampler.set_joint_distribution(class_distri_info.get_joint_distribution())
            val_info = sampler.sample(sample_num_dict, s_type='joint')

    if is_return_info:
        return val_info
    return val_info['index']


from sklearn.model_selection import train_test_split, StratifiedKFold


def split_train_val(indexes, y, seed, k=1, val_ratio=0.2):
    """Build stratified train/validation splits.

    Args:
        indexes: Items to split.
        y: Labels used for stratification, aligned with ``indexes``.
        seed: Random state passed to scikit-learn.
        k: ``1`` produces a single hold-out split; any other value produces a
            ``k``-fold stratified split with shuffling.
        val_ratio: Validation share for the hold-out split (ignored otherwise).

    Returns:
        A list of ``(train, val)`` pairs, one per split.
    """
    if k == 1:
        train_part, val_part, _, _ = train_test_split(
            indexes, y, test_size=val_ratio, stratify=y, random_state=seed)
        return [(train_part, val_part)]

    # NOTE(review): StratifiedKFold.split yields positions into ``indexes``,
    # while the hold-out branch above returns elements of ``indexes``; the two
    # only coincide when ``indexes`` is 0..n-1 -- confirm against callers.
    folds = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
    return [(train_pos, val_pos) for train_pos, val_pos in folds.split(indexes, y)]


if __name__ == '__main__':
    # Manual check: run one DB_ADJOINT sampling pass on the 100-class pool and
    # report the wall-clock time it took.
    import time
    began_at = time.time()
    sample_val(num_class=100, method='DB_ADJOINT', val_num_per_class=30, random_seed=42)
    elapsed = time.time() - began_at
    print("time ", elapsed)
