# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.optimizers."""

import dataclasses
import functools
import operator

from absl.testing import absltest
from absl.testing import parameterized
import chex
import flax
from flax.core import frozen_dict
import jax
import jax.numpy as jnp
import numpy as np
import optax
import seqio
from t5x import models
from t5x import optimizers
from t5x import partitioning
from t5x import test_utils
from t5x import trainer
from t5x import utils
from t5x.examples.t5 import network


def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  """Asserts that `a` and `b` are element-wise close.

  Both inputs are converted with `jnp.array`; bfloat16 values are upcast to
  float32 before being handed to `np.testing.assert_allclose`.

  Args:
    a: actual value, anything `jnp.array` accepts.
    b: desired value, anything `jnp.array` accepts.
    atol: absolute tolerance. `None` leaves `np.testing.assert_allclose` on its
      own default.
    rtol: relative tolerance. `None` leaves `np.testing.assert_allclose` on its
      own default.

  Raises:
    AssertionError: if the values are not equal up to the tolerances.
  """
  a, b = jnp.array(a), jnp.array(b)
  a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
  # Compare against None instead of relying on truthiness: an explicit
  # tolerance of 0 (exact comparison) must be forwarded, not dropped.
  tolerances = {
      name: value
      for name, value in (('atol', atol), ('rtol', rtol))
      if value is not None
  }
  np.testing.assert_allclose(a, b, **tolerances)


def check_eq(xs, ys, atol=None, rtol=None):
  """Asserts that two pytrees match in structure, leaf shapes and values.

  Args:
    xs: first pytree.
    ys: second pytree.
    atol: absolute tolerance forwarded to `_assert_numpy_allclose`.
    rtol: relative tolerance forwarded to `_assert_numpy_allclose`.

  Raises:
    AssertionError: if the tree structures differ, if any pair of leaves has
      different shapes, or if any pair of leaves is not close.
  """
  xs_leaves, xs_tree = jax.tree_util.tree_flatten(xs)
  ys_leaves, ys_tree = jax.tree_util.tree_flatten(ys)
  assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
  assert all(
      np.shape(x) == np.shape(y) for x, y in zip(xs_leaves, ys_leaves)
  ), "Leaves' shapes don't match."
  # The value check raises on its own and returns None, so it is called
  # directly rather than wrapped in an `assert` (which would spuriously fail
  # for trees that have no leaves).
  for x, y in zip(xs_leaves, ys_leaves):
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)


def flattened_state_dict(x):
  """Returns the Flax state dict of `x`, flattened with '/'-joined keys."""
  return flax.traverse_util.flatten_dict(
      flax.serialization.to_state_dict(x), sep='/')


def tree_shape(x):
  """Returns a pytree like `x` with every leaf replaced by its shape tuple."""
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is not available in newer JAX releases.
  return jax.tree_util.tree_map(jnp.shape, x)


def tree_equals(x, y):
  """Returns whether all corresponding leaves of `x` and `y` compare equal.

  Leaves are compared pairwise with `==` and the results are reduced with
  `jax.tree_util.tree_all`.

  Args:
    x: first pytree.
    y: second pytree, mapped in lockstep with `x`.

  Returns:
    The truth value of all pairwise leaf comparisons.
  """
  # `tree_map` accepts several trees; the separate `jax.tree_multimap` alias
  # is not available in newer JAX releases.
  return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y))


def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__):
  """Returns the fake tokenized dataset without '*_pretokenized' features.

  Positional arguments and any keyword other than `split` are accepted and
  ignored.

  Args:
    split: dataset split forwarded to `test_utils.get_fake_tokenized_dataset`.
  """

  def _drop_pretokenized(example):
    # Keep every feature whose name does not end in '_pretokenized'.
    return {
        name: value
        for name, value in example.items()
        if not name.endswith('_pretokenized')
    }

  dataset = test_utils.get_fake_tokenized_dataset(split=split)
  return dataset.map(_drop_pretokenized)


