"""Utility functions for defining Graph-based Clustering Loss.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

import graph.utils as utils

# Small positive constant, presumably intended as a numerical-stability
# epsilon. NOTE(review): not referenced by the code visible in this module.
EPS = 1e-15


def _rank3_trace(x):
    return torch.einsum('ijj->i', x)


def _rank3_diag(x):
    eye = torch.eye(x.size(1)).type_as(x)
    out = eye * x.unsqueeze(2).expand(*x.size(), x.size(1))
    return out


def dmon_pool_loss(x, adj, s, mask=None, softmax=False):
    """DMon pooling loss [1].

    [1]: Graph Clustering with Graph Neural Networks.

    Notation:
        C: the soft assignment matrix.
        A: the adjacency matrix.
        d: the degree vector, d = A 1.
        2m: the total edge weight, sum_ij A_ij. This is twice the edge count
            for a symmetric, unweighted A.
        k: the number of clusters.
        n: the number of (valid) nodes.

    The two losses are

        dmon_loss     = 1 - 1/(2m) * Tr(C^T A C - C^T d d^T C / (2m))
        collapse_loss = sqrt(k) / n * || sum_i C_i ||_2

    `dmon_loss` is one minus the soft modularity of the assignment.
    `collapse_loss` is the paper's collapse regularizer without its constant
    `- 1` offset; the offset has no effect on gradients. Both losses are
    averaged over the batch.

    Args:
        x: A `tensor` of shape `[batch_size, num_nodes, channels]`. It is not
            used by the computation and is kept for API compatibility.
        adj: A `tensor` of shape `[batch_size, num_nodes, num_nodes]`, or
            `[num_nodes, num_nodes]` for a single graph. It must have the
            same floating dtype as `s`.
        s: A `tensor` of shape `[batch_size, num_nodes, num_clusters]`, or
            `[num_nodes, num_clusters]` for a single graph. These are the
            cluster assignment scores. A softmax over clusters is applied
            only when `softmax=True`; otherwise `s` is used as given.
        mask: (Optional) A `tensor` of shape `[batch_size, num_nodes]`,
            indicating the valid nodes for each graph. Padded nodes are
            removed from the assignment and from the adjacency, and `n` in
            the collapse loss is the per-graph number of valid nodes.
        softmax: If True, apply a softmax over the cluster axis of `s`.

    Returns:
        dmon_loss: A scalar `Tensor`.
        collapse_loss: A scalar `Tensor`.
    """
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s
    batch_size, num_nodes, k = s.shape

    if softmax:
        s = torch.softmax(s, dim=-1)

    if mask is not None:
        mask = mask.reshape(batch_size, num_nodes, 1).to(s.dtype)
        s = s * mask
        # Drop edges touching padded nodes so that they count towards
        # neither the degrees nor 2m.
        adj = adj * mask * mask.transpose(1, 2)
        num_valid = mask.sum(dim=(1, 2))
    else:
        num_valid = s.new_full((batch_size,), num_nodes)

    # Degrees and total edge weight 2m. Each graph's 2m is clamped away from
    # zero: a graph without edges then has all-zero numerators and gets a
    # modularity of 0 instead of NaN.
    degrees = adj.sum(dim=-1)  # [B, N]
    two_m = degrees.sum(dim=-1)  # [B]
    two_m = two_m.clamp(min=torch.finfo(two_m.dtype).tiny)

    # Tr(C^T A C): within-cluster edge weight.
    trace_adj = (s * torch.matmul(adj, s)).sum(dim=(1, 2))  # [B]

    # Tr(C^T d d^T C) / 2m, computed from the k cluster degrees C^T d rather
    # than the N x N outer product d d^T. Dividing one factor by 2m first
    # keeps intermediates bounded by 2m, which avoids overflow in low
    # precision.
    cluster_degrees = torch.einsum('bnk,bn->bk', s, degrees)  # [B, k]
    trace_deg = (cluster_degrees
                 * (cluster_degrees / two_m.unsqueeze(1))).sum(dim=1)

    modularity = (trace_adj - trace_deg) / two_m
    dmon_loss = torch.mean(1 - modularity)

    # Collapse regularization: sqrt(k) / n * ||sum_i C_i||_2.
    cluster_sizes = s.sum(dim=1)  # [B, k]
    collapse_loss = (torch.norm(cluster_sizes, dim=1) * math.sqrt(k)
                     / num_valid.clamp(min=1))
    collapse_loss = torch.mean(collapse_loss)

    return dmon_loss, collapse_loss


class DMonLoss(_Loss):
    """DMon clustering loss [1].

    Builds an affinity graph from the node features with
    `utils.affinity_matrix_as_attention` and evaluates `dmon_pool_loss` on
    the soft cluster assignment given by the logits.

    [1]: Graph Clustering with Graph Neural Networks.
    """

    def __init__(self,
                 adj_knn=None,
                 remove_self_loop=True,
                 binarize=True,
                 size_average=None,
                 reduce=None,
                 reduction='mean'):
        """Initializes DMonLoss class.

        Args:
            adj_knn: Forwarded to `utils.affinity_matrix_as_attention` as its
                K-NN argument. Presumably this is the number of neighbours
                kept when building a K-NN affinity matrix — TODO confirm
                against `graph.utils`.
            remove_self_loop: Forwarded to
                `utils.affinity_matrix_as_attention`; presumably drops
                diagonal (self) affinities.
            binarize: Forwarded to `utils.affinity_matrix_as_attention`;
                presumably turns affinities into 0/1 edges.
            size_average, reduce, reduction: Standard `_Loss` arguments,
                passed to the base class. `forward` does not read them; the
                returned losses are always averaged by `dmon_pool_loss`.
        """
        super(DMonLoss, self).__init__(size_average, reduce, reduction)
        self._knn = adj_knn
        self._binarize = binarize
        self._remove_self_loop = remove_self_loop

    def __repr__(self):
        return 'DMonLoss(adj_knn={})'.format(self._knn)

    def forward(self, logits, x, x_padding_mask=None):
        """Computes the DMon clustering losses.

        The affinity matrix A is built from `x` under `torch.no_grad()`, so
        gradients reach the losses only through `logits`. The assignment is
        C = softmax(logits) over the cluster axis.

        dmon_loss:     1 minus the modularity-style score of C on A,
                       as computed by `dmon_pool_loss`.
        collapse_loss: sqrt(k) / n * || sum_i C_i ||_2, for k clusters and
                       n nodes.

        Args:
            logits: A `tensor` of shape `[batch_size, length, num_clusters]`.
            x: A `tensor` of shape `[batch_size, length, channels]`.
            x_padding_mask: A `tensor` of shape `[batch_size, length]`.
                It is truthy for padded nodes.

        Returns:
            dmon_loss: A scalar `tensor`.
            collapse_loss: A scalar `tensor`.
        """
        # Positive affinity for pulling cluster assignment. The kernel is
        # evaluated in float64, and the resulting affinity is cast back to
        # `x.dtype`.
        with torch.no_grad():
            def kernel_fn(feats):
                return utils.normed_exp_inner_product_kernel(
                    feats.to(torch.float64))

            adj = utils.affinity_matrix_as_attention(
                x, x_padding_mask, self._knn, self._remove_self_loop,
                self._binarize, kernel_fn)
            adj = adj.to(x.dtype)

        # `dmon_pool_loss` expects a mask of *valid* nodes. `logical_not`
        # matches `~` on bool masks and is also correct for 0/1 integer
        # masks, where `~` would flip every bit.
        valid_mask = (torch.logical_not(x_padding_mask)
                      if x_padding_mask is not None else None)
        dmon_loss, collapse_loss = dmon_pool_loss(
            x, adj, logits, valid_mask, True)

        return dmon_loss, collapse_loss
