from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

class HSICLoss(nn.Module):
    """Barlow-Twins-style redundancy-reduction loss with an HSIC-inspired target.

    The first two views of each sample are standardised over the batch, their
    feature-by-feature cross-correlation matrix is computed, and the loss pushes
    the diagonal of that matrix towards ``1`` and every off-diagonal entry
    towards ``-1``.

    Reference implementation this follows:
    https://github.com/facebookresearch/barlowtwins/blob/main/main.py

    Args:
        projection_dim: reference feature dimensionality. It sets the weight of
            the off-diagonal term (``1 / projection_dim``) and, when different
            from 128, the rescaling applied to both terms.
        lmda, lmda_task, temperature, contrast_mode, base_temperature, device:
            stored on the module; not read by ``forward``.
    """

    def __init__(self, projection_dim, lmda=0.051, lmda_task=0.0051, temperature=0.07,
                 contrast_mode='all', base_temperature=0.07, device=None):
        super(HSICLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.projection_dim = projection_dim
        # Scale factor relative to a 128-dimensional projection; float literal
        # so the ratio is never truncated by integer division.
        self.penalty = projection_dim / 128.0
        self.device = device
        self.lmda = lmda
        self.lmda_task = lmda_task

    @staticmethod
    def _off_diagonal(x):
        """Return the off-diagonal elements of a square matrix, flattened."""
        n, m = x.shape
        if n != m:
            raise ValueError('expected a square matrix, got {}x{}'.format(n, m))
        # Dropping the last element and viewing as (n - 1, n + 1) lines every
        # diagonal entry up in column 0, which is then sliced away.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def forward(self, features, labels=None, mask=None, adv=False, standardize=True, prof=None):
        """Compute the loss from the first two views in ``features``.

        Args:
            features: tensor of shape [bsz, n_views, ...] with ``n_views >= 2``.
                Trailing dimensions are flattened into one feature dimension.
                Only views 0 and 1 are used.
            labels: optional tensor with ``bsz`` elements. Only validated; it
                does not influence the loss.
            mask: optional; only checked for not being passed together with
                ``labels``. It does not influence the loss.
            adv, standardize, prof: accepted for interface compatibility and
                ignored. Standardisation is always applied.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if ``features`` has fewer than 3 dimensions or fewer
                than 2 views, if both ``labels`` and ``mask`` are given, or if
                the number of labels differs from the batch size.
        """
        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if features.dim() > 3:
            # reshape (not view) so non-contiguous inputs are accepted too.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        if labels is not None and labels.contiguous().view(-1, 1).shape[0] != batch_size:
            raise ValueError('Num of labels does not match num of features')

        n_views = features.shape[1]
        if n_views < 2:
            raise ValueError('`features` needs at least 2 views along dim 1, '
                             'got {}'.format(n_views))

        anchor = features[:, 0]
        contrast = features[:, 1]

        # Standardise every feature over the batch; the epsilon guards against
        # division by zero for constant features.
        anchor = (anchor - anchor.mean(0)) / (anchor.std(0) + 1e-6)
        contrast = (contrast - contrast.mean(0)) / (contrast.std(0) + 1e-6)

        # Cross-correlation matrix between the two views: [dim, dim].
        c = torch.matmul(anchor.T, contrast) / batch_size

        # Out-of-place arithmetic so `c` is not modified while it is still read.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = (self._off_diagonal(c) + 1).pow(2).sum()

        if self.projection_dim != 128:
            # NOTE(review): for projection_dim < 128 the factor
            # penalty * (penalty - 1) is negative, which flips the sign of the
            # off-diagonal term. Kept as-is; confirm the intended normalisation.
            on_diag = on_diag / self.penalty
            off_diag = off_diag / (self.penalty * (self.penalty - 1))

        return on_diag + (1.0 / self.projection_dim) * off_diag