# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `losses.py`."""

from typing import Generator

from absl.testing import absltest
from absl.testing import parameterized

from clrs._src import dataset
from clrs._src import losses
from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs
import jax
import jax.numpy as jnp
import numpy as np

_Array = np.ndarray
_Location = specs.Location


def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=nb_nodes,
  )
  return sampler


def _make_iterable_sampler(
    algo: str, batch_size: int,
    nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
  sampler = _make_sampler(algo, nb_nodes)
  while True:
    yield sampler.next(batch_size)


def _as_pred_data(x, nb_nodes, seed, batch_axis):
  """Fake a prediction from a data point."""
  # Permute along batch axis to make the prediction different.
  key = jax.random.PRNGKey(seed)
  data = jax.random.permutation(key, x.data, axis=batch_axis)
  # Extend to one-hot for pointer types.
  if x.type_ == specs.Type.POINTER:
    return jax.nn.one_hot(data, nb_nodes)
  return data


def _mask_datapoint(x, seed, t_axis=None):
  """Add some masking to data."""
  key = jax.random.PRNGKey(seed)
  data = x.data
  if x.type_ == specs.Type.MASK:
    # mask some data at random
    mask_shape = list(data.shape)
    if t_axis is not None:
      mask_shape[t_axis] = 1
    mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
    data = jnp.where(mask, specs.OutputClass.MASKED, data)
  elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
    # mask some data at random (all categories together)
    mask_shape = list(data.shape)[:-1]
    if t_axis is not None:
      mask_shape[t_axis] = 1
    mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
    data = jnp.where(mask[..., None], specs.OutputClass.MASKED, data)
  return probing.DataPoint(name=x.name, location=x.location, type_=x.type_,
                           data=data)


def _rand_diff(seed, shape):
  return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0


def _rand_mask(seed, shape, p=0.5):
  return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float)


def invert(d):
  """Dict of lists -> list of dicts."""
  if d:
    return [dict(zip(d, i)) for i in zip(*d.values())]


def _create_data(algo, nb_nodes):
  batch_size = 8

  ds = _make_iterable_sampler(algo, batch_size, nb_nodes)
  full_sample = next(ds)

  chunk_length = full_sample.features.lengths[0].astype(int)
  chunked_ds = dataset.chunkify(
      _make_iterable_sampler(algo, batch_size, nb_nodes),
      chunk_length)
  chunk_sample = next(chunked_ds)

  gt_diffs = {
      _Location.NODE: _rand_mask(0, (chunk_length, batch_size,
                                     nb_nodes)),
      _Location.EDGE: _rand_mask(1, (chunk_length, batch_size,
                                     nb_nodes, nb_nodes)),
      _Location.GRAPH: _rand_mask(2, (chunk_length, batch_size)),
  }
  diff_logits = {
      _Location.NODE: _rand_diff(3, (chunk_length, batch_size,
                                     nb_nodes)),
      _Location.EDGE: _rand_diff(4, (chunk_length, batch_size,
                                     nb_nodes, nb_nodes)),
      _Location.GRAPH: _rand_diff(5, (chunk_length, batch_size)),
  }

  return full_sample, chunk_sample, gt_diffs, diff_logits


class FullVsChunkLossesTest(parameterized.TestCase):
  """Test that the full and chunked versions of the losses match."""

  # Test two algorithms with fixed-length, covering all data types
  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_output_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample, _, _ = _create_data(algo, nb_nodes)

    # Calculate output loss.
    for truth_full, truth_chunked in zip(full_sample.outputs,
                                         chunk_sample.outputs):
      chunk_output_loss = losses.output_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=0),
          pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
          is_last=chunk_sample.features.is_last,
          nb_nodes=nb_nodes,
      )
      full_output_loss = losses.output_loss(
          truth=_mask_datapoint(truth_full, seed=0),
          pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_output_loss, full_output_loss, rtol=1e-4)

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_diff_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample, gt_diffs, diff_logits = _create_data(
        algo, nb_nodes)
    chunk_diff_loss = losses.diff_loss_chunked(
        diff_logits=diff_logits,
        gt_diffs=gt_diffs,
        is_first=chunk_sample.features.is_first,
    )
    full_diff_loss = losses.diff_loss(
        diff_logits=invert(diff_logits)[1:],
        gt_diffs=invert(gt_diffs)[1:],
        lengths=full_sample.features.lengths,
    )
    np.testing.assert_allclose(chunk_diff_loss, full_diff_loss, rtol=1e-4)

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_hint_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample, gt_diffs, unused_diff_logits = _create_data(
        algo, nb_nodes)
    for decode_diffs in [False, True]:
      for truth_full, truth_chunked in zip(full_sample.features.hints,
                                           chunk_sample.features.hints):
        np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
        pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
        chunk_hint_loss = losses.hint_loss_chunked(
            truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
            pred=pred,
            gt_diffs=gt_diffs,
            is_first=chunk_sample.features.is_first,
            nb_nodes=nb_nodes,
            decode_diffs=decode_diffs,
        )

        full_preds = pred[1:]
        full_gt_diffs = invert(gt_diffs)[1:]
        full_hint_loss = losses.hint_loss(
            truth=_mask_datapoint(truth_full, 1, t_axis=0),
            preds=full_preds,
            gt_diffs=full_gt_diffs,
            lengths=full_sample.features.lengths,
            nb_nodes=nb_nodes,
            decode_diffs=decode_diffs,
        )
        np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)


if __name__ == '__main__':
  absltest.main()
