import torch
import torch.nn as nn
import torch.nn.functional as F


class PreserveDistanceLoss(nn.Module):
    """Mean squared gap between the pairwise L2 distance matrices of two embeddings.

    Both inputs are passed through ``torch.cdist`` against themselves, and the mean of the
    elementwise squared difference of the two resulting distance matrices is returned.
    Inputs are expected to be (B, Z) and (B, Z') with the same batch size B.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, z_prime):
        # elementwise gap between the two (B, B) distance matrices
        dist_gap = torch.cdist(z, z) - torch.cdist(z_prime, z_prime)
        return dist_gap.pow(2).mean()


class DynamicWeightingPDLoss(nn.Module):
    """
    Implements dynamically weighted preserve distance loss.

    Classic PD loss creates conflicting objectives when combined with Lpred: Lpred wants to
    separate samples with different labels as far as possible, while Lpreserve tries to keep
    their pairwise distances unchanged. To mitigate this, this version of PD loss "turns off"
    the loss for pairs (i, j) with different class labels, or down-weights it for pairs with
    dissimilar regression values, so that Lpred can separate them freely.
    """

    def __init__(self, enable_dynamic=True):
        super(DynamicWeightingPDLoss, self).__init__()
        self.enable_dynamic = enable_dynamic
        print("Using dist_method: orig")
        print("Enable dynamic weighting: ", enable_dynamic)

    @staticmethod
    def _pair_weights(labels):
        """
        Build a (B, B) pair-weight matrix from (B, N) labels.

        :param labels: (B, N) tensor. ``torch.long`` means class labels; any other dtype is
            treated as regression values.
        :return: (B, B) tensor. For class labels: 1.0 where samples i and j agree on every label
            column, 0.0 otherwise. For regression values: the Gaussian kernel
            exp(-||y_i - y_j||^2) over all N columns (exp(-(y_i - y_j)^2) when N == 1).
        """
        # (B, 1, N) against (1, B, N) broadcasts to a (B, B, N) pairwise comparison,
        # which works for any number of label columns N
        rows, cols = labels.unsqueeze(1), labels.unsqueeze(0)
        if labels.dtype != torch.long:
            return torch.exp(-((rows - cols) ** 2).sum(dim=-1))
        return (rows == cols).all(dim=-1).float()

    def forward(self, z, z_prime, labels):
        """
        :param z: (B, Z)
        :param z_prime: (B, Z')
        :param labels: (B,) or (B, 1) long tensor of class labels (cls), float tensor of values (reg),
            or (B, N) in the multitask setting
        :return: scalar mean of the squared difference between cdist(z) and cdist(z'), i.e. the
            squared Frobenius norm divided by B^2. When dynamic weighting is enabled, each entry is
            first weighted by label agreement (cls: 0/1, reg: Gaussian kernel on label difference).
        """
        diff = torch.cdist(z, z, p=2) - torch.cdist(z_prime, z_prime, p=2)  # (B, B)

        if self.enable_dynamic:
            if labels.dim() < 2:
                labels = labels.unsqueeze(-1)
            # zero (cls) or soft (reg) weights for pairs whose labels disagree
            diff = diff * self._pair_weights(labels)

        return torch.mean(diff ** 2)
    

class DynamicWeightingRBFLoss(nn.Module):
    """
    Dynamically weighted preservation loss on kernel similarities exp(-||z_i - z_j||_2).

    Same scheme as the distance-based PD loss, but compares the (B, B) matrices
    exp(-cdist(z, z)) and exp(-cdist(z', z')) instead of raw distances. With dynamic weighting,
    pairs with different class labels are dropped, and pairs with dissimilar regression values
    are down-weighted.
    """

    def __init__(self, enable_dynamic=True):
        super(DynamicWeightingRBFLoss, self).__init__()
        self.enable_dynamic = enable_dynamic
        print("Using dist_method: rbf")
        print("Enable dynamic weighting: ", enable_dynamic)

    @staticmethod
    def _pair_weights(labels):
        """
        Build a (B, B) pair-weight matrix from (B, N) labels.

        :param labels: (B, N) tensor. ``torch.long`` means class labels; any other dtype is
            treated as regression values.
        :return: (B, B) tensor. For class labels: 1.0 where samples i and j agree on every label
            column, 0.0 otherwise. For regression values: the Gaussian kernel
            exp(-||y_i - y_j||^2) over all N columns (exp(-(y_i - y_j)^2) when N == 1).
        """
        # (B, 1, N) against (1, B, N) broadcasts to a (B, B, N) pairwise comparison,
        # which works for any number of label columns N
        rows, cols = labels.unsqueeze(1), labels.unsqueeze(0)
        if labels.dtype != torch.long:
            return torch.exp(-((rows - cols) ** 2).sum(dim=-1))
        return (rows == cols).all(dim=-1).float()

    def forward(self, z, z_prime, labels):
        """
        :param z: (B, Z)
        :param z_prime: (B, Z')
        :param labels: (B,) or (B, 1) long tensor of class labels (cls), float tensor of values (reg),
            or (B, N) in the multitask setting
        :return: scalar mean of the squared difference between exp(-cdist(z)) and exp(-cdist(z')),
            weighted by label agreement when dynamic weighting is enabled
        """
        sim_z = torch.exp(-torch.cdist(z, z, p=2))  # (B, B)
        sim_z_prime = torch.exp(-torch.cdist(z_prime, z_prime, p=2))  # (B, B)
        diff = sim_z - sim_z_prime

        if self.enable_dynamic:
            if labels.dim() < 2:
                labels = labels.unsqueeze(-1)
            # zero (cls) or soft (reg) weights for pairs whose labels disagree
            diff = diff * self._pair_weights(labels)

        return torch.mean(diff ** 2)


class DynamicPreserveCosineLoss(nn.Module):
    """
    Dynamically weighted preservation loss on cosine-similarity (Gram) matrices.

    Rows of z and z' are L2-normalized, so z @ z.T holds pairwise cosine similarities. The loss
    is the mean squared difference of the two (B, B) Gram matrices. With dynamic weighting,
    pairs with different class labels are dropped, and pairs with dissimilar regression values
    are down-weighted.
    """

    def __init__(self, enable_dynamic=True):
        super(DynamicPreserveCosineLoss, self).__init__()
        self.enable_dynamic = enable_dynamic
        print("Using dist_method: cosine")
        print("Enable dynamic weighting: ", enable_dynamic)

    @staticmethod
    def _pair_weights(labels):
        """
        Build a (B, B) pair-weight matrix from (B, N) labels.

        :param labels: (B, N) tensor. ``torch.long`` means class labels; any other dtype is
            treated as regression values.
        :return: (B, B) tensor. For class labels: 1.0 where samples i and j agree on every label
            column, 0.0 otherwise. For regression values: the Gaussian kernel
            exp(-||y_i - y_j||^2) over all N columns (exp(-(y_i - y_j)^2) when N == 1).
        """
        # (B, 1, N) against (1, B, N) broadcasts to a (B, B, N) pairwise comparison,
        # which works for any number of label columns N
        rows, cols = labels.unsqueeze(1), labels.unsqueeze(0)
        if labels.dtype != torch.long:
            return torch.exp(-((rows - cols) ** 2).sum(dim=-1))
        return (rows == cols).all(dim=-1).float()

    def forward(self, z, z_prime, labels):
        """
        :param z: (B, Z)
        :param z_prime: (B, Z')
        :param labels: (B,) or (B, 1) long tensor of class labels (cls), float tensor of values (reg),
            or (B, N) in the multitask setting
        :return: scalar mean of the squared difference between the two Gram matrices,
            weighted by label agreement when dynamic weighting is enabled
        """
        z = F.normalize(z)
        z_prime = F.normalize(z_prime)

        # pairwise dot products (ZZ^T, the Gram matrix) of unit vectors
        diff = z @ z.T - z_prime @ z_prime.T  # (B, B)

        if self.enable_dynamic:
            if labels.dim() < 2:
                labels = labels.unsqueeze(-1)
            # zero (cls) or soft (reg) weights for pairs whose labels disagree
            diff = diff * self._pair_weights(labels)

        return torch.mean(diff ** 2)


class DynamicIRDLoss(nn.Module):
    """
    IRD-style relational loss with optional label-aware masking.

    Rows of z and z' are L2-normalized, and pairwise similarities exp(<z_i, z_j> / temperature)
    are built for both. Optionally the diagonal and/or pairs with different labels are zeroed.
    Each row is then normalized to sum to one, giving p (from z) and q (from z').
    The loss is sum(-p_ij * log(q_ij + 1e-8)) over the non-masked entries, divided by the number
    of off-diagonal entries (or by B * B when the diagonal is kept).
    """

    def __init__(self, temperature=1, set_diag_to_zero=True, enable_dynamic=True):
        super(DynamicIRDLoss, self).__init__()
        # by default authors set diagonals to zero
        self.set_diag_zero = set_diag_to_zero
        self.temp = temperature
        self.enable_masking = enable_dynamic
        print("Using dist_method: IRD")
        print("Enable dynamic weighting: ", enable_dynamic)

    @staticmethod
    def _row_normalize(sim):
        """Divide each row by its sum; rows summing to zero (fully masked) stay zero instead of 0/0 = NaN."""
        row_sum = sim.sum(dim=1, keepdim=True)
        return sim / torch.where(row_sum > 0, row_sum, torch.ones_like(row_sum))

    def forward(self, z, z_prime, labels):
        """
        :param z: (B, Z)
        :param z_prime: (B, Z')
        :param labels: (B,), (B, 1) or (B, N) labels; when masking is enabled, pairs that differ
            in any label column are excluded
        :return: scalar masked cross-entropy between the row-normalized similarity matrices
        """
        # turn to unit vectors for spherical projection
        z = F.normalize(z)
        z_prime = F.normalize(z_prime)

        z_sim = torch.exp(z @ z.T / self.temp)
        z_prime_sim = torch.exp(z_prime @ z_prime.T / self.temp)

        batch_size = z_sim.shape[0]
        diag_mask = torch.eye(batch_size, dtype=torch.bool, device=z_sim.device)
        # by default the loss sets diagonals to zero
        if self.set_diag_zero:
            z_sim = z_sim.masked_fill(diag_mask, 0)
            z_prime_sim = z_prime_sim.masked_fill(diag_mask, 0)

        if self.enable_masking:
            if labels.dim() < 2:
                labels = labels.unsqueeze(-1)
            # True where i and j differ in any label column, so masked_fill zeroes those pairs
            label_mask = (labels.unsqueeze(1) != labels.unsqueeze(0)).any(dim=-1)
            z_sim = z_sim.masked_fill(label_mask, 0)
            z_prime_sim = z_prime_sim.masked_fill(label_mask, 0)

        # normalize by row
        z_sim = self._row_normalize(z_sim)
        z_prime_sim = self._row_normalize(z_prime_sim)

        # masked cross entropy; 1e-8 keeps the log finite where q_ij == 0
        ird_loss = -z_sim * torch.log(z_prime_sim + 1e-8)
        if self.enable_masking:
            ird_loss = ird_loss.masked_fill(label_mask, 0)
        num_terms = ird_loss.numel()
        if self.set_diag_zero:
            ird_loss = ird_loss.masked_fill(diag_mask, 0)
            num_terms -= batch_size
        # max(..., 1) guards B == 1 with a zeroed diagonal, where no terms remain
        return ird_loss.sum() / max(num_terms, 1)

class DynamicWeightingRKDLoss(nn.Module):
    """
    RKD-style distance loss with optional label-aware masking.

    The pairwise L2 distance matrices of z and z' are each divided by their mean, and the Huber
    loss between the two normalized (B, B) matrices is returned. When dynamic weighting is
    enabled, both matrices are multiplied by a pair-weight matrix before the Huber loss. Pairs
    with different class labels get weight 0; pairs with dissimilar regression values are
    down-weighted by a Gaussian kernel on the label difference.
    """

    def __init__(self, enable_dynamic=True):
        super(DynamicWeightingRKDLoss, self).__init__()
        self.enable_masking = enable_dynamic

        print("Using dist_method: RKD")
        print("Enable dynamic weighting: ", enable_dynamic)

    @staticmethod
    def _pair_weights(labels):
        """
        Build a (B, B) pair-weight matrix from (B, N) labels.

        :param labels: (B, N) tensor. ``torch.long`` means class labels; any other dtype is
            treated as regression values.
        :return: (B, B) tensor. For class labels: 1.0 where samples i and j agree on every label
            column, 0.0 otherwise. For regression values: the Gaussian kernel
            exp(-||y_i - y_j||^2) over all N columns (exp(-(y_i - y_j)^2) when N == 1).
        """
        # (B, 1, N) against (1, B, N) broadcasts to a (B, B, N) pairwise comparison,
        # which works for any number of label columns N
        rows, cols = labels.unsqueeze(1), labels.unsqueeze(0)
        if labels.dtype != torch.long:
            return torch.exp(-((rows - cols) ** 2).sum(dim=-1))
        return (rows == cols).all(dim=-1).float()

    @staticmethod
    def _mean_normalize(dists):
        """Divide a distance matrix by its mean; an all-zero matrix (e.g. a batch of one) stays zero instead of 0/0 = NaN."""
        mean = torch.mean(dists)
        return dists / torch.where(mean > 0, mean, torch.ones_like(mean))

    def forward(self, z, z_prime, labels):
        """
        :param z: (B, Z)
        :param z_prime: (B, Z')
        :param labels: (B,) or (B, 1) long tensor of class labels (cls), float tensor of values (reg),
            or (B, N) in the multitask setting
        :return: scalar Huber loss between the mean-normalized distance matrices of z' and z
        """
        # pairwise L2 distances, normalized by dividing by their mean
        z_dists = self._mean_normalize(torch.cdist(z, z))
        z_prime_dists = self._mean_normalize(torch.cdist(z_prime, z_prime))

        # if dynamic masking is turned on, pairs with different labels are zeroed (cls)
        # or down-weighted (reg) in both matrices
        if self.enable_masking:
            if labels.dim() < 2:
                labels = labels.unsqueeze(-1)
            mask = self._pair_weights(labels)
            z_prime_dists = z_prime_dists * mask
            z_dists = z_dists * mask

        return F.huber_loss(input=z_prime_dists, target=z_dists)

