# Adapted from https://github.com/deepmind/optax/
#
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Callable
from typing import Any, NamedTuple, Optional, Union

import jax
import jax.numpy as jnp
from optax._src import base, combine, transform, wrappers

from egxc.utils.typing import UInt1


class WeightDecaySchedule(NamedTuple):
    """Optimizer state used by `add_decayed_weights` when the decay is a schedule.

    Holds the step value that is passed to the weight-decay schedule callable.
    """

    # Step index fed to the weight-decay schedule; `add_decayed_weights`'
    # init_fn creates it as a uint32 scalar starting at 0.
    step: UInt1


def add_decayed_weights(
    weight_decay: Union[float, jax.Array, base.ScalarOrSchedule] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """Add parameters scaled by `weight_decay` to the updates.

    Variant of :func:`optax.add_decayed_weights` that also accepts a schedule,
    i.e. a callable mapping a step count to a decay rate. When a schedule is
    given, a step counter is kept in the optimizer state. It starts at 0 and
    is advanced by one on every update, so the schedule is evaluated at
    0, 1, 2, ... on successive calls.

    Args:
      weight_decay: A scalar weight decay rate, or a schedule ``step -> rate``.
      mask: A tree with same structure as (or a prefix of) the params PyTree, or a
        Callable that returns such a pytree given the params/updates. The leaves
        should be booleans, `True` for leaves/subtrees you want to apply the
        transformation to, and `False` for those you want to skip.

    Returns:
      A :class:`optax.GradientTransformation` object.

    Raises:
      ValueError: (from the returned ``update_fn``) if called without ``params``.
    """

    def init_fn(params):
        del params
        if callable(weight_decay):
            return WeightDecaySchedule(step=jnp.array(0, dtype=jnp.uint32))
        return base.EmptyState()

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)

        if callable(weight_decay):
            s = weight_decay(state.step)
            # Advance the counter so the next call evaluates the next schedule
            # step. Saturate at the dtype maximum instead of wrapping to 0,
            # which would silently restart the schedule.
            max_step = jnp.iinfo(state.step.dtype).max
            next_step = jnp.where(state.step < max_step, state.step + 1, state.step)
            new_state = WeightDecaySchedule(step=next_step)
        else:
            s = weight_decay
            new_state = state

        # `None` gradients (e.g. frozen leaves) are passed through untouched.
        updates = jax.tree.map(
            lambda g, p: None if g is None else g + s * p,
            updates,
            params,
            is_leaf=lambda x: x is None,
        )
        return updates, new_state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(
            base.GradientTransformation(init_fn, update_fn),  # type: ignore
            mask,
        )
    return base.GradientTransformation(init_fn, update_fn)  # type: ignore


def fromage(
    learning_rate: base.ScalarOrSchedule, min_norm: float = 1e-6
) -> base.GradientTransformationExtraArgs:
    """The Frobenius matched gradient descent (Fromage) optimizer.

    Fromage (Bernstein et al., 2020) models neural network gradients through
    deep relative trust, a distance function on deep neural networks, and is
    designed to need little learning-rate tuning. It is close in spirit to
    LARS and has been applied to benchmarks such as natural language
    Transformers and generative adversarial networks.

    The returned chain is: trust-ratio scaling, learning-rate scaling by
    ``lr * m``, then adding ``(m - 1) * params``, where
    ``m = 1 / sqrt(1 + lr**2)``. For a fixed learning rate, ``m`` is computed
    once. For a schedule, ``m`` is recomputed from the schedule value at each
    step, and the decay term is built with this module's
    ``add_decayed_weights``, which accepts a schedule.

    Args:
      learning_rate: A global scaling factor, either fixed or evolving along
        iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
      min_norm: A minimum value that the norm of the gradient updates and the norm
        of the layer parameters can be clipped to, to avoid dividing by zero when
        computing the trust ratio (as in the LARS paper).

    Returns:
      The corresponding :class:`optax.GradientTransformationExtraArgs`.
    """
    if callable(learning_rate):

        def _multiplier(step):
            # m(step) = 1 / sqrt(1 + lr(step)^2)
            return 1 / jnp.sqrt(1 + learning_rate(step) ** 2)

        def _scaled_learning_rate(step):
            return _multiplier(step) * learning_rate(step)

        def _decay_rate(step):
            return _multiplier(step) - 1

        return combine.chain(
            transform.scale_by_trust_ratio(min_norm),
            transform.scale_by_learning_rate(_scaled_learning_rate),
            add_decayed_weights(_decay_rate),
        )

    multiplier = 1 / jnp.sqrt(1 + learning_rate**2)
    return combine.chain(
        transform.scale_by_trust_ratio(min_norm),
        transform.scale_by_learning_rate(learning_rate * multiplier),
        transform.add_decayed_weights(multiplier - 1),
    )


def custom_fromage(
    learning_rate: base.ScalarOrSchedule, weight_decay: float, min_norm: float = 1e-6
) -> base.GradientTransformation:
    """Fromage-style optimizer with a caller-chosen weight-decay coefficient.

    Based on the Frobenius matched gradient descent (Fromage) optimizer,
    Bernstein et al, 2020: https://arxiv.org/abs/2002.03432.

    The chain is: trust-ratio scaling, learning-rate scaling, then adding
    ``weight_decay * params`` to the updates. Unlike :func:`fromage`, the
    learning rate is not rescaled by ``1 / sqrt(1 + lr**2)``, and the decay
    coefficient is ``weight_decay`` rather than ``1 / sqrt(1 + lr**2) - 1``.
    Because the decay step comes after the learning-rate step, the decay term
    is not multiplied by the learning rate.

    NOTE(review): under optax's convention that updates are added to params,
    a positive ``weight_decay`` enlarges the parameters, whereas Fromage's own
    coefficient is negative. Confirm the sign callers pass.

    Args:
      learning_rate: Fixed learning rate or schedule for the update step.
      weight_decay: Coefficient multiplying the params in the added term.
      min_norm: Lower clip for the norms used in the trust ratio, to avoid
        dividing by zero.

    Returns:
      The chained gradient transformation.
    """
    stages = (
        transform.scale_by_trust_ratio(min_norm),
        transform.scale_by_learning_rate(learning_rate),
        add_decayed_weights(weight_decay),
    )
    return combine.chain(*stages)
