"""
Optax Adam optimizer implementation for regression example.
"""

import jax
import jax.numpy as jnp
import numpy as np
import time
import pickle
from pathlib import Path
from typing import Dict, Tuple, List, Any, Optional
import optax
from flax.training import train_state
from foo.models import TrainState

# Define Flax-style TrainState specifically for Optax's Adam optimizer
class OptaxAdamTrainState(train_state.TrainState):
   """Flax ``TrainState`` extended with an optional ``batch_stats`` field.

   ``batch_stats`` carries the statistics of normalization layers and stays
   ``None`` for models that have no such collection.
   """
   batch_stats: Optional[Any] = None

   @classmethod
   def create(cls, *, apply_fn, params, tx, batch_stats=None):
      """Build a fresh state at ``step=0``.

      Args:
         apply_fn: Model apply function stored on the state.
         params: Initial model parameters.
         tx: Optax gradient transformation; its state is initialized from
            ``params`` via ``tx.init``.
         batch_stats: Optional normalization statistics (default ``None``).

      Returns:
         A new ``OptaxAdamTrainState`` instance.
      """
      initial_opt_state = tx.init(params)
      return cls(
         step=0,
         apply_fn=apply_fn,
         params=params,
         tx=tx,
         opt_state=initial_opt_state,
         batch_stats=batch_stats,
      )

# MSE loss function for Optax Adam training workflow
def optax_mse_loss(logits, targets):
   """Return the mean squared error between predictions and regression targets.

   Args:
      logits: Model predictions.
      targets: Regression targets, broadcast against ``logits``.

   Returns:
      Scalar mean of the element-wise squared residuals.
   """
   # NOTE(review): standard broadcasting applies, so predictions of shape
   # (B, 1) against targets of shape (B,) silently become a (B, B) residual
   # matrix — callers should pass matching shapes.
   residuals = logits - targets
   return jnp.mean(jnp.square(residuals))

# Single step function for Optax Adam training
@jax.jit
def optax_adam_train_step(state, batch):
   """Run one jitted Adam update on a single mini-batch.

   Args:
      state: ``OptaxAdamTrainState``-like pytree exposing ``apply_fn``,
         ``params``, ``batch_stats``, ``apply_gradients`` and ``replace``.
      batch: ``(x, y)`` tuple of model inputs and regression targets.

   Returns:
      ``(new_state, loss, logits)``: the updated state, the MSE loss evaluated
      with the pre-update parameters, and the predictions it was computed from.
   """
   x, y = batch
   # A ``None`` batch_stats contributes no leaves to the pytree, so this
   # Python-level check is resolved at trace time rather than inside XLA.
   has_batch_stats = state.batch_stats is not None

   def loss_fn(params):
      if not has_batch_stats:
         logits = state.apply_fn({'params': params}, x)
         return optax_mse_loss(logits, y), (logits, None)
      # Mark batch_stats mutable so the model can return updated statistics
      # alongside its predictions.
      logits, mutated = state.apply_fn(
         {'params': params, 'batch_stats': state.batch_stats},
         x, mutable=['batch_stats'])
      return optax_mse_loss(logits, y), (logits, mutated)

   (loss, (logits, mutated)), grads = jax.value_and_grad(
      loss_fn, has_aux=True)(state.params)

   new_state = state.apply_gradients(grads=grads)

   # If the model did not actually write a 'batch_stats' collection (e.g. an
   # empty dict was supplied for a model without normalization layers), keep
   # the existing statistics instead of failing with a KeyError.
   if mutated is not None and 'batch_stats' in mutated:
      new_state = new_state.replace(batch_stats=mutated['batch_stats'])

   return new_state, loss, logits

# Create state compatible with foo.models.TrainState for interface consistency
def create_compatible_state(optax_state):
   """Project an Optax training state onto ``foo.models.TrainState``.

   Only ``apply_fn``, ``params`` and ``batch_stats`` are carried over, for
   interface consistency with code that expects ``foo.models.TrainState``.

   Args:
      optax_state: Source state exposing ``apply_fn``, ``params`` and
         ``batch_stats``.

   Returns:
      A ``foo.models.TrainState`` with ``step=0`` and ``tx``, ``opt_state``
      and ``rngs`` set to ``None``.
   """
   carried_fields = {
      'apply_fn': optax_state.apply_fn,
      'params': optax_state.params,
      'batch_stats': optax_state.batch_stats,
   }
   # The step counter is deliberately pinned to 0; tx/opt_state/rngs are
   # required by TrainState but have no meaningful value here.
   return TrainState(step=0, tx=None, opt_state=None, rngs=None, **carried_fields)

