# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `losses.py`."""

from typing import Generator

from absl.testing import absltest
from absl.testing import parameterized

from clrs._src import dataset
from clrs._src import losses
from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs
import jax
import jax.numpy as jnp
import numpy as np

# Short-hand type aliases.
# NOTE(review): neither alias is referenced in the visible code of this file;
# keep them only if something else relies on them.
_Array = np.ndarray
_Location = specs.Location


def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
  """Builds a sampler for `algo` using the CLRS30 validation-split settings.

  Args:
    algo: Name of the CLRS algorithm to sample.
    nb_nodes: Problem size, forwarded to the sampler as `length`.

  Returns:
    The first element of the pair returned by `samplers.build_sampler`; the
    second element is discarded.
  """
  val_cfg = samplers.CLRS30['val']
  built_sampler, _ = samplers.build_sampler(
      algo,
      seed=val_cfg['seed'],
      num_samples=val_cfg['num_samples'],
      length=nb_nodes,
  )
  return built_sampler


def _make_iterable_sampler(
    algo: str, batch_size: int,
    nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
  """Yields an endless stream of batches from a freshly built sampler.

  Every call constructs its own sampler through `_make_sampler`, which always
  uses the fixed CLRS30 validation seed. Separate generators created with the
  same arguments therefore start from the same seed.

  Args:
    algo: Name of the CLRS algorithm to sample.
    batch_size: Number of samples requested per `next` call.
    nb_nodes: Problem size passed on to `_make_sampler`.

  Yields:
    Successive results of `sampler.next(batch_size)`.
  """
  source = _make_sampler(algo, nb_nodes)
  while True:
    batch = source.next(batch_size)
    yield batch


def _as_pred_data(x, nb_nodes, seed, batch_axis):
  """Fakes a prediction by shuffling a datapoint's data along its batch axis.

  Args:
    x: Datapoint whose `data` is shuffled; its `type_` selects the encoding.
    nb_nodes: Number of one-hot classes used for pointer-typed data.
    seed: Seed for the PRNG key that drives the permutation.
    batch_axis: Axis of `x.data` to permute, so that the fake prediction
      differs from the ground truth.

  Returns:
    The permuted data. When `x` is a pointer, it is one-hot encoded over
    `nb_nodes` classes.
  """
  shuffled = jax.random.permutation(
      jax.random.PRNGKey(seed), x.data, axis=batch_axis)
  # Pointer data is returned one-hot over `nb_nodes`; presumably the pointer
  # losses expect per-node scores rather than node indices.
  return (jax.nn.one_hot(shuffled, nb_nodes)
          if x.type_ == specs.Type.POINTER else shuffled)


def _mask_datapoint(x, seed, t_axis=None):
  """Returns a new DataPoint in which random entries are set to MASKED.

  Masking depends on the datapoint type:
  - MASK data is masked element-wise.
  - CATEGORICAL and MASK_ONE data is masked per position, with all entries
    along the last axis masked together.
  - Any other type keeps its data unchanged (still wrapped in a new
    DataPoint).

  Each mask entry is set with probability 0.2. When `t_axis` is given, the
  mask has size 1 along that axis and is therefore broadcast across it.

  Args:
    x: Source datapoint; its name, location and type are copied over.
    seed: Seed for the PRNG key used to draw the mask.
    t_axis: Optional axis along which the mask is shared.

  Returns:
    A `probing.DataPoint` carrying the (possibly) masked data.
  """
  key = jax.random.PRNGKey(seed)
  values = x.data

  def _draw_mask(shape):
    # Collapse the optional shared axis to size 1 so `jnp.where` broadcasts.
    dims = list(shape)
    if t_axis is not None:
      dims[t_axis] = 1
    return jax.random.uniform(key, tuple(dims)) < 0.2

  if x.type_ == specs.Type.MASK:
    values = jnp.where(
        _draw_mask(values.shape), specs.OutputClass.MASKED, values)
  elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
    # One draw per position, expanded over the trailing (category) axis.
    per_position = _draw_mask(list(values.shape)[:-1])
    values = jnp.where(
        per_position[..., None], specs.OutputClass.MASKED, values)
  return probing.DataPoint(name=x.name, location=x.location, type_=x.type_,
                           data=values)


def _rand_diff(seed, shape):
  return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0


def _rand_mask(seed, shape, p=0.5):
  return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float)


