"""Define Transformer for Clustering."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.geometry as dgl_geo


class TransformerClustering(nn.Module):
    """Clusters a feature sequence around centroids picked by FPS.

    Centroids are chosen from the input sequence with Farthest Point
    Sampling (FPS). Assignment logits are scaled cosine similarities between
    every entry and the selected centroids, computed on features projected
    by a small `LayerNorm -> ReLU -> Linear` head.
    """

    def __init__(self,
                 num_clusters=4,
                 d_model=512,
                 dropout=0.1,
                 activation="relu",
                 normalize_before=False,
                 detach_src_for_logit=True,
                 l2_normalize_for_fps=True,
                 return_intermediate_dec=False,
                 logit_scale=5):
        """Initializes a Transformer for Clustering.

        Args:
          num_clusters: A scalar indicates the number of centroids.
          d_model: A scalar indicates the input channels to Transformer.
          dropout: A `float` indicates the dropout rate. Accepted for
            interface compatibility; not used by this module.
          activation: A string indicates the type of non-linear activation.
            Accepted for interface compatibility; not used by this module
            (the projection head always applies ReLU).
          normalize_before: A `bool` indicates if applying normalization
            first. Accepted for interface compatibility; not used.
          detach_src_for_logit: A `bool` indicates if `src` is detached
            before the projection head, so that the logits do not
            back-propagate into `src`.
          l2_normalize_for_fps: A `bool` indicates if features are
            L2-normalized along the channel axis before FPS.
          return_intermediate_dec: A `bool` indicates if return intermediate
            results from decoders. Accepted for interface compatibility;
            not used.
          logit_scale: A scalar multiplied onto the cosine-similarity
            logits. Defaults to 5, the value previously hard-coded.
        """
        super().__init__()
        self.fc = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, d_model, bias=True))

        self._num_clusters = num_clusters
        self._detach_src_for_logit = detach_src_for_logit
        self._l2_normalize_for_fps = l2_normalize_for_fps
        self._logit_scale = logit_scale

    def _fill_with_mean(self, src, mask):
        """A helper function to fill invalid entries with mean values.

        Args:
          src: A `tensor` of shape `[batch_size, sequence_length, channels]`.
          mask: `None`, or a bool `tensor` of shape
            `[batch_size, sequence_length]` where `True` marks an invalid
            (padded) entry.

        Returns:
          filled_src: A `tensor` shaped like `src` whose masked entries are
            replaced by the per-sample mean of the valid entries; `src`
            itself when `mask` is `None`.
          mean_src: A `tensor` of shape `[batch_size, 1, channels]`.
        """
        _, seq_len, channels = src.shape
        if mask is not None:
            # `mask` flags padding, whereas `valid_mean` expects a validity mask.
            mean_src = valid_mean(src, ~mask).unsqueeze(1).type_as(src)
            # Fill padded entries with mean values.
            fill_mask = mask.unsqueeze(2).expand(-1, -1, channels)
            filled_src = torch.where(
                fill_mask, mean_src.expand(-1, seq_len, -1), src)
        else:
            mean_src = torch.mean(src, dim=1, keepdim=True).type_as(src)
            filled_src = src

        return filled_src, mean_src

    def forward(self, src, mask=None, pos_embed=None):
        """Feedforward for clustering with Transformer.

        Args:
          src: A `tensor` of shape `[batch_size, source_sequence_length, channels]`.
          mask: A bool `tensor` of shape `[batch_size, sequence_length]`,
            where `True` marks a padded entry, or `None` when every entry
            is valid.
          pos_embed: A `tensor` of shape
            `[batch_size, source_sequence_length, channels]`. Not used by
            this module; kept for interface compatibility.

        Returns:
          centroids: A `tensor` of shape `[batch_size, num_clusters, channels]`.
          logits: A `tensor` of shape
            `[batch_size, source_sequence_length, num_clusters]`.
          sampled_inds: A `tensor` of shape
            `[batch_size, num_clusters]` indexing into `src` along the
            sequence axis. An index of -1 refers to the prepended mean
            entry rather than to a position of `src`.
        """
        channels = src.shape[-1]

        # Sample query by Farthest Point Sampling.
        # The mean entry is prepended at index 0 and used as the FPS seed.
        filled_src, mean_src = self._fill_with_mean(src, mask)
        padded_src = torch.cat([mean_src, filled_src], dim=1)

        if self._l2_normalize_for_fps:
            sampling_src = F.normalize(padded_src, dim=-1)
        else:
            sampling_src = padded_src

        # NOTE: if the number of valid entries is smaller than
        # num_clusters, the 1st padded entry will be re-sampled.
        sampled_inds = dgl_geo.farthest_point_sampler(
            sampling_src.to(torch.float64),
            self._num_clusters + 1,
            0).long()
        # Drop the seed (the mean entry) so that `num_clusters` picks remain.
        sampled_inds = sampled_inds[:, 1:]
        # Broadcast the indices over channels for `torch.gather`.
        sampled_inds = sampled_inds.unsqueeze(2).expand(-1, -1, channels)

        # `centroids` is of shape `[batch_size, num_clusters, channels]`.
        centroids = torch.gather(padded_src, 1, sampled_inds)

        # Optimize clustering regularization after a fc layer.
        node_features = src.detach() if self._detach_src_for_logit else src
        node_features = self.fc(node_features)
        filled_node_features, mean_node_features = self._fill_with_mean(
            node_features, mask)
        padded_node_features = torch.cat(
            [mean_node_features, filled_node_features], dim=1)
        centroid_features = torch.gather(padded_node_features, 1, sampled_inds)

        # Pick centroids with maximum activations w.r.t `node_features`, where
        # `logits` is of shape
        # `[batch_size, source_sequence_length, num_clusters]`.
        normed_centroid_features = F.normalize(centroid_features, dim=-1)
        normed_node_features = F.normalize(node_features, dim=-1)
        logits = torch.einsum(
            'bij,bjk->bik',
            normed_node_features,
            normed_centroid_features.transpose(1, 2))
        logits = logits * self._logit_scale

        # Shift by one to undo the prepended mean entry.
        return centroids, logits, sampled_inds[:, :, 0] - 1


def valid_mean(x, mask):
    """Average `x` over the node axis, counting only valid nodes.

    Args:
        x: A `float` tensor of shape `[batch_size, num_nodes, channels]`.
        mask: A `bool` tensor of shape `[batch_size, num_nodes]`, where
            `True` indicates the entry is valid.

    Returns:
        mean_x: A `float` tensor of shape `[batch_size, channels]`. A sample
            without any valid node yields zeros, because the divisor is
            clamped to at least 1.
    """
    # Validity weights in the dtype of `x`, broadcastable over channels.
    weights = mask.unsqueeze(2).type_as(x)
    # Number of valid nodes per sample, floored at 1 to avoid dividing by 0.
    valid_count = weights.sum(dim=1).clamp(min=1)
    weighted_total = (x * weights).sum(dim=1)

    return weighted_total / valid_count