# Save Optax Adam training state to file
def save_optax_adam_state(state, path, filename):
   """Pickle an Optax Adam training state to ``path / filename``.

   Array-like leaves of the parameters, optimizer state and (when present)
   batch statistics are converted to NumPy arrays before pickling. The data
   is first written to a temporary sibling file and then moved into place,
   so an interrupted save cannot leave a truncated checkpoint under the
   final name.

   Args:
      state: Object with ``step``, ``params``, ``opt_state`` and
         ``batch_stats`` attributes (e.g. ``OptaxAdamTrainState``).
      path: Directory to write into; created with parents if missing.
      filename: Name of the checkpoint file inside ``path``.

   Returns:
      ``True`` once the checkpoint has been written.
   """
   path = Path(path)
   path.mkdir(parents=True, exist_ok=True)
   target = path / filename
   tmp_target = target.with_name(target.name + '.tmp')

   def convert_jax_to_numpy(x):
      # Anything array-like (JAX or NumPy) becomes a host NumPy array;
      # other leaves are stored unchanged.
      if hasattr(x, 'dtype') and hasattr(x, 'shape'):
         return np.array(x)
      return x

   state_dict = {
      'step': int(state.step),
      'params': jax.tree_util.tree_map(convert_jax_to_numpy, state.params),
      'opt_state': jax.tree_util.tree_map(convert_jax_to_numpy, state.opt_state),
   }
   if state.batch_stats is not None:
      state_dict['batch_stats'] = jax.tree_util.tree_map(
         convert_jax_to_numpy, state.batch_stats)

   try:
      with open(tmp_target, 'wb') as f:
         pickle.dump(state_dict, f)
      # Path.replace delegates to os.replace, which is an atomic rename on
      # POSIX; both files live in the same directory.
      tmp_target.replace(target)
   except BaseException:
      # Do not leave a half-written temp file behind; propagate the error.
      tmp_target.unlink(missing_ok=True)
      raise

   print(f"Successfully saved Optax Adam state to {target}")
   return True

# Load Optax Adam training state from file
def load_optax_adam_state(path, filename, apply_fn, learning_rate, beta1=0.9, beta2=0.999, eps=1e-8):
   """Restore a training state written by ``save_optax_adam_state``.

   The Adam transformation is rebuilt from the given hyperparameters; the
   step, parameters, optimizer state and optional batch statistics come from
   the checkpoint, with array-like leaves converted back to JAX arrays.

   Args:
      path: Directory containing the checkpoint.
      filename: Checkpoint file name inside ``path``.
      apply_fn: Model apply function to attach to the restored state.
      learning_rate, beta1, beta2, eps: Adam hyperparameters for the rebuilt
         optimizer.

   Returns:
      ``(state, True)`` on success. On any exception the error and traceback
      are printed and ``(None, False)`` is returned.
   """
   try:
      checkpoint_file = Path(path) / filename
      # NOTE(review): pickle.load can execute arbitrary code embedded in the
      # file; only load checkpoints from trusted locations.
      with open(checkpoint_file, 'rb') as f:
         state_dict = pickle.load(f)

      def to_jax(leaf):
         # Array-like leaves go back to JAX arrays; anything else is kept.
         is_array_like = hasattr(leaf, 'dtype') and hasattr(leaf, 'shape')
         return jnp.array(leaf) if is_array_like else leaf

      params = jax.tree_util.tree_map(to_jax, state_dict['params'])
      opt_state = jax.tree_util.tree_map(to_jax, state_dict['opt_state'])

      tx = optax.adam(learning_rate=learning_rate, b1=beta1, b2=beta2, eps=eps)

      restored_stats = None
      stored_stats = state_dict.get('batch_stats')
      if stored_stats is not None:
         restored_stats = jax.tree_util.tree_map(to_jax, stored_stats)

      new_state = OptaxAdamTrainState(
         step=state_dict['step'],
         apply_fn=apply_fn,
         params=params,
         tx=tx,
         opt_state=opt_state,
         batch_stats=restored_stats,
      )

      print(f"Successfully loaded Optax Adam state from {checkpoint_file}")
      return new_state, True
   except Exception as e:
      print(f"Error loading Optax Adam state: {e}")
      from traceback import print_exc
      print_exc()
      return None, False

