import functools
import pickle
import numpy as np
import random
import os
from utils import PathConfig
from copulas.multivariate import GaussianMultivariate
from copulas.univariate import GammaUnivariate, GaussianUnivariate, Univariate
from copulas.bivariate import Bivariate, CopulaTypes
import torch
from sklearn.neighbors import KernelDensity


# Module-level path configuration from the project's `utils.PathConfig`, built
# once at import time. sample_val() uses it to locate the data-pool info file
# (get_data_pool_info_path) and the distribution file (get_distribution_path).
PC = PathConfig()


def pair_dist(a: torch.Tensor, b: torch.Tensor, p: int = 2, batch_size: int = 512, weights=None) -> torch.Tensor:
    """Weighted Minkowski distance between every row of ``a`` and every row of ``b``.

    ``a`` is indexed as ``(M, 1, D)`` and ``b`` as ``(1, N, D)`` so the two
    broadcast against each other. The result is an ``(M, N)`` float64 tensor
    (allocated on the CPU) with::

        out[i, j] = (sum_d (w_d * |a[i, 0, d] - b[0, j, d]|) ** p) ** (1 / p)

    :param a: query points, shape ``(M, 1, D)``.
    :param b: reference points, shape ``(1, N, D)``.
    :param p: order of the Minkowski norm.
    :param batch_size: number of rows of ``a`` processed per step, which bounds
        the size of the ``(batch, N, D)`` intermediate.
    :param weights: per-dimension weights of length ``D``. Defaults to equal
        weights ``1 / D`` (``[0.5, 0.5]`` for 2-D points).
    :returns: ``(M, N)`` distance matrix.
    """
    total_num = a.shape[0]
    if weights is None:
        dim = a.shape[-1]
        weights = torch.full((dim,), 1.0 / dim)
    weight_vec = torch.as_tensor(weights).to(a.device)

    results = torch.empty(total_num, b.shape[1], dtype=float)
    for head in range(0, total_num, batch_size):
        tail = min(head + batch_size, total_num)
        diff = (a[head:tail, ...] - b).abs() * weight_vec
        results[head:tail, :] = diff.pow(p).sum(-1).pow(1 / p)
    return results


def cmp_fc(a, b):
    """Old-style comparator that orders records by field 3, largest first.

    Meant for ``functools.cmp_to_key``. It returns -1 when ``a`` should come
    before ``b`` (``a[3] > b[3]``), 1 when it should come after, and 0 on ties,
    so ``cmp_fc(x, x) == 0`` as the comparator contract requires.

    NOTE(review): an earlier comment said field 3 reflects "angle changes";
    confirm this against the code that builds the rank records.
    """
    if a[3] > b[3]:
        return -1
    if a[3] < b[3]:
        return 1
    return 0


class DistriInfo(object):
    """Read-only accessor over a pickled dictionary of per-class distribution info.

    Each getter indexes the loaded dict under a fixed key:
    ``'original_distribution'``, ``'calibrated_distribution'``,
    ``'class_centers'``, ``'joint_distribution_params'``,
    ``'distribution_params'``, ``'angle_diversity_data'`` and ``'angle_data'``.
    The getters that accept ``class_index`` return the whole mapping when it is
    ``None``, and the entry for that class otherwise.
    """

    def __init__(self, save_path, is_train=True):
        """Load the pickled distribution info.

        :param save_path: with ``is_train=True``, a directory that contains
            ``kfold_class_distri_info_Gaussian.pkl``; otherwise a path prefix to
            which ``'.pkl'`` is appended.
        :param is_train: selects which of the two path conventions applies.
        """
        if is_train:
            pkl_path = os.path.join(save_path, 'kfold_class_distri_info_Gaussian.pkl')
        else:
            pkl_path = save_path + '.pkl'
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_path, 'rb') as fi:
            self.distri_info = pickle.load(fi)

    def _lookup(self, key, class_index=None):
        """Return ``distri_info[key]``, or its ``class_index`` entry when one is given."""
        value = self.distri_info[key]
        return value if class_index is None else value[class_index]

    def get_origin_distribution(self):
        return self._lookup('original_distribution')

    def get_calibrated_distribution(self):
        return self._lookup('calibrated_distribution')

    def get_class_centers(self):
        return self._lookup('class_centers')

    def get_joint_distribution(self, class_index=None):
        return self._lookup('joint_distribution_params', class_index)

    def get_univar_distribution(self, class_index=None):
        return self._lookup('distribution_params', class_index)

    def get_angle_diversity_data(self, class_index=None):
        return self._lookup('angle_diversity_data', class_index)

    def get_angle_data(self, class_index=None):
        return self._lookup('angle_data', class_index)


