from typing import Any, Callable, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

def inverse_softplus(x):
  """Returns y such that softplus(y) = log(1 + exp(y)) equals `x`.

  Uses x + log(-expm1(-x)), which is algebraically equal to
  log(exp(x) - 1). It avoids overflowing exp(x) for large `x` and keeps
  precision for small `x`.

  Args:
    x: strictly positive scalar or array. Softplus never produces values
      <= 0, so such inputs yield -inf/nan, as the naive formula does.

  Returns:
    The inverse softplus of `x`, as a NumPy scalar or array.
  """
  return x + np.log(-np.expm1(-x))

class MomentumOptimizer(hk.Module):
  """Unrolled momentum-descent optimizer with learnable hyperparameters.

  The learning rate and momentum coefficient are Haiku parameters. They are
  stored unconstrained and mapped through softplus and sigmoid, so they
  stay in (0, inf) and (0, 1) respectively. Each step evaluates the gradient
  at the look-ahead point `y + momentum * velocity` (Nesterov-style).
  """

  def __init__(self, learning_rate: float = 0.125,
               momentum: float = 0.9,
               name: Optional[str] = None):
    """Creates the module and its two scalar hyperparameter parameters.

    Args:
      learning_rate: initial step size (> 0). Stored as its inverse softplus.
      momentum: initial momentum coefficient in (0, 1). Stored as its logit.
      name: optional Haiku module name.
    """
    super().__init__(name=name)
    # Stored unconstrained; the `learning_rate` / `momentum` properties map
    # them back into their valid ranges.
    self._mu = hk.get_parameter(
        "momentum", [], jnp.float32,
        hk.initializers.Constant(jsp.special.logit(momentum)))
    self._lr = hk.get_parameter(
        "lr", [], jnp.float32,
        hk.initializers.Constant(inverse_softplus(learning_rate)))

  @property
  def learning_rate(self):
    """Positive step size: softplus of the raw `lr` parameter."""
    return jax.nn.softplus(self._lr)

  @property
  def momentum(self):
    """Momentum coefficient in (0, 1): sigmoid of the raw parameter."""
    return jax.nn.sigmoid(self._mu)

  def __call__(self, f: Callable[[jnp.ndarray, Any, Any], jnp.ndarray],
               y_init: jnp.ndarray, x: Any, theta: Any, max_iters: int = 5,
               gtol: float = 1e-3, clip_value: Optional[float] = None):
    """Runs `max_iters` momentum-descent steps on `f` with respect to y.

    All iterations are always executed; there is no early exit. Examples
    whose gradient max-norm falls below `gtol` simply stop moving.

    Args:
      f: objective. It takes y (optimization argument) of shape
        [batch_size, ...], x (conditioning input) of shape [batch_size, ...],
        and theta (shared params), and outputs a vector of objective values
        of shape [batch_size].
      y_init: the initial value for y, of shape [batch_size, ...].
      x: conditioning parameters.
      theta: shared parameters for the objective.
      max_iters: number of optimization iterations.
      gtol: per-example tolerance on the gradient max-norm. Below it, that
        example's y is no longer updated.
      clip_value: if specified, defines an interval [-clip_value, clip_value]
        onto which each dimension of y is projected after every step.

    Returns:
      The final loop state `(y, grad_norm, momentum, max_norm, fval)`:
        y: optimized value, same shape as `y_init`.
        grad_norm: [batch_size] sum over iterations of the mean squared
          gradient.
        momentum: final velocity, same shape as `y_init`.
        max_norm: [batch_size] max |gradient| from the last iteration.
        fval: [batch_size] objective values from the last gradient
          evaluation (taken at the look-ahead point).
    """
    def combined_objective(y, x, theta):
      fval = f(y, x, theta)
      # jax.grad needs a scalar output. Summing per-example objectives gives
      # per-example gradients, assuming f has no cross-example interaction
      # (TODO confirm).
      return jnp.sum(fval), fval

    grad_fn = jax.grad(combined_objective, argnums=0, has_aux=True)

    batch_size = y_init.shape[0]
    dtype = y_init.dtype
    # Reduce over every non-batch axis so y may have any rank, not just 2.
    feature_axes = tuple(range(1, y_init.ndim))
    # Shape that broadcasts a [batch_size] vector against y.
    mask_shape = (batch_size,) + (1,) * (y_init.ndim - 1)

    # Hoisted out of the loop and cast to y's dtype. Otherwise float32
    # hyperparameters would promote a lower-precision y and change the
    # fori_loop carry type between iterations.
    mu = self.momentum.astype(dtype)
    lr = self.learning_rate.astype(dtype)

    grad_norm = jnp.zeros([batch_size], dtype=dtype)
    fval = jnp.zeros([batch_size], dtype=dtype)
    max_norm = jnp.zeros([batch_size], dtype=dtype)
    velocity = jnp.zeros_like(y_init)

    def loop_body(_, carry):
      y, grad_norm, velocity, _, _ = carry
      # Look-ahead gradient evaluation.
      grad, f_val = grad_fn(y + mu * velocity, x, theta)
      max_norm = jnp.max(jnp.abs(grad), axis=feature_axes)
      # Converged examples keep their y fixed; velocity is still updated.
      grad_mask = (max_norm >= gtol).astype(dtype)
      velocity = mu * velocity - lr * grad
      y = y + grad_mask.reshape(mask_shape) * velocity
      if clip_value is not None:
        y = jnp.clip(y, -clip_value, clip_value)
      grad_norm = grad_norm + jnp.square(grad).mean(axis=feature_axes)
      # fori_loop requires the carry to be returned with a fixed type.
      return y, grad_norm, velocity, max_norm, f_val.astype(dtype)

    init_state = (y_init, grad_norm, velocity, max_norm, fval)
    return jax.lax.fori_loop(0, max_iters, loop_body, init_state)