from typing import Dict, Any
from foo.models import TrainState, NllFlax, FlaxNet
import jax
import jax.numpy as jnp

class BaseOptimizer:
   """
   Base class holding learning-rate scheduling state and the loss closure.

   The constructor reads hyper-parameters from ``config['optimizer']``, expands
   the loss functions / loss weights to one entry per apply function of the
   train state, and builds ``self._eval_and_loss``.

   NOTE(review): ``step`` relies on ``self._step`` and ``self._optimizer_state``,
   neither of which is defined in this class; they are presumably supplied by
   subclasses - confirm.
   """

   def __init__(self, state: TrainState, config: Dict[str, Any] = None, model: FlaxNet = None, loss_fn = None, loss_weights = None):
      """
      Initialize the optimizer from a configuration dictionary

      Args:
         state: The state of the model. Must expose ``num_apply_fns``,
            ``apply_fns`` and ``batch_stats``
         config: The configuration dictionary; hyper-parameters are read from
            its ``'optimizer'`` entry. ``None`` means "use all defaults"
         model: The model (not used by the code in this class)
         loss_fn: A single loss function, or a list/tuple of loss functions.
            A list whose length equals ``state.num_apply_fns`` is used
            element by element; a list of any other length is collapsed to its
            first element, repeated. ``None`` selects ``NllFlax``
         loss_weights: A single weight, or a list/tuple of weights, expanded
            with the same rule as ``loss_fn``. ``None`` selects ``1.0``

      Raises:
         NotImplementedError: if the state has batch stats together with more
            than one apply function, or has no apply function at all
      """
      # The signature allows config=None, so it must not be dereferenced blindly.
      if config is None:
         config = {}
      config_optimizer = config.get('optimizer') or {}
      self._learning_rate_init = config_optimizer.get('learning_rate_init', 1e-3)
      self._learning_rate = config_optimizer.get('learning_rate', 1e-3)
      self._learning_rate_min = config_optimizer.get('learning_rate_min', min(self._learning_rate_init*0.1, 1e-5))
      self._learning_rate_step_decay = config_optimizer.get('learning_rate_step_decay', False)
      self._learning_rate_step_decay_rate = config_optimizer.get('learning_rate_step_decay_rate', 0.1)
      self._learning_rate_step_decay_steps = config_optimizer.get('learning_rate_step_decay_steps', 1000)
      self._learning_rate_cosine_decay = config_optimizer.get('learning_rate_cosine_decay', False)
      self._learning_rate_cosine_decay_T_max = config_optimizer.get('learning_rate_cosine_decay_T_max', 1000)
      self._grad_normalization = config_optimizer.get('grad_normalization', False)
      self._output_normalization = config_optimizer.get('output_normalization', False)
      # Number of calls to update_learning_rate so far.
      self._nits = 0

      num_fns = state.num_apply_fns

      # The default loss is only looked up when the caller did not supply one.
      if loss_fn is None:
         loss_fn = NllFlax
      self._loss_fns = self._per_objective(loss_fn, num_fns)

      if loss_weights is None:
         loss_weights = 1.0
      self._loss_weights = self._per_objective(loss_weights, num_fns)

      self._eval_and_loss = self._build_eval_and_loss(state)

   @staticmethod
   def _per_objective(value, count):
      """
      Expand ``value`` into a list holding one entry per apply function

      Args:
         value: A single object, or a list/tuple of objects
         count: The number of apply functions

      Returns:
         A list of length ``count``. A list/tuple of matching length is copied
         as is; a list/tuple of another length contributes only its first
         element; any other value is repeated ``count`` times
      """
      if isinstance(value, (list, tuple)):
         if len(value) == count:
            return list(value)
         # Length mismatch: fall back to the first entry for every objective.
         return [value[0] for _ in range(count)]
      return [value for _ in range(count)]

   def _build_eval_and_loss(self, state):
      """
      Build the closure that evaluates the apply function(s) and the loss

      Args:
         state: The state of the model, inspected once to pick the variant

      Returns:
         A function ``eval_and_loss(params, state, batch)`` returning
         ``(total_loss, aux)``. Without batch stats ``aux`` is
         ``(list_of_weighted_losses, list_of_logits)``; with batch stats it is
         ``(logits, updates)``

      Raises:
         NotImplementedError: for batch stats with more than one apply
            function, or when there is no apply function
      """
      num_fns = state.num_apply_fns

      if state.batch_stats is not None:
         # With batch stats, the old interface
         if num_fns != 1:
            # multi-objective with batch stats is not yet supported
            raise NotImplementedError("Multi-objective with batch stats is not yet supported")

         def eval_and_loss(params, state, batch):
            x, y = batch
            logits, updates = state.apply_fns[0](
               {'params': params, 'batch_stats': state.batch_stats},
               x, train = True, mutable=['batch_stats']
            )
            loss = self._loss_weights[0] * self._loss_fns[0](logits, y, batch=batch)
            # NOTE(review): aux here is (logits, updates), unlike the
            # (loss_list, logits_list) returned without batch stats - confirm
            # that the consumer of this closure handles both shapes.
            return loss, (logits, updates)

         return eval_and_loss

      if num_fns < 1:
         raise NotImplementedError("At least one apply function is required")

      if num_fns == 1:
         # Single objective: the batch inputs/targets are passed whole.
         def eval_and_loss(params, state, batch):
            x, y = batch
            logits = state.apply_fns[0]({'params': params}, x)
            loss = self._loss_weights[0] * self._loss_fns[0](logits, y, batch=batch)
            return loss, ([loss,], [logits,])

         return eval_and_loss

      # Multi-objective: objective i consumes x[i] / y[i]; all objectives
      # share the same params. The loop runs in plain Python, so it works for
      # any number of apply functions.
      def eval_and_loss(params, state, batch):
         x, y = batch
         losses = []
         logits_list = []
         for i in range(num_fns):
            logits = state.apply_fns[i]({'params': params}, x[i])
            losses.append(self._loss_weights[i] * self._loss_fns[i](logits, y[i], batch=batch))
            logits_list.append(logits)
         # Accumulate left to right, as an explicit loss1 + loss2 + ... would.
         total = losses[0]
         for partial in losses[1:]:
            total = total + partial
         return total, (losses, logits_list)

      return eval_and_loss

   def step(self, state: TrainState, batch):
      """
      Perform a single step of the optimizer

      Delegates to ``self._step`` with the current optimizer state and
      learning rate, and stores the optimizer state it returns.

      Args:
         state: The state of the model
         batch: The batch of data

      Returns:
         The tuple ``(state, loss, grads, loss_list, logits_list)`` as
         produced by ``self._step``
      """
      state, loss, grads, loss_list, logits_list, self._optimizer_state = self._step(state, batch, self._optimizer_state, self._learning_rate)
      return state, loss, grads, loss_list, logits_list

   def update_learning_rate(self):
      """
      Advance the iteration counter and apply the enabled schedules

      Step decay multiplies the learning rate by the decay rate every
      ``learning_rate_step_decay_steps`` calls. Cosine decay recomputes the
      learning rate from ``learning_rate_init`` and ``learning_rate_min``; when
      both are enabled the cosine value overwrites the step-decayed one.
      """
      self._nits += 1
      if self._learning_rate_step_decay:
         if self._nits % self._learning_rate_step_decay_steps == 0:
            self._learning_rate *= self._learning_rate_step_decay_rate
      if self._learning_rate_cosine_decay:
         # NOTE(review): the counter is not capped at T_max, so the value
         # rises again once _nits exceeds T_max - confirm this is intended.
         self._learning_rate = self._learning_rate_min + 0.5 * (self._learning_rate_init - self._learning_rate_min) * (1 + jnp.cos(jnp.pi * self._nits / self._learning_rate_cosine_decay_T_max))

   def set_learning_rate(self, learning_rate):
      """
      Overwrite the current learning rate

      Args:
         learning_rate: The new learning rate
      """
      self._learning_rate = learning_rate

   def get_sample_config(self):
      """
      Return the current optimizer settings as a configuration dictionary

      Returns:
         A dictionary with a single ``'optimizer'`` entry, using the same keys
         that ``__init__`` reads
      """
      config_optimizer = {
         'learning_rate': self._learning_rate,
         'learning_rate_init': self._learning_rate_init,
         'learning_rate_min': self._learning_rate_min,
         'learning_rate_step_decay': self._learning_rate_step_decay,
         'learning_rate_step_decay_rate': self._learning_rate_step_decay_rate,
         'learning_rate_step_decay_steps': self._learning_rate_step_decay_steps,
         'learning_rate_cosine_decay': self._learning_rate_cosine_decay,
         'learning_rate_cosine_decay_T_max': self._learning_rate_cosine_decay_T_max,
         'grad_normalization': self._grad_normalization,
         'output_normalization': self._output_normalization
      }
      return {'optimizer': config_optimizer}

