"""
Optax L-BFGS optimizer implementation for regression example.
"""

import jax
import jax.numpy as jnp
import numpy as np
import time
import pickle
from pathlib import Path
from typing import Dict, Tuple, List, Any, Optional
import optax
from flax.training import train_state
from foo.models import TrainState

# Define Flax-style TrainState specifically for Optax's L-BFGS optimizer
class OptaxLBFGSTrainState(train_state.TrainState):
   """Flax ``TrainState`` extended with an optional ``batch_stats`` collection.

   Lets models with normalization layers carry their batch statistics next to
   the parameters and the Optax L-BFGS optimizer state.
   """
   # Mutable 'batch_stats' variables of the model; None when the model has none.
   batch_stats: Optional[Any] = None

   @classmethod
   def create(cls, *, apply_fn, params, tx, batch_stats=None):
      """Return a new state at ``step=0`` with ``tx`` initialised on ``params``."""
      initial_opt_state = tx.init(params)
      return cls(
         step=0,
         apply_fn=apply_fn,
         params=params,
         tx=tx,
         opt_state=initial_opt_state,
         batch_stats=batch_stats,
      )

# MSE loss function for Optax L-BFGS training workflow
def optax_mse_loss(logits, targets):
   """Return the mean squared error between predictions and regression targets.

   Inputs are broadcast against each other (standard ``jnp`` broadcasting)
   before the element-wise squared difference is averaged over all entries.
   """
   residuals = logits - targets
   return jnp.mean(jnp.square(residuals))

# Single step function for Optax L-BFGS training
@jax.jit
def optax_lbfgs_train_step(state, batch):
   """Run one Optax L-BFGS update (including its line search) on a single batch.

   Args:
      state: ``OptaxLBFGSTrainState``-like pytree providing ``params``,
         ``opt_state``, ``tx`` (an Optax L-BFGS transformation), ``apply_fn``,
         ``step`` and an optional ``batch_stats`` collection.
      batch: Tuple ``(x, y)`` of model inputs and regression targets.

   Returns:
      ``(new_state, loss)``. ``loss`` is the MSE at the parameters *before* the
      update. ``new_state`` holds the updated parameters, optimizer state and
      batch stats (when tracked), with ``step`` incremented by one.
   """
   x, y = batch

   def loss_fn(params):
      """Return ``(loss, updated_variables)`` for ``params`` on this batch.

      ``updated_variables`` is the mutated-collections dict returned by
      ``apply_fn`` when batch stats are tracked, otherwise ``None``.
      """
      if state.batch_stats is None:
         preds = state.apply_fn({'params': params}, x)
         return optax_mse_loss(preds, y), None
      preds, updated_variables = state.apply_fn(
         {'params': params, 'batch_stats': state.batch_stats},
         x,
         mutable=['batch_stats'],
      )
      return optax_mse_loss(preds, y), updated_variables

   def value_fn(params):
      # The zoom line search needs a function returning only the scalar loss.
      # It reuses loss_fn so the objective and the gradient can never diverge.
      return loss_fn(params)[0]

   # The batch changes every step, so value and gradient are recomputed here
   # rather than reused from the previous line search's optimizer state.
   (loss, updated_variables), grad = jax.value_and_grad(
      loss_fn, has_aux=True)(state.params)

   # L-BFGS needs the current value/grad plus value_fn for its line search.
   updates, new_opt_state = state.tx.update(
      grad,
      state.opt_state,
      params=state.params,
      value=loss,
      value_fn=value_fn,
      grad=grad,
   )
   new_params = optax.apply_updates(state.params, updates)

   replacements = {
      'step': state.step + 1,
      'params': new_params,
      'opt_state': new_opt_state,
   }
   # `mutable=['batch_stats']` returns {'batch_stats': ...}; store only the inner
   # collection, matching how `state.batch_stats` is passed back to apply_fn.
   if updated_variables is not None:
      replacements['batch_stats'] = updated_variables['batch_stats']

   return state.replace(**replacements), loss

