import torch
import numpy as np
import torch.nn.functional as F
from .supcon_loss import SupConLoss

def euclidean_distance(a, b):
    """Return the matrix of Euclidean distances between the rows of ``a`` and ``b``.

    Uses the expansion ||a_i||^2 + ||b_j||^2 - 2 * a_i . b_j, which can come out
    slightly negative through floating-point cancellation when two rows are
    (nearly) identical. The squared distances are therefore clamped to a tiny
    positive floor before the square root.

    Args:
        a: 2-D tensor of shape (m, d).
        b: 2-D tensor of shape (n, d).

    Returns:
        Tensor of shape (m, n) with ``out[i, j] = ||a[i] - b[j]||``.
    """
    sum_sq_a = torch.sum(a ** 2, dim=1).unsqueeze(1)  # m -> [m, 1]
    sum_sq_b = torch.sum(b ** 2, dim=1).unsqueeze(0)  # n -> [1, n]
    sq_dist = sum_sq_a + sum_sq_b - 2 * a.mm(b.t())
    # Negative values would make sqrt return NaN, and sqrt's gradient is
    # infinite at exactly 0; a 1e-12 floor avoids both (clamped entries get
    # zero gradient).
    return torch.sqrt(sq_dist.clamp(min=1e-12))

def pairwise_cosine_sim(x, y):
    """Return the cosine-similarity matrix between the rows of ``x`` and ``y``.

    Both inputs are L2-normalized along dim 1, and then the matrix product of
    the normalized ``x`` with the transposed normalized ``y`` is taken.
    """
    x_unit, y_unit = (F.normalize(t, p=2, dim=1) for t in (x, y))
    return x_unit @ y_unit.T

def cosine_similarity(qf, gf):
    """Return the pairwise angular distance (in radians) between rows of ``qf`` and ``gf``.

    Despite the name, the result is ``arccos`` of the clipped cosine
    similarity, so larger values mean less similar.

    Args:
        qf: 2-D tensor of shape (m, d).
        gf: 2-D tensor of shape (n, d).

    Returns:
        numpy.ndarray of shape (m, n). The cosine is clipped to
        [-1 + 1e-5, 1 - 1e-5] before ``arccos``, keeping the result finite.
    """
    epsilon = 0.00001
    # F.normalize divides by max(norm, 1e-12), so an all-zero row yields a
    # cosine of 0 (angle pi/2) instead of the 0/0 NaN that dividing by the raw
    # norm product would produce.
    qf_unit = F.normalize(qf, p=2, dim=1)
    gf_unit = F.normalize(gf, p=2, dim=1)
    # detach(): Tensor.numpy() raises on tensors that require grad.
    dist_mat = qf_unit.mm(gf_unit.t()).detach().cpu().numpy()
    dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon)
    return np.arccos(dist_mat)

def compute_cosine_distance(qf, gf):
    """Computes cosine distance (1 - cosine similarity) between rows.

    Args:
        qf (torch.Tensor): 2-D feature matrix of shape (m, d).
        gf (torch.Tensor): 2-D feature matrix of shape (n, d).

    Returns:
        numpy.ndarray: (m, n) distance matrix, clipped to
        [1e-5, 1 - 1e-5]. Note that the upper clip maps every pair with
        non-positive cosine similarity to 1 - 1e-5.
    """
    features = F.normalize(qf, p=2, dim=1)
    others = F.normalize(gf, p=2, dim=1)
    dist_m = 1 - torch.mm(features, others.t())
    epsilon = 0.00001
    # detach(): Tensor.numpy() raises on tensors that require grad.
    dist_m = dist_m.detach().cpu().numpy()
    return np.clip(dist_m, epsilon, 1 - epsilon)

def info_nce_logits(features, paired=False, use_pseudo_labels=False, args=None):
    """Build InfoNCE logits and targets from stacked multi-view features.

    The rows of ``features`` are split into ``args.n_views`` consecutive,
    equally sized blocks. Two rows are positives when they carry the same
    label. By default, row ``i`` of a block shares a label with row ``i`` of
    every other block. Self-similarities are removed. For each row the
    positive similarities come first, followed by the negatives, so the target
    class for every row is column 0.

    Args:
        features: 2-D tensor of shape (N, d). N must equal
            ``args.n_views * args.batch_size``, or twice that when ``paired``.
        paired: when True, N is expected to be ``2 * n_views * batch_size``.
        use_pseudo_labels: only used with ``paired``. Rows 2k and 2k + 1
            inside each view block share one label.
        args: namespace providing ``n_views``, ``batch_size``, ``device`` and
            ``temperature``.

    Returns:
        (logits, labels): ``logits`` has shape (N, N - 1) and holds cosine
        similarities divided by ``args.temperature``. ``labels`` is a zero
        LongTensor of length N.
    """
    n_samples = int(features.size(0))
    # Rows per view block. The former `0.5 * N` (a float) matched this only
    # when n_views == 2; for other view counts it produced the wrong number of
    # labels and tripped the shape assertion below.
    per_view = n_samples // args.n_views
    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)
    expected = args.n_views * args.batch_size * (2 if paired else 1)
    assert similarity_matrix.shape == (expected, expected)

    if paired and use_pseudo_labels:
        # consecutive rows (2k, 2k + 1) of each view block share one label
        view_labels = torch.arange(per_view) // 2
    else:
        view_labels = torch.arange(per_view)
    labels = torch.cat([view_labels] * args.n_views, dim=0)

    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(args.device)
    assert similarity_matrix.shape == labels.shape

    # discard the main diagonal (self-similarity) from labels and similarities
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    positive_mask = labels.bool()
    # .view(N, -1) requires every row to have the same number of positives
    positives = similarity_matrix[positive_mask].view(labels.shape[0], -1)
    negatives = similarity_matrix[~positive_mask].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1) / args.temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=args.device)
    return logits, labels

def pcl_loss(features, prototypes, temperature):
    """Soft-assignment prototype loss.

    Each feature is softly assigned to the prototypes by a softmax over its
    temperature-scaled cosine similarities. The loss is the assignment-weighted
    sum of Euclidean distances to the prototypes, averaged over features.

    Args:
        features: 2-D tensor, one row per sample.
        prototypes: 2-D tensor, one row per prototype, same width as features.
        temperature: scalar dividing the cosine similarities before softmax.

    Returns:
        Scalar tensor.
    """
    assignment = F.softmax(pairwise_cosine_sim(features, prototypes) / temperature, dim=1)
    distances = euclidean_distance(features, prototypes)
    per_sample = torch.sum(distances * assignment, dim=1)
    return per_sample.mean()
