from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

class BTLoss(nn.Module):
    """Barlow Twins loss between two batches of embeddings.

    Both inputs are standardised per feature with a shared, non-affine
    ``BatchNorm1d``; their cross-correlation matrix is then pushed towards
    the identity: diagonal entries towards 1, off-diagonal entries
    (weighted by ``lmda``) towards 0.

    Args:
        projection_dim: number of features of each embedding; also used to
            scale the final loss.
        lmda: weight applied to the off-diagonal term.
        device: optional device on which the loss is computed. When ``None``
            the device of ``z_a`` is used at call time.
    """

    def __init__(self, projection_dim, lmda=0.0051, device=None):
        super(BTLoss, self).__init__()
        self.projection_dim = projection_dim
        self.device = device
        self.lmda = lmda
        self.bn = nn.BatchNorm1d(self.projection_dim, affine=False)
        if self.device is not None:
            self.bn = self.bn.to(self.device)

    @staticmethod
    def _off_diagonal(x):
        """Return a flattened copy of the off-diagonal entries of a square matrix."""
        n, m = x.shape
        if n != m:
            raise ValueError(
                "expected a square matrix, got shape {}".format(tuple(x.shape))
            )
        # Dropping the last element and reshaping to (n - 1, n + 1) lines every
        # diagonal entry up in column 0, which is then sliced away.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def forward(self, z_a, z_b, labels=None, mask=None, adv=False, standardize=True, prof=None):
        """Compute the Barlow Twins loss for two views of the same batch.

        Args:
            z_a: embeddings of the first view, shape [bsz, projection_dim].
            z_b: embeddings of the second view, same shape as ``z_a``.
            labels, mask, adv, standardize, prof: accepted so existing call
                sites keep working, but not used by this loss.
                NOTE(review): ``standardize`` is ignored; the inputs are
                always batch-normalised.

        Returns:
            A scalar tensor:
            ``(on_diag + lmda * off_diag) / projection_dim``.

        Raises:
            ValueError: if the inputs are not 2-D or their shapes differ.
        """
        if z_a.dim() != 2 or z_b.dim() != 2 or z_a.shape != z_b.shape:
            raise ValueError(
                "z_a and z_b must be 2-D tensors of identical shape, got {} and {}".format(
                    tuple(z_a.shape), tuple(z_b.shape)
                )
            )

        # Fall back to the device the data already lives on; this keeps the
        # CUDA index intact and also works for non-CUDA accelerators.
        device = self.device if self.device is not None else z_a.device
        # The running-stat buffers must sit on the same device as the inputs.
        # This is a no-op when they already do.
        self.bn.to(device)

        batch_size = z_a.shape[0]
        z_a = self.bn(z_a.to(device))
        z_b = self.bn(z_b.to(device))

        # Empirical cross-correlation matrix, shape [projection_dim, projection_dim].
        c = torch.matmul(z_a.T, z_b) / batch_size

        # Out-of-place ops: same values as the in-place variants, but ``c`` is
        # not mutated through its diagonal view.
        on_diag = torch.diagonal(c).add(-1).pow(2).sum()
        off_diag = self._off_diagonal(c).pow(2).sum()

        loss = on_diag + self.lmda * off_diag
        return loss / self.projection_dim