# Create state compatible with foo.models.TrainState for interface consistency
def create_compatible_state(optax_state):
   """Wrap an Optax L-BFGS training state in a ``foo.models.TrainState``.

   ``apply_fn``, ``params`` and ``batch_stats`` are copied over. ``step`` is
   pinned to 0, and the optimizer/rng fields are filled with None because the
   target constructor requires them.
   """
   constructor_kwargs = dict(
      step=0,  # fixed step 0 keeps the interface consistent
      apply_fn=optax_state.apply_fn,
      params=optax_state.params,
      tx=None,
      opt_state=None,
      rngs=None,
      batch_stats=optax_state.batch_stats,
   )
   return TrainState(**constructor_kwargs)

# Function to train model with Optax L-BFGS
def _resample_training_data(n_samples, noise_level, seed, dtype, jax_device):
   """Generate a fresh training set with ``data.generate_data``.

   The first two generated arrays are stacked column-wise as input features,
   and the third is reshaped into a column vector of targets.

   Returns:
      Tuple ``(features, targets)`` of JAX arrays, placed on ``jax_device``
      when one is given.
   """
   # Imported here to avoid circular imports. It is only needed when
   # resampling is actually enabled.
   from data import generate_data

   x_coord, y_coord, z_target = generate_data(n_samples, noise_level, seed)
   features = jnp.array(np.column_stack((x_coord, y_coord)), dtype=dtype)
   targets = jnp.array(z_target, dtype=dtype).reshape(-1, 1)
   if jax_device:
      features = jax.device_put(features, jax_device)
      targets = jax.device_put(targets, jax_device)
   return features, targets


def _remove_first_epoch_jit_overhead(time_per_step, total_opt_time, phase_name):
   """Subtract estimated JIT-compilation overhead from the first epoch's time.

   The first epoch's duration is replaced by the mean duration of the later
   epochs, and every cumulative entry is shifted by the same amount.

   Returns:
      ``(total_opt_time, time_per_step)``, adjusted when the first epoch was
      slower than average and unchanged otherwise.
   """
   if len(time_per_step) <= 1:
      return total_opt_time, time_per_step

   # Entries are cumulative, so this is the mean duration of epochs 2..N.
   avg_step_time = (time_per_step[-1] - time_per_step[0]) / (len(time_per_step) - 1)

   if not (avg_step_time > 0 and time_per_step[0] > avg_step_time):
      print(f"First step time ({time_per_step[0]:.4f}s) appears normal, no adjustment needed")
      return total_opt_time, time_per_step

   time_adjustment = time_per_step[0] - avg_step_time
   adjusted_time_per_step = [max(0.0, t - time_adjustment) for t in time_per_step]
   adjusted_opt_time = max(0.0, total_opt_time - time_adjustment)

   print(f"Adjusted first step time: removed {time_adjustment:.4f}s of JIT compilation overhead")
   print(f"{phase_name} adjusted optimization time: {adjusted_opt_time:.2f} seconds")
   return adjusted_opt_time, adjusted_time_per_step