# Function to train model with Optax Adam
def train_with_optax_adam(
   initial_params,
   apply_fn,
   batch_stats,
   data,
   batch_size,
   epochs,
   learning_rate=0.001,
   beta1=0.9,
   beta2=0.999,
   eps=1e-8,
   resample_freq=0,
   n_samples=0,
   noise_level=0.1,
   base_seed=42,
   resample_counter=0,
   print_every=100,
   phase_name="Optax Adam",
   jax_device=None,
   shuffle_seed=None,
   current_optax_state=None,
   dtype=jnp.float32
):
   """
   Train a model using Optax Adam optimizer, following Flax official example style.

   Args:
      initial_params: Initial model parameters (unused when resuming).
      apply_fn: Model application function (unused when resuming).
      batch_stats: Initial batch statistics, if any (unused when resuming).
      data: Training data ``(X_train, y_train)``.
      batch_size: Mini-batch size; the last batch of an epoch may be smaller.
      epochs: Number of epochs to train.
      learning_rate: Learning rate for Adam (only used when not resuming).
      beta1: Beta1 parameter for Adam (only used when not resuming).
      beta2: Beta2 parameter for Adam (only used when not resuming).
      eps: Epsilon parameter for Adam (only used when not resuming).
      resample_freq: Resample the data every this many epochs (0 disables).
      n_samples: Number of samples to generate when resampling; values <= 0
         fall back to the size of the current dataset.
      noise_level: Noise level passed to ``generate_data``.
      base_seed: Base random seed; resample ``k`` uses ``base_seed + 10000 * k``.
      resample_counter: Resample count carried over from a previous phase.
      print_every: Print the loss every this many epochs (0 disables printing).
      phase_name: Name of the training phase used in log messages.
      jax_device: Optional JAX device that resampled data is placed on.
      shuffle_seed: If given, epoch ``e`` is shuffled with seed
         ``shuffle_seed + e``; otherwise NumPy's global RNG is used.
      current_optax_state: Existing Optax training state to resume from; when
         given, the optimizer hyperparameters, ``initial_params``,
         ``apply_fn`` and ``batch_stats`` arguments are ignored.
      dtype: Data type for resampled arrays.

   Returns:
      compatible_train_state: State compatible with foo.models.TrainState
      states_history: Compatible states before training and after each epoch
      losses: Mean training loss per epoch (nan for an epoch with no batches)
      training_time: Total optimization time (excluding resampling)
      time_per_step: Cumulative optimization time after each epoch
      resample_counter: Resample counter for next phase
      optax_state: Optax training state for saving and restoring
   """
   # Import here to avoid circular imports
   from data import generate_data

   X_train, y_train = data
   n_samples_current = len(X_train)

   if current_optax_state is not None:
      state = current_optax_state
   else:
      tx = optax.adam(
         learning_rate=learning_rate,
         b1=beta1,
         b2=beta2,
         eps=eps
      )
      state = OptaxAdamTrainState.create(
         apply_fn=apply_fn,
         params=initial_params,
         tx=tx,
         batch_stats=batch_stats
      )

   states_history = [create_compatible_state(state)]
   losses = []
   time_per_step = []  # cumulative optimization time after each epoch
   total_opt_time = 0.0  # excludes resampling time

   # Continue counting from the previous phase so resample seeds never repeat.
   current_resample_counter = resample_counter

   for epoch in range(epochs):
      if resample_freq > 0 and epoch > 0 and epoch % resample_freq == 0:
         current_resample_counter += 1
         resample_seed = base_seed + 10000 * current_resample_counter
         X_train, y_train = _resample_training_data(
            generate_data, n_samples, n_samples_current, noise_level,
            resample_seed, dtype, jax_device)
         n_samples_current = len(X_train)

      # JAX dispatches work asynchronously on every backend (not only GPU),
      # so wait for pending state updates and data transfers before starting
      # the clock; otherwise their cost would leak into this epoch's timing.
      jax.block_until_ready((state, X_train, y_train))
      opt_start_time = time.time()

      if shuffle_seed is not None:
         # Per-epoch seed keeps shuffling reproducible across runs.
         rng = np.random.RandomState(shuffle_seed + epoch)
         indices = rng.permutation(n_samples_current)
      else:
         indices = np.random.permutation(n_samples_current)

      x_train_shuffled = X_train[indices]
      y_train_shuffled = y_train[indices]

      epoch_losses = []
      for i in range(0, n_samples_current, batch_size):
         batch = (x_train_shuffled[i:i + batch_size], y_train_shuffled[i:i + batch_size])
         state, loss, _ = optax_adam_train_step(state, batch)
         epoch_losses.append(loss)

      # np.mean([]) warns and yields nan; report nan explicitly instead.
      train_loss = np.mean(epoch_losses) if epoch_losses else np.nan
      losses.append(train_loss)

      compat_state = create_compatible_state(state)
      states_history.append(compat_state)

      # Make sure this epoch's updates have finished before stopping the clock.
      jax.block_until_ready(state)
      total_opt_time += time.time() - opt_start_time
      time_per_step.append(total_opt_time)

      if print_every and (epoch + 1) % print_every == 0:
         print(f"{phase_name} - Epoch {epoch+1}/{epochs}, Loss: {train_loss:.6f}")

   print(f"{phase_name} optimization time: {total_opt_time:.2f} seconds")

   time_per_step, total_opt_time = _remove_jit_overhead(
      time_per_step, total_opt_time, phase_name)

   final_compat_state = create_compatible_state(state)
   return (final_compat_state, states_history, losses, total_opt_time,
           time_per_step, current_resample_counter, state)


