# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `baselines.py`."""

import copy
import functools
from typing import Generator

from absl.testing import absltest
from absl.testing import parameterized
import chex

from clrs._src import baselines
from clrs._src import dataset
from clrs._src import probing
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs

import haiku as hk
import jax
import numpy as np

_Array = np.ndarray


def _error(x, y):
  return np.sum(np.abs(x-y))


def _make_sampler(algo: str, length: int) -> samplers.Sampler:
  """Builds a sampler for `algo` with the CLRS30 'val' seed and sample count.

  Args:
    algo: Name of the algorithm to sample.
    length: Problem size passed to `samplers.build_sampler`.

  Returns:
    The sampler returned by `samplers.build_sampler`; its second return value
    is discarded.
  """
  val_config = samplers.CLRS30['val']
  built_sampler, _ = samplers.build_sampler(
      algo,
      seed=val_config['seed'],
      num_samples=val_config['num_samples'],
      length=length,
  )
  return built_sampler


def _without_permutation(feedback):
  """Replace should-be-permutation outputs in `feedback` with pointers.

  Outputs whose type is `SHOULD_BE_PERMUTATION` are rebuilt as `POINTER`
  data points with the same name, location and data. All other outputs are
  passed through unchanged.

  Args:
    feedback: An object with an `outputs` sequence and a `_replace` method.

  Returns:
    A copy of `feedback` (via `_replace`) with the converted outputs list.
  """

  def _as_pointer(dp):
    if dp.type_ == specs.Type.SHOULD_BE_PERMUTATION:
      # Only node-level permutations are expected to appear here.
      assert dp.location == specs.Location.NODE
      return probing.DataPoint(name=dp.name, location=dp.location,
                               type_=specs.Type.POINTER, data=dp.data)
    return dp

  return feedback._replace(outputs=[_as_pointer(dp) for dp in feedback.outputs])


def _make_iterable_sampler(
    algo: str, batch_size: int,
    length: int) -> Generator[samplers.Feedback, None, None]:
  """Yields an endless stream of feedback batches without permutation types.

  The sampler is created lazily, on the first `next()` call, and each batch
  of `batch_size` samples is passed through `_without_permutation`.
  """
  source = _make_sampler(algo, length)
  while True:
    batch = source.next(batch_size)
    yield _without_permutation(batch)


def _remove_permutation_from_spec(spec):
  """Return a copy of `spec` where node permutation types become pointers.

  Each spec entry is indexed by position: element 1 is compared to a
  `specs.Location` and element 2 to a `specs.Type` (presumably the entries are
  (stage, location, type) tuples). Node-located `SHOULD_BE_PERMUTATION` entries
  are rebuilt with type `POINTER`; all other entries are copied as-is.
  """

  def _convert(entry):
    if (entry[1] == specs.Location.NODE and
        entry[2] == specs.Type.SHOULD_BE_PERMUTATION):
      return (entry[0], entry[1], specs.Type.POINTER)
    return entry

  return {name: _convert(spec[name]) for name in spec}