class DistributionSampler(object):
    """Pick augmented samples per class so that their angles follow fitted distributions.

    ``sample_rank_list`` maps a class index to a list of per-sample records
    (indexable sequences). Record fields read by this class:

    * ``info[0]`` (``angle_pos`` = 0): the angle value the samplers match on.
      The univariate samplers also assume each list is sorted by this field,
      because they take the first/last record as the angle range -- TODO confirm.
    * ``diversity_pos`` (4): diversity value used by the joint sampler.
    * ``index_pos`` (5): identifier reported as the selected sample index.

    ``distribution`` is stored as-is; none of the methods here read it.
    """

    def __init__(self, distribution, sample_rank_list):
        self.distribution = distribution
        self.sample_rank_list = sample_rank_list
        self.index_pos = 5
        self.angle_pos = 0
        self.diversity_pos = 4
        # Fraction of the requested samples drawn via the fitted distribution.
        # In sample_in_class the remainder ("out" samples) is taken from records
        # at or above the upper end of the target angle range.
        self.in_out_sample_ratio = 1
        self.bi_type = None
        # Set through the setters below; None means "not configured yet".
        self.joint_distribution = None
        self.univar_distribution = None
        # Number of nearest candidates examined per sampled point in sample_in_class_v2.
        self.match_top_k = 50

    def set_bi_type(self, bi_type):
        """Select the joint model used by sample_in_class_v2: 'Gaussian', 'Gumbel' or 'KDE'."""
        self.bi_type = bi_type

    def set_joint_distribution(self, joint_distribution):
        """Set the per-class joint (angle, diversity) distribution parameters."""
        self.joint_distribution = joint_distribution

    def set_univar_distribution(self, univar_distribution):
        """Set the per-class univariate angle distribution parameters (copulas dicts)."""
        self.univar_distribution = univar_distribution

    def set_in_out_sample_ratio(self, ratio):
        """Set the fraction of samples drawn from the fitted distribution."""
        self.in_out_sample_ratio = ratio

    def _preprocess(self, class_index, sample_num):
        """Draw ``sample_num`` angles for a class and intersect them with the record range.

        :returns: ``(sample_list, candidate_range, sampled_angles, target_range)``,
            i.e. the class's records, the (first, last) record angle, the sorted
            sampled angles, and the overlap of those two ranges.
        :raises RuntimeError: if set_univar_distribution() was never called.
        """
        if self.univar_distribution is None:
            raise RuntimeError("set_univar_distribution() must be called before sampling")
        distri = Univariate.from_dict(self.univar_distribution[class_index])
        sampled_angle_list = np.sort(distri.sample(sample_num))
        sample_list = self.sample_rank_list[class_index]

        candidate_angle_range = (sample_list[0][0], sample_list[-1][0])
        print("augmented angle range: ", candidate_angle_range[0], '--', candidate_angle_range[1])
        print("sampled angle range: ", sampled_angle_list[0], '--', sampled_angle_list[-1])
        # Overlap of the two ranges (ties keep the same operand as before).
        target_angle_range = (max(candidate_angle_range[0], sampled_angle_list[0]),
                              min(sampled_angle_list[-1], candidate_angle_range[1]))
        print('target angle range: ', target_angle_range[0], '--', target_angle_range[1])
        return sample_list, candidate_angle_range, sampled_angle_list, target_angle_range

    def sample_in_class_in_intersect_interval(self, class_index, sample_num, is_debug=False):
        """Randomly pick up to ``sample_num`` records that fall in the target angle range.

        :returns: ``(selected_index, data_info)``, plus the two overlap ratios when
            ``is_debug`` is true. ``data_info`` entries are
            ``(position_in_rank_list, *record)``.
        """
        print("=" * 20 + str(class_index) + "=" * 20)
        sample_list, _, sampled_angle_list, target_angle_range = self._preprocess(class_index, sample_num)
        low, high = target_angle_range

        # Share of the distribution-sampled angles that lie inside the overlap.
        ratio1 = sum(1 for angle in sampled_angle_list if low <= angle <= high) / len(sampled_angle_list)
        print('intersection part ratio of sampled angle list: ', ratio1)

        # NOTE(review): this filters on info[1], but the range was built from
        # info[0] in _preprocess; confirm which field is intended.
        selected_index, data_info = [], []
        for position, info in enumerate(sample_list):
            if low <= info[1] <= high:
                selected_index.append(info[self.index_pos])
                data_info.append((position, *info))

        ratio2 = len(selected_index) / len(sample_list)
        print('intersection part ratio of aug angle list: ', ratio2)

        if sample_num < len(selected_index):
            keep = np.random.permutation(len(selected_index))[0:sample_num].tolist()
            selected_index = [selected_index[i] for i in keep]
            data_info = [data_info[i] for i in keep]

        if is_debug:
            return selected_index, data_info, ratio1, ratio2
        return selected_index, data_info

    def sample_in_class(self, class_index, sample_num):
        """Match each sampled angle to the next unused record in its half-degree bin.

        Each sampled angle is clamped at 0 and snapped to a 0.5 grid. It takes the
        next record in that bin unless the bin is exhausted, i.e. the cursor has
        reached where the following bin starts, or the end of the list when
        there is no following bin. The remaining ``(1 - in_out_sample_ratio)``
        share is filled with records from the bin at the upper end of the target
        range onwards, ordered by cmp_fc. At most as many are taken as exist.

        :returns: ``(selected_index, data_info)``. ``data_info`` entries are
            ``(angle, *record)`` for in-range picks and ``(i, *record)`` for
            out-of-range picks.
        """
        in_sample_num = int(sample_num * self.in_out_sample_ratio)
        out_sample_num = sample_num - in_sample_num
        print("=" * 20 + str(class_index) + "=" * 20)
        sample_list, candidate_angle_range, sampled_angle_list, target_angle_range = \
            self._preprocess(class_index, in_sample_num)

        # Cursor per angle bin: starts at the bin's first record position.
        start_index = {}
        for position, info in enumerate(sample_list):
            start_index.setdefault(info[0], position)

        not_in_range_num = 0
        exceed_interval_num = 0
        selected_index, data_info = [], []
        for angle in sampled_angle_list:
            angle = 0 if angle < 0 else angle
            target_angle = int(angle) + 0.5 if round(angle) >= angle else int(angle)
            if target_angle < candidate_angle_range[0] or target_angle > candidate_angle_range[1]:
                not_in_range_num += 1
                continue
            if target_angle not in start_index:
                print('angle not found ', angle)
                continue

            cursor = start_index[target_angle]
            next_bin = target_angle + 0.5
            # NOTE(review): if the next 0.5 bin is absent, the cursor may run into
            # later bins' records; confirm this is acceptable.
            if next_bin in start_index:
                if cursor < start_index[next_bin]:
                    selected_index.append(sample_list[cursor][self.index_pos])
                    data_info.append((angle, *sample_list[cursor]))
                elif cursor == start_index[next_bin]:
                    exceed_interval_num += 1
                    print("reach the end: ", target_angle)
                else:
                    exceed_interval_num += 1
                    print("exceed the end ", target_angle)
            elif cursor >= len(sample_list):
                exceed_interval_num += 1
                print("exceed the max angle in sample list ", target_angle)
            else:
                selected_index.append(sample_list[cursor][self.index_pos])
                data_info.append((angle, *sample_list[cursor]))
            start_index[target_angle] += 1
        print("not in range sample times: ", not_in_range_num)
        print("not in interval sample times: ", exceed_interval_num)

        if out_sample_num >= 1:
            print("out sample num ", out_sample_num)
            upper = target_angle_range[1]
            out_angle = int(upper) + 0.5 if round(upper) >= upper else int(upper)
            if out_angle in start_index:
                out_angle_head = start_index[out_angle]
                out_sample_list = sorted((sample_list[i] for i in range(out_angle_head, len(sample_list))),
                                         key=functools.cmp_to_key(cmp_fc))
                # Never ask for more out-samples than exist (this used to raise IndexError).
                n_out = min(out_sample_num, len(out_sample_list))
                if n_out < out_sample_num:
                    print("not enough out samples: ", n_out, '<', out_sample_num)
                selected_index.extend(out_sample_list[i][self.index_pos] for i in range(n_out))
                data_info.extend((i, *out_sample_list[i]) for i in range(n_out))
            else:
                print("no out samples")
        return selected_index, data_info

    def sample_in_class_default(self, class_index, sample_num):
        """Bin sampled angles on the half-degree grid, then draw records per bin at random.

        Each bin gets ``min(requested, available)`` distinct records, chosen with
        ``np.random.permutation``.

        :returns: ``(selected_index, data_info)`` where ``selected_index`` holds the
            records' ``index_pos`` field as ints and each ``data_info`` entry is
            ``(bin_angle, *record)``.
        """
        in_sample_num = int(sample_num * self.in_out_sample_ratio)
        print("=" * 20 + str(class_index) + "=" * 20)
        rank_rows, _, sampled_angle_list, _ = self._preprocess(class_index, in_sample_num)

        # Snap sampled angles onto the same 0.5 grid used by sample_in_class.
        snapped = [int(a) + 0.5 if round(a) >= a else int(a) for a in sampled_angle_list]

        # bin angle -> (first record position, number of records in the bin)
        bins = {}
        for position, info in enumerate(rank_rows):
            first, count = bins.get(info[0], (position, 0))
            bins[info[0]] = (first, count + 1)

        # Requested count per bin, in first-seen order, which fixes the RNG call order.
        requested = {}
        for bin_angle in snapped:
            requested[bin_angle] = requested.get(bin_angle, 0) + 1

        not_in_range_num = 0
        exceed_interval_num = 0
        selected_index, data_info = [], []
        rank_array = np.array(rank_rows)
        for bin_angle, num in requested.items():
            if bin_angle not in bins:
                not_in_range_num += num
                continue
            first, available = bins[bin_angle]
            if num > available:
                exceed_interval_num += num - available
                num = available
            sampled_ids = np.random.permutation(available)[0:num] + first
            selected_index.extend(rank_array[sampled_ids, self.index_pos].astype(int).tolist())
            # One (angle, *record) tuple per picked record.
            data_info.extend((bin_angle, *rank_rows[i]) for i in sampled_ids)

        print("not in range sample times: ", not_in_range_num)
        print("not in interval sample times: ", exceed_interval_num)
        return selected_index, data_info

    def sample_in_class_v2(self, class_index, sample_num):
        """Sample (angle, diversity) points jointly and match each one to an unused record.

        For 'Gaussian' and 'KDE', points with diversity < 1 are discarded. Each
        remaining point takes the nearest record (pair_dist) among its
        ``match_top_k`` closest that no earlier point has claimed.

        :returns: ``(selected_index, info)`` where each ``info`` entry is
            ``(rank_position, distance, |d_angle|, |d_diversity|, *record)``.
        :raises ValueError: if ``bi_type`` is not 'Gaussian', 'Gumbel' or 'KDE'.
        """
        if self.bi_type not in ('Gaussian', 'Gumbel', 'KDE'):
            raise ValueError("unsupported bi_type %r; expected 'Gaussian', 'Gumbel' or 'KDE'" % (self.bi_type,))
        class_rows = self.sample_rank_list[class_index]
        aug_angle_diversity_vec = torch.tensor([[item[self.angle_pos], item[self.diversity_pos]]
                                                for item in class_rows], dtype=float)
        distr_dict = self.joint_distribution[class_index]

        if self.bi_type == 'Gaussian':
            joint_distri = GaussianMultivariate.from_dict(distr_dict)
            sampled_points = joint_distri.sample(sample_num)
            sampled_points = sampled_points[sampled_points.diversity >= 1]
            sampled_points = torch.from_numpy(sampled_points.values)
        elif self.bi_type == 'Gumbel':
            joint_distri = Bivariate(CopulaTypes.GUMBEL).from_dict(distr_dict['joint_d'])
            am_distri = GammaUnivariate.from_dict(distr_dict['angle_md'])
            dm_distri = GammaUnivariate.from_dict(distr_dict['diversity_md'])
            copula_points = joint_distri.sample(sample_num)
            # Transform copula samples through each Gamma marginal's ppf.
            sampled_points = torch.from_numpy(np.stack((am_distri.ppf(copula_points[:, 0]),
                                                        dm_distri.ppf(copula_points[:, 1])), axis=1))
        else:  # 'KDE': the stored object is used directly as the sampler.
            sampled_points = distr_dict.sample(sample_num)
            sampled_points = torch.from_numpy(sampled_points[sampled_points[:, 1] >= 1])

        match_dist_matrix = pair_dist(sampled_points.unsqueeze(1), aug_angle_diversity_vec.unsqueeze(0))
        # Clamp k so that classes with fewer than match_top_k records do not make topk fail.
        top_k = min(self.match_top_k, aug_angle_diversity_vec.shape[0])
        dists, indices = match_dist_matrix.topk(top_k, dim=1, largest=False)
        dists, indices = dists.numpy(), indices.numpy()

        selected_index_set = set()
        selected_sample_info_list = []
        total_not_in_num = 0
        for j in range(indices.shape[0]):
            for k in range(indices.shape[1]):
                aug_index = indices[j, k]
                if aug_index in selected_index_set:
                    continue
                selected_index_set.add(aug_index)
                selected_sample_info_list.append((
                    aug_index,
                    dists[j, k],
                    abs(sampled_points[j, 0] - aug_angle_diversity_vec[aug_index, 0]).item(),
                    abs(sampled_points[j, 1] - aug_angle_diversity_vec[aug_index, 1]).item(),
                    *class_rows[aug_index]))
                break
            else:
                # Every candidate for this point was already claimed.
                total_not_in_num += 1
        print(total_not_in_num)
        # Skip the 4 matching fields prepended above to reach the record's index field.
        return [item[self.index_pos + 4] for item in selected_sample_info_list], selected_sample_info_list

    def sample(self, sample_num_dict, s_type='default'):
        """Run the per-class sampler selected by ``s_type`` for every class in ``sample_rank_list``.

        :param sample_num_dict: class index -> number of samples to request.
        :param s_type: 'default' (sample_in_class), 'joint' (sample_in_class_v2)
            or 'random' (sample_in_class_in_intersect_interval).
        :returns: ``{'index': {class: [...]}, 'info': {class: [...]}}``.
        :raises ValueError: for an unknown ``s_type``.
        """
        results = {'index': {}, 'info': {}}
        r1_list, r2_list = [], []
        for c in self.sample_rank_list.keys():
            if s_type == 'default':
                inclass_indexes, selected_sample_info = self.sample_in_class(c, sample_num_dict[c])
                print(c, len(inclass_indexes))
            elif s_type == 'joint':
                inclass_indexes, selected_sample_info = self.sample_in_class_v2(c, sample_num_dict[c])
            elif s_type == 'random':
                inclass_indexes, selected_sample_info, r1, r2 = \
                    self.sample_in_class_in_intersect_interval(c, sample_num_dict[c], is_debug=True)
                r1_list.append(r1)
                r2_list.append(r2)
            else:
                raise ValueError("unknown s_type %r; expected 'default', 'joint' or 'random'" % (s_type,))
            results['index'][c] = inclass_indexes
            results['info'][c] = selected_sample_info
        if r2_list:
            print("average intersection part ratio of sampled angle list", np.mean(r1_list), np.min(r1_list), np.max(r1_list))
            print('average intersection part ratio of aug angle list: ', np.mean(r2_list), np.min(r2_list), np.max(r2_list))
        return results


