# Modified based on https://github.com/lightly-ai/lightly

import torch
from torch import nn


class BCLoss(nn.Module):
    """Implementation of Balanced Contrastive Loss.

    For every sample ``i`` of each view the loss combines

    * an attracting term: the negated cosine similarity between the two
      views of sample ``i``, and
    * a repelling term: ``(1 / a) * log(sum_j exp(a * s_ij))`` over the
      similarities ``s_ij`` between sample ``i`` and every *other* sample
      ``j != i`` of both views (the positive pair is excluded),

    weighted as ``attract + l * repel`` and averaged over all ``2 * batch_size``
    rows.

    Attributes:
        a: alpha, the scale applied to similarities inside the log-sum-exp.
            Must be non-zero because the repelling term is divided by it.
        l: lambda, the weight of the repelling term.
        eps: Threshold below which ``abs(a)`` is rejected.

    Raises:
        ValueError: If abs(a) < 1e-8 to prevent divide by zero.

    Examples:
        >>> # initialize loss function
        >>> loss_fn = BCLoss(a, l)
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed each view through the model
        >>> out0 = model(t0)
        >>> out1 = model(t1)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(out0, out1)
    """

    def __init__(
        self,
        a: float,
        l: float,
    ):
        super().__init__()
        self.a = a
        self.l = l
        self.eps = 1e-8

        if abs(self.a) < self.eps:
            raise ValueError(
                "Illegal alpha: abs({}) < 1e-8".format(self.a)
            )

    def forward(self, out0: torch.Tensor, out1: torch.Tensor) -> torch.Tensor:
        """Forward pass through Balanced Contrastive Loss.

        Args:
            out0:
                Output projections of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            out1:
                Output projections of the second set of transformed images.
                Shape: (batch_size, embedding_size)

        Returns:
            Balanced Contrastive Loss value as a scalar (0-d) tensor.
            With batch_size == 1 there are no other samples to repel from,
            so the repelling term (and thus the loss) is non-finite.
        """
        batch_size, _ = out0.shape

        # Normalize to unit length so that dot products are cosine similarities.
        out0 = nn.functional.normalize(out0, dim=1)
        out1 = nn.functional.normalize(out1, dim=1)

        # Notation in einsum: n, m = batch_size, c = embedding_size.
        logits_00 = torch.einsum("nc,mc->nm", out0, out0)
        logits_01 = torch.einsum("nc,mc->nm", out0, out1)
        logits_10 = torch.einsum("nc,mc->nm", out1, out0)
        logits_11 = torch.einsum("nc,mc->nm", out1, out1)

        # Attracting term: similarities between the two views of the same
        # image (the diagonals), shape (2 * batch_size,). Keeping this 1-D is
        # what prevents an accidental (2B, 1) + (2B,) -> (2B, 2B) broadcast.
        logits_att = torch.cat([logits_01.diagonal(), logits_10.diagonal()], dim=0)

        # Repelling term: every off-diagonal similarity, i.e. all pairs of
        # *different* images from both views.
        # Resulting shape: (2 * batch_size, 2 * batch_size - 2).
        off_diag = ~torch.eye(batch_size, device=out0.device, dtype=torch.bool)
        logits_00_rep = logits_00[off_diag].view(batch_size, -1)
        logits_01_rep = logits_01[off_diag].view(batch_size, -1)
        logits_10_rep = logits_10[off_diag].view(batch_size, -1)
        logits_11_rep = logits_11[off_diag].view(batch_size, -1)

        logits_0100_rep = torch.cat([logits_01_rep, logits_00_rep], dim=1)
        logits_1011_rep = torch.cat([logits_10_rep, logits_11_rep], dim=1)
        logits_rep = torch.cat([logits_0100_rep, logits_1011_rep], dim=0)

        # Construct the loss. logsumexp is the numerically stable form of
        # log(sum(exp(.))); the naive form overflows to inf for large |a|.
        loss_att = -logits_att
        loss_rep = torch.logsumexp(self.a * logits_rep, dim=1) / self.a
        return torch.mean(loss_att + self.l * loss_rep)


