import torch
import math
from tqdm import tqdm
import numpy as np


# def large_scale_self_cosine_similarity(data_x):
#     total_num = data_x.shape[0]
#     data_x_part_one = data_x[0:total_num//2, :]
#     data_x_part_two = data_x[total_num//2:total_num, :]
#     upper_half_part = torch.cat([torch.cosine_similarity(data_x_part_one.unsqueeze(1), data_x_part_one.unsqueeze(0), dim=-1),
#                torch.cosine_similarity(data_x_part_one.unsqueeze(1), data_x_part_two.unsqueeze(0), dim=-1)], dim=1).cpu()
#     lower_halp_part = torch.cat([torch.cosine_similarity(data_x_part_two.unsqueeze(1), data_x_part_one.unsqueeze(0), dim=-1),
#                torch.cosine_similarity(data_x_part_two.unsqueeze(1), data_x_part_two.unsqueeze(0), dim=-1)], dim=1).cpu()
#     return torch.cat([upper_half_part, lower_halp_part], dim=0).to("cuda:0")


def large_scale_cosine_similarity(data_a, data_b, batch_size=64, epsilon=1e-8):
    """
    Pairwise cosine similarity between the rows of data_a and the rows of data_b,
    computed in row chunks of data_a to bound peak memory.

    :param data_a: (n * feature_dim) tensor
    :param data_b: (m * feature_dim) tensor
    :param batch_size: number of rows of data_a processed per chunk (must be > 0)
    :param epsilon: similarities are clamped to [-1 + epsilon, 1 - epsilon]
    :return: (n * m) similarity matrix on data_a's device; an empty (0 * m)
        tensor when data_a has no rows
    :raises ValueError: if batch_size is not positive
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    total_num = data_a.shape[0]
    if total_num == 0:
        # torch.cat on an empty list would raise; return a well-shaped empty result.
        return data_a.new_empty((0, data_b.shape[0]))
    chunks = []
    for head in range(0, total_num, batch_size):
        tail = min(head + batch_size, total_num)
        sim = torch.cosine_similarity(data_a[head:tail, :].unsqueeze(1), data_b.unsqueeze(0), dim=-1)
        # NOTE(review): for float32 inputs an epsilon of 1e-8 is below float32
        # resolution near 1, so this clamp likely has no visible effect at +-1;
        # confirm if strictly open bounds are required.
        chunks.append(torch.clamp(sim, min=-1 + epsilon, max=1 - epsilon))
    return torch.cat(chunks, dim=0).to(data_a.device)


def cal_class_center(data_x):
    """
    Compute a similarity-weighted class center.

    Each sample is weighted by its mean cosine similarity to the other samples
    of the class (self-similarity excluded). The weights are normalized to sum
    to 1 and used to form a weighted average of the samples.

    :param data_x: (n * feature_dim), n is the number of samples in the class
    :return: a tensor of class center vector, which shape is (feature_dim)
    :raises ValueError: if data_x contains no samples
    """
    num_samples = data_x.shape[0]
    if num_samples == 0:
        raise ValueError("cal_class_center requires at least one sample")
    if num_samples == 1:
        return data_x[0, :]
    sim_mat = large_scale_cosine_similarity(data_x, data_x)
    # Row sums include the diagonal self-similarity (~1.0); subtract it so the
    # mean is taken over the other n - 1 samples only.
    mean_sim = (sim_mat.sum(dim=1) - 1.0) / (num_samples - 1)
    # NOTE(review): if the mean similarities sum to ~0 (e.g. strongly
    # anti-correlated samples) this normalization divides by ~0.
    sample_weight_vec = mean_sim / mean_sim.sum()
    return torch.einsum("ij,i->j", [data_x, sample_weight_vec])


def cal_angle_to_center(data_x, center, radian=False):
    """
    Angle between each sample and the class center.

    :param data_x: (n * feature_dim), n is the number of samples in the class
    :param center: a tensor of class center vector, which shape is (feature_dim)
    :param radian: return radians if True, degrees otherwise
    :return: vector of angle, shape (n)
    """
    cos_sim = torch.cosine_similarity(data_x, center.unsqueeze(0), dim=1)
    # Floating-point rounding can push the cosine slightly outside [-1, 1]
    # (e.g. a sample parallel to the center), where arccos returns NaN.
    angle_vec = torch.arccos(torch.clamp(cos_sim, min=-1.0, max=1.0))
    if not radian:
        angle_vec = angle_vec * 180 / math.pi
    return angle_vec


def cal_angle_p2p(data_a, data_b, radian=False):
    """
    Row-wise (point-to-point) angle between two equally shaped sets of vectors.

    :param data_a: (n * feature_dim) tensor
    :param data_b: (n * feature_dim) tensor; row i is compared with row i of data_a
    :param radian: return radians if True, degrees otherwise
    :return: vector of angle, point to point, shape (n)
    """
    cos_sim = torch.cosine_similarity(data_a, data_b, dim=1)
    # Clamp so rounding slightly beyond [-1, 1] does not make arccos return NaN.
    angle_vec = torch.arccos(torch.clamp(cos_sim, min=-1.0, max=1.0))
    if not radian:
        angle_vec = angle_vec * 180 / math.pi
    return angle_vec


def cal_intra_class_angular_distri(data_x, center, radian=False, is_return_angles=False):
    """
    Summarize how samples spread angularly around their class center.

    :param data_x: (n * feature_dim), n is the number of samples in the class
    :param center: a tensor of class center vector, which shape is (feature_dim)
    :param radian: measure angles in radians if True, degrees otherwise
    :param is_return_angles: also return the per-sample angle vector
    :return: (mean, population variance) of the angles, or
        (mean, population variance, angles) when is_return_angles is True
    """
    angles = cal_angle_to_center(data_x, center, radian)
    # unbiased=False -> population variance (divide by n, not n - 1).
    stats = (torch.mean(angles), torch.var(angles, unbiased=False))
    return (*stats, angles) if is_return_angles else stats


def cal_class_angles(num_class, class_split_features=None):
    """
    Compute per-class centers and the angle (in degrees) of every sample to its center.

    :param num_class: number of classes; ids 0 .. num_class - 1 are processed
    :param class_split_features: indexable by class id, entry i being the
        (n_i * feature_dim) feature tensor of class i. It is indexed
        unconditionally, so it must be supplied despite the None default.
    :return: (list of per-class angle vectors, list of per-class sample counts,
        list of per-class center vectors)
    """
    angle_vectors, sample_counts, centers = [], [], []
    for class_idx in tqdm(range(num_class)):
        features = class_split_features[class_idx]
        class_center = cal_class_center(features)
        _, _, per_sample_angles = cal_intra_class_angular_distri(
            features, class_center, radian=False, is_return_angles=True
        )
        angle_vectors.append(per_sample_angles)
        sample_counts.append(features.shape[0])
        centers.append(class_center)
    return angle_vectors, sample_counts, centers


def cal_class_distri(dst, feature_extractor, num_class, is_return_center=False, class_split_features=None):
    """
    Compute, for every class, the mean and population variance of sample angles
    (in degrees) to the similarity-weighted class center.

    :param dst: dataset handed to feature_extractor when features are not supplied
    :param feature_extractor: object exposing extractor_features_from_dst(dst, num_classes=...);
        presumably returns a per-class indexable container of feature tensors — confirm
        against the extractor implementation
    :param num_class: number of classes; ids 0 .. num_class - 1 are processed
    :param is_return_center: also return the list of class centers
    :param class_split_features: precomputed per-class features; skips extraction when given
    :return: (list of (mean, var) float tuples, list of sample counts), plus the
        list of centers when is_return_center is True
    """
    if class_split_features is None:
        class_split_features = feature_extractor.extractor_features_from_dst(dst, num_classes=num_class)
    stats, counts, centers = [], [], []
    for class_idx in tqdm(range(num_class)):
        features = class_split_features[class_idx]
        class_center = cal_class_center(features)
        angle_mean, angle_var = cal_intra_class_angular_distri(features, class_center, radian=False)
        # .item() converts the 0-d tensors to plain Python floats.
        stats.append((angle_mean.item(), angle_var.item()))
        counts.append(features.shape[0])
        centers.append(class_center)
    return (stats, counts, centers) if is_return_center else (stats, counts)


def angle_distribution_calibration(distribution, class_sample_num):
    """
    Raise every class variance to at least the sample-weighted global variance.

    The global variance is the average of the class variances weighted by each
    class's share of the total sample count. Class means are left untouched.

    :param distribution: a list of (mean, var) pairs, one per class
    :param class_sample_num: a list, contains the number of samples in each class
    :return: list of (mean, var) pairs where var is the original variance if it
        strictly exceeds the global variance, and the global variance otherwise
    """
    counts = np.asarray(class_sample_num, dtype=float)
    weights = counts / counts.sum()
    variances = np.asarray([var for _, var in distribution], dtype=float)
    global_var = np.sum(weights * variances)

    calibrated = []
    for mean, var in distribution:
        if var > global_var:
            calibrated.append((mean, var))
        else:
            calibrated.append((mean, global_var))
    return calibrated


if __name__ == '__main__':
    def angle_distr_test():
        """Cross-check the weighted-center math against cal_class_center on random data."""
        a = torch.randint(100, (10, 50)).float()
        n = a.shape[0]
        c = torch.cosine_similarity(a.unsqueeze(1), a.unsqueeze(0), dim=-1)
        # Exclude self-similarity and average over the other n - 1 samples,
        # matching cal_class_center.
        c = (c.sum(dim=1) - 1.0) / (n - 1)
        c = c / c.sum()

        c2 = torch.mul(a, c.reshape((n, 1))).sum(dim=0)
        c1 = torch.einsum("ij,i->j", [a, c])

        print(c2 - c1)
        print(c1)
        print(cal_class_center(a))
        print(cal_intra_class_angular_distri(a, c1, radian=False))

    def similarity_test():
        """Compare the batched similarity with the dense reference; differences should be ~0."""
        a = torch.randint(100, (10, 50)).float()
        gt = torch.cosine_similarity(a.unsqueeze(1), a.unsqueeze(0), dim=-1)
        # large_scale_self_cosine_similarity only exists as commented-out code, so
        # calling it raised NameError; the two-argument batched version covers
        # the self-similarity case.
        ours = large_scale_cosine_similarity(a, a)
        print(gt - ours)

    similarity_test()