def _resample_training_data(generate_data, n_samples, fallback_n_samples,
                            noise_level, seed, dtype, jax_device):
   """Draw a fresh dataset with ``generate_data`` and pack it as ``(X, y)``.

   ``generate_data`` returns three arrays; the first two are stacked as input
   columns and the third becomes a column vector of targets. When
   ``n_samples`` is not positive, ``fallback_n_samples`` is generated instead
   so resampling never produces an empty dataset by default.
   """
   count = n_samples if n_samples > 0 else fallback_n_samples
   x_coord, y_coord, z_values = generate_data(count, noise_level, seed)

   X_new = jnp.array(np.column_stack((x_coord, y_coord)), dtype=dtype)
   y_new = jnp.array(z_values, dtype=dtype).reshape(-1, 1)

   if jax_device:
      X_new = jax.device_put(X_new, jax_device)
      y_new = jax.device_put(y_new, jax_device)
   return X_new, y_new


def _remove_jit_overhead(time_per_step, total_opt_time, phase_name):
   """Discount first-epoch JIT compilation time from cumulative timings.

   The first epoch's duration is replaced by the mean duration of the later
   epochs; every cumulative entry and the total are shifted down by the
   difference (clamped at 0). Returns the inputs unchanged when there are
   fewer than two epochs or the first epoch was not slower than average.
   """
   if len(time_per_step) <= 1:
      return time_per_step, total_opt_time

   avg_step_time = (time_per_step[-1] - time_per_step[0]) / (len(time_per_step) - 1)
   if not (avg_step_time > 0 and time_per_step[0] > avg_step_time):
      print(f"First step time ({time_per_step[0]:.4f}s) appears normal, no adjustment needed")
      return time_per_step, total_opt_time

   time_adjustment = time_per_step[0] - avg_step_time
   adjusted_time_per_step = [max(0.0, t - time_adjustment) for t in time_per_step]
   adjusted_opt_time = max(0.0, total_opt_time - time_adjustment)

   print(f"Adjusted first step time: removed {time_adjustment:.4f}s of JIT compilation overhead")
   print(f"{phase_name} adjusted optimization time: {adjusted_opt_time:.2f} seconds")
   return adjusted_time_per_step, adjusted_opt_time