"""Clustering metrics."""

from typing import  Optional, Sequence, Union

from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np

Ndarray = Union[np.ndarray, jnp.ndarray]


def check_shape(x, expected_shape: Sequence[Optional[int]], name: str):
  """Verifies that the shape of `x` agrees with `expected_shape`.

  Args:
    x: Any object. Its `shape` attribute is compared against `expected_shape`;
      objects without a `shape` attribute are treated as scalars with shape ().
    expected_shape: A list or tuple with one entry per dimension. An integer
      entry must equal the corresponding dimension of `x`, whereas `None` acts
      as a wildcard that matches any size. For example, [None, None, 3]
      describes a color image of arbitrary resolution and [4, None, None, 3] a
      batch of four such images.
    name: Human-readable name of `x`, used only in the error message.

  Raises:
    ValueError: If `expected_shape` is neither a list nor a tuple, or if the
      shape of `x` differs from `expected_shape` in rank or in any
      non-wildcard dimension.
  """
  if not isinstance(expected_shape, (list, tuple)):
    raise ValueError(
        "expected_shape should be a list or tuple of ints but got "
        f"{expected_shape}.")

  # Objects without a `shape` attribute are handled as scalars.
  actual = getattr(x, "shape", ())

  # A rank mismatch is always an error; otherwise compare dimension by
  # dimension, skipping the `None` wildcards.
  mismatch = len(actual) != len(expected_shape)
  if not mismatch:
    for got, want in zip(actual, expected_shape):
      if want is not None and got != want:
        mismatch = True
        break

  if mismatch:
    raise ValueError(
        f"Input {name} had shape {actual} but {expected_shape} was expected.")


def _validate_inputs(predicted_segmentations: Ndarray,
                     ground_truth_segmentations: Ndarray,
                     padding_mask: Ndarray,
                     mask: Optional[Ndarray] = None) -> None:
  """Validates rank, shape agreement and dtype of the segmentation inputs.

  Both segmentation arrays must be rank-3 ([bs, H, W]), must have identical
  shapes and must be integer-valued.

  Args:
    predicted_segmentations: An array of integers of shape [bs, H, W]
      containing model segmentation predictions.
    ground_truth_segmentations: An array of integers of shape [bs, H, W]
      containing ground truth segmentations.
    padding_mask: Mask marking the valid (non-padded) regions. It is accepted
      for interface compatibility but is not inspected by this function.
    mask: An optional per-example batch mask. It is accepted for interface
      compatibility but is not inspected by this function.

  Raises:
    ValueError: If either segmentation array is not rank-3, if their shapes
      differ, or if either of them does not have an integer dtype.
  """
  del padding_mask, mask  # Not validated here.

  # Each entry is (array, name); the order fixes which error is reported first.
  segmentations = (
      (predicted_segmentations, "predicted_segmentations"),
      (ground_truth_segmentations, "ground_truth_segmentations"),
  )

  # 1) Rank check: both arrays must be [bs, h, w].
  for array, label in segmentations:
    check_shape(array, [None, None, None], f"{label} [bs, h, w]")

  # 2) Predictions and ground truth must agree in every dimension.
  check_shape(
      predicted_segmentations, ground_truth_segmentations.shape,
      "predicted_segmentations [should match ground_truth_segmentations]")

  # 3) Segment ids must be integers.
  for array, label in segmentations:
    if not jnp.issubdtype(array.dtype, jnp.integer):
      raise ValueError(
          f"{label} has to be integer-valued. Got {array.dtype}")


def adjusted_rand_index(true_ids: Ndarray, pred_ids: Ndarray,
                        num_instances_true: int, num_instances_pred: int,
                        padding_mask: Optional[Ndarray] = None,
                        ignore_background: bool = False) -> Ndarray:
  """Computes the adjusted Rand index (ARI), a clustering similarity score.

  The first axis of the inputs is the batch axis; all remaining axes are
  treated as a flat set of points to be clustered. Images ([batch_size, H, W])
  and videos ([batch_size, seq_len, H, W]) are therefore both supported.

  Args:
    true_ids: An integer-valued array of shape [batch_size, ...], e.g.
      [batch_size, H, W] or [batch_size, seq_len, H, W]. The true cluster
      assignment encoded as integer ids.
    pred_ids: An integer-valued array of the same shape as `true_ids`. The
      predicted cluster assignment encoded as integer ids.
    num_instances_true: An integer, the number of instances in true_ids
      (i.e. max(true_ids) + 1).
    num_instances_pred: An integer, the number of instances in pred_ids
      (i.e. max(pred_ids) + 1).
    padding_mask: An optional array of the same shape as `true_ids` defining
      regions where the ground truth is meaningless, for example because this
      corresponds to regions which were padded during data augmentation.
      Value 0 corresponds to padded regions, 1 corresponds to valid regions to
      be used for metric calculation.
    ignore_background: Boolean, if True, then ignore all pixels where
      true_ids == 0 (default: False).

  Returns:
    ARI scores as a float32 array of shape [batch_size].

  References:
    Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions"
      https://link.springer.com/article/10.1007/BF01908075
    Wikipedia
      https://en.wikipedia.org/wiki/Rand_index
    Scikit Learn
      http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
  """
  # pylint: disable=invalid-name
  true_oh = jax.nn.one_hot(true_ids, num_instances_true)
  pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred)
  if padding_mask is not None:
    # Masking the ground truth alone is sufficient: the contingency table
    # below is a product of both one-hot encodings, so a zeroed ground-truth
    # row removes that point regardless of its prediction.
    true_oh = true_oh * padding_mask[..., None]

  if ignore_background:
    true_oh = true_oh[..., 1:]  # Remove the background row.

  # The one-hot arrays have shape [batch_size, *point_dims, num_instances].
  # Assign one einsum index to every point dimension so that inputs of any
  # rank are contracted into a per-example contingency table. The letters
  # "b", "c" and "k" are reserved for batch / true cluster / predicted cluster.
  point_dims = "adefghijlmnopqrstuvwxyz"[:true_oh.ndim - 2]
  N = jnp.einsum(f"b{point_dims}c,b{point_dims}k->bck", true_oh, pred_oh)
  A = jnp.sum(N, axis=-1)  # row-sum  (batch_size, c)
  B = jnp.sum(N, axis=-2)  # col-sum  (batch_size, k)
  num_points = jnp.sum(A, axis=1)

  rindex = jnp.sum(N * (N - 1), axis=(1, 2))
  aindex = jnp.sum(A * (A - 1), axis=1)
  bindex = jnp.sum(B * (B - 1), axis=1)
  # The clip guards against a division by zero when fewer than two points are
  # left after masking.
  expected_rindex = aindex * bindex / jnp.clip(num_points * (num_points-1), 1)
  max_rindex = (aindex + bindex) / 2
  denominator = max_rindex - expected_rindex
  ari = (rindex - expected_rindex) / denominator

  # There are two cases for which the denominator can be zero:
  # 1. If both label_pred and label_true assign all pixels to a single cluster.
  #    (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
  # 2. If both label_pred and label_true assign max 1 point to each cluster.
  #    (max_rindex == expected_rindex == rindex == 0)
  # In both cases, we want the ARI score to be 1.0:
  return jnp.where(denominator != 0, ari, 1.0)


