# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gradient transformations and other optax utilities."""

import operator
import big_vision.utils as u
import jax
import jax.numpy as jnp
import optax


def find_states(opt_state, cls):
  leaves = jax.tree.leaves(
      opt_state, is_leaf=lambda node: isinstance(node, cls))
  return [leaf for leaf in leaves if isinstance(leaf, cls)]


def get_count(opt_state, jittable=False):
  """Returns `ScaleByScheduleState.count` from `opt_state` as an integer."""
  counts = [
      state.count
      for state in find_states(opt_state, optax.ScaleByScheduleState)
  ]
  if jittable:
    return counts[0]
  else:
    counts = {int(c) for c in counts}
    assert len(counts) == 1, f"Expected exactly 1 ScaleByScheduleState:{counts}"
    return next(iter(counts))


def replace_frozen(schedule, pytree, replacement, log=None):
  """Replaces values matching frozen params in `pytree` with `replacement`."""
  if not isinstance(schedule, (list, tuple)):
    return pytree
  masks, scheds = _make_mask_trees(pytree, schedule, log=log)
  frozen_mask, _, _ = _split_frozen(masks, scheds)
  return jax.tree.map(
      lambda v, f: replacement if f else v, pytree, frozen_mask)


def clip_by_per_example_global_norm(
    max_norm: float,
) -> optax.GradientTransformation:
  """Clips the norm of per-example gradients."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates)
    clipped, _ = optax.per_example_global_norm_clip(grads_flat, max_norm)
    return jax.tree_util.tree_unflatten(grads_treedef, clipped), state

  return optax.GradientTransformation(init_fn, update_fn)


def make(config, params, *, sched_kw):
  """Returns gradient transform and learning rate functions."""

  # Global schedule. No schedule means frozen.
  schedule = config.get("schedule", {})
  if not isinstance(schedule, (tuple, list)):
    schedule = [(".*", schedule)]
  masks, scheds = _make_mask_trees(params, schedule, "config.schedule")
  frozen_mask, masks, scheds = _split_frozen(masks, scheds)
  not_frozen_mask = jax.tree.map(operator.not_, frozen_mask)
  def create_schedule(mult=1.0, **kw):
    assert "base" not in kw, kw
    return u.create_learning_rate_schedule(base=mult, **kw)
  schedule_fns = [create_schedule(**sched_kw, **sched) for sched in scheds]
  schedule_txs = [
      optax.masked(optax.scale_by_schedule(schedule_fn), mask)
      for schedule_fn, mask in zip(schedule_fns, masks)
  ] + [
      # Removes weight decay updates. Note that weight decay already has an
      # independent mask (which cannot be combined easily with a second mask),
      # so instead we multiply updates for frozen params with zero.
      optax.masked(optax.set_to_zero(), frozen_mask)
  ]

  # Gradient clipping.
  if clip_norm := config.get("grad_clip_norm"):
    if config.get("grad_clip_per_example"):
      clip_tx = clip_by_per_example_global_norm(clip_norm)
    else:
      clip_tx = optax.clip_by_global_norm(clip_norm)
    grad_clip_norm_tx = optax.masked(clip_tx, not_frozen_mask)
  else:
    grad_clip_norm_tx = optax.identity()

  # Optimizer updates.
  tx_func = operator.attrgetter(config.optax_name)(optax)
  opt_txs = [optax.masked(tx_func(**config.get("optax", {})), not_frozen_mask)]
  assert "optim" not in config, "Deprecated option, use config.optax."

  # Learning rate multipliers. Defaults to 1.0.
  lr_mult_txs = [optax.scale(config.lr)]
  if config.get("lr_mults"):
    masks, mults = _make_mask_trees(params, config.lr_mults, "config.lr_mults")
    assert all(mult > 0 for mult in mults), (
        f"Use schedule=None for parameter freezing instead of lr_mults={mults}")
    lr_mult_txs += [
        optax.masked(optax.scale(mult), mask)
        for mult, mask in zip(mults, masks)
    ]

  # Weight decay. Defaults to 0.0.
  # Weight decay is not gradient-based but instead uses "params side-input".
  # Hence, weight decay is additive and independent of previous gradient-based
  # updates.
  assert "weight_decay" not in config, "Deprecated option. Use wd and schedule."
  assert config.get("weight_decay_decouple", True), (
      "Coupled weight decay not supported anymore.")
  if config.get("wd"):
    wd_mults = config.get("wd_mults", [(".*/kernel$", 1.0)])
    masks, mults = _make_mask_trees(params, wd_mults, "config.wd_mults")
    weight_decay_txs = [
        optax.add_decayed_weights(config.wd * mult, mask)
        for mult, mask in zip(mults, masks)
    ]
  else:
    weight_decay_txs = []

  # Combine gradient updates and learning rate schedules.
  return optax.chain(
      grad_clip_norm_tx,
      *opt_txs,
      *lr_mult_txs,
      *weight_decay_txs,
      *schedule_txs,
      optax.scale(-1.0)), schedule_fns


def _make_mask_trees(params, patterns_values, log):
  patterns, values = zip(*patterns_values)
  masks = u.make_mask_trees(params, patterns, log=log)
  return masks, values


def _split_frozen(masks, scheds):
  """Computes `frozen_mask` and updates `masks` and `scheds`."""
  # Specifying `None` as a scheduler freezes params.
  all_false = jax.tree.map(lambda *bools: not any(bools), *masks)
  not_covered = [k for k, v in u.tree_flatten_with_names(all_false)[0] if v]
  assert not not_covered, (
      f"All params must be covered (use `None` for freezing): {not_covered}")
  frozen_masks = [
      mask for mask, sched in zip(masks, scheds) if sched is None]
  frozen_mask = jax.tree.map(
      lambda *bools: any(bools), *frozen_masks,
      all_false)  # `all_false` is required when `frozen_masks==[]`.
  masks, scheds = zip(*(
      (mask, sched) for mask, sched in zip(masks, scheds) if sched is not None))
  return frozen_mask, masks, scheds


############ Custom BigVision optimizers #######################################
# Currently there's only one custom optimizer and we don't foresee new ones in
# the near future, we opt not to create a new optimizer folder/module for just
# one isolated case. If there will be more optimizers, we can consider moving
# them into individual files in a subfolder.


# A dummy object to allow for foo.bar access syntax, see
# https://stackoverflow.com/a/19476841/2366315
optax.big_vision = type("", (), {})()


def scale_by_adafactor(min_dim_size_to_factor=32,
                       decay_rate=0.8, decay_offset=0,
                       beta2_cap=0.999,
                       clipping_threshold=None,
                       momentum=0.9, dtype_momentum=jnp.bfloat16,
                       eps=1e-30):
  """The BigVision variant of Adafactor optimizer."""

  def _decay_rate_pow(i, exponent):
    """Second-order moment decay schedule."""
    t = jnp.array(i, jnp.float32) + 1.0
    return jnp.minimum(beta2_cap, 1.0 - t**(-exponent))

  scale_by_rms = optax.scale_by_factored_rms(
      factored=True,
      decay_rate=decay_rate,
      step_offset=decay_offset,
      min_dim_size_to_factor=min_dim_size_to_factor,
      epsilon=eps,
      decay_rate_fn=_decay_rate_pow)

  clip = (optax.clip_by_block_rms(clipping_threshold) if clipping_threshold
          else optax.identity())

  mom = (optax.ema(momentum, debias=False, accumulator_dtype=dtype_momentum)
         if momentum else optax.identity())

  return optax.chain(scale_by_rms, clip, mom)

optax.big_vision.scale_by_adafactor = scale_by_adafactor  # pytype: disable=module-attr


# A few more aliases we use frequently:
def momentum_hp(momentum=0.9, dtype=jnp.bfloat16, nesterov=False):
  """SGD-Momentum with half-precision accumulator."""
  return optax.trace(decay=momentum, accumulator_dtype=dtype, nesterov=nesterov)

optax.big_vision.momentum_hp = momentum_hp  # pytype: disable=module-attr
optax.big_vision.sgd = optax.identity  # pytype: disable=module-attr