class GNTXentLoss(nn.Module):
    """Implementation of Generalized NT-Xent Loss.

    For every sample ``i`` of each view the loss combines

    * the negated cosine similarity between the two views of sample ``i``, and
    * ``(1 / a) * log(sum_j exp(a * s_ij))`` over the row of similarities
      between sample ``i`` and all samples of the other view (including its
      positive) plus all *other* samples of its own view,

    weighted as ``positive_term + l * log_sum_exp_term`` and averaged over all
    ``2 * batch_size`` rows. Within-batch samples serve as negatives.

    Attributes:
        a: alpha, the scale applied to similarities inside the log-sum-exp.
            Must be non-zero because the second term is divided by it.
        l: lambda, the weight of the log-sum-exp term.
        eps: Threshold below which ``abs(a)`` is rejected.

    Raises:
        ValueError: If abs(a) < 1e-8 to prevent divide by zero.

    Examples:
        >>> # initialize loss function
        >>> loss_fn = GNTXentLoss(a, l)
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed each view through the SSL model
        >>> out0 = model(t0)
        >>> out1 = model(t1)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(out0, out1)
    """

    def __init__(
        self,
        a: float,
        l: float,
    ):
        super().__init__()
        self.a = a
        self.l = l
        self.eps = 1e-8

        if abs(self.a) < self.eps:
            raise ValueError(
                "Illegal alpha: abs({}) < 1e-8".format(self.a)
            )

    def forward(self, out0: torch.Tensor, out1: torch.Tensor) -> torch.Tensor:
        """Forward pass through Generalized NT-Xent Loss.

        Within-batch samples are used as negative samples.

        Args:
            out0:
                Output projections of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            out1:
                Output projections of the second set of transformed images.
                Shape: (batch_size, embedding_size)

        Returns:
            Generalized NT-Xent Loss value as a scalar (0-d) tensor.
        """
        batch_size, _ = out0.shape

        # Normalize to unit length so that dot products are cosine similarities.
        out0 = nn.functional.normalize(out0, dim=1)
        out1 = nn.functional.normalize(out1, dim=1)

        # Notation in einsum: n, m = batch_size, c = embedding_size.
        logits_00 = torch.einsum("nc,mc->nm", out0, out0)
        logits_01 = torch.einsum("nc,mc->nm", out0, out1)
        logits_10 = torch.einsum("nc,mc->nm", out1, out0)
        logits_11 = torch.einsum("nc,mc->nm", out1, out1)

        # Positive similarities (same image, different view), shape
        # (2 * batch_size,). These are the entries the original label/gather
        # scheme selected: column i of row i in both halves.
        logits_pos = torch.cat([logits_01.diagonal(), logits_10.diagonal()], dim=0)

        # Drop each sample's similarity with itself within the same view.
        off_diag = ~torch.eye(batch_size, device=out0.device, dtype=torch.bool)
        logits_00 = logits_00[off_diag].view(batch_size, -1)
        logits_11 = logits_11[off_diag].view(batch_size, -1)

        # Concatenate logits; the result has shape
        # (2 * batch_size, 2 * batch_size - 1).
        logits_0100 = torch.cat([logits_01, logits_00], dim=1)
        logits_1011 = torch.cat([logits_10, logits_11], dim=1)
        logits = torch.cat([logits_0100, logits_1011], dim=0)

        # logsumexp is the numerically stable form of log(sum(exp(.))); the
        # naive form overflows to inf for large |a|. Both terms are 1-D with
        # shape (2 * batch_size,), so no accidental broadcasting occurs.
        loss_1 = -logits_pos
        loss_2 = torch.logsumexp(self.a * logits, dim=1) / self.a
        return torch.mean(loss_1 + self.l * loss_2)
