import torch
import tqdm
import math
from angle_distribution import large_scale_cosine_similarity

def cal_diversity_to_set_by_class(target_set, src_set, num_class, batch_size=512):
    """Compute the minimum-angle vector of every class, one class at a time.

    For each class index ``i`` in ``range(num_class)`` this calls
    ``cal_diversity_to_set(target_set[i], src_set[i], batch_size=batch_size)``
    and collects the results. Progress is reported through a ``tqdm`` bar.

    Args:
        target_set: Container indexable by class index ``0..num_class-1``;
            each entry is passed as ``tar_feats`` to ``cal_diversity_to_set``.
        src_set: Container indexable by class index ``0..num_class-1``;
            each entry is passed as ``src_feats`` to ``cal_diversity_to_set``.
        num_class: Number of class indices to process.
        batch_size: Chunk size forwarded to ``cal_diversity_to_set``.
            Defaults to 512, the value that was previously hard-coded.

    Returns:
        A list of length ``num_class`` whose ``i``-th element is the result
        of ``cal_diversity_to_set`` for class ``i``.
    """
    return [
        cal_diversity_to_set(target_set[i], src_set[i], batch_size=batch_size)
        for i in tqdm.tqdm(range(num_class))
    ]


def cal_diversity_to_set(tar_feats, src_feats, batch_size=None, epsilon=1e-8):
    """Return, for each target row, the smallest angle to any source row.

    The cosine similarity between every row of ``tar_feats`` and every row
    of ``src_feats`` is converted to an angle in degrees, and the minimum
    over the source rows is taken. Target rows are processed in chunks of
    ``batch_size`` so that only one chunk's pairwise matrix exists at a time.

    Args:
        tar_feats: Target feature tensor; expected to be 2-D with one
            feature vector per row (the first dimension is iterated over).
        src_feats: Source feature tensor; expected to be 2-D with the same
            feature dimension as ``tar_feats``. It must contain at least one
            row, otherwise the minimum over sources is undefined.
        batch_size: Number of target rows handled per chunk. ``None`` means
            all rows are processed in a single chunk.
        epsilon: Margin used when clamping the cosine to
            ``[-1 + epsilon, 1 - epsilon]`` before ``arccos``.
            NOTE(review): with float32 inputs the default ``1e-8`` is below
            machine precision, so the clamp is effectively ``[-1, 1]``.

    Returns:
        A 1-D tensor with one entry per row of ``tar_feats`` holding the
        minimum angle in degrees. An empty 1-D tensor is returned when
        ``tar_feats`` has no rows.

    Raises:
        ValueError: If ``batch_size`` is given and is not positive.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    total_len = tar_feats.shape[0]
    if total_len == 0:
        # No target rows: avoid a zero step / torch.cat on an empty list.
        return tar_feats.new_empty(0)

    step = total_len if batch_size is None else batch_size
    # Loop-invariant: add the broadcast axis to the sources only once.
    src_expanded = src_feats.unsqueeze(0)

    min_angle_vec = []
    for head in range(0, total_len, step):
        # Slicing past the end is clamped, so the last chunk may be shorter.
        chunk = tar_feats[head:head + step].unsqueeze(1)
        cos_sim = torch.cosine_similarity(chunk, src_expanded, dim=-1)
        # Clamp guards arccos against values drifting outside its domain.
        angle_matrix = torch.arccos(
            torch.clamp(cos_sim, min=-1 + epsilon, max=1 - epsilon)
        )
        angle_matrix = angle_matrix * 180 / math.pi  # radians -> degrees
        min_angle_vec.append(torch.min(angle_matrix, dim=1)[0])
    return torch.cat(min_angle_vec, dim=0)