def invert(d):
  """Converts a dict of lists into a list of dicts.

  Example: {'a': [1, 2], 'b': [3, 4]} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}].

  Args:
    d: Mapping from key to a sequence of values. An empty or `None` mapping
      yields an empty list.

  Returns:
    A list whose i-th element maps every key of `d` to the i-th entry of that
    key's sequence. As with `zip`, the result is truncated to the shortest
    sequence.
  """
  if not d:
    # Return an empty list (not an implicit None) so the result is always
    # iterable, matching the documented list-of-dicts contract.
    return []
  return [dict(zip(d, row)) for row in zip(*d.values())]


def _create_data(algo, nb_nodes, batch_size=8):
  """Draws one batch both as full trajectories and as a chunked sample.

  Args:
    algo: Name of the CLRS algorithm to sample.
    nb_nodes: Problem size passed to the samplers.
    batch_size: Number of trajectories per batch. Defaults to 8, the value
      this helper always used.

  Returns:
    A `(full_sample, chunk_sample)` pair. `full_sample` is the first batch of
    a fresh sampler. `chunk_sample` is the first item produced by
    `dataset.chunkify` over a second fresh sampler. The chunk length is the
    length of the first trajectory in `full_sample`.
  """
  full_sample = next(_make_iterable_sampler(algo, batch_size, nb_nodes))

  # Only the first trajectory's length is read. This assumes every trajectory
  # in the batch has the same length, which holds for the fixed-length
  # algorithms these tests use. TODO: confirm before adding other algorithms.
  chunk_length = full_sample.features.lengths[0].astype(int)

  # A second, independently built generator is used on purpose. Each
  # `_make_iterable_sampler` call reseeds from the CLRS30 validation seed, so
  # it should replay the same batch that produced `full_sample`.
  chunked_ds = dataset.chunkify(
      _make_iterable_sampler(algo, batch_size, nb_nodes),
      chunk_length)
  chunk_sample = next(chunked_ds)
  return full_sample, chunk_sample


class FullVsChunkLossesTest(parameterized.TestCase):
  """Checks that the chunked losses agree with their full-trajectory versions.

  Both parameterizations use algorithms with fixed-length trajectories, which
  together cover all data types.
  """

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_output_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)

    paired_outputs = zip(full_sample.outputs, chunk_sample.outputs)
    for full_truth, chunk_truth in paired_outputs:
      # Chunked outputs are shuffled along axis 1 and full outputs along
      # axis 0. Presumably chunked data carries an extra leading time axis.
      chunked_loss = losses.output_loss_chunked(
          truth=_mask_datapoint(chunk_truth, seed=0),
          pred=_as_pred_data(
              chunk_truth, nb_nodes=nb_nodes, seed=0, batch_axis=1),
          is_last=chunk_sample.features.is_last,
          nb_nodes=nb_nodes,
      )
      full_loss = losses.output_loss(
          truth=_mask_datapoint(full_truth, seed=0),
          pred=_as_pred_data(
              full_truth, nb_nodes=nb_nodes, seed=0, batch_axis=0),
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunked_loss, full_loss, rtol=1e-4)

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_hint_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)

    paired_hints = zip(full_sample.features.hints,
                       chunk_sample.features.hints)
    for full_truth, chunk_truth in paired_hints:
      # The chunk spans a whole trajectory, so both views hold the same hints.
      np.testing.assert_array_equal(full_truth.data, chunk_truth.data)
      chunk_pred = _as_pred_data(
          chunk_truth, nb_nodes=nb_nodes, seed=0, batch_axis=1)
      chunked_loss = losses.hint_loss_chunked(
          truth=_mask_datapoint(chunk_truth, seed=1, t_axis=0),
          pred=chunk_pred,
          is_first=chunk_sample.features.is_first,
          nb_nodes=nb_nodes,
      )
      # The full loss gets the same predictions minus the first entry along
      # axis 0, which is the time axis given `t_axis=0`. Presumably
      # `hint_loss` expects predictions only for steps after the first.
      full_loss = losses.hint_loss(
          truth=_mask_datapoint(full_truth, seed=1, t_axis=0),
          preds=chunk_pred[1:],
          lengths=full_sample.features.lengths,
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunked_loss, full_loss, rtol=1e-4)


if __name__ == '__main__':
  # Run this module's tests through absltest when executed as a script.
  absltest.main()
