import torch
import torch.nn.functional as F


def batch_to_torch(batch, device):
    """Turn a dict of NumPy arrays into a dict of tensors placed on ``device``.

    Args:
        batch: mapping from name to NumPy array.
        device: target passed to ``Tensor.to`` (e.g. ``"cpu"`` or a
            ``torch.device``).

    Returns:
        A new dict with the same keys, in the same order, whose values are the
        converted tensors. The copy is requested with ``non_blocking=True``.
    """
    converted = {}
    for name, array in batch.items():
        tensor = torch.from_numpy(array)
        converted[name] = tensor.to(device=device, non_blocking=True)
    return converted


def batch_free_memory(batch):
    """Drop the dict's references to its values so they can be garbage-collected.

    The dict is emptied in place. The values' memory is reclaimed only once no
    other references to them remain elsewhere.

    Args:
        batch: mutable mapping (e.g. the output of ``batch_to_torch``).
    """
    # Deleting loop variables would only unbind local names, and deleting keys
    # while iterating raises RuntimeError; clearing the dict is what actually
    # releases the references it holds.
    batch.clear()


def contr_loss(anchor, pos, neg, T=1.0):
    """Contrastive loss whose logits are negative squared Euclidean distances.

    For each anchor row, the candidate logits are ``-||anchor - x||^2 / T``.
    Here ``x`` ranges over every negative row plus the matching positive row.
    The loss is the mean negative log-softmax probability of the positive.

    Args:
        anchor: 2-D tensor (``torch.mm`` requires 2-D), one row per sample;
            its row count is the batch size ``bs``.
        pos: tensor shaped like ``anchor``; row ``j`` is the positive for
            anchor row ``j``.
        neg: sequence of 2-D tensors concatenated along dim 0, or a single
            already-concatenated 2-D tensor.
            NOTE(review): the masking below assumes the negatives are stacked
            in chunks of ``bs`` rows aligned with ``anchor`` — confirm
            against callers.
        T: temperature; logits are divided by it before the softmax.

    Returns:
        Scalar tensor: mean over anchors of ``-log p(positive)``.
    """
    neg = neg if torch.is_tensor(neg) else torch.cat(neg)
    bs = anchor.shape[0]

    anchor_sq = (anchor ** 2).sum(1)  # (bs,)

    # -||a - n||^2 expanded as -(||a||^2 - 2 a.n + ||n||^2) for all pairs.
    neg_dists = -(
        anchor_sq.unsqueeze(1)
        - 2 * torch.mm(anchor, neg.t())
        + (neg ** 2).sum(1).unsqueeze(0)
    )  # (bs, num_neg)

    # In every complete bs-row chunk of negatives, exclude the entry aligned
    # with the anchor's own row (presumably the same sample). The chunk count
    # comes from the negative (column) axis; the row axis is always bs.
    rows = torch.arange(bs, device=anchor.device)
    mask = torch.zeros_like(neg_dists, dtype=torch.bool)
    for chunk in range(neg_dists.shape[1] // bs):
        mask[rows, chunk * bs + rows] = True
    neg_dists = neg_dists.masked_fill(mask, float('-inf'))

    pos_dot = (anchor * pos).sum(dim=1)  # (bs,)
    pos_dists = -(anchor_sq - 2 * pos_dot + (pos ** 2).sum(1)).unsqueeze(1)  # (bs, 1)

    # The positive logit is placed in the last column so it is selected by [:, -1].
    logits = torch.cat((neg_dists, pos_dists), dim=1) / T
    log_probs = F.log_softmax(logits, dim=1)
    return -log_probs[:, -1].mean()