"""Define Transformer for Clustering."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.geometry as dgl_geo
import segment.utils as seg_utils


class TransformerClustering(nn.Module):
    """Clusters a token sequence with attention-seeded K-medoids.

    Initial medoids are the tokens with the largest accumulated attention;
    they are then refined by a fixed number of K-medoids iterations. The
    module defines no learnable parameters of its own.
    """

    def __init__(self,
                 num_clusters=4,
                 d_model=512,
                 dropout=0.1,
                 activation="relu",
                 normalize_before=False,
                 detach_src_for_logit=True,
                 l2_normalize_for_fps=True,
                 return_intermediate_dec=False,
                 kmedoid_iterations=5):
        """Initializes a Transformer for Clustering.

        Args:
          num_clusters: A scalar indicates the number of centroids.
          d_model: A scalar indicates the input channels to Transformer.
            Accepted for interface compatibility; not used by this class.
          dropout: A `float` indicates the dropout rate. Not used by this
            class.
          activation: A string indicates the type of non-linear activation.
            Not used by this class.
          normalize_before: A `bool` indicates if applying normalization
            first. Not used by this class.
          detach_src_for_logit: A `bool` stored on the instance; it is not
            read by any method of this class.
          l2_normalize_for_fps: A `bool` stored on the instance; it is not
            read by any method of this class.
          return_intermediate_dec: A `bool` indicates if return intermediate
            results from decoders. Not used by this class.
          kmedoid_iterations: A scalar indicates the number of medoid update
            steps performed in `forward`.
        """
        super().__init__()

        self._num_clusters = num_clusters
        self._detach_src_for_logit = detach_src_for_logit
        self._l2_normalize_for_fps = l2_normalize_for_fps
        self._kmedoid_iterations = kmedoid_iterations

    def _kmedoids(self, src, sampled_inds, mask, iterations, metric='l2'):
        """Refines medoids and cluster labels by K-medoids.

        Args:
          src: A `tensor` of shape
            `[batch_size, source_sequence_length, channels]`.
          sampled_inds: An integer `tensor` of shape
            `[batch_size, num_samples]` holding the initial medoid indices
            into the sequence dimension of `src`.
          mask: A bool `tensor` of shape
            `[batch_size, source_sequence_length]` or `None`. `True` entries
            are never selected as medoids.
          iterations: A scalar indicates the number of medoid updates. The
            loop runs `iterations + 1` times so that the returned labels are
            computed against the returned medoids.
          metric: `'l2'` for squared L2 distance; any other value selects
            cosine distance.

        Returns:
          kmedoid_labels: An integer `tensor` of shape
            `[batch_size, source_sequence_length]` with the index of the
            nearest medoid for every entry.
          sampled_inds: An integer `tensor` of shape
            `[batch_size, num_samples]` with the medoid indices that
            `kmedoid_labels` refers to.
        """
        _, _, cs = src.shape
        num_samples = sampled_inds.shape[1]

        # Distances are computed in double precision.
        src = src.to(torch.float64)

        # Helper function to compute pairwise distance.
        def _dist_fn(x, y, metric):
            # Compute distances between datas and centroids.
            if metric == 'l2':
                # Squared L2 distance via |x|^2 + |y|^2 - 2 * x.y
                sqr_x = torch.sum(x * x, dim=-1).unsqueeze(2)
                sqr_y = torch.sum(y * y, dim=-1).unsqueeze(1)
                x_y = torch.einsum('bij,bjk->bik', x, y.transpose(1, 2))
                dists = sqr_x + sqr_y - 2 * x_y
            else:
                # Cosine distance: 1 - cosine similarity.
                normed_x = F.normalize(x, dim=-1)
                normed_y = F.normalize(y, dim=-1)
                dists = 1 - torch.einsum(
                    'bij,bjk->bik', normed_x, normed_y.transpose(1, 2))

            return dists

        # Pairwise distances among all entries do not depend on the current
        # medoids, so compute them once instead of once per iteration.
        src_dists = _dist_fn(src, src, metric)  # BxSxS

        next_sampled_inds = sampled_inds
        for _ in range(iterations + 1):
            sampled_inds = next_sampled_inds
            # Collect medoids.
            medoids = torch.gather(
                src, 1, sampled_inds.unsqueeze(2).expand(-1, -1, cs))

            # Assign every entry to its nearest medoid.
            src_medoid_dists = _dist_fn(src, medoids, metric)
            kmedoid_labels = torch.argmin(src_medoid_dists, dim=-1)  # BxS

            # Only consider distances within the same cluster.
            label_affinity = (kmedoid_labels.unsqueeze(2)
                              == kmedoid_labels.unsqueeze(1))  # BxSxS
            within_dists = src_dists.masked_fill(~label_affinity, 0)

            # Summed distance from each entry to the rest of its cluster.
            sum_dists = torch.sum(within_dists, dim=-1)  # BxS

            # Any value above the largest cost excludes an entry from argmin.
            fill_value = sum_dists.max() + 1

            # Per-cluster cost of choosing each entry as the new medoid.
            # NOTE(review): assumes `seg_utils.one_hot` returns a BxSxM
            # tensor -- confirm against its implementation.
            unfold_kmedoid_labels = seg_utils.one_hot(
                kmedoid_labels, num_samples).transpose(1, 2)
            unfold_sum_dists = (sum_dists.unsqueeze(1)
                                * unfold_kmedoid_labels.type_as(sum_dists))  # BxMxS
            unfold_sum_dists = unfold_sum_dists.masked_fill(
                ~unfold_kmedoid_labels.bool(), fill_value)

            # Avoid selecting masked datas.
            if mask is not None:
                unfold_sum_dists = unfold_sum_dists.masked_fill(
                    mask.unsqueeze(1), fill_value)
            next_sampled_inds = torch.argmin(unfold_sum_dists, dim=-1)  # BxM

        return kmedoid_labels, sampled_inds

    def _rank_significance(self, attn, mask, num_samples):
        """Rank significance score by attention score.

        Args:
          attn: A `tensor` of shape
            `[batch_size, num_heads, source_sequence_length,
            source_sequence_length]`.
          mask: A bool `tensor` of shape
            `[batch_size, source_sequence_length]` or `None`. `True` entries
            get a significance of -1 so they rank below the others
            (presumably attention scores are non-negative).
          num_samples: A scalar indicates the number of indices to pick.

        Returns:
          sampled_inds: A `tensor` of shape `[batch_size, num_samples]`.
        """
        # Accumulate attention over dim 2 and then over heads -> BxS.
        significance = attn.sum(dim=2).sum(dim=1)
        # Suppress masked entries only when a mask is actually provided.
        if mask is not None:
            significance = significance.masked_fill(mask, -1)
        _, sampled_inds = torch.topk(significance, num_samples, dim=1)

        return sampled_inds

    def forward(self, attn, src, mask, pos_embed):
        """Feedforward for clustering with Transformer.

        Args:
          attn: A `tensor` of shape
            `[batch_size, num_heads, source_sequence_length, source_sequence_length]`.
          src: A `tensor` of shape `[batch_size, source_sequence_length, channels]`.
          mask: A bool `tensor` of shape `[batch_size, sequence_length]`, or
            `None`. `True` entries are excluded from medoid selection.
          pos_embed: A `tensor` of shape
            `[batch_size, source_sequence_length, channels]`. Not used by
            this method.

        Returns:
          centroids: A `tensor` of shape `[batch_size, num_clusters, channels]`.
          logits: A `tensor` of shape
            `[batch_size, source_sequence_length, num_clusters]`.
          sampled_inds: A `tensor` of shape `[batch_size, num_clusters]`.
        """
        _, _, cs = src.shape

        # Sample indices based on significance score.
        sampled_inds = self._rank_significance(attn, mask, self._num_clusters)

        # Refine the seeds with K-medoids under squared L2 distance.
        kmedoid_labels, sampled_inds = self._kmedoids(
            src, sampled_inds, mask,
            self._kmedoid_iterations, 'l2')

        # Centroids are the rows of `src` at the final medoid indices.
        unfold_sampled_inds = sampled_inds.unsqueeze(2).expand(-1, -1, cs)
        centroids = torch.gather(src, 1, unfold_sampled_inds)

        # Dummy variables: one-hot cluster assignments stand in for logits.
        logits = seg_utils.one_hot(kmedoid_labels, self._num_clusters).type_as(src)

        return centroids, logits, sampled_inds


def valid_mean(x, mask):
    """Average `x` over the node axis, counting only valid nodes.

    Args:
        x: A `float` tensor of shape `[batch_size, num_nodes, channels]`.
        mask: A `bool` tensor of shape `[batch_size, num_nodes]`, where
            `True` indicates the entry is valid.

    Returns:
        mean_x: A `float` tensor of shape `[batch_size, channels]`. The
            divisor is clamped to at least 1, so a batch row without any
            valid node yields zeros.
    """
    # Cast the mask to x's dtype and add a channel axis for broadcasting.
    weights = mask.type_as(x).unsqueeze(2)

    # Number of valid nodes per batch row, never below 1.
    valid_count = weights.sum(dim=1).clamp(min=1)

    # Zero out invalid nodes, then sum over the node axis.
    total = (x * weights).sum(dim=1)

    return total / valid_count

