import jax
import jax.numpy as jnp
from typing import Dict, Any
from .optimizer import BaseOptimizer
from foo.models import TrainState, NllFlax, FlaxNet
from foo.linalg import vec_normalize, vec_pow, vec_axpy, vec_dot, vec_norm_inf, vec_scale

class NLTGCROptimizer(BaseOptimizer):
   """Nonlinear Truncated Generalized Conjugate Residual (NLTGCR) optimizer.

   Each direction computation takes ``r = -grads``, forms ``v = A r`` with a damped
   second-order operator ``A`` (a Gauss-Newton style "fisher" product or the Hessian
   of the weighted loss, plus ``tol * I``), orthogonalizes ``(p, v)`` against the
   stored window ``(P, V)``, appends the normalized pair, and returns the negated
   direction ``-P (V^T r)`` (the caller subtracts ``learning_rate * direction``).

   The optimizer state is the tuple ``(its, tits, V, P, xrec)``:
     its:  number of stored pairs the next call orthogonalizes against
           (0 <= its <= wsize); the new pair is written to slot ``its``.
     tits: total number of direction computations (drives periodic restarts).
     V, P: pytrees shaped like ``state.params`` with a leading axis of size
           ``wsize + 1`` holding the window.
     xrec: running upper bound on the infinity norm of each stored ``P`` entry,
           used by the ``safeguard`` restart test.
   """

   def __init__(self, state: TrainState, config: Dict[str, Any] = None, model: FlaxNet = None, loss_fn = None, loss_weights = None):
      """
      Initialize the optimizer with the given configuration.

      Reads the ``optimizer`` section of ``config`` (keys ``wsize``,
      ``inner_iterations``, ``restart``, ``safeguard`` and a ``second_order``
      sub-section with ``option``, ``tol``, ``accumulate``, ``rank``,
      ``print_level``). It then builds and jit-compiles the second-order
      matrix-vector product, the NLTGCR direction and the training ``step``.

      Args:
         state: The state of the model
         config: The configuration dictionary; ``None`` is treated as empty
         model: The model
         loss_fn: The loss function or a list of loss functions
         loss_weights: The weights for the loss functions or a list of weights

      Raises:
         ValueError: if ``state.num_apply_fns`` is unsupported (1-3 without
            batch_stats, exactly 1 with batch_stats).
         NotImplementedError: if ``second_order.accumulate`` is enabled.
      """
      super().__init__(state, config, model, loss_fn, loss_weights)

      # ``config`` defaults to None in the signature; treat that as "all defaults".
      config_optimizer = (config or {}).get('optimizer', {})
      self._wsize = config_optimizer.get('wsize', 1)
      self._inner_iterations = config_optimizer.get('inner_iterations', 1)
      self._restart = config_optimizer.get('restart', 10)
      self._safeguard = config_optimizer.get('safeguard', 1e3)
      second_order_config = config_optimizer.get('second_order', {})
      self._second_order_option = second_order_config.get('option', 'fisher')
      self._second_order_tol = second_order_config.get('tol', 1e-5)
      self._accumulate = second_order_config.get('accumulate', False)
      self._rank = second_order_config.get('rank', 5)
      self._print_level = second_order_config.get('print_level', 1)

      if self._print_level > 0:
         print(f"Initialize NLTGCR optimizer with")
         print(f"wsize={self._wsize}")
         print(f"inner_iterations={self._inner_iterations}")
         print(f"restart={self._restart}")
         print(f"safeguard={self._safeguard}")
         print(f"second_order_option={self._second_order_option}")
         print(f"second_order_tol={self._second_order_tol}")
         print(f"accumulate={self._accumulate}")
         print(f"rank={self._rank}")

      # Get a sample parameter to determine the dtype
      sample_param = jax.tree_util.tree_leaves(state.params)[0]
      self._dtype = sample_param.dtype

      # Window storage: wsize + 1 slots per parameter leaf (wsize previous pairs
      # plus the slot the newest pair is written to).
      V = jax.tree_util.tree_map(
         lambda p: jnp.zeros((self._wsize + 1,) + p.shape, dtype=self._dtype), 
         state.params
      )
      P = jax.tree_util.tree_map(
         lambda p: jnp.zeros((self._wsize + 1,) + p.shape, dtype=self._dtype), 
         state.params
      )
      xrec = jnp.zeros(self._wsize + 1, dtype=self._dtype)

      # define matvec
      # TODO: make this reusable in other optimizers
      # Note: this fisher_mv does not work for arbitrary model structure.
      # When using it on an MLP, no extra reshape is needed.
      # When using it on a CNN, check the shapes and add an extra dimension as needed.
      # NOTE(review): per-sample VJPs below are averaged over the batch
      # (``mean(axis=0)``); if the loss is itself a batch mean this adds an extra
      # 1/batch factor relative to the Hessian path — confirm the intended scaling.

      if self._second_order_option == 'fisher':
         # Fisher matrix-vector product
         if state.batch_stats is None:
            if state.num_apply_fns == 1:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  jvp = jax.jvp(lambda params: state.apply_fns[0]({'params': params}, x), (state.params, ), (vec, ))
                  hvp = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[0] * self._loss_fns[0](output, y, batch=batch))(output), jvp[1]))(logits_list[0])
                  vjp = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns_single_point[0]({'params': params}, data), state.params)[1](v))(x, hvp)
                  return jax.tree_util.tree_map(lambda v: v.mean(axis=0), vjp)[0]
            elif state.num_apply_fns == 2:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  jvp1 = jax.jvp(lambda params: state.apply_fns[0]({'params': params}, x[0]), (state.params, ), (vec, ))
                  jvp2 = jax.jvp(lambda params: state.apply_fns[1]({'params': params}, x[1]), (state.params, ), (vec, ))
                  hvp1 = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[0] * self._loss_fns[0](output, y[0], batch=batch))(output), jvp1[1]))(logits_list[0])
                  hvp2 = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[1] * self._loss_fns[1](output, y[1], batch=batch))(output), jvp2[1]))(logits_list[1])
                  vjp1 = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns_single_point[0]({'params': params}, data), state.params)[1](v))(x[0], hvp1)
                  vjp2 = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns_single_point[1]({'params': params}, data), state.params)[1](v))(x[1], hvp2)
                  return jax.tree_util.tree_map(lambda v1, v2: v1.mean(axis=0) + v2.mean(axis=0), vjp1, vjp2)[0]
            elif state.num_apply_fns == 3:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  jvp1 = jax.jvp(lambda params: state.apply_fns[0]({'params': params}, x[0]), (state.params, ), (vec, ))
                  jvp2 = jax.jvp(lambda params: state.apply_fns[1]({'params': params}, x[1]), (state.params, ), (vec, ))
                  jvp3 = jax.jvp(lambda params: state.apply_fns[2]({'params': params}, x[2]), (state.params, ), (vec, ))
                  hvp1 = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[0] * self._loss_fns[0](output, y[0], batch=batch))(output), jvp1[1]))(logits_list[0])
                  hvp2 = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[1] * self._loss_fns[1](output, y[1], batch=batch))(output), jvp2[1]))(logits_list[1])
                  hvp3 = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[2] * self._loss_fns[2](output, y[2], batch=batch))(output), jvp3[1]))(logits_list[2])
                  vjp1 = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns_single_point[0]({'params': params}, data), state.params)[1](v))(x[0], hvp1)
                  vjp2 = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns_single_point[1]({'params': params}, data), state.params)[1](v))(x[1], hvp2)
                  vjp3 = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns_single_point[2]({'params': params}, data), state.params)[1](v))(x[2], hvp3)
                  return jax.tree_util.tree_map(lambda v1, v2, v3: v1.mean(axis=0) + v2.mean(axis=0) + v3.mean(axis=0), vjp1, vjp2, vjp3)[0]
            else:
               raise ValueError(f"Number of apply functions must be 1, 2, or 3, but got {state.num_apply_fns}")
         
         else:
            if state.num_apply_fns == 1:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  jvp = jax.jvp(lambda params: state.apply_fns[0]({'params': params, 'batch_stats': state.batch_stats}, x, train=False), (state.params, ), (vec, ))
                  hvp = jax.grad(lambda output: jnp.vdot(jax.grad(lambda output: self._loss_weights[0] * self._loss_fns[0](output, y, batch=batch))(output), jvp[1]))(logits_list[0])
                  vjp = jax.vmap(lambda data, v: jax.vjp(lambda params: state.apply_fns[0]({'params': params, 'batch_stats': state.batch_stats}, data, train=False), state.params)[1](v))(x, hvp)
                  return jax.tree_util.tree_map(lambda v: v.mean(axis=0), vjp)[0]
            else:
               raise ValueError(f"Number of apply functions must be 1 when using batch_stats")
      else:
         # Hessian matrix-vector product (forward-over-reverse: jvp of grad)
         if state.batch_stats is None:
            if state.num_apply_fns == 1:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  return jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[0] * self._loss_fns[0](state.apply_fns[0]({'params': params}, x), y, batch=batch))(params), (state.params,), (vec,))[1]
            elif state.num_apply_fns == 2:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  return jax.tree_util.tree_map(lambda v1, v2: v1 + v2, 
                     jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[0] * self._loss_fns[0](state.apply_fns[0]({'params': params}, x[0]), y[0], batch=batch))(params), (state.params,), (vec,))[1],
                     jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[1] * self._loss_fns[1](state.apply_fns[1]({'params': params}, x[1]), y[1], batch=batch))(params), (state.params,), (vec,))[1])
            elif state.num_apply_fns == 3:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  return jax.tree_util.tree_map(lambda v1, v2, v3: v1 + v2 + v3, 
                     jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[0] * self._loss_fns[0](state.apply_fns[0]({'params': params}, x[0]), y[0], batch=batch))(params), (state.params,), (vec,))[1],
                     jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[1] * self._loss_fns[1](state.apply_fns[1]({'params': params}, x[1]), y[1], batch=batch))(params), (state.params,), (vec,))[1],
                     jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[2] * self._loss_fns[2](state.apply_fns[2]({'params': params}, x[2]), y[2], batch=batch))(params), (state.params,), (vec,))[1])
            else:
               # Previously fell through and left base_mv unbound; fail explicitly
               # like the Fisher branch does.
               raise ValueError(f"Number of apply functions must be 1, 2, or 3, but got {state.num_apply_fns}")
         else:
            if state.num_apply_fns == 1:
               def base_mv(batch, state, optimizer_state, logits_list, vec):
                  x, y = batch
                  return jax.jvp(lambda params: jax.grad(lambda params: self._loss_weights[0] * self._loss_fns[0](state.apply_fns[0]({'params': params, 'batch_stats': state.batch_stats}, x, train=False), y, batch=batch))(params), (state.params,), (vec,))[1]
            else:
               raise ValueError(f"Number of apply functions must be 1 when using batch_stats")
      
      # JIT compile the base matrix-vector product function
      self._base_mv = jax.jit(base_mv)
      self._optimizer_state = (0, 0, V, P, xrec)

      if self._accumulate:
         # TODO: we assume rank is always less than the number of parameters, so there might be error
         raise NotImplementedError("Accumulate is not implemented for NLTGCR optimizer")
      else:
         def second_order_mv(batch, state, optimizer_state, logits, vector):
            # Damped product (A + tol * I) @ vector.
            tol = jnp.array(self._second_order_tol, dtype=self._dtype)
            return jax.tree_util.tree_map(lambda v, w: v + tol * w, 
                                          self._base_mv(batch, state, optimizer_state, logits, vector), vector)
         
         self._second_order_mv = jax.jit(second_order_mv)
          
         def second_order_bmv(batch, state, optimizer_state, logits, vectors):   
            # Batched version of second_order_mv over the leading axis of ``vectors``.
            return jax.vmap(lambda vector: self._second_order_mv(batch, state, optimizer_state, logits, vector), in_axes=0)(vectors)
         self._second_order_bmv = jax.jit(second_order_bmv)
      
      # Define and JIT compile efficient sub-functions for the algorithm
      def body_fun_step(V, P, xrec, i, p, v, w):
         """Orthogonalize (p, v) against stored pair i and update the inf-norm bound w of p."""
         Vi = jax.tree_util.tree_map(lambda arr: arr[i], V)
         Pi = jax.tree_util.tree_map(lambda arr: arr[i], P)
         beta = vec_dot(Vi, v)
         p = vec_axpy(jnp.array(1.0, dtype=self._dtype), p, 
                      jnp.array(-1.0, dtype=self._dtype) * beta, Pi)
         v = vec_axpy(jnp.array(1.0, dtype=self._dtype), v, 
                      jnp.array(-1.0, dtype=self._dtype) * beta, Vi)
         # ||p - beta P_i||_inf <= ||p||_inf + |beta| * bound(P_i)
         w = w + jnp.abs(beta) * xrec[i]
         return p, v, w
      self._body_fun_step = jax.jit(body_fun_step)
      
      def grad_body_fun_step(V, P, i, r, grads):
         """Return grads - (V[i] . r) * P[i]."""
         Vi = jax.tree_util.tree_map(lambda arr: arr[i], V)
         Pi = jax.tree_util.tree_map(lambda arr: arr[i], P)
         Vi_dot_r = vec_dot(Vi, r)
         return vec_axpy(jnp.array(1.0, dtype=self._dtype), grads, 
                         jnp.array(-1.0, dtype=self._dtype) * Vi_dot_r, Pi)
      self._grad_body_fun_step = jax.jit(grad_body_fun_step)
      
      def update_V_P_entry(V, P, xrec, its, v, p, nrmv, w_restart):
         """Store the normalized pair (p / nrmv, v / nrmv) and its inf-norm bound at slot ``its``."""
         V = jax.tree_util.tree_map(
            lambda arr, v_leaf: arr.at[its].set(vec_scale(jnp.array(1.0, dtype=self._dtype) / nrmv, v_leaf)), V, v)
         P = jax.tree_util.tree_map(
            lambda arr, p_leaf: arr.at[its].set(vec_scale(jnp.array(1.0, dtype=self._dtype) / nrmv, p_leaf)), P, p)
         xrec = xrec.at[its].set(w_restart / nrmv)
         return V, P, xrec
      self._update_V_P_entry = jax.jit(update_V_P_entry)
      
      def restart_fn(state_params, V, P, xrec, v, r, wsize):
         """Clear the window and seed slot 0 with the fresh pair (r, v) where v = A r.

         Returns the new (V, P, xrec, its). ``its`` is 1 so the seed is used by the
         next call (0 when wsize == 0, where there is no room to keep it).
         """
         V_new = jax.tree_util.tree_map(
            lambda p: jnp.zeros((wsize + 1,) + p.shape, dtype=self._dtype), 
            state_params
         )
         P_new = jax.tree_util.tree_map(
            lambda p: jnp.zeros((wsize + 1,) + p.shape, dtype=self._dtype), 
            state_params
         )
         xrec_new = jnp.zeros(wsize + 1, dtype=self._dtype)
         nrmv = jnp.sqrt(vec_dot(v, v))
         w_restart = vec_norm_inf(r)
         
         V_new = jax.tree_util.tree_map(
               lambda arr, v_leaf: arr.at[0].set(vec_scale(jnp.array(1.0, dtype=self._dtype) / nrmv, v_leaf)), V_new, v)
         P_new = jax.tree_util.tree_map(
               lambda arr, p_leaf: arr.at[0].set(vec_scale(jnp.array(1.0, dtype=self._dtype) / nrmv, p_leaf)), P_new, r)
         xrec_new = xrec_new.at[0].set(w_restart / nrmv)
         return V_new, P_new, xrec_new, min(1, wsize)
      self._restart_fn = jax.jit(restart_fn, static_argnums=(6,))
      
      def update_fn(V, P, xrec):
         """Rotate the window left by one: slots 1..wsize move to 0..wsize-1 and the
         oldest pair (slot 0) moves to slot wsize, where it is overwritten next call."""
         V = jax.tree_util.tree_map(lambda arr: jnp.roll(arr, shift=-1, axis=0), V)
         P = jax.tree_util.tree_map(lambda arr: jnp.roll(arr, shift=-1, axis=0), P)
         xrec = jnp.roll(xrec, shift=-1, axis=0)
         return V, P, xrec
      self._update_fn = jax.jit(update_fn)
      
      def nltgcr_direction(state, batch, optimizer_state, logits_list, grads):
         """Compute the NLTGCR direction and the updated optimizer state.

         Returns ``(grad_nltgcr, state, new_optimizer_state)``; the caller applies
         ``params - learning_rate * grad_nltgcr``.
         """
         its, tits, V, P, xrec = optimizer_state

         grads = jax.lax.cond(
            self._grad_normalization,
            lambda _: vec_normalize(grads),
            lambda _: grads,
            operand = None
         )
         # r = - grads
         r = jax.lax.stop_gradient(vec_scale(jnp.array(-1.0, dtype=self._dtype), grads))
         # v0 = (A + tol I) r; kept un-orthogonalized for a possible restart
         v0 = jax.lax.stop_gradient(self._second_order_mv(batch, state, optimizer_state, logits_list, r))

         # for i in [0, its): beta = V[i]' * v, p = p - beta * P[i], v = v - beta * V[i]
         def body_fun(i, carry):
            p, v, w = carry
            p, v, w = self._body_fun_step(V, P, xrec, i, p, v, w)
            return (p, v, w)
         p, v, w_restart = jax.lax.fori_loop(0, its, body_fun, (r, v0, vec_norm_inf(r)))

         # Store the orthogonalized pair (p, v), both scaled by 1/||v||, in slot its
         nrmv = jnp.sqrt(vec_dot(v, v))
         V, P, xrec = self._update_V_P_entry(V, P, xrec, its, v, p, nrmv, w_restart)

         # for i in [0, its] (including the pair just stored):
         #   grad = grad - (V[i]' * r) * P[i], i.e. grad = -P (V^T r)
         def grad_body_fun(i, grads):
            return self._grad_body_fun_step(V, P, i, r, grads)

         grad_nltgcr = jax.tree_util.tree_map(jnp.zeros_like, r)
         grad_nltgcr = jax.lax.fori_loop(0, its + 1, grad_body_fun, grad_nltgcr)

         # The pair was written to the last slot only if the window was already
         # full; in that case rotate so the oldest pair is overwritten next call.
         # (Must be decided from the pre-increment count.)
         is_full = its == self._wsize
         its = jax.lax.select(is_full, its, its + 1)
         tits = tits + 1

         V, P, xrec = jax.lax.cond(
            is_full,
            lambda _: self._update_fn(V, P, xrec),
            lambda _: (V, P, xrec),
            operand=None
         )

         # Periodic restart, or restart when the inf-norm bound of p blows up;
         # reseed with the fresh pair (r, A r) so that A P = V still holds.
         V, P, xrec, its = jax.lax.cond(
            (tits % self._restart == 0) | (w_restart > self._safeguard),
            lambda _: self._restart_fn(state.params, V, P, xrec, v0, r, self._wsize),
            lambda _: (V, P, xrec, its),
            operand=None
         )

         # Ensure -grad_nltgcr is a descent direction (grad_nltgcr . grads >= 0).
         tc = vec_dot(grad_nltgcr, r)
         grad_nltgcr = jax.lax.cond(
            tc > 0,
            lambda _: vec_scale(jnp.array(-1.0, dtype=self._dtype), grad_nltgcr),
            lambda _: grad_nltgcr,
            operand=None,
         )

         grad_nltgcr = jax.lax.cond(
            self._output_normalization,
            lambda _: vec_normalize(grad_nltgcr),
            lambda _: grad_nltgcr,
            operand=None
         )

         return grad_nltgcr, state, (its, tits, V, P, xrec)

      # JIT compile the nltgcr_direction function
      self._nltgcr_direction = jax.jit(nltgcr_direction)

      if state.batch_stats is None:
         def step(state, batch, optimizer_state, learning_rate):
            # Runs inner_iterations NLTGCR updates; gradients are recomputed after
            # every update except the last one.
            def inner_loop_body(i, carry):
               state, optimizer_state, loss, loss_list, logits_list, grads = carry
               grads_nltgcr, state, optimizer_state = self._nltgcr_direction(
                  state, batch, optimizer_state, logits_list, grads
               )
               params = jax.tree_util.tree_map(
                  lambda p, g: p - learning_rate * g, 
                  state.params, grads_nltgcr
               )
               state = state.replace(params=params)
               grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
               (loss, (loss_list, logits_list)), grads = jax.lax.cond(
                  i < self._inner_iterations - 1,
                  lambda _: grad_fn(state.params, state, batch),
                  lambda _: ((loss, (loss_list, logits_list)), grads),
                  operand=None
               )
               return (state, optimizer_state, loss, loss_list, logits_list, grads)
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (loss_list, logits_list)), grads = grad_fn(state.params, state, batch)
            state, optimizer_state, loss, loss_list, logits_list, _ = jax.lax.fori_loop(
               0, self._inner_iterations, inner_loop_body, (state, optimizer_state, loss, loss_list, logits_list, grads)
            )
            return state, loss, grads, loss_list, logits_list, optimizer_state
          
      else:
                
         def step(state, batch, optimizer_state, learning_rate):
            # Same as above, but also threads the updated batch_stats through.
            def inner_loop_body(i, carry):
               state, optimizer_state, loss, logits_list, updates, grads = carry
               grads_nltgcr, state, optimizer_state = self._nltgcr_direction(
                  state, batch, optimizer_state, logits_list, grads
               )
               params = jax.tree_util.tree_map(
                  lambda p, g: p - learning_rate * g, 
                  state.params, grads_nltgcr
               )
               state = state.replace(params=params, batch_stats=updates['batch_stats'])
               grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
               (loss, (logits, updates)), grads = jax.lax.cond(
                  i < self._inner_iterations - 1,
                  lambda _: grad_fn(state.params, state, batch),
                  lambda _: ((loss, (logits_list[0], updates)), grads),
                  operand=None
               )
               return (state, optimizer_state, loss, [logits,], updates, grads)
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (logits, updates)), grads = grad_fn(state.params, state, batch)
            state, optimizer_state, loss, logits_list, updates, _ = jax.lax.fori_loop(
               0, self._inner_iterations, inner_loop_body, (state, optimizer_state, loss,[logits,], updates, grads)
            )
            return state, loss, grads, [loss,], logits_list, optimizer_state

      # JIT compile the step function
      self._step = jax.jit(step)