def remove_duplication(sample_rank_list, tolerance=0):
    """Drop near-duplicate records from each class's rank list.

    A record is a duplicate of the current anchor (the last non-duplicate record)
    when field 1 matches AND field 2 or field 3 matches. Up to ``tolerance``
    consecutive duplicates after an anchor are kept; any further ones are
    dropped. A non-duplicate becomes the new anchor and resets the tolerance.

    The input lists are modified in place, and the same list objects are
    returned. This existing behaviour is kept for callers that rely on it.

    :param sample_rank_list: class index -> list of records.
    :param tolerance: number of duplicates to keep after each anchor.
    :returns: dict mapping each class to its (de-duplicated) list.
    """
    shift = 0  # offset of the compared fields within each record
    new_rank_list = {}
    for c, rank_list in sample_rank_list.items():
        print(c, len(rank_list))
        if rank_list:
            anchor = rank_list[0]
            kept = [anchor]
            curr_tolerance = tolerance
            for record in rank_list[1:]:
                is_dup = (record[1 + shift] == anchor[1 + shift]
                          and (record[2 + shift] == anchor[2 + shift]
                               or record[3 + shift] == anchor[3 + shift]))
                if not is_dup:
                    anchor = record
                    curr_tolerance = tolerance
                    kept.append(record)
                elif curr_tolerance > 0:
                    curr_tolerance -= 1
                    kept.append(record)
                # otherwise the duplicate is dropped
            # Rebuilding once is O(n); repeated `del` was O(n^2). Write back in place.
            rank_list[:] = kept
        print(len(rank_list))
        new_rank_list[c] = rank_list
    return new_rank_list


