from typing import Any, Dict, Optional

import jax
import jax.numpy as jnp

from foo.models import TrainState, NllFlax, FlaxNet
from foo.linalg import vec_normalize, vec_axpy
from .optimizer import BaseOptimizer

class SGDOptimizer(BaseOptimizer):
   """Stochastic gradient descent with optional (Nesterov) momentum.

   Each step updates ``params <- params - learning_rate * d``. ``d`` is the
   (optionally normalized) gradient. When ``momentum > 0``, ``d`` is the
   momentum-accumulated direction instead. A velocity buffer with the same tree
   structure as ``state.params`` is always kept as the optimizer state, so the
   state layout does not depend on the configuration.
   """

   def __init__(self, state: TrainState, config: Optional[Dict[str, Any]] = None, model: FlaxNet = None, loss_fn = None, loss_weights = None):
      """
      Initialize the optimizer from the given configuration dictionary.

      Args:
         state: The state of the model; its ``params`` tree defines the
            structure of the velocity buffer.
         config: The configuration dictionary, or None. Reads
            ``config['optimizer']['momentum']`` (default 0.0) and
            ``config['optimizer']['nesterov']`` (default False).
         model: The model
         loss_fn: The loss function or a list of loss functions
         loss_weights: The weights for the loss functions or a list of weights
      """
      super().__init__(state, config, model, loss_fn, loss_weights)

      # `config` defaults to None, so guard before reading it. A missing or
      # null `optimizer` section falls back to the defaults below.
      config_optimizer = (config or {}).get('optimizer') or {}

      # Coerce to float so integer or numeric-string values read from config
      # files compare and multiply correctly inside the traced functions.
      self._momentum = float(config_optimizer.get('momentum', 0.0))
      self._use_nesterov = config_optimizer.get('nesterov', False)

      # Initialize velocity state, even if momentum is 0.
      # This keeps structure consistent with Adam implementation.
      self._optimizer_state = jax.tree_util.tree_map(jnp.zeros_like, state.params)

      def sgd_direction(state, batch, logits_list, optimizer_state, grads):
         """Compute the SGD descent direction and the updated velocity.

         `batch` and `logits_list` are not used here. They are presumably
         accepted to keep a signature uniform with other optimizers
         (TODO confirm).

         Returns:
            (direction, state, velocity), where `state` is returned unchanged.
         """
         # `_grad_normalization` (set by the base class, assumed to be a Python
         # bool) and `_momentum` are read at trace time. A plain `if` selects
         # the branch statically; `lax.cond` would trace and compile both
         # branches on every jit.
         if self._grad_normalization:
            grads = vec_normalize(grads)

         grads = jax.lax.stop_gradient(grads)

         if self._momentum > 0:
            momentum = self._momentum
            # Heavy-ball accumulation: v <- momentum * v + g.
            velocity = jax.tree_util.tree_map(
               lambda v, g: momentum * v + g, optimizer_state, grads
            )
            if self._use_nesterov:
               # Nesterov look-ahead direction: g + momentum * v_new.
               direction = jax.tree_util.tree_map(
                  lambda v, g: momentum * v + g, velocity, grads
               )
            else:
               direction = velocity
         else:
            # Plain SGD: keep the velocity buffer at zeros so the optimizer
            # state structure is identical regardless of configuration.
            direction = grads
            velocity = jax.tree_util.tree_map(jnp.zeros_like, optimizer_state)

         return direction, state, velocity

      # JIT compile the direction function.
      self._sgd_direction = jax.jit(sgd_direction)

      def apply_direction(params, direction, learning_rate):
         # Descent update applied leaf by leaf: p <- p - lr * d.
         return jax.tree_util.tree_map(lambda p, d: p - learning_rate * d, params, direction)

      # The two step variants differ in the auxiliary output of
      # `self._eval_and_loss` (as unpacked below) and in whether
      # `batch_stats` are written back into the state.
      if state.batch_stats is None:
         def step(state, batch, optimizer_state, learning_rate):
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (loss_list, logits_list)), grads = grad_fn(state.params, state, batch)
            direction, state, optimizer_state = self._sgd_direction(state, batch, logits_list, optimizer_state, grads)
            state = state.replace(params=apply_direction(state.params, direction, learning_rate))
            return state, loss, grads, loss_list, logits_list, optimizer_state
      else:
         def step(state, batch, optimizer_state, learning_rate):
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (logits, updates)), grads = grad_fn(state.params, state, batch)
            direction, state, optimizer_state = self._sgd_direction(state, batch, [logits,], optimizer_state, grads)
            # `updates` is expected to carry the new 'batch_stats' collection.
            state = state.replace(
               params=apply_direction(state.params, direction, learning_rate),
               batch_stats=updates['batch_stats'],
            )
            return state, loss, grads, [loss,], [logits,], optimizer_state

      # JIT compile the step function.
      self._step = jax.jit(step)

   def get_sample_config(self):
      """Return an example config dict reflecting this optimizer's settings.

      Starts from the base class's ``optimizer`` section and sets
      ``option='sgd'`` plus the current momentum and nesterov values.
      """
      config_optimizer = super().get_sample_config()['optimizer']
      config_optimizer['option'] = 'sgd'
      config_optimizer['momentum'] = self._momentum
      config_optimizer['nesterov'] = self._use_nesterov
      return {'optimizer': config_optimizer}
