import jax
import jax.numpy as jnp
from typing import Dict, Any, Tuple, List, Callable
from .optimizer import BaseOptimizer
from foo.models import TrainState, NllFlax, FlaxNet
from foo.linalg import vec_normalize, vec_pow, vec_axpy, vec_dot
import time
from tqdm import tqdm

class CGOptimizer(BaseOptimizer):
   def __init__(self, state: TrainState, config: Dict[str, Any] = None, model: FlaxNet = None, loss_fn = None, loss_weights = None):
      """
      Initialize the optimizer with the given configuration file and build
      the JIT-compiled matrix-vector product, CG solver and step functions.

      Args:
         state: The state of the model
         config: The configuration file. Settings are read from
            config['optimizer'] (momentum, nesterov, its, tol) and from
            config['optimizer']['second_order'] (option, tol, accumulate,
            rank, print_level). Missing keys fall back to defaults.
         model: The model
         loss_fn: The loss function or a list of loss functions
         loss_weights: The weights for the loss functions or a list of weights

      Raises:
         ValueError: If the number of apply functions is unsupported
            (exactly 1 when batch_stats is used, otherwise 1, 2 or 3).
         NotImplementedError: If second_order.accumulate is enabled.
      """
      super().__init__(state, config, model, loss_fn, loss_weights)

      # `config` defaults to None in the signature; treat that as "no settings"
      config_optimizer = (config or {}).get('optimizer', {})

      self._momentum = config_optimizer.get('momentum', 0.0)
      self._use_nesterov = config_optimizer.get('nesterov', False)

      self._its = config_optimizer.get('its', 2)
      self._tol = config_optimizer.get('tol', 1e-5)
      second_order_config = config_optimizer.get('second_order', {})
      self._second_order_option = second_order_config.get('option', 'fisher')
      self._second_order_tol = second_order_config.get('tol', 1e-5)
      self._accumulate = second_order_config.get('accumulate', False)
      self._rank = second_order_config.get('rank', 5)
      self._print_level = second_order_config.get('print_level', 1)

      if self._print_level > 0:
         print("Initialize CG optimizer with")
         print(f"its={self._its}")
         print(f"tol={self._tol}")
         print(f"second_order_option={self._second_order_option}")
         print(f"second_order_tol={self._second_order_tol}")
         print(f"accumulate={self._accumulate}")
         print(f"rank={self._rank}")
         print(f"momentum={self._momentum}")
         print(f"use_nesterov={self._use_nesterov}")

      # define matvec
      # TODO: make this reusable in other optimizers
      # Note: this fisher_mv does not work for arbitrary model structure.
      # when using on MLP, no extra reshape is needed.
      # when using on CNN, be sure to check shape. Adding extra dimension as needed.

      # The model layout is fixed at construction time; the closures below
      # branch on these Python values, not on the traced `state` argument.
      has_batch_stats = state.batch_stats is not None
      num_fns = state.num_apply_fns

      # Validate once for BOTH the Fisher and the Hessian option. Previously the
      # Hessian path silently left `base_mv` undefined for unsupported counts.
      if has_batch_stats:
         if num_fns != 1:
            raise ValueError("Number of apply functions must be 1 when using batch_stats")
      elif num_fns not in (1, 2, 3):
         raise ValueError(f"Number of apply functions must be 1, 2, or 3, but got {num_fns}")

      def select_term(batch, index):
         # With a single apply function the batch is used as-is; with several,
         # inputs and targets are indexed per apply function.
         x, y = batch
         if num_fns == 1:
            return x, y
         return x[index], y[index]

      def forward(state, index, params, inputs):
         # Batched forward pass of apply function `index`
         if has_batch_stats:
            return state.apply_fns[index]({'params': params, 'batch_stats': state.batch_stats}, inputs, train=False)
         return state.apply_fns[index]({'params': params}, inputs)

      def forward_single(state, index, params, sample):
         # Forward pass used under vmap for the per-sample VJP
         if has_batch_stats:
            return state.apply_fns[index]({'params': params, 'batch_stats': state.batch_stats}, sample, train=False)
         return state.apply_fns_single_point[index]({'params': params}, sample)

      def weighted_loss(index, outputs, targets, batch):
         return self._loss_weights[index] * self._loss_fns[index](outputs, targets, batch=batch)

      def fisher_term(index, batch, state, logits_list, vec):
         # 1) push `vec` through the network (JVP w.r.t. params)
         # 2) apply the Hessian of the weighted loss w.r.t. the outputs,
         #    evaluated at the supplied logits
         # 3) pull the result back per sample (VJP under vmap)
         inputs, targets = select_term(batch, index)
         _, out_tangent = jax.jvp(lambda params: forward(state, index, params, inputs), (state.params, ), (vec, ))
         loss_grad = jax.grad(lambda outputs: weighted_loss(index, outputs, targets, batch))
         out_curvature = jax.grad(lambda outputs: jnp.vdot(loss_grad(outputs), out_tangent))(logits_list[index])
         pull_back = lambda sample, cotangent: jax.vjp(lambda params: forward_single(state, index, params, sample), state.params)[1](cotangent)
         # Result is a 1-tuple (one primal) whose leaves carry a leading sample axis
         return jax.vmap(pull_back)(inputs, out_curvature)

      def hessian_term(index, batch, state, vec):
         # Forward-over-reverse Hessian-vector product of one weighted loss term
         inputs, targets = select_term(batch, index)
         loss_of_params = lambda params: weighted_loss(index, forward(state, index, params, inputs), targets, batch)
         return jax.jvp(jax.grad(loss_of_params), (state.params,), (vec,))[1]

      def sum_batch_means(*per_sample):
         # Average each term over its leading (sample) axis, then add the terms
         total = per_sample[0].mean(axis=0)
         for extra in per_sample[1:]:
            total = total + extra.mean(axis=0)
         return total

      def sum_terms(*terms):
         total = terms[0]
         for extra in terms[1:]:
            total = total + extra
         return total

      if self._second_order_option == 'fisher':
         # Fisher matrix-vector product
         def base_mv(batch, state, optimizer_state, logits_list, vec):
            per_sample = [fisher_term(i, batch, state, logits_list, vec) for i in range(num_fns)]
            return jax.tree_util.tree_map(sum_batch_means, *per_sample)[0]
      else:
         # Hessian matrix-vector product
         def base_mv(batch, state, optimizer_state, logits_list, vec):
            terms = [hessian_term(i, batch, state, vec) for i in range(num_fns)]
            if num_fns == 1:
               return terms[0]
            return jax.tree_util.tree_map(sum_terms, *terms)

      # JIT compile the base matrix-vector product function
      self._base_mv = jax.jit(base_mv)
      # self._velocity = jax.tree_util.tree_map(lambda p: jnp.zeros_like(p), state.params)  # Removed to avoid JAX tracer leaks
      self._optimizer_state = None

      if self._accumulate:
         # TODO: we assume rank is always less than the number of parameters, so there might be error
         raise NotImplementedError("Accumulate is not implemented for CG optimizer")
      else:
         # Define and JIT compile the second-order matrix-vector product function
         # (base product plus the damping term second_order_tol * vector)
         def second_order_mv(batch, state, optimizer_state, logits_list, vector):
            return jax.tree_util.tree_map(lambda v, w: v + self._second_order_tol * w,
                                          self._base_mv(batch, state, optimizer_state, logits_list, vector), vector)

         self._second_order_mv = jax.jit(second_order_mv)

         # Batched variant: maps the product over the leading axis of `vectors`
         def second_order_bmv(batch, state, optimizer_state, logits_list, vectors):
            return jax.vmap(lambda vector: self._second_order_mv(batch, state, optimizer_state, logits_list, vector), in_axes=0)(vectors)

         self._second_order_bmv = jax.jit(second_order_bmv)

      # JIT compile the conjugate gradient solver
      def conjugate_gradient(A, b, x0=None, tol=1e-5, maxiter=5):
         # Solves A(x) = b for pytree-valued x; returns (x, iterations performed).
         if x0 is None:
            x = jax.tree_util.tree_map(jnp.zeros_like, b)
         else:
            x = x0

         # Initial residual r = b - A(x); the first search direction is r itself
         r = vec_axpy(1.0, b, -1.0, A(x))
         p = r
         rr = vec_dot(r, r)

         def cg_step(state):
            x, r, p, its, rr = state
            Ap = A(p)
            App = vec_dot(Ap, p)
            alpha = rr / App
            x_new = vec_axpy(1.0, x, alpha, p)
            r_new = vec_axpy(1.0, r, -alpha, Ap)
            rr_new = vec_dot(r_new, r_new)
            beta = rr_new / rr
            p_new = vec_axpy(1.0, r_new, beta, p)
            return x_new, r_new, p_new, its + 1, rr_new

         def cg_cond(state):
            # Note: the stopping test compares the SQUARED residual norm to tol
            _, _, _, its, rr = state
            return (rr > tol) & (its < maxiter)

         final_state = jax.lax.while_loop(cg_cond, cg_step, (x, r, p, 0, rr))
         x, _, _, its, _ = final_state

         info = its

         return x, info

      # A, x0, tol and maxiter are static: each distinct (hashable) value,
      # including each distinct callable passed as A, triggers a new trace.
      self._conjugate_gradient = jax.jit(conjugate_gradient, static_argnums=(0, 2, 3, 4))

      # Define and JIT compile the CG direction function
      # (self._grad_normalization / self._output_normalization are not assigned
      # in this method — presumably set by BaseOptimizer.__init__)
      def cg_direction(state, batch, logits_list, optimizer_state, grads):
         grads = jax.lax.cond(
            self._grad_normalization,
            lambda _: vec_normalize(grads),
            lambda _: grads,
            operand = None
         )
         # Since velocity is removed, we don't apply momentum
         # Just use the gradients directly
         grads_cg = jax.lax.stop_gradient(grads)

         mv = lambda v: self._second_order_mv(batch, state, optimizer_state, logits_list, v)
         grad_cg, _ = self._conjugate_gradient(mv, grads_cg, None, self._tol, self._its)
         grad_cg = jax.lax.cond(
            self._output_normalization,
            lambda _: vec_normalize(grad_cg),
            lambda _: grad_cg,
            operand = None
         )
         return grad_cg, state, optimizer_state

      self._cg_direction = jax.jit(cg_direction)

      # Define the step functions based on batch_stats
      # (the auxiliary output of self._eval_and_loss is unpacked differently
      # in the two cases, exactly as the two variants below show)
      if state.batch_stats is None:
         def step(state, batch, optimizer_state, learning_rate):
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (loss_list, logits_list)), grads = grad_fn(state.params, state, batch)
            grads_cg, state, optimizer_state = self._cg_direction(state, batch, logits_list, optimizer_state, grads)
            params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, state.params, grads_cg)
            state = state.replace(params=params)
            return state, loss, grads, loss_list, logits_list, optimizer_state
      else:
         def step(state, batch, optimizer_state, learning_rate):
            grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
            (loss, (logits, updates)), grads = grad_fn(state.params, state, batch)
            grads_cg, state, optimizer_state = self._cg_direction(state, batch, [logits,], optimizer_state, grads)
            params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, state.params, grads_cg)
            state = state.replace(params=params, batch_stats=updates['batch_stats'])
            return state, loss, grads, [loss,], [logits,], optimizer_state

      # JIT compile the step function
      self._step = jax.jit(step)
      
   def get_hessian(self, state: TrainState, batch) -> jnp.ndarray:
      """
      Computes and returns the full Hessian matrix of the (weighted) loss
      function with respect to model parameters.

      No regularization is added: the λ*I term is disabled in this method,
      so second_order_tol has no effect on the result.

      Args:
         state: The state of the model
         batch: The batch of data (x, y)

      Returns:
         jnp.ndarray: A 2D array representing the Hessian matrix

      Raises:
         ValueError: If the number of apply functions is unsupported
            (exactly 1 when batch_stats is used, otherwise 1, 2 or 3).

      Note:
         - This always computes the true Hessian, regardless of second_order_option
         - The matrix is computed column-by-column using matrix-vector products
         - Progress bar shows computation progress
      """
      has_batch_stats = state.batch_stats is not None
      num_fns = state.num_apply_fns

      # Fail early with a clear message instead of a NameError at trace time
      if has_batch_stats:
         if num_fns != 1:
            raise ValueError("Number of apply functions must be 1 when using batch_stats")
      elif num_fns not in (1, 2, 3):
         raise ValueError(f"Number of apply functions must be 1, 2, or 3, but got {num_fns}")

      # Get flat parameters structure
      params_flat, tree_def = jax.tree_util.tree_flatten(state.params)
      params_shapes = [p.shape for p in params_flat]
      params_sizes = [p.size for p in params_flat]

      # Calculate total number of parameters
      n_params = sum(params_sizes)

      # Flatten parameters into a single vector
      def _flatten_params(params):
         flat_params, _ = jax.tree_util.tree_flatten(params)
         return jnp.concatenate([p.reshape(-1) for p in flat_params])

      # Unflatten a vector back to parameter tree structure.
      # Offsets are plain Python ints, so the slices are static and no
      # device computation is needed to find the split points.
      def _unflatten_params(flat_params):
         leaves = []
         offset = 0
         for size, shape in zip(params_sizes, params_shapes):
            leaves.append(flat_params[offset:offset + size].reshape(shape))
            offset += size
         return jax.tree_util.tree_unflatten(tree_def, leaves)

      # Hessian matrix-vector product (always the true Hessian): forward-mode
      # derivative of the gradient, summed over all weighted loss terms.
      def hessian_mv_fn(batch, state, vec):
         x, y = batch

         def term(index):
            # A single apply function consumes the batch as-is; several
            # apply functions index inputs/targets per function.
            inputs, targets = (x, y) if num_fns == 1 else (x[index], y[index])

            def loss_of_params(params):
               if has_batch_stats:
                  outputs = state.apply_fns[index]({'params': params, 'batch_stats': state.batch_stats}, inputs, train=False)
               else:
                  outputs = state.apply_fns[index]({'params': params}, inputs)
               return self._loss_weights[index] * self._loss_fns[index](outputs, targets, batch=batch)

            return jax.jvp(jax.grad(loss_of_params), (state.params,), (vec,))[1]

         total = term(0)
         for index in range(1, num_fns):
            total = jax.tree_util.tree_map(lambda a, b: a + b, total, term(index))
         return total

      # Compile the Hessian matrix-vector product
      hessian_mv = jax.jit(hessian_mv_fn)

      # Compute the i-th column H*e_i from a unit vector
      def get_column(i):
         unit_vec = jnp.zeros(n_params).at[i].set(1.0)
         result = hessian_mv(batch, state, _unflatten_params(unit_vec))
         return _flatten_params(result)

      # Use vmap to efficiently compute all columns
      # Note: For JIT-compiled functions, we can't use tqdm directly inside vmap
      # Instead, we'll compute in batches to show progress
      batch_size = min(100, n_params)  # Process columns in batches
      all_columns = []

      # Create progress bar
      with tqdm(total=n_params, desc="Computing Hessian", unit="cols") as pbar:
         start_time = time.time()

         for start_idx in range(0, n_params, batch_size):
            end_idx = min(start_idx + batch_size, n_params)
            batch_indices = jnp.arange(start_idx, end_idx)

            # Compute batch of columns (each row of the result is one H*e_i)
            all_columns.append(jax.vmap(get_column)(batch_indices))

            # Update progress bar
            pbar.update(end_idx - start_idx)

            # Update ETA (skipped for the first batch)
            elapsed = time.time() - start_time
            if start_idx > 0:
               rate = end_idx / elapsed
               remaining = (n_params - end_idx) / rate
               pbar.set_postfix({"ETA": f"{remaining:.1f}s"})

      # Rows hold the columns H*e_i, so transpose after concatenating
      hessian_matrix = jnp.concatenate(all_columns, axis=0).T

      print(f"Hessian computation complete. Matrix shape: {hessian_matrix.shape}")
      return hessian_matrix
      
   def get_second_order_matrix(self, state: TrainState, batch) -> jnp.ndarray:
      """
      Materialize the matrix behind the optimizer's second-order
      matrix-vector product (self._second_order_mv) as a dense 2D array.

      Which matrix that is depends on the second_order_option chosen at
      construction (Fisher or Hessian); the product also adds the damping
      term second_order_tol * vector, so the result includes λ*I.

      Args:
         state: The state of the model
         batch: The batch of data (x, y)

      Returns:
         jnp.ndarray: A 2D array with one row/column per scalar parameter
      """
      # Evaluate the loss once; only the auxiliary output (logits) is used below
      value_and_grad = jax.value_and_grad(self._eval_and_loss, has_aux=True)
      (_, aux_out), _ = value_and_grad(state.params, state, batch)

      # With batch stats the logits come first in aux; otherwise the list is second
      logits_list = [aux_out[0]] if state.batch_stats is not None else aux_out[1]

      # Record the parameter tree layout for vector <-> tree conversion
      leaves, structure = jax.tree_util.tree_flatten(state.params)
      leaf_shapes = [leaf.shape for leaf in leaves]
      total_size = sum(leaf.size for leaf in leaves)

      def to_vector(tree):
         # Concatenate all leaves of `tree` into one flat vector
         tree_leaves, _ = jax.tree_util.tree_flatten(tree)
         return jnp.concatenate([leaf.reshape(-1) for leaf in tree_leaves])

      def to_tree(vector):
         # Split a flat vector at the leaf boundaries and restore the tree
         boundaries = jnp.cumsum(jnp.array([leaf.size for leaf in leaves]))[:-1]
         pieces = jnp.split(vector, boundaries)
         restored = [piece.reshape(shape) for piece, shape in zip(pieces, leaf_shapes)]
         return jax.tree_util.tree_unflatten(structure, restored)

      def matrix_column(index):
         # M*e_index: apply the optimizer's product to a unit vector
         basis = jnp.zeros(total_size)
         basis = basis.at[index].set(1.0)
         product = self._second_order_mv(batch, state, self._optimizer_state, logits_list, to_tree(basis))
         return to_vector(product)

      # Columns are produced in vmapped chunks so progress can be reported
      chunk = min(100, total_size)
      row_blocks = []

      with tqdm(total=total_size, desc=f"Computing {self._second_order_option} matrix", unit="cols") as progress:
         t0 = time.time()

         for lo in range(0, total_size, chunk):
            hi = min(lo + chunk, total_size)
            indices = jnp.arange(lo, hi)

            block = jax.vmap(matrix_column)(indices)
            row_blocks.append(block)

            progress.update(hi - lo)

            # ETA is only shown from the second chunk onwards
            elapsed = time.time() - t0
            if lo > 0:
               rate = hi / elapsed
               remaining = (total_size - hi) / rate
               progress.set_postfix({"ETA": f"{remaining:.1f}s"})

      # Each row of a block is one column of M, hence the transpose
      matrix = jnp.concatenate(row_blocks, axis=0).T

      print(f"{self._second_order_option} matrix computation complete. Matrix shape: {matrix.shape}")
      return matrix
      
   def get_hessian_pure(self, state: TrainState, batch) -> jnp.ndarray:
      """
      Computes and returns the pure Hessian matrix (without regularization)
      of the loss function with respect to model parameters.

      Args:
         state: The state of the model
         batch: The batch of data (x, y)

      Returns:
         jnp.ndarray: A 2D array representing the pure Hessian matrix

      Note:
         This method temporarily sets second_order_tol to 0 and
         second_order_option to 'hessian' while delegating to get_hessian(),
         and always restores the original values afterwards, even if
         get_hessian() raises.
      """
      # Store the original regularization value and second-order option
      original_tol = self._second_order_tol
      original_option = self._second_order_option

      # Temporarily set to compute pure Hessian
      self._second_order_tol = 0.0
      self._second_order_option = 'hessian'

      try:
         # Compute the Hessian using the regular method
         return self.get_hessian(state, batch)
      finally:
         # Restore the original values on success and on failure alike, so an
         # exception cannot leave the optimizer with altered settings
         self._second_order_tol = original_tol
         self._second_order_option = original_option
      
   def get_preconditioned_hessian(self, state: TrainState, batch, model: FlaxNet = None, 
                                   include_learning_rate: bool = False) -> jnp.ndarray:
      """
      Build the preconditioned Hessian: every column of the matrix returned
      by get_hessian() is passed through the same CG solve the optimizer
      applies to gradients (optional input normalization, CG against
      self._second_order_mv, optional output normalization).

      Args:
         state: The state of the model
         batch: The batch of data (x, y)
         model: The model; its pytree2leaves, vector2pytree and pytree2vector
                methods are used for the pytree/vector conversions
         include_learning_rate: If True, the result is multiplied by
                                self._config['optimizer']['learning_rate']

      Returns:
         jnp.ndarray: A 2D array representing the preconditioned Hessian matrix,
                      optionally scaled by learning rate

      Raises:
         ValueError: If model is None
      """
      if model is None:
         raise ValueError("Model must be provided for get_preconditioned_hessian")

      # Presumably primes the model's cached tree layout for the conversions
      # below (return value is not needed) — verify against FlaxNet
      model.pytree2leaves(state.params)

      # Step 1: the matrix whose columns get preconditioned
      hessian = self.get_hessian(state, batch)

      # Logits are required by the second-order matrix-vector product
      value_and_grad = jax.value_and_grad(self._eval_and_loss, has_aux=True)
      (_, aux_out), _ = value_and_grad(state.params, state, batch)
      logits_list = [aux_out[0]] if state.batch_stats is not None else aux_out[1]

      n_params = hessian.shape[0]

      # One callable is reused for every column, so the CG solver (which
      # treats it as a static argument) sees the same object each time
      mv_fn = lambda v: self._second_order_mv(batch, state, self._optimizer_state, logits_list, v)

      def normalize_if(flag, tree):
         # Same conditional normalization the CG direction uses for gradients
         return jax.lax.cond(
            flag,
            lambda _: vec_normalize(tree),
            lambda _: tree,
            operand = None
         )

      preconditioned_columns = []

      with tqdm(total=n_params, desc="Computing Preconditioned Hessian", unit="cols") as progress:
         t0 = time.time()

         for col in range(n_params):
            # Step 2: i-th Hessian column as a parameter pytree
            rhs = model.vector2pytree(hessian[:, col])

            # Step 3: solve M * x = rhs with CG, where M is defined by
            # self._second_order_mv (momentum/Nesterov are not applied)
            rhs = normalize_if(self._grad_normalization, rhs)
            solution, _ = self._conjugate_gradient(mv_fn, rhs, None, self._tol, self._its)
            solution = normalize_if(self._output_normalization, solution)

            # Step 4: back to a flat vector
            preconditioned_columns.append(model.pytree2vector(solution))

            progress.update(1)

            # ETA is shown from the second column onwards
            elapsed = time.time() - t0
            if col > 0:
               rate = (col + 1) / elapsed
               remaining = (n_params - col - 1) / rate
               progress.set_postfix({"ETA": f"{remaining:.1f}s"})

      # Assemble the solved vectors as the columns of the result
      preconditioned_hessian = jnp.column_stack(preconditioned_columns)

      if include_learning_rate:
         learning_rate = self._config['optimizer']['learning_rate']
         preconditioned_hessian = preconditioned_hessian * learning_rate

      print(f"Preconditioned Hessian computation complete. Matrix shape: {preconditioned_hessian.shape}")
      return preconditioned_hessian
      
   def get_preconditioned_hessian_lowmem(self, state: TrainState, batch, model: FlaxNet = None,
                                          max_cols: int = 100, include_learning_rate: bool = False) -> jnp.ndarray:
      """
      Build the CG-preconditioned second-order matrix column by column, in chunks.

      For every parameter index ``i`` a unit vector is pushed through
      ``self._second_order_mv``, optionally normalized with ``vec_normalize``
      (when ``self._grad_normalization`` is set), and handed to
      ``self._conjugate_gradient`` together with the same operator,
      ``self._tol`` and ``self._its``. The first element returned by the solver,
      optionally normalized again (``self._output_normalization``), becomes
      column ``i``. Columns are stacked ``max_cols`` at a time and the chunks are
      concatenated horizontally at the end.

      NOTE: every finished chunk is kept until the final concatenation, so the
      complete matrix is still materialized before it is returned.

      Args:
         state: The state of the model
         batch: The batch of data (x, y)
         model: The model (needed for pytree/vector conversions)
         max_cols: Maximum number of columns stacked into one chunk (must be >= 1)
         include_learning_rate: If True, scales the result by
                               ``self._config['optimizer']['learning_rate']``

      Returns:
         jnp.ndarray: A 2D array holding one preconditioned column per parameter,
                     optionally scaled by learning rate

      Raises:
         ValueError: If ``model`` is None or ``max_cols`` is smaller than 1.
      """
      if model is None:
         raise ValueError("Model must be provided for get_preconditioned_hessian_lowmem. "
                         "The model should have pytree2vector and vector2pytree methods.")
      if max_cols < 1:
         # range() with a zero step fails with an unrelated message and a negative
         # step silently yields no chunks, so reject both up front.
         raise ValueError(f"max_cols must be a positive integer, got {max_cols}")

      # Initialize model's tree structure for vector conversion
      _ = model.pytree2leaves(state.params)

      # Only the auxiliary output (logits) of the loss evaluation is needed here
      grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
      (_, aux), _ = grad_fn(state.params, state, batch)

      if state.batch_stats is None:
         logits_list = aux[1]  # For standard models
      else:
         logits_list = [aux[0]]  # For models with batch stats

      # Calculate total number of parameters
      params_flat, _ = jax.tree_util.tree_flatten(state.params)
      n_params = sum(p.size for p in params_flat)

      # Pytree -> pytree operator, shared by the column construction and the CG solve
      def mv_fn(vec_tree):
         return self._second_order_mv(batch, state, self._optimizer_state, logits_list, vec_tree)

      verbose = getattr(self, '_print_level', 0) > 0
      finished_chunks = []

      with tqdm(total=n_params, desc="Computing Preconditioned Hessian (Low Memory)", unit="cols") as pbar:
         start_time = time.time()

         for chunk_start in range(0, n_params, max_cols):
            chunk_end = min(chunk_start + max_cols, n_params)

            if verbose:
               print(f"Processing Hessian columns {chunk_start} to {chunk_end-1} of {n_params}")

            chunk_columns = []
            for i in range(chunk_start, chunk_end):
               # Column i of the operator: apply it to the i-th unit vector.
               # NOTE(review): the operator output is used as a pytree directly; this
               # assumes model.vector2pytree(model.pytree2vector(t)) reproduces t.
               unit_vec = jnp.zeros(n_params).at[i].set(1.0)
               h_col_pytree = mv_fn(model.vector2pytree(unit_vec))

               # The normalization flags are read eagerly in this Python loop, so a
               # plain branch is enough and avoids tracing a new cond per column.
               if self._grad_normalization:
                  h_col_pytree = vec_normalize(h_col_pytree)

               # Momentum / Nesterov terms are disabled since velocity is removed,
               # so the column goes to the solver unchanged.
               precond_col_pytree, _ = self._conjugate_gradient(mv_fn, h_col_pytree, None, self._tol, self._its)

               if self._output_normalization:
                  precond_col_pytree = vec_normalize(precond_col_pytree)

               chunk_columns.append(model.pytree2vector(precond_col_pytree))

               pbar.update(1)

               # ETA from the average time per finished column
               elapsed = time.time() - start_time
               if i > 0:
                  rate = (i + 1) / elapsed
                  remaining = (n_params - i - 1) / rate
                  pbar.set_postfix({"ETA": f"{remaining:.1f}s"})

            finished_chunks.append(jnp.column_stack(chunk_columns))

      # Combine all chunks into the final matrix
      result = jnp.hstack(finished_chunks)

      # Apply learning rate scaling if requested
      if include_learning_rate:
         result = result * self._config['optimizer']['learning_rate']

      print(f"Preconditioned Hessian (low memory) complete. Matrix shape: {result.shape}")
      return result
      
   def get_preconditioner(self, state: TrainState, batch, model: FlaxNet = None,
                          include_learning_rate: bool = False) -> jnp.ndarray:
      """
      Assemble the CG preconditioner as a dense matrix, one column at a time.

      Each unit vector is converted to a pytree, optionally normalized with
      ``vec_normalize`` (when ``self._grad_normalization`` is set) and passed to
      ``self._conjugate_gradient`` with the operator ``self._second_order_mv``,
      ``self._tol`` and ``self._its``. The first element of the solver's return
      value, optionally normalized again (``self._output_normalization``), is
      converted back to a vector and becomes the corresponding column.

      Args:
         state: The state of the model
         batch: The batch of data (x, y)
         model: The model (needed for pytree/vector conversions)
         include_learning_rate: If True, scales the result by
                               ``self._config['optimizer']['learning_rate']``

      Returns:
         jnp.ndarray: A 2D array representing the preconditioner matrix,
                     optionally scaled by learning rate

      Raises:
         ValueError: If ``model`` is None.
      """
      if model is None:
         raise ValueError("Model must be provided for get_preconditioner")

      # Initialize model's tree structure for vector conversion
      _ = model.pytree2leaves(state.params)

      # Only the auxiliary output (logits) of the loss evaluation is needed here
      grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
      (_, aux), _ = grad_fn(state.params, state, batch)

      if state.batch_stats is None:
         logits_list = aux[1]  # For standard models
      else:
         logits_list = [aux[0]]  # For models with batch stats

      # Calculate total number of parameters
      params_flat, _ = jax.tree_util.tree_flatten(state.params)
      n_params = sum(p.size for p in params_flat)

      # Matrix-vector product used by the CG solve (pytree in, pytree out)
      def mv_fn(vec_tree):
         return self._second_order_mv(batch, state, self._optimizer_state, logits_list, vec_tree)

      preconditioned_columns = []

      with tqdm(total=n_params, desc="Computing Preconditioner Matrix", unit="cols") as pbar:
         start_time = time.time()

         for i in range(n_params):
            # Build the i-th identity column on demand instead of slicing it out of
            # a dense n_params x n_params identity matrix.
            identity_col = jnp.zeros(n_params).at[i].set(1.0)
            identity_col_pytree = model.vector2pytree(identity_col)

            # The normalization flags are read eagerly in this Python loop, so a
            # plain branch is enough and avoids tracing a new cond per column.
            if self._grad_normalization:
               identity_col_pytree = vec_normalize(identity_col_pytree)

            # Momentum / Nesterov terms are disabled since velocity is removed,
            # so the column goes to the solver unchanged.
            precond_col_pytree, _ = self._conjugate_gradient(mv_fn, identity_col_pytree, None, self._tol, self._its)

            if self._output_normalization:
               precond_col_pytree = vec_normalize(precond_col_pytree)

            preconditioned_columns.append(model.pytree2vector(precond_col_pytree))

            pbar.update(1)

            # ETA from the average time per finished column
            elapsed = time.time() - start_time
            if i > 0:
               rate = (i + 1) / elapsed
               remaining = (n_params - i - 1) / rate
               pbar.set_postfix({"ETA": f"{remaining:.1f}s"})

      # Combine all columns into a matrix
      preconditioner_matrix = jnp.column_stack(preconditioned_columns)

      # Apply learning rate scaling if requested
      if include_learning_rate:
         preconditioner_matrix = preconditioner_matrix * self._config['optimizer']['learning_rate']

      print(f"Preconditioner matrix computation complete. Matrix shape: {preconditioner_matrix.shape}")
      return preconditioner_matrix
      
   def get_hessian_properties(self, state: TrainState, batch) -> Dict[str, Any]:
      """
      Estimate spectral properties of the second-order matrix from
      matrix-vector products only, without forming the matrix.

      The operator is ``self._base_mv`` applied to flattened parameter vectors.
      The largest eigenvalue is estimated with power iteration, the smallest
      with inverse power iteration whose inner solves use a fixed number of
      conjugate-gradient steps (this assumes a positive definite operator and is
      approximate). The trace is extrapolated from up to 100 evenly spaced
      diagonal entries.

      Args:
         state: The state of the model
         batch: The batch of data (x, y)

      Returns:
         Dict: ``max_eigenvalue``, ``min_eigenvalue``, ``condition_number``,
              ``trace_estimate``, ``num_parameters`` and ``is_fisher``
              (True when ``self._second_order_option == 'fisher'``).
      """
      # Only the auxiliary output (logits) of the loss evaluation is needed here
      grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
      (_, aux), _ = grad_fn(state.params, state, batch)

      if state.batch_stats is None:
         logits_list = aux[1]  # For standard models
      else:
         logits_list = [aux[0]]  # For models with batch stats

      # Get flat parameters structure
      params_flat, tree_def = jax.tree_util.tree_flatten(state.params)
      params_shapes = [p.shape for p in params_flat]

      # Calculate total number of parameters
      n_params = sum(p.size for p in params_flat)

      # Slice boundaries of every leaf inside the flat vector. They are plain
      # Python ints on purpose: _unflatten_params is also called while being
      # traced (inside jax.lax.fori_loop below), where boundaries computed with
      # jnp would no longer be concrete values.
      leaf_bounds = []
      offset = 0
      for leaf in params_flat:
         leaf_bounds.append((offset, offset + leaf.size))
         offset += leaf.size

      # Flatten parameters into a single vector
      def _flatten_params(params):
         flat_params, _ = jax.tree_util.tree_flatten(params)
         return jnp.concatenate([p.reshape(-1) for p in flat_params])

      # Unflatten a vector back to parameter tree structure
      def _unflatten_params(flat_params):
         leaves = [flat_params[lo:hi].reshape(shape)
                   for (lo, hi), shape in zip(leaf_bounds, params_shapes)]
         return jax.tree_util.tree_unflatten(tree_def, leaves)

      # Matrix-vector product on flat vectors.
      # Use base_mv directly to get the pure matrix without regularization for accurate eigenvalues
      def hvp_fn(v):
         v_tree = _unflatten_params(v)
         result = self._base_mv(batch, state, self._optimizer_state, logits_list, v_tree)
         return _flatten_params(result)

      def power_iteration(matvec_fn, n, num_iters=100):
         # --- Largest eigenvalue: plain power iteration from a fixed random start ---
         x = jax.random.normal(jax.random.PRNGKey(0), (n,))
         x = x / jnp.linalg.norm(x)

         with tqdm(total=num_iters, desc="Power iteration (largest eigenvalue)", unit="iters") as pbar:
            for _ in range(num_iters):
               y = matvec_fn(x)
               x = y / jnp.linalg.norm(y)
               pbar.update(1)

         v1 = x
         lambda1 = jnp.dot(v1, matvec_fn(v1))

         # --- Smallest eigenvalue: inverse power iteration ---
         # This is approximate and assumes positive definite Hessian
         def inverse_matvec_fn(rhs):
            # Approximate H^-1 rhs with a fixed number of conjugate-gradient steps
            def cg_body_fn(_, val):
               sol, r, p, rr_old = val
               Ap = matvec_fn(p)
               pAp = jnp.dot(p, Ap)
               # Once the residual (or the curvature along p) is exactly zero the
               # textbook updates divide 0 by 0; take a zero step instead so the
               # remaining fixed iterations leave the solution untouched.
               can_step = pAp != 0
               alpha = jnp.where(can_step, rr_old / jnp.where(can_step, pAp, 1.0), 0.0)
               sol_new = sol + alpha * p
               r_new = r - alpha * Ap
               rr_new = jnp.dot(r_new, r_new)
               has_residual = rr_old != 0
               beta = jnp.where(has_residual, rr_new / jnp.where(has_residual, rr_old, 1.0), 0.0)
               p_new = r_new + beta * p
               return sol_new, r_new, p_new, rr_new

            # Residual of the zero initial guess
            r = rhs - matvec_fn(jnp.zeros_like(rhs))
            val_init = (jnp.zeros_like(rhs), r, r, jnp.dot(r, r))
            approx_inv, _, _, _ = jax.lax.fori_loop(0, min(n, 50), cg_body_fn, val_init)
            return approx_inv

         x = jax.random.normal(jax.random.PRNGKey(1), (n,))
         x = x / jnp.linalg.norm(x)

         with tqdm(total=num_iters, desc="Power iteration (smallest eigenvalue)", unit="iters") as pbar:
            for _ in range(num_iters):
               y = inverse_matvec_fn(x)
               x = y / jnp.linalg.norm(y)
               pbar.update(1)

         v2 = x
         # Rayleigh quotient of the final iterate
         lambda2 = jnp.dot(v2, matvec_fn(v2)) / jnp.dot(v2, v2)

         # The denominator is floored so a vanishing estimate cannot divide by zero
         return lambda1, lambda2, lambda1 / jnp.maximum(lambda2, 1e-10)

      # Compute eigenvalue stats
      print(f"Computing Hessian properties for {n_params} parameters...")
      max_eig, min_eig, condition_number = power_iteration(hvp_fn, n_params)

      # Diagonal entry i is e_i^T H e_i
      def diagonal_element(i):
         unit = jnp.zeros(n_params)
         unit = unit.at[i].set(1.0)
         return jnp.dot(unit, hvp_fn(unit))

      # Evenly spaced sample of diagonal positions, computed with integer
      # arithmetic: a floating-point linspace followed by truncation can land one
      # below the intended index and sample the same entry twice.
      sample_size = min(100, n_params)
      if sample_size > 1:
         indices = [(k * (n_params - 1)) // (sample_size - 1) for k in range(sample_size)]
      else:
         indices = [0] * sample_size

      sampled_diag = []
      with tqdm(total=sample_size, desc="Computing diagonal elements", unit="elements") as pbar:
         for idx in indices:
            sampled_diag.append(diagonal_element(idx))
            pbar.update(1)

      # Mean sampled diagonal entry, scaled up to the full dimension
      sampled_diag = jnp.array(sampled_diag)
      trace_estimate = jnp.mean(sampled_diag) * n_params

      result = {
         "max_eigenvalue": max_eig,
         "min_eigenvalue": min_eig,
         "condition_number": condition_number,
         "trace_estimate": trace_estimate,
         "num_parameters": n_params,
         "is_fisher": self._second_order_option == 'fisher'
      }

      print(f"Hessian properties computed: κ={condition_number:.2e}, λ_max={max_eig:.2e}, λ_min={min_eig:.2e}")
      return result
      
   def get_hessian_jit(self, state: TrainState, batch, max_params: int = 1000) -> jnp.ndarray:
      """
      JIT-compiled version of get_hessian for small models.
      Only use this for models with fewer parameters than max_params.
      
      Args:
         state: The state of the model
         batch: The batch of data (x, y)
         max_params: Maximum number of parameters for which this function is safe to use
         
      Returns:
         jnp.ndarray: A 2D array representing the Hessian matrix
      """
      # First, compute the loss and logits to use in Hessian computation
      grad_fn = jax.value_and_grad(self._eval_and_loss, has_aux=True)
      (loss, aux), grads = grad_fn(state.params, state, batch)
      
      if state.batch_stats is None:
         logits_list = aux[1]  # For standard models
      else:
         logits_list = [aux[0]]  # For models with batch stats
      
      # Get flat parameters structure
      params_flat, tree_def = jax.tree_util.tree_flatten(state.params)
      
      # Calculate total number of parameters
      n_params = sum(p.size for p in params_flat)
      
      # Check if the model is small enough
      if n_params > max_params:
         raise ValueError(f"Model has {n_params} parameters, which exceeds the maximum of {max_params} "
                         f"for JIT-compiled Hessian computation. Use get_hessian instead.")
      
      # Flatten the parameters into a single vector
      flattened_params = jnp.concatenate([p.reshape(-1) for p in params_flat])
      
      # Define a function to compute Hessian using JAX's Jacobian of gradient
      if self._second_order_option == 'fisher':
         # For Fisher, we need to use a different approach
         # This computes Fisher Information Matrix (FIM) directly
         @jax.jit
         def compute_fisher_matrix():
            # Define a function to compute gradient for a single sample
            def compute_single_sample_grad(single_x, single_y):
               # Compute gradient for this single sample
               if state.batch_stats is None:
                  logits = state.apply_fns[0]({'params': state.params}, single_x[None, ...])
                  grad = jax.grad(lambda p: self._loss_weights[0] * self._loss_fns[0](
                     state.apply_fns[0]({'params': p}, single_x[None, ...]), 
                     single_y[None, ...], batch=(single_x[None, ...], single_y[None, ...])))(state.params)
               else:
                  logits, _ = state.apply_fns[0](
                     {'params': state.params, 'batch_stats': state.batch_stats},
                     single_x[None, ...], train=False, mutable=[]
                  )
                  grad = jax.grad(lambda p: self._loss_weights[0] * self._loss_fns[0](
                     state.apply_fns[0]({'params': p, 'batch_stats': state.batch_stats}, 
                     single_x[None, ...], train=False)[0], 
                     single_y[None, ...], batch=(single_x[None, ...], single_y[None, ...])))(state.params)
               
               # Flatten the gradient
               grad_flat, _ = jax.tree_util.tree_flatten(grad)
               return jnp.concatenate([g.reshape(-1) for g in grad_flat])
            
            # Get batch data
            x, y = batch
            
            # Compute gradients for each sample in the batch
            sample_grads = jax.vmap(compute_single_sample_grad)(x, y)
            
            # Compute the Fisher Information Matrix as the average of outer products
            # F = (1/n) * Σ(∇L_i * ∇L_i^T)
            fim = jnp.mean(jax.vmap(lambda g: jnp.outer(g, g))(sample_grads), axis=0)
            
            return fim
            
         # Compute and return the Fisher matrix
         return compute_fisher_matrix()
      else:
         # For true Hessian, use JAX's hessian function directly
         @jax.jit
         def compute_hessian():
            # Define a loss function that takes flat parameters
            def loss_fn(params_vector):
               # Reshape the vector back to the original params structure
               params_split = []
               start_idx = 0
               for p in params_flat:
                  end_idx = start_idx + p.size
                  params_split.append(params_vector[start_idx:end_idx].reshape(p.shape))
                  start_idx = end_idx
               
               params_tree = jax.tree_util.tree_unflatten(tree_def, params_split)
               
               # Compute the loss
               x, y = batch
               if state.batch_stats is None:
                  return self._loss_weights[0] * self._loss_fns[0](
                     state.apply_fns[0]({'params': params_tree}, x), y, batch=batch)
               else:
                  return self._loss_weights[0] * self._loss_fns[0](
                     state.apply_fns[0]({'params': params_tree, 'batch_stats': state.batch_stats}, x, train=False), 
                     y, batch=batch)
            
            # Compute the Hessian
            return jax.hessian(loss_fn)(flattened_params)
            
         # Compute and return the Hessian matrix
         return compute_hessian()
