from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from loss_functions import MdarLoss

def pdist(e, squared=False, eps=1e-12):
    """Return the pairwise Euclidean distance matrix between the rows of ``e``.

    Args:
        e: 2-D tensor of shape (N, D); distances are taken between rows.
        squared: if True, return squared distances instead of distances.
        eps: lower clamp applied to the squared distances before the optional
            square root, so sqrt never sees zero or tiny negative round-off.

    Returns:
        (N, N) tensor of distances with an exact zero diagonal.
    """
    sq_norms = e.pow(2).sum(dim=1)
    gram = torch.matmul(e, e.t())
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped away from zero
    dist = (sq_norms[:, None] + sq_norms[None, :] - 2 * gram).clamp(min=eps)

    if not squared:
        dist = dist.sqrt()

    # Write the diagonal into a copy so the in-place assignment does not
    # clobber a tensor autograd may have saved for backward (e.g. sqrt's output).
    dist = dist.clone()
    diag_idx = torch.arange(e.size(0), device=e.device)
    dist[diag_idx, diag_idx] = 0
    return dist

class RKdAngle(nn.Module):
    """RKD angle-wise loss.

    For every anchor sample i, forms the unit difference vectors
    ``x[j] - x[i]`` and takes their pairwise cosine similarities. The loss is
    the smooth-L1 distance between the student's and the teacher's similarity
    tensors. Gradients flow only through the student.
    """

    @staticmethod
    def _angles(emb):
        # emb: (N, C) -> differences (N, N, C), unit-normalised along C
        unit_diffs = F.normalize(emb.unsqueeze(0) - emb.unsqueeze(1), p=2, dim=2)
        # per-anchor Gram matrix of the difference vectors -> (N, N, N), flattened
        return torch.bmm(unit_diffs, unit_diffs.transpose(1, 2)).view(-1)

    def forward(self, student, teacher):
        # Teacher targets are fixed: no graph is built for them.
        with torch.no_grad():
            teacher_angles = self._angles(teacher)

        student_angles = self._angles(student)
        return F.smooth_l1_loss(student_angles, teacher_angles, reduction='mean')


class RkdDistance(nn.Module):
    """RKD distance-wise loss.

    Compares the student's and the teacher's pairwise distance matrices (from
    ``pdist``), each divided by the mean of its positive entries. The penalty
    is smooth-L1. Gradients flow only through the student.
    """

    @staticmethod
    def _normalized_pdist(emb):
        dist = pdist(emb, squared=False)
        positive = dist[dist > 0]
        # A single-sample batch has no off-diagonal pairs, so ``positive`` is
        # empty and its mean would be NaN, poisoning the loss. Leave the
        # (all-zero) matrix unscaled in that case.
        if positive.numel() > 0:
            dist = dist / positive.mean()
        return dist

    def forward(self, student, teacher):
        # Teacher targets are fixed: no graph is built for them.
        with torch.no_grad():
            t_d = self._normalized_pdist(teacher)

        d = self._normalized_pdist(student)
        return F.smooth_l1_loss(d, t_d, reduction='mean')


class RKDLoss(nn.Module):
    """Relational knowledge distillation (RKD) loss between task-net and oracle features.

    Combines the RKD angle-wise loss (weight ``w_angle``) and the RKD
    distance-wise loss (weight ``w_dist``). Only the target-domain pair
    (task-net target view vs. oracle target view) currently contributes.

    Args:
        projection_dim: feature dimension. Also sets ``penalty`` and is passed
            to the ``MdarLoss`` criterion.
        w_angle: weight of the angle-wise term.
        w_dist: weight of the distance-wise term.
        lmda, lmda_task: stored as attributes; not read by ``forward``.
        device: stored as an attribute; not read by ``forward``.

    Note:
        ``oracle_criterion`` (``MdarLoss``) is constructed but not used by
        ``forward``.
    """

    def __init__(self, projection_dim, w_angle=2.0, w_dist=1.0, lmda=0.051, lmda_task=0.0051, device=None):
        super(RKDLoss, self).__init__()
        self.projection_dim = projection_dim  # [TODO] - added to normalize loss
        self.penalty = projection_dim / 128  # normalized loss
        self.device = device
        self.lmda = lmda
        self.lmda_task = lmda_task
        self.angle_loss = RKdAngle()
        self.dist_loss = RkdDistance()
        self.w_angle = w_angle
        self.w_dist = w_dist

        # MI
        self.oracle_criterion = MdarLoss(projection_dim=projection_dim)

    def forward(self, features, labels=None, mask=None, adv=False, standardize=True, prof=None):
        """Compute the RKD loss.

        Args:
            features: tensor of shape [bsz, n_views, ...] with at least 4 views.
                Per the original variable names, the views are ordered
                (task-net source, task-net target, oracle source, oracle target).
                Dimensions after the second are flattened.
            labels, mask, adv, standardize, prof: accepted but ignored
                (presumably kept for signature compatibility with other losses).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if ``features`` has fewer than 3 dimensions or fewer
                than 4 views.
        """
        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if features.shape[1] < 4:
            raise ValueError('`features` needs at least 4 views '
                             '(src, tgt, oracle, tgt_oracle), got %d' % features.shape[1])
        if features.dim() > 3:
            # reshape (unlike view) also accepts non-contiguous inputs
            features = features.reshape(features.shape[0], features.shape[1], -1)

        h_tgt = features[:, 1]         # task-net, target
        h_tgt_oracle = features[:, 3]  # oracle, target
        # Views 0 (h_src) and 2 (h_oracle) are unused. Disabled alternatives:
        # adding the same term for (h_src, h_oracle), or distilling on the
        # source+target batch concatenated along dim 0.

        loss = (self.w_angle * self.angle_loss(h_tgt, h_tgt_oracle)
                + self.w_dist * self.dist_loss(h_tgt, h_tgt_oracle))
        return loss