def train_with_optax_lbfgs(
   initial_params,
   apply_fn,
   batch_stats,
   data,
   batch_size,
   epochs,
   max_linesearch_steps=20,
   memory_size=10,
   scale_init_precond=True,
   sufficient_decrease=0.1,
   curvature=0.9,
   initial_guess_strategy='one',
   resample_freq=0,
   n_samples=0,
   noise_level=0.1,
   base_seed=42,
   resample_counter=0,
   print_every=100,
   phase_name="Optax L-BFGS",
   jax_device=None,
   shuffle_seed=None,
   dtype=jnp.float32
):
   """
   Train a model using Optax L-BFGS optimizer.

   Each epoch shuffles the data and runs one L-BFGS step (with zoom line
   search) per mini-batch.

   Args:
      initial_params: Initial model parameters
      apply_fn: Model application function
      batch_stats: Initial batch statistics (if any)
      data: Training data (X_train, y_train)
      batch_size: Batch size for training (must be positive)
      epochs: Number of epochs to train
      max_linesearch_steps: Maximum number of line search steps for L-BFGS
      memory_size: Number of past updates kept to approximate the inverse Hessian
      scale_init_precond: Whether to scale the initial preconditioner
      sufficient_decrease: Armijo-Goldstein slope_rtol parameter (default 0.1)
      curvature: Wolfe curv_rtol parameter (default 0.9)
      initial_guess_strategy: Optax zoom line-search initial step strategy
         ('one' or 'keep')
      resample_freq: Resample data every this many epochs (0 for no resampling)
      n_samples: Number of samples to generate when resampling
      noise_level: Noise level for data generation
      base_seed: Base random seed
      resample_counter: Counter for resampling (continued across phases)
      print_every: Print progress every this many epochs (<= 0 disables)
      phase_name: Name of the training phase
      jax_device: JAX device to place resampled data on
      shuffle_seed: Seed for data shuffling (None uses the global NumPy RNG)
      dtype: Data type for arrays

   Returns:
      compatible_train_state: State compatible with foo.models.TrainState
      states_history: Compatible states, initial state first, then one per epoch
      losses: Mean training loss per epoch
      training_time: Total optimization time (excluding resampling)
      time_per_step: Cumulative optimization time at the end of each epoch
      resample_counter: Resample counter for next phase
      optax_state: Optax training state for saving and restoring

   Raises:
      ValueError: If ``batch_size`` is not positive.
   """
   if batch_size <= 0:
      raise ValueError(f"batch_size must be positive, got {batch_size}")

   X_train, y_train = data
   n_samples_current = len(X_train)

   tx = optax.lbfgs(
      memory_size=memory_size,
      linesearch=optax.scale_by_zoom_linesearch(
         max_linesearch_steps=max_linesearch_steps,
         slope_rtol=sufficient_decrease,  # Armijo-Goldstein sufficient decrease
         curv_rtol=curvature,             # Wolfe curvature condition
         initial_guess_strategy=initial_guess_strategy,
         verbose=False
      ),
      scale_init_precond=scale_init_precond
   )

   state = OptaxLBFGSTrainState.create(
      apply_fn=apply_fn,
      params=initial_params,
      tx=tx,
      batch_stats=batch_stats
   )

   states_history = [create_compatible_state(state)]
   losses = []
   time_per_step = []  # cumulative optimization time after each epoch
   total_opt_time = 0.0  # excludes resampling time
   current_resample_counter = resample_counter  # continuity between phases

   for epoch in range(epochs):
      if resample_freq > 0 and epoch > 0 and epoch % resample_freq == 0:
         current_resample_counter += 1
         # Deterministic seed derived from the base seed and resample count.
         resample_seed = base_seed + 10000 * current_resample_counter
         X_train, y_train = _resample_training_data(
            n_samples, noise_level, resample_seed, dtype, jax_device
         )
         n_samples_current = len(X_train)

      # JAX dispatches work asynchronously on every backend, so wait for all
      # pending work (previous epoch, resampled-data transfers) before timing.
      jax.block_until_ready((state, X_train, y_train))
      opt_start_time = time.perf_counter()

      if shuffle_seed is not None:
         # Offset by epoch so each epoch gets a different but reproducible order.
         indices = np.random.RandomState(shuffle_seed + epoch).permutation(n_samples_current)
      else:
         indices = np.random.permutation(n_samples_current)

      x_train_shuffled = X_train[indices]
      y_train_shuffled = y_train[indices]

      epoch_losses = []
      for start in range(0, n_samples_current, batch_size):
         batch = (
            x_train_shuffled[start:start + batch_size],
            y_train_shuffled[start:start + batch_size],
         )
         state, loss = optax_lbfgs_train_step(state, batch)
         epoch_losses.append(loss)

      # An empty dataset yields NaN explicitly instead of a NumPy warning.
      train_loss = np.mean(epoch_losses) if epoch_losses else np.nan
      losses.append(train_loss)
      states_history.append(create_compatible_state(state))

      # Make sure the epoch's device work has finished before stopping the timer.
      jax.block_until_ready(state)
      total_opt_time += time.perf_counter() - opt_start_time
      time_per_step.append(total_opt_time)

      if print_every > 0 and (epoch + 1) % print_every == 0:
         print(f"{phase_name} - Epoch {epoch+1}/{epochs}, Loss: {train_loss:.6f}")

   print(f"{phase_name} optimization time: {total_opt_time:.2f} seconds")

   total_opt_time, time_per_step = _remove_first_epoch_jit_overhead(
      time_per_step, total_opt_time, phase_name
   )

   final_compat_state = create_compatible_state(state)
   return final_compat_state, states_history, losses, total_opt_time, time_per_step, current_resample_counter, state