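"""
Training utilities for the regression example: optimizer construction and a
mini-batch training loop over PDE (interior) and boundary-condition losses.
"""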

import jax
import jax.numpy as jnp
import numpy as np
import time
from typing import List, Dict, Any, Tuple, Optional

from foo.models import MSEFlax
from foo.optimizers import SGDOptimizer, CGOptimizer, NLTGCROptimizer

def create_optimizer(name, state, config, model, loss_fn_pde=MSEFlax, loss_fn_bc=MSEFlax, weights_pde=1.0, weights_bc=1.0):
   """
   Create an optimizer instance based on name and configuration.

   The optimizer family is chosen by the *prefix* of ``name``, so names such
   as ``"SGD"`` or ``"CG-phase2"`` are both accepted.

   Args:
      name: Optimizer name; must start with "SGD", "CG" or "Nltgcr".
         Names starting with "Adam" are rejected (Adam is run through Optax).
      state: Model state handed to the optimizer constructor.
      config: Optimizer configuration handed to the optimizer constructor.
      model: Model instance.
      loss_fn_pde: Loss for the PDE (interior) term; first entry of ``loss_fn``.
      loss_fn_bc: Loss for the boundary-condition term; second entry of ``loss_fn``.
      weights_pde: Weight of the PDE term; first entry of ``loss_weights``.
      weights_bc: Weight of the boundary term; second entry of ``loss_weights``.

   Returns:
      optimizer: An SGDOptimizer, CGOptimizer or NLTGCROptimizer built with
      ``loss_fn=[loss_fn_pde, loss_fn_bc]`` and
      ``loss_weights=[weights_pde, weights_bc]``.

   Raises:
      ValueError: If ``name`` starts with "Adam" or matches no known prefix.
   """
   if name.startswith("Adam"):
      # Adam is driven through Optax directly by the caller (run_test), not
      # through this factory.
      raise ValueError("Direct call to create_optimizer with 'Adam' should be avoided. Use Optax Adam implementation instead.")

   # Prefixes are mutually exclusive, so lookup order does not matter.
   optimizer_classes = (
      ("SGD", SGDOptimizer),
      ("CG", CGOptimizer),
      ("Nltgcr", NLTGCROptimizer),
   )
   for prefix, optimizer_cls in optimizer_classes:
      if name.startswith(prefix):
         return optimizer_cls(
            state,
            config,
            model=model,
            loss_fn=[loss_fn_pde, loss_fn_bc],
            loss_weights=[weights_pde, weights_bc],
         )

   raise ValueError(f"Unknown optimizer: {name}")

def _resample_training_data(generate_data, n_interior, n_boundary, seed, dtype, jax_device):
   """Draw a fresh (X_interior, X_boundary, f_interior, u_boundary) set as JAX arrays, optionally placed on ``jax_device``."""
   interior_points, boundary_points, f_interior, u_boundary = generate_data(n_interior, n_boundary, seed)
   arrays = [jnp.array(a, dtype=dtype) for a in (interior_points, boundary_points, f_interior, u_boundary)]
   if jax_device:
      arrays = [jax.device_put(a, jax_device) for a in arrays]
   return tuple(arrays)


def _discount_first_epoch_overhead(time_per_step, total_opt_time, phase_name):
   """Shift cumulative epoch timings so the first epoch costs the mean of the others (removes JIT warm-up).

   Returns the (possibly adjusted) ``(time_per_step, total_opt_time)`` pair;
   with fewer than two epochs the inputs are returned unchanged.
   """
   if len(time_per_step) <= 1:
      return time_per_step, total_opt_time

   first_epoch_time = time_per_step[0]
   # time_per_step is cumulative, so this is the mean duration of epochs 2..N.
   avg_step_time = (time_per_step[-1] - first_epoch_time) / (len(time_per_step) - 1)

   if not (avg_step_time > 0 and first_epoch_time > avg_step_time):
      print(f"First step time ({time_per_step[0]:.4f}s) appears normal, no adjustment needed")
      return time_per_step, total_opt_time

   time_adjustment = first_epoch_time - avg_step_time
   adjusted_time_per_step = [max(0.0, t - time_adjustment) for t in time_per_step]
   adjusted_opt_time = max(0.0, total_opt_time - time_adjustment)

   print(f"Adjusted first step time: removed {time_adjustment:.4f}s of JIT compilation overhead")
   print(f"{phase_name} adjusted optimization time: {adjusted_opt_time:.2f} seconds")
   return adjusted_time_per_step, adjusted_opt_time


def train_optimizer(optimizer, state, data, n_batches, epochs, resample_freq=0, n_boundary_batch=0, n_interior_batch=0,
               weight_pde=1.0, weight_bc=1.0, base_seed=42, resample_counter=0, print_every=100,
               phase_name="", jax_device=None, dtype=jnp.float32, shuffle_seed=None):
   """
   Train with a specific optimizer and return history, with optional periodic data resampling.

   Each epoch shuffles the interior and boundary sets independently, then feeds
   ``n_batches`` mini-batches of ``n_interior_batch`` interior points and
   ``n_boundary_batch`` boundary points to ``optimizer.step``.

   Args:
      optimizer: Object whose ``step(state, batch)`` returns a 5-tuple
         ``(new_state, loss, _, (loss_pde, loss_bc), _)``.
      state: Initial state of the model.
      data: Nested tuple ``((X_interior, X_boundary), (f_interior, u_boundary))``.
      n_batches: Number of mini-batches per epoch.
      epochs: Number of epochs to train for.
      resample_freq: Resample the dataset every ``resample_freq`` epochs
         (0 means no resampling).
      n_boundary_batch: Boundary points per mini-batch.
      n_interior_batch: Interior points per mini-batch.
      weight_pde, weight_bc: Accepted for API compatibility; not used by this
         function (loss weights are presumably configured on the optimizer,
         see ``create_optimizer``).
      base_seed: Base random seed for resampling.
      resample_counter: Resampling counter carried over from a previous phase,
         so resample seeds stay distinct across phases.
      print_every: Print progress every ``print_every`` epochs (0 disables).
      phase_name: Name of the training phase for logging.
      jax_device: Device to place resampled data on (None keeps the default).
      dtype: Data type for resampled arrays.
      shuffle_seed: Random seed for data shuffling (for reproducibility);
         None uses NumPy's global RNG.

   Note:
      Only the first ``n_batches * n_*_batch`` shuffled points of each set are
      used per epoch. If a set is smaller than that, trailing batches are short
      or empty.

   Returns:
      A 7-tuple:
      states_history: Initial state followed by the state after each epoch.
      losses_both: Mean total loss per epoch.
      losses_pde: Mean PDE loss per epoch.
      losses_bc: Mean boundary loss per epoch.
      training_time: Total optimization time in seconds (excludes resampling;
         first-epoch JIT overhead removed when detected).
      time_per_step: Cumulative optimization time after each epoch (adjusted
         the same way).
      resampling_counter: Updated resampling counter for the next phase.
   """
   # Imported lazily, and only when resampling is requested, to avoid
   # circular imports and an unneeded dependency on the data module.
   if resample_freq > 0:
      from data import generate_data

   (X_interior, X_boundary), (f_interior, u_boundary) = data
   n_interior = n_interior_batch * n_batches
   n_boundary = n_boundary_batch * n_batches

   states_history = [state]
   losses_both, losses_pde, losses_bc = [], [], []
   time_per_step = []  # cumulative optimization time after each epoch
   total_opt_time = 0.0
   current_resample_counter = resample_counter

   for epoch in range(epochs):
      if resample_freq > 0 and epoch > 0 and epoch % resample_freq == 0:
         current_resample_counter += 1
         # Deterministic seed derived from base seed and resample counter.
         resample_seed = base_seed + 10000 * current_resample_counter
         X_interior, X_boundary, f_interior, u_boundary = _resample_training_data(
            generate_data, n_interior, n_boundary, resample_seed, dtype, jax_device)

      # JAX dispatches work asynchronously, so drain pending computation before
      # starting the clock. This is done on every backend, not just GPU.
      jax.block_until_ready(state)
      opt_start_time = time.perf_counter()

      if shuffle_seed is not None:
         # Per-epoch seed keeps shuffling reproducible but different each epoch.
         rng = np.random.RandomState(shuffle_seed + epoch)
         indices_interior = rng.permutation(len(X_interior))
         indices_boundary = rng.permutation(len(X_boundary))
      else:
         indices_interior = np.random.permutation(len(X_interior))
         indices_boundary = np.random.permutation(len(X_boundary))

      x_interior = X_interior[indices_interior]
      x_boundary = X_boundary[indices_boundary]
      y_interior = f_interior[indices_interior]
      y_boundary = u_boundary[indices_boundary]

      epoch_losses, epoch_losses_pde, epoch_losses_bc = [], [], []
      for i in range(n_batches):
         sl_interior = slice(i * n_interior_batch, (i + 1) * n_interior_batch)
         sl_boundary = slice(i * n_boundary_batch, (i + 1) * n_boundary_batch)
         batch = ((x_interior[sl_interior], x_boundary[sl_boundary]),
                  (y_interior[sl_interior], y_boundary[sl_boundary]))
         state, loss, _, component_losses, _ = optimizer.step(state, batch)
         epoch_losses.append(float(loss))
         epoch_losses_pde.append(float(component_losses[0]))
         epoch_losses_bc.append(float(component_losses[1]))

      train_loss = np.mean(epoch_losses)
      train_loss_pde = np.mean(epoch_losses_pde)
      train_loss_bc = np.mean(epoch_losses_bc)
      losses_both.append(train_loss)
      losses_pde.append(train_loss_pde)
      losses_bc.append(train_loss_bc)
      states_history.append(state)

      # Make sure the epoch's work has finished before stopping the clock.
      jax.block_until_ready(state)
      total_opt_time += time.perf_counter() - opt_start_time
      time_per_step.append(total_opt_time)

      if print_every and (epoch + 1) % print_every == 0:
         print(f"{phase_name} - Epoch {epoch+1}/{epochs}, Loss: {train_loss:.6f}, PDE Loss: {train_loss_pde:.6f}, BC Loss: {train_loss_bc:.6f}")

   print(f"{phase_name} optimization time: {total_opt_time:.2f} seconds")

   # Remove first-epoch JIT compilation overhead; returned values reflect it
   # (no adjustment for 0-1 epochs).
   time_per_step, total_opt_time = _discount_first_epoch_overhead(time_per_step, total_opt_time, phase_name)

   return states_history, losses_both, losses_pde, losses_bc, total_opt_time, time_per_step, current_resample_counter