from typing import NamedTuple
from optax import tree_utils as otu
from optax._src import utils as optax_utils
import jax.numpy as jnp
from optax._src import numerics
from jax import tree_util as jtu
from flax import struct
import jax
from functools import partial

def update_moment(updates, moments, decay, order):
    """Return the per-leaf exponential moving average of ``updates ** order``.

    Each output leaf is ``(1 - decay) * g ** order + decay * m``, where ``g`` is
    the matching leaf of `updates` and ``m`` the matching leaf of `moments`.

    Args:
        updates: pytree of new values (e.g. gradients).
        moments: pytree of running averages with the same structure.
        decay: EMA decay factor applied to the previous moment.
        order: power each update leaf is raised to before averaging.

    Returns:
        A pytree with the same structure as `updates`.
    """
    def _ema(grad, prev):
        return (1 - decay) * (grad ** order) + decay * prev

    return jtu.tree_map(_ema, updates, moments)

def update_moment_per_elem_norm(updates, moments, decay, order):
    """Compute the EMA of the `order`-th power of each element's magnitude.

    Real leaves use ``g ** order`` directly. Complex leaves use
    ``abs_sq(g) ** (order / 2)``, i.e. ``|g| ** order``, so the resulting
    moment is real-valued.

    Args:
        updates: pytree of new values (e.g. gradients).
        moments: pytree of running averages with the same structure.
        decay: EMA decay factor applied to the previous moment.
        order: power of the element-wise magnitude to average.

    Returns:
        A pytree with each leaf equal to
        ``(1 - decay) * norm_order(g) + decay * m``.
    """
    def orderth_norm(g):
        if jnp.isrealobj(g):
            return g ** order
        # Complex path only: raise |g|^2 to order/2.
        half_order = order / 2
        # JAX generates different HLO for int and float `order`.
        if half_order.is_integer():
            half_order = int(half_order)
        return numerics.abs_sq(g) ** half_order

    return jtu.tree_map(
        lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments)

@partial(jax.jit, inline=True)
def bias_correction(moment, decay, count):
    """Divide every leaf of `moment` by ``1 - decay ** count``.

    The correction approaches a no-op as `count` grows. The denominator is
    computed before any cast to the moment's dtype: evaluating
    ``decay ** count`` in a low-precision type could round it to 1 and make
    the denominator zero. Casting the finished denominator to each leaf's
    dtype keeps, e.g., bfloat16 moments in bfloat16.
    """
    denom = 1 - decay ** count

    def _rescale(leaf):
        # Divide in the leaf's own precision.
        return leaf / denom.astype(leaf.dtype)

    return jtu.tree_map(_rescale, moment)

def old_adam_step(opt, params, state, updates, AdamState, lr_factor, wd_factor):
    """Apply one Adam step with decoupled weight decay directly to `params`.

    Moments are updated as exponential moving averages::

        mu = b1 * mu + (1 - b1) * g
        nu = b2 * nu + (1 - b2) * |g| ** 2      (element-wise)

    They are then bias-corrected and combined into the Adam direction
    ``mu_hat / (sqrt(nu_hat + eps_root) + eps)``. Each parameter is first
    shrunk by ``(1 - wd_strength)`` (subject to `selective_wd`) and then moved
    by ``-direction * lr``.

    Args:
        opt: optimizer config. Reads ``b1``, ``b2``, ``eps``, ``eps_root``,
            ``lr`` (called with the step count), ``max_lr``, ``wd``,
            ``factored_lr_wd`` and ``selective_wd``.
        params: pytree of parameters.
        state: previous optimizer state exposing ``count``, ``mu`` and ``nu``.
        updates: pytree of gradients with the same structure as `params`.
        AdamState: constructor used to build the returned state.
        lr_factor: multiplier applied to the scheduled learning rate.
        wd_factor: multiplier applied to ``opt.wd``.

    Returns:
        ``(new_params, AdamState(count=count + 1, mu=mu, nu=nu))``. The stored
        moments are the raw (not bias-corrected) ones.
    """
    selective_wd = opt.selective_wd

    mu = update_moment(updates, state.mu, opt.b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, opt.b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)

    mu_hat = bias_correction(mu, opt.b1, count_inc)
    nu_hat = bias_correction(nu, opt.b2, count_inc)

    direction = jtu.tree_map(
        lambda m, v: m / (jnp.sqrt(v + opt.eps_root) + opt.eps), mu_hat, nu_hat)

    # The schedule is evaluated at the pre-increment count.
    lr = opt.lr(state.count) * lr_factor
    wd = opt.wd * wd_factor
    max_lr = opt.max_lr

    # Intended: wd * lr if not factored_lr_wd else wd * (lr / max_lr).
    # lax.cond(pred, true_fn, false_fn) runs true_fn when pred holds, so the
    # factored branch must come first (it was previously swapped). lax.cond is
    # kept in case `factored_lr_wd` is a traced value under jit.
    wd_strength = jax.lax.cond(
        opt.factored_lr_wd,
        lambda: wd * (lr / max_lr),
        lambda: wd * lr)
    shrink = 1. - wd_strength

    def per_param_update(p, u):
        # Matrices (ndim > 1) are always decayed; lower-rank leaves are decayed
        # only when `selective_wd` is off. `p.ndim` is static, so a Python
        # `if` suffices; `selective_wd` may be traced, so it keeps lax.cond.
        if p.ndim > 1:
            p = p * shrink
        else:
            p = jax.lax.cond(selective_wd, lambda: p, lambda: p * shrink)
        return p - u * lr

    new_params = jtu.tree_map(per_param_update, params, direction)

    return new_params, AdamState(count=count_inc, mu=mu, nu=nu)
