from functools import partial
from itertools import combinations

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseMethod


def contrastive_loss(x0, x1, tau, norm, mean=True):
    """NT-Xent (SimCLR) contrastive loss between two batches of embeddings.

    Row ``k`` of ``x0`` and row ``k`` of ``x1`` form the positive pair; all
    other rows of both batches act as negatives. Mirrors
    https://github.com/google-research/simclr/blob/master/objective.py

    Args:
        x0: 2-D embeddings of the first view (samples along dim 0).
        x1: 2-D embeddings of the second view; same batch size as ``x0``.
        tau: temperature dividing every similarity logit.
        norm: if true, L2-normalize both inputs along dim 1 first, so the
            logits are cosine similarities.
        mean: if true, return the average of the two directional mean
            cross-entropies; otherwise return the per-sample losses of both
            directions concatenated.

    Returns:
        A scalar tensor when ``mean`` is true, else a 1-D tensor of length
        ``2 * x0.shape[0]`` (x0->x1 losses first, then x1->x0).
    """
    bsize = x0.shape[0]
    # Build helper tensors on the inputs' device instead of a hard-coded
    # ``.cuda()``, so CPU inputs and non-default GPUs both work.
    device = x0.device
    # For row k, the positive is column k of the cross-view block.
    target = torch.arange(bsize, device=device)
    # Huge diagonal penalty removes each sample's similarity with itself from
    # the within-view logits. Left in the default float dtype (as before) so
    # the 1e9 constant cannot overflow when the inputs are half precision.
    eye_mask = torch.eye(bsize, device=device) * 1e9
    if norm:
        x0 = F.normalize(x0, p=2, dim=1)
        x1 = F.normalize(x1, p=2, dim=1)
    logits00 = x0 @ x0.t() / tau - eye_mask
    logits11 = x1 @ x1.t() / tau - eye_mask
    logits01 = x0 @ x1.t() / tau
    logits10 = x1 @ x0.t() / tau
    # Columns [0, bsize) are cross-view logits (positive on the diagonal);
    # columns [bsize, 2 * bsize) are the masked within-view negatives.
    logits_0 = torch.cat([logits01, logits00], dim=1)
    logits_1 = torch.cat([logits10, logits11], dim=1)
    if mean:
        return (F.cross_entropy(logits_0, target)
                + F.cross_entropy(logits_1, target)) / 2
    return torch.cat([
        F.cross_entropy(logits_0, target, reduction='none'),
        F.cross_entropy(logits_1, target, reduction='none'),
    ], dim=0)


class Contrastive(BaseMethod):
    """ implements contrastive loss https://arxiv.org/abs/2002.05709 """

    def __init__(self, cfg):
        """Init the method plus the additional BatchNorm used after the head.

        Reads ``cfg.emb`` (BatchNorm width) and ``cfg.tau`` / ``cfg.norm``
        (bound into ``contrastive_loss``); the rest of ``cfg`` is passed to
        ``BaseMethod``.
        """
        super().__init__(cfg)
        self.bn_last = nn.BatchNorm1d(cfg.emb)
        self.loss_f = partial(contrastive_loss, tau=cfg.tau, norm=cfg.norm)

    def forward(self, samples, mean=True):
        """Accumulate the contrastive loss over every unordered pair of views.

        Args:
            samples: sequence of view batches. The joint embedding is split
                using ``len(samples[0])``, so presumably every view has that
                same batch size (TODO confirm for all callers).
            mean: if true, the summed pair losses are divided by
                ``self.num_pairs`` (not defined in this class; presumably set
                by ``BaseMethod``). If false, the per-sample loss vectors
                returned by ``contrastive_loss`` are summed over pairs with
                no normalization.

        Returns:
            The accumulated loss (``0`` if fewer than two views are given).
        """
        bs = len(samples[0])
        # Encode each view separately, then run head + bn_last once on the
        # concatenation so BatchNorm statistics span all views together.
        emb = torch.cat([self.model(x) for x in samples])
        emb = self.bn_last(self.head(emb))
        views = [emb[k * bs:(k + 1) * bs] for k in range(len(samples))]
        loss = 0
        # combinations() yields (0,1), (0,2), ..., (1,2), ... -- the same
        # order as the nested i<j loop, so the summation order is unchanged.
        for x0, x1 in combinations(views, 2):
            loss += self.loss_f(x0, x1, mean=mean)
        if mean:
            loss = loss / self.num_pairs
        return loss