def get_t5_test_model(optimizer_def,
                      **config_overrides) -> models.EncoderDecoderModel:
  """Builds a tiny encoder-decoder T5 model for use in tests.

  Args:
    optimizer_def: optimizer definition handed to the model.
    **config_overrides: `network.T5Config` fields that replace the small
      defaults defined in this function.

  Returns:
    A `models.EncoderDecoderModel` wrapping a `network.Transformer` built from
    the resulting config, using the fake test vocabulary for inputs and outputs.
  """
  default_config = network.T5Config(
      vocab_size=128,
      dtype='bfloat16',
      emb_dim=8,
      num_heads=4,
      num_encoder_layers=2,
      num_decoder_layers=2,
      head_dim=3,
      mlp_dim=16,
      mlp_activations=('gelu', 'linear'),
      dropout_rate=0.0,
      logits_via_embedding=False,
  )
  # Caller-supplied fields win over the defaults above.
  config = dataclasses.replace(default_config, **config_overrides)
  vocab = test_utils.get_fake_vocab()
  transformer = network.Transformer(config)
  return models.EncoderDecoderModel(
      module=transformer,
      input_vocabulary=vocab,
      output_vocabulary=vocab,
      optimizer_def=optimizer_def)


class BasicTest(chex.TestCase):
  """Tests `optimizers.OptaxWrapper` on a small hand-built parameter tree.

  Covers derivation of logical axes for the optimizer state and the
  state_dict / restore_state round trip.
  """

  @classmethod
  def get_params(cls):
    """Returns a zero-initialized nested parameter tree as a FrozenDict."""
    return frozen_dict.FrozenDict({
        'forward': {
            'input_layer': {
                'embedding': jnp.zeros([16, 8], dtype=jnp.float32),
            },
            'output_layer': {
                'layer_norm': {
                    'scale': jnp.zeros([8], dtype=jnp.float32),
                },
                'proj': {
                    'bias': jnp.zeros([1], dtype=jnp.float32),
                    'kernel': jnp.zeros([8, 1], dtype=jnp.float32),
                },
            },
        },
        'loss': {
            'loss_fn': {
                'loss_biases': jnp.zeros([2], dtype=jnp.float32),
            },
        },
    })

  @classmethod
  def get_params_shapes(cls):
    """Returns the tree from `get_params` with each leaf replaced by its shape."""
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is not available in newer JAX releases.
    return jax.tree_util.tree_map(jnp.shape, cls.get_params())

  @classmethod
  def get_param_logical_axes(cls):
    """Returns logical axis names mirroring the structure of `get_params`."""
    return frozen_dict.FrozenDict({
        'forward': {
            'input_layer': {
                'embedding': partitioning.PartitionSpec('vocab', 'embed'),
            },
            'output_layer': {
                'layer_norm': {
                    'scale': partitioning.PartitionSpec('embed',),
                },
                'proj': {
                    'bias':
                        partitioning.PartitionSpec('output_head',),
                    'kernel':
                        partitioning.PartitionSpec('embed', 'output_head'),
                },
            },
        },
        'loss': {
            'loss_fn': {
                'loss_biases': partitioning.PartitionSpec('unmodeled',),
            },
        },
    })

  def test_logical_axes_adamw(self):
    """Checks the logical axes derived for an AdamW optimizer state."""
    opt = optax.adamw(0.001, weight_decay=0.001)
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    got = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes())
    # Expected: the Adam moments carry the parameter axes, while the step,
    # the count and the two remaining chain states carry none.
    want = optimizers.Optimizer(
        optimizer_def=wrapper,
        state=optimizers.OptimizerState(
            step=None,
            param_states=(
                optax.ScaleByAdamState(
                    count=None,
                    mu=self.get_param_logical_axes(),
                    nu=self.get_param_logical_axes()),
                optax.EmptyState(),
                optax.EmptyState(),
            )),
        target=self.get_param_logical_axes())
    chex.assert_trees_all_equal(got, want)

  @parameterized.parameters(
      ('sgd', lambda: optax.sgd(1e-2, 0.0)),
      ('adam', lambda: optax.adam(1e-1)),
      ('adamw', lambda: optax.adamw(1e-1)),
      # NOTE(review): this case is labelled 'lamb' but constructs adamw, so it
      # duplicates the case above. Confirm whether `optax.lamb` was intended.
      ('lamb', lambda: optax.adamw(1e-1)),
      ('rmsprop', lambda: optax.rmsprop(1e-1)),
      ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)),
      ('fromage', lambda: optax.fromage(1e-2)),
      ('adabelief', lambda: optax.adabelief(1e-1)),
      ('radam', lambda: optax.radam(1e-1)),
      ('yogi', lambda: optax.yogi(1.0)),
  )
  def test_sanity_check_logical_axes(self, opt_name, opt_fn):
    """Checks that deriving logical axes runs without raising."""
    del opt_name  # Unused; only labels the case in the parameter list.
    opt = opt_fn()

    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())
    _ = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes())

    # TODO(rosun): basic sanity check, we just want to make sure that if a
    # param name, e.g., `loss_biases`, appears in the tree, the corresponding
    # value is always a PartitionSpec.

  def test_adamw_state_serialization(self):
    """Checks the state_dict layout and the restore_state round trip."""
    opt = optax.adamw(0.001, weight_decay=0.001)
    wrapper = optimizers.OptaxWrapper(opt)
    optimizer = wrapper.create(self.get_params())

    state_dict = optimizer.state_dict()

    # Compare shapes only: the serialized state must mirror the parameter
    # tree under `target`, `mu` and `nu`, with scalar `step` and `count`.
    chex.assert_trees_all_equal(
        frozen_dict.FrozenDict(jax.tree_util.tree_map(jnp.shape, state_dict)),
        frozen_dict.FrozenDict({
            'target': self.get_params_shapes(),
            'state': {
                'step': (),
                'param_states': {
                    '0': {
                        'count': (),
                        'mu': self.get_params_shapes(),
                        'nu': self.get_params_shapes(),
                    },
                    '1': {},
                    '2': {},
                },
            }
        }))

    new_optimizer = optimizer.restore_state(state_dict)

    chex.assert_trees_all_equal(optimizer, new_optimizer)


