import torch
import torch.nn.functional as F
import torch.nn as nn

def uda_loss(outputs_x, targets_x, outputs_u, targets_u):
    """Return the supervised and consistency terms of a UDA-style loss.

    Args:
        outputs_x: Scores for the labelled batch; log-softmax is taken
            along dim 1.
        targets_x: Target weights multiplied element-wise with the
            log-probabilities (presumably one-hot or soft labels with the
            same shape as ``outputs_x`` -- confirm against callers).
        outputs_u: Scores for the unlabelled batch; softmax is taken
            along dim 1.
        targets_u: Target values compared against the softmax of
            ``outputs_u``.

    Returns:
        Tuple ``(Lx, Lu)``: ``Lx`` is the batch mean of the cross-entropy
        between ``targets_x`` and the predicted distribution, ``Lu`` is the
        mean squared difference between the predicted probabilities and
        ``targets_u`` taken over every element.
    """
    # Supervised term: cross-entropy against (possibly soft) targets.
    log_probs_x = F.log_softmax(outputs_x, dim=1)
    weighted_log_probs = (log_probs_x * targets_x).sum(dim=1)
    Lx = -weighted_log_probs.mean()

    # Consistency term: squared error in probability space.
    residual = torch.softmax(outputs_u, dim=1) - targets_u
    Lu = residual.pow(2).mean()

    return Lx, Lu

def entropy(x, input_as_probabilities, q):
    """Return the negated sum of ``p * log(.)`` terms for ``x``.

    Args:
        x: Input tensor. Treated as probabilities when
            ``input_as_probabilities`` is true, otherwise as scores that
            are soft-maxed along dim 1.
        input_as_probabilities: Selects which of the two formulas is used.
        q: Divisor applied inside the logarithm in the probability branch,
            i.e. the term is ``p * log(p / q)``. It is not used in the
            other branch. (Presumably a reference distribution or a scalar;
            confirm against callers.)

    Returns:
        For a 2-D term tensor, the mean over dim 0 of the negated row sums;
        for a 1-D term tensor, the negated sum.

    Raises:
        ValueError: If the term tensor is neither 1-D nor 2-D.
    """
    if input_as_probabilities:
        # Clamp away from zero so the logarithm stays finite.
        clamped = torch.clamp(x, min=1e-8)
        terms = clamped * torch.log(clamped / q)
    else:
        terms = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)

    rank = terms.dim()
    if rank == 1:
        return -terms.sum()
    if rank == 2:
        return -terms.sum(dim=1).mean()
    raise ValueError('Input tensor is %d-Dimensional' % rank)

def symmetric_mse_loss(input1, input2):
    """Return the summed squared difference between two same-sized tensors.

    Neither argument is detached, so gradients can flow to both inputs.
    Note that the result is a sum over all elements, not a mean.

    Raises:
        AssertionError: If the two tensors differ in size.
    """
    assert input1.size() == input2.size()
    delta = input1 - input2
    return delta.pow(2).sum()

def supcon_knn(features, features_all, mask, temperature=0.07):
    """Supervised-contrastive style loss of anchors against a set of contrast features.

    Each row of ``features`` is an anchor. Its similarity to every row of
    ``features_all`` is turned into a log-softmax over the contrast set, and
    the entries selected (weighted) by ``mask`` are averaged per anchor.

    Args:
        features: Anchor features, shape ``(N, D)``. Gradients flow through
            this argument only.
        features_all: Contrast features, shape ``(M, D)``. Detached inside
            the function, so it receives no gradient. It is not moved, so it
            must already be on the same device as ``features``.
        mask: Shape ``(N, M)``; entry ``[i, j]`` weights contrast ``j`` as a
            positive for anchor ``i``. Converted to float and detached.
        temperature: Divisor applied to the dot products.

    Returns:
        Scalar tensor: the negated mean over anchors of the per-anchor mean
        log-probability of its positives. Anchors whose mask row sums to
        zero contribute 0 instead of producing NaN.

    Raises:
        AssertionError: If ``mask`` is not shaped ``(N, M)``.
    """
    assert mask.shape[0] == features.shape[0] and mask.shape[1] == features_all.shape[0]
    # Work on a private, gradient-free copy of the contrast features.
    # NOTE(review): the clone presumably protects against the caller updating
    # a feature bank in place before backward -- confirm before removing it.
    features_all = features_all.detach().clone()
    # Follow the anchors' actual device (handles non-default GPUs as well).
    device = features.device
    mask = mask.float().detach().clone().to(device)

    anchor_dot_contrast = torch.matmul(features, features_all.T) / temperature

    # Subtract the row maximum for numerical stability; the shift is treated
    # as a constant so it does not alter gradients.
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    log_denominator = torch.log(torch.exp(logits).sum(1, keepdim=True) + 1e-12)
    log_prob = logits - log_denominator

    # An anchor with no positives would give 0 / 0 = NaN and poison the whole
    # batch loss; divide by 1 instead so such anchors simply contribute 0.
    pos_count = mask.sum(1)
    safe_count = torch.where(pos_count == 0, torch.ones_like(pos_count), pos_count)
    mean_log_prob_pos = (mask * log_prob).sum(1) / safe_count

    return -mean_log_prob_pos.mean()

def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'):
    """Build InfoNCE logits and targets from stacked multi-view features.

    ``features`` is expected to hold ``n_views`` blocks of equal size stacked
    along dim 0, where row ``i`` of every block comes from the same sample.

    Args:
        features: Tensor of shape ``(n_views * B, D)``; rows are
            L2-normalised along dim 1 before similarities are computed.
        n_views: Number of stacked views; must be at least 2 and divide
            ``features.size(0)``.
        temperature: Divisor applied to the logits.
        device: Device on which the label and mask tensors are created. It
            should be the device that holds ``features``, since those
            tensors are used to index the similarity matrix.

    Returns:
        Tuple ``(logits, labels)``. ``logits`` has shape
        ``(n_views * B, n_views * B - 1)`` with each row's ``n_views - 1``
        positive similarities first, followed by its negatives. ``labels``
        is an all-zero long tensor, i.e. the target is column 0 (the first
        positive) of every row.

    Raises:
        ValueError: If ``n_views`` is smaller than 2 or does not evenly
            divide the number of rows in ``features``.
    """
    total = int(features.size(0))
    if n_views < 2 or total % n_views != 0:
        raise ValueError(
            'features.size(0)=%d cannot be split into n_views=%d equal views'
            % (total, n_views))
    # Samples per view. This used to be hard-coded as half the batch, which
    # was only correct for n_views == 2.
    b_ = total // n_views

    # labels[i, j] is 1 when rows i and j originate from the same sample.
    sample_ids = torch.arange(b_).repeat(n_views)
    labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).float()
    labels = labels.to(device)

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # Drop the diagonal: a row is never its own positive or negative.
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    is_positive = labels.bool()
    positives = similarity_matrix[is_positive].view(labels.shape[0], -1)
    negatives = similarity_matrix[~is_positive].view(similarity_matrix.shape[0], -1)

    # Positives are placed first so the cross-entropy target is index 0.
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

    logits = logits / temperature
    return logits, labels