class BaselinesTest(parameterized.TestCase):
  """Consistency tests for `BaselineModel` and `BaselineModelChunked`."""

  def test_full_vs_chunked(self):
    """Test that chunking does not affect gradients."""

    batch_size = 4
    length = 8
    algo = 'insertion_sort'
    spec = _remove_permutation_from_spec(specs.SPECS[algo])
    rng_key = jax.random.PRNGKey(42)

    # All three streams come from identically seeded samplers (see
    # `_make_sampler`). The chunked streams re-slice the trajectories into
    # chunks of `length` and `2 * length` steps respectively.
    full_ds = _make_iterable_sampler(algo, batch_size, length)
    chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length)
    double_chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length * 2)

    full_batches = [next(full_ds) for _ in range(2)]
    chunked_batches = [next(chunked_ds) for _ in range(2)]
    double_chunk_batch = next(double_chunked_ds)

    with chex.fake_jit():  # jitting makes test longer

      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_full = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches[0], **common_args)
      b_full.init(full_batches[0].features, seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      full_params = b_full.params
      full_loss_0 = b_full.feedback(rng_key, full_batches[0])
      # Restore the initial parameters so each feedback step starts from the
      # same point as its chunked counterpart.
      b_full.params = full_params
      full_loss_1 = b_full.feedback(rng_key, full_batches[1])
      new_full_params = b_full.params

      b_chunked = baselines.BaselineModelChunked(
          spec, dummy_trajectory=chunked_batches[0], **common_args)
      b_chunked.init([[chunked_batches[0].features]], seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      chunked_params = b_chunked.params
      # Same seed, so both models must start from identical parameters.
      jax.tree_util.tree_map(np.testing.assert_array_equal, full_params,
                             chunked_params)
      chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
      b_chunked.params = chunked_params
      chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
      new_chunked_params = b_chunked.params

      b_chunked.params = chunked_params
      double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)

    # Test that losses match
    np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
    np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
    # The doubled chunk's loss is expected to equal the mean of the two
    # single-trajectory losses.
    np.testing.assert_allclose(full_loss_0 + full_loss_1,
                               2 * double_chunked_loss,
                               rtol=1e-4)

    # Test that gradients are the same (parameters changed equally).
    # First check that gradients were not zero, i.e., parameters have changed.
    param_change = jax.tree_util.tree_leaves(
        jax.tree_util.tree_map(_error, full_params, new_full_params))
    self.assertGreater(np.mean(param_change), 0.1)
    # Now check that full and chunked gradients are the same.
    jax.tree_util.tree_map(
        functools.partial(np.testing.assert_allclose, rtol=1e-4),
        new_full_params, new_chunked_params)

  def test_multi_vs_single(self):
    """Test that multi = single when we only train one of the algorithms."""

    batch_size = 4
    length = 16
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    full_ds = [_make_iterable_sampler(algo, batch_size, length)
               for algo in algos]
    full_batches = [next(ds) for ds in full_ds]
    full_batches_2 = [next(ds) for ds in full_ds]

    with chex.fake_jit():  # jitting makes test longer

      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_single = baselines.BaselineModel(
          spec[0], dummy_trajectory=full_batches[0], **common_args)
      b_multi = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches, **common_args)
      b_single.init(full_batches[0].features, seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
      b_multi.init([f.features for f in full_batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

      single_params = []
      single_losses = []
      multi_params = []
      multi_losses = []

      # Snapshot parameters (deep copies) before and after each of two
      # training steps on the first algorithm only.
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches[0]))
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
      single_params.append(copy.deepcopy(b_single.params))

      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))

    # Test that losses match
    np.testing.assert_array_equal(single_losses, multi_losses)
    # Test that loss decreased. Use unittest assertions rather than bare
    # `assert` so the checks survive `python -O` and report the values.
    self.assertLess(single_losses[1], single_losses[0])

    # Test that param changes were the same in single and multi-algorithm
    for single, multi in zip(single_params, multi_params):
      self.assertTrue(
          hk.data_structures.is_subset(subset=single, superset=multi))
      for module_name, params in single.items():
        jax.tree_util.tree_map(np.testing.assert_array_equal, params,
                               multi[module_name])

    # Test that params change for the trained algorithm, but not the others
    for module_name, params in multi_params[0].items():
      param_changes = jax.tree_util.tree_map(_error, params,
                                             multi_params[1][module_name])
      param_change = sum(param_changes.values())
      if module_name in single_params[0]:  # params of trained algorithm
        self.assertGreater(param_change, 1e-3)
      else:  # params of non-trained algorithms
        self.assertEqual(param_change, 0.0)

  @parameterized.parameters(True, False)
  def test_multi_algorithm_idx(self, is_chunked):
    """Test that algorithm selection works as intended."""

    batch_size = 4
    length = 8
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    if is_chunked:
      ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length),
                             2 * length) for algo in algos]
    else:
      ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos]
    batches = [next(d) for d in ds]

    processor_factory = processors.get_processor_factory(
        'mpnn', use_ln=False, nb_triplet_fts=0)
    common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                       learning_rate=0.01,
                       decode_hints=True, encode_hints=True)
    if is_chunked:
      baseline = baselines.BaselineModelChunked(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([[f.features for f in batches]], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
    else:
      baseline = baselines.BaselineModel(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([f.features for f in batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

    # Find out what parameters change when we train each algorithm
    def _change(x, y):
      """Maps each module name to its total absolute parameter change."""
      return {
          module_name: sum(
              jax.tree_util.tree_map(_error, params, y[module_name]).values())
          for module_name, params in x.items()
      }

    param_changes = []
    for algo_idx in range(len(algos)):
      init_params = copy.deepcopy(baseline.params)
      # The chunked model is given a tuple index; the plain model a bare int.
      _ = baseline.feedback(
          rng_key,
          batches[algo_idx],
          algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
      param_changes.append(_change(init_params, baseline.params))

    # Test that non-changing parameters correspond to encoders/decoders
    # associated with the non-trained algorithms
    unchanged = [{k for k, v in pc.items() if v == 0} for pc in param_changes]

    def _get_other_algos(algo_idx, modules):
      """Encoder/decoder module names not tagged with `algo_{algo_idx}`."""
      return {k for k in modules if '_construct_encoders_decoders' in k
              and f'algo_{algo_idx}' not in k}

    for algo_idx in range(len(algos)):
      expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
      self.assertNotEmpty(expected_unchanged)
      self.assertSetEqual(expected_unchanged, unchanged[algo_idx])


if __name__ == '__main__':
  absltest.main()