class OptaxWrapperTest(chex.TestCase):
  """Runs short training loops on a tiny T5 model with Optax optimizers."""

  def run_train_loop(self, optimizer_def):
    """Trains the tiny T5 model for two steps and checks params stay finite.

    Args:
      optimizer_def: optimizer definition used to build the model and the
        initial train state.
    """
    # Construct input data.

    ds = get_fake_tokenized_dataset_no_pretokenized(split='validation')
    ds = seqio.EncDecFeatureConverter()(
        ds, task_feature_lengths={
            'inputs': 8,
            'targets': 8
        })
    ds = ds.repeat().batch(8)
    ds_iter = ds.as_numpy_iterator()
    first_batch = next(ds_iter)

    model = get_t5_test_model(optimizer_def, vocab_size=128)

    learning_rate_fn = utils.create_learning_rate_scheduler()

    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is not available in newer JAX releases.
    input_shapes = jax.tree_util.tree_map(jnp.shape, first_batch)
    input_types = jax.tree_util.tree_map(lambda x: jnp.dtype(x.dtype),
                                         first_batch)

    partitioner = partitioning.PjitPartitioner(
        num_partitions=2,
        logical_axis_rules=partitioning.standard_logical_axis_rules())

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=input_types,
        partitioner=partitioner)

    train_state_axes = train_state_initializer.train_state_axes
    train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0))

    trainer_instance = trainer.Trainer(
        model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=[],
        summary_dir=None,
        train_state_axes=train_state_axes,
        rng=jax.random.PRNGKey(0),
        learning_rate_fn=learning_rate_fn,
        num_microbatches=1)

    chex.assert_tree_all_finite(train_state.params)
    for _ in range(2):
      trainer_instance.train(ds_iter, 1)
      # Read the state back from the trainer: the local `train_state` is never
      # reassigned, so checking it would only re-validate the initial params.
      chex.assert_tree_all_finite(trainer_instance.train_state.params)

  # NOTE(levskaya): these are surprisingly slow tests on CPU.
  @parameterized.parameters(
      ('sgd', lambda: optax.sgd(1e-2, 0.0)),
      ('adam', lambda: optax.adam(1e-1)),
      # ('adamw', lambda: optax.adamw(1e-1)),
      # NOTE(review): this case is labelled 'lamb' but constructs adamw.
      # Confirm whether `optax.lamb` was intended.
      ('lamb', lambda: optax.adamw(1e-1)),
      # ('rmsprop', lambda: optax.rmsprop(1e-1)),
      # ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)),
      # ('fromage', lambda: optax.fromage(1e-2)),
      ('adabelief', lambda: optax.adabelief(1e-1)),
      # ('radam', lambda: optax.radam(1e-1)),
      ('yogi', lambda: optax.yogi(1.0)),
  )
  def test_optimizer(self, opt_name, opt_fn):
    """Wraps the Optax optimizer and runs it through the training loop."""
    del opt_name  # Unused; only labels the case in the parameter list.
    opt = opt_fn()
    optimizer_def = optimizers.OptaxWrapper(opt)
    self.run_train_loop(optimizer_def)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