def sample_val(dst_name, method, val_num_per_class, random_seed, class_ratio_list=None, is_return_info=False):
    """Select a validation subset from the augmented data pool of ``dst_name``.

    :param dst_name: dataset name passed to PathConfig to locate files.
    :param method: 'RANDOM' (uniform pick from the de-duplicated pool),
        'DB' (univariate distribution sampling), 'DB_RANDOM' (random pick
        inside the distribution's angle range) or 'DB_ADJOINT' (joint
        Gaussian-copula sampling).
    :param val_num_per_class: samples requested per class.
    :param random_seed: seed for both ``numpy.random`` and ``random``.
    :param class_ratio_list: optional per-class multiplier on ``val_num_per_class``.
    :param is_return_info: return ``{'index': ..., 'info': ...}`` instead of
        only the per-class index lists.
    :raises ValueError: for an unknown ``method``. This is checked before any
        file is read.
    """
    valid_methods = ('RANDOM', 'DB', 'DB_RANDOM', 'DB_ADJOINT')
    if method not in valid_methods:
        raise ValueError("unknown method %r; expected one of %s" % (method, ', '.join(valid_methods)))

    np.random.seed(random_seed)
    random.seed(random_seed)
    # NOTE: pickle.load can execute arbitrary code; only load trusted data-pool files.
    with open(os.path.join(PC.get_data_pool_info_path(dst_name), 'data_pool_info.pkl'), 'rb') as fi:
        augmented_data_info = pickle.load(fi)
    num_class = len(augmented_data_info['index_map'].keys())
    if class_ratio_list is None:
        sample_num_dict = {i: val_num_per_class for i in range(num_class)}
    else:
        sample_num_dict = {i: int(val_num_per_class * class_ratio_list[i]) for i in range(num_class)}

    if method == 'RANDOM':
        val_info = {'index': {}, 'info': {}}
        pool_info = remove_duplication(augmented_data_info['rank_info'], tolerance=0)
        for c in range(num_class):
            # Size the permutation from the de-duplicated pool, which is what is indexed below.
            total_num = len(pool_info[c])
            picked = np.random.permutation(total_num)[0:sample_num_dict[c]].tolist()
            # NOTE(review): these are positions in the pool list, whereas the DB
            # methods report each record's index field; confirm which is intended.
            val_info['index'][c] = picked
            val_info['info'][c] = [pool_info[c][i] for i in picked]
    else:
        distri_file_path = PC.get_distribution_path(dst_name)
        class_distri_info = DistriInfo(distri_file_path, is_train=True)
        sampler = DistributionSampler(class_distri_info.get_calibrated_distribution(),
                                      remove_duplication(augmented_data_info['rank_info'], tolerance=0))
        if method == 'DB':
            sampler.set_univar_distribution(class_distri_info.get_univar_distribution())
            sampler.set_in_out_sample_ratio(ratio=1.0)
            val_info = sampler.sample(sample_num_dict)
        elif method == 'DB_RANDOM':
            sampler.set_univar_distribution(class_distri_info.get_univar_distribution())
            val_info = sampler.sample(sample_num_dict, s_type='random')
        else:  # 'DB_ADJOINT'
            print(sample_num_dict)
            sampler.set_bi_type('Gaussian')
            sampler.set_joint_distribution(class_distri_info.get_joint_distribution())
            val_info = sampler.sample(sample_num_dict, s_type='joint')

    return val_info if is_return_info else val_info['index']


if __name__ == '__main__':
    import time
    # perf_counter is monotonic, so changes to the system clock cannot skew the elapsed time.
    start_t = time.perf_counter()
    sample_val(dst_name='reuters', method="DB_ADJOINT", val_num_per_class=450, random_seed=0)
    print("time ", time.perf_counter() - start_t)