@flax.struct.dataclass
class Ari(metrics.Average):
  """Adjusted Rand Index (ARI) computed from predictions and labels.

  ARI is a similarity score to compare two clusterings. ARI returns values in
  the range [-1, 1], where 1 corresponds to two identical clusterings (up to
  permutation), i.e. a perfect match between the predicted clustering and the
  ground-truth clustering. A value of (close to) 0 corresponds to chance.
  Negative values correspond to cases where the agreement between the
  clusterings is less than expected from a random assignment.

  In this implementation, we use ARI to compare predicted instance segmentation
  masks (including background prediction) with ground-truth segmentation
  annotations.
  """

  @classmethod
  def from_model_output(cls,
                        predicted_segmentations: Ndarray,
                        ground_truth_segmentations: Ndarray,
                        padding_mask: Ndarray,
                        ground_truth_max_num_instances: int,
                        predicted_max_num_instances: int,
                        ignore_background: bool = False,
                        mask: Optional[Ndarray] = None,
                        **_) -> metrics.Metric:
    """Computation of the ARI clustering metric.

    The per-example ARI scores are accumulated as a masked sum (`total`) and
    the number of unmasked examples (`count`).

    Args:
      predicted_segmentations: An array of integers of shape [bs, H, W]
        containing model segmentation predictions.
      ground_truth_segmentations: An array of integers of shape [bs, H, W]
        containing ground truth segmentations.
      padding_mask: An array of integers of shape [bs, H, W] defining regions
        where the ground truth is meaningless, for example because this
        corresponds to regions which were padded during data augmentation.
        Value 0 corresponds to padded regions, 1 corresponds to valid regions
        to be used for metric calculation. It is forwarded to
        `adjusted_rand_index`.
      ground_truth_max_num_instances: Maximum number of instances (incl.
        background, which counts as the 0-th instance) possible in the dataset.
      predicted_max_num_instances: Maximum number of predicted instances (incl.
        background).
      ignore_background: If True, then ignore all pixels where
        ground_truth_segmentations == 0 (default: False).
      mask: An optional array of boolean mask values of shape [bs]. `True`
        corresponds to actual batch examples whereas `False` corresponds to
        padding. If omitted, every example in the batch is counted.
      **_: Additional keyword arguments are accepted and ignored.

    Returns:
      Object of Ari with computed intermediate values.

    Raises:
      ValueError: If the segmentation inputs fail validation.
    """
    _validate_inputs(
        predicted_segmentations=predicted_segmentations,
        ground_truth_segmentations=ground_truth_segmentations,
        padding_mask=padding_mask,
        mask=mask)

    batch_size = predicted_segmentations.shape[0]
    weight_dtype = predicted_segmentations.dtype

    # Honor a caller-supplied batch mask; only fall back to "all examples are
    # valid" when none is given. Casting keeps the dtype of `count` identical
    # in both cases.
    if mask is None:
      mask = jnp.ones(batch_size, dtype=weight_dtype)
    else:
      mask = jnp.asarray(mask, dtype=weight_dtype)

    ari_batch = adjusted_rand_index(
        pred_ids=predicted_segmentations,
        true_ids=ground_truth_segmentations,
        num_instances_true=ground_truth_max_num_instances,
        num_instances_pred=predicted_max_num_instances,
        padding_mask=padding_mask,
        ignore_background=ignore_background)
    return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask))  # pytype: disable=wrong-keyword-args


@flax.struct.dataclass
class AriNoBg(Ari):
  """ARI variant that excludes ground-truth background pixels from the score."""

  @classmethod
  def from_model_output(cls, **kwargs) -> metrics.Metric:
    """Delegates to `Ari.from_model_output` with `ignore_background=True`.

    Args:
      **kwargs: Keyword arguments forwarded unchanged to the parent
        implementation; see the `Ari` docstring for the allowed keywords.

    Returns:
      The metric object produced by the parent implementation.
    """
    parent_factory = super().from_model_output
    return parent_factory(ignore_background=True, **kwargs)