import jax
import jax.numpy as jnp
from typing import Dict, Any
from .optimizer import BaseOptimizer
from foo.models import TrainState, NllFlax, FlaxNet
from foo.linalg import vec_normalize, vec_pow, vec_axpy

class AdamOptimizer(BaseOptimizer):
   """
   Adam optimizer built on top of BaseOptimizer.

   Keeps running first- and second-moment estimates of the gradient and uses
   their bias-corrected ratio as the update direction. The direction function
   and the full training step are both JIT compiled in the constructor.

   NOTE(review): `self._grad_normalization`, `self._output_normalization` and
   `self._eval_and_loss` are not defined in this class; they are presumably
   provided by BaseOptimizer - confirm against the base class.
   """

   def __init__(self, state: TrainState, config: Dict[str, Any] = None, model: FlaxNet = None, loss_fn = None, loss_weights = None):
      """
      Initialize the optimizer with the given configuration

      Args:
         state: The state of the model
         config: The configuration dictionary. Hyper-parameters are read from
            its 'optimizer' entry ('beta1', 'beta2', 'eps'). May be None, in
            which case the defaults below are used.
         model: The model
         loss_fn: The loss function or a list of loss functions
         loss_weights: The weights for the loss functions or a list of weights
      """
      super().__init__(state, config, model, loss_fn, loss_weights)

      # `config` defaults to None in the signature, and the 'optimizer' entry
      # may itself be missing or None; fall back to an empty mapping so that
      # the defaults apply instead of raising AttributeError on `.get`.
      config_optimizer = (config or {}).get('optimizer') or {}
      self._beta1 = config_optimizer.get('beta1', 0.9)
      self._beta2 = config_optimizer.get('beta2', 0.999)
      # Default epsilon depends on whether JAX 64-bit mode is enabled
      # (a larger value is used when running in 32-bit mode).
      default_eps = 1e-8 if jax.config.x64_enabled else 1e-7
      self._eps = config_optimizer.get('eps', default_eps)

      # Optimizer state: (first moment, second moment, step counter).
      # Both moments start as zeros with the same tree structure as the params.
      m_vec = jax.tree_util.tree_map(jnp.zeros_like, state.params)
      v_vec = jax.tree_util.tree_map(jnp.zeros_like, state.params)
      t_val = 0

      self._optimizer_state = (m_vec, v_vec, t_val)

      # Define and JIT compile the Adam direction calculation function
      def adam_direction(state, batch, logits_list, optimizer_state, grads):
         """
         Compute the Adam update direction from raw gradients.

         `state` is passed through unchanged; `batch` and `logits_list` are
         accepted but not used here.

         Returns:
            (direction, state, (m_vec, v_vec, t_val)) with the updated moments
            and the incremented step counter.
         """
         m_vec, v_vec, t_val = optimizer_state
         t_val += 1
         grads = jax.lax.cond(
            self._grad_normalization,
            lambda _: vec_normalize(grads),
            lambda _: grads,
            operand = None
         )
         grads_copy = jax.lax.stop_gradient(grads)
         grads_squared = vec_pow(grads_copy, 2)
         # NOTE(review): vec_axpy(a, x, b, y) presumably computes a*x + b*y
         # leaf-wise (exponential moving averages) - confirm in foo.linalg.
         m_vec = vec_axpy(self._beta1, m_vec, 1 - self._beta1, grads_copy)
         v_vec = vec_axpy(self._beta2, v_vec, 1 - self._beta2, grads_squared)

         # Bias-correction factors are the same for every leaf, so compute
         # them once outside of the per-leaf lambda.
         bias1 = 1 - self._beta1**t_val
         bias2 = 1 - self._beta2**t_val
         grads_adam = jax.tree_util.tree_map(
            lambda m, v: m / bias1 / (jnp.sqrt(v / bias2) + self._eps),
            m_vec,
            v_vec,
         )

         grads_adam = jax.lax.cond(
            self._output_normalization,
            lambda _: vec_normalize(grads_adam),
            lambda _: grads_adam,
            operand = None
         )
         return grads_adam, state, (m_vec, v_vec, t_val)

      # JIT compile the direction function
      self._adam_direction = jax.jit(adam_direction)

      def apply_direction(params, direction, learning_rate):
         # Plain descent step along the Adam direction: p <- p - lr * d
         return jax.tree_util.tree_map(lambda p, d: p - learning_rate * d, params, direction)

      # Define step functions based on whether batch_stats is present.
      # The two variants differ in how the auxiliary output of
      # `self._eval_and_loss` is unpacked and in whether batch_stats is updated.
      # Both return (state, loss, grads, loss_list, logits_list, optimizer_state),
      # where `grads` are the raw gradients (not the Adam direction).
      if state.batch_stats is None:
         def step(state, batch, optimizer_state, learning_rate):
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (loss_list, logits_list)), grads = grad_fn(state.params, state, batch)
            grads_adam, state, optimizer_state = self._adam_direction(state, batch, logits_list, optimizer_state, grads)
            params = apply_direction(state.params, grads_adam, learning_rate)
            state = state.replace(params=params)
            return state, loss, grads, loss_list, logits_list, optimizer_state
      else:
         def step(state, batch, optimizer_state, learning_rate):
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (logits, updates)), grads = grad_fn(state.params, state, batch)
            grads_adam, state, optimizer_state = self._adam_direction(state, batch, [logits,], optimizer_state, grads)
            params = apply_direction(state.params, grads_adam, learning_rate)
            state = state.replace(params=params, batch_stats=updates['batch_stats'])
            return state, loss, grads, [loss,], [logits,], optimizer_state

      # JIT compile the step function
      self._step = jax.jit(step)

   def get_sample_config(self):
      """
      Return a sample configuration for this optimizer.

      Starts from the 'optimizer' section of the base class sample config,
      sets 'option' to 'adam' and records the current beta1, beta2 and eps.

      Returns:
         A dictionary with a single 'optimizer' key.
      """
      config_optimizer = super().get_sample_config()['optimizer']
      config_optimizer['option'] = 'adam'
      config_optimizer['beta1'] = self._beta1
      config_optimizer['beta2'] = self._beta2
      config_optimizer['eps'] = self._eps
      return {'optimizer': config_optimizer}
