"""
Optax Adam optimizer implementation for Poisson PDE example.
"""

import jax
import jax.numpy as jnp
from functools import partial
import numpy as np
import time
import optax

from .common import (
    OptaxTrainState, 
    optax_mse_loss, 
    create_compatible_state,
    save_optimizer_state,
    load_optimizer_state
)

# Jitted single-step update. Positional slots 2 and 3 (the loss weights) are
# static, so each distinct (weight_pde, weight_bc) pair is compiled separately.
@partial(jax.jit, static_argnums=(2, 3))
def optax_adam_train_step(state, batch, weight_pde=1.0, weight_bc=1.0):
    """Apply one Optax Adam update to ``state`` using a single mini-batch.

    Args:
        state: Training state exposing ``apply_fn``, ``params``,
            ``batch_stats`` (may be ``None``), ``apply_gradients`` and
            ``replace``.
        batch: ``(inputs, (target_interior, target_boundary))``. ``inputs``
            is passed unchanged to ``apply_fn``, whose output is unpacked
            into an (interior, boundary) prediction pair.
        weight_pde: Multiplier applied to the interior (PDE) MSE term.
        weight_bc: Multiplier applied to the boundary-condition MSE term.

    Returns:
        ``(new_state, total_loss, (pde_term, bc_term))`` where both
        component terms are already multiplied by their weights.
    """
    inputs, (target_interior, target_boundary) = batch
    uses_batch_stats = state.batch_stats is not None

    def objective(params):
        if uses_batch_stats:
            # Mark batch_stats mutable so the refreshed statistics are
            # returned alongside the predictions.
            variables = {'params': params, 'batch_stats': state.batch_stats}
            predictions, mutated = state.apply_fn(
                variables, inputs, mutable=['batch_stats']
            )
        else:
            predictions = state.apply_fn({'params': params}, inputs)
            mutated = None
        pred_interior, pred_boundary = predictions
        pde_term = weight_pde * optax_mse_loss(pred_interior, target_interior)
        bc_term = weight_bc * optax_mse_loss(pred_boundary, target_boundary)
        aux = ((pred_interior, pred_boundary), (pde_term, bc_term), mutated)
        return pde_term + bc_term, aux

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (total_loss, (_, component_losses, mutated_vars)), grads = grad_fn(state.params)

    next_state = state.apply_gradients(grads=grads)
    if mutated_vars is not None:
        # The mutable collection is keyed by its name.
        next_state = next_state.replace(batch_stats=mutated_vars['batch_stats'])
    return next_state, total_loss, component_losses


def save_optax_adam_state(state, path, filename):
    """Persist an Optax Adam training state to ``path``/``filename``.

    Delegates entirely to ``save_optimizer_state`` and hands back whatever
    that helper returns.
    """
    outcome = save_optimizer_state(state, path, filename)
    return outcome


def load_optax_adam_state(path, filename, apply_fn, learning_rate, beta1=0.9, beta2=0.999, eps=1e-8):
    """Restore an Optax Adam training state saved by ``save_optax_adam_state``.

    The Adam hyper-parameters are bound into a factory that takes a learning
    rate and builds ``optax.adam``. The factory is passed, together with the
    other arguments, to ``load_optimizer_state``, and its result is returned.
    """
    adam_hparams = {'b1': beta1, 'b2': beta2, 'eps': eps}

    def build_adam(learning_rate):
        return optax.adam(learning_rate=learning_rate, **adam_hparams)

    return load_optimizer_state(path, filename, apply_fn, learning_rate, build_adam)


def _resample_dataset(n_interior, n_boundary, seed, dtype, jax_device):
    """Generate a fresh dataset via ``generate_data`` and convert it to JAX.

    Returns:
        ``(X_interior, X_boundary, f_interior, u_boundary)`` as ``dtype``
        arrays, moved to ``jax_device`` when one is given.
    """
    # Import here to avoid circular imports
    from data import generate_data

    interior_points, boundary_points, f_interior, u_boundary = generate_data(
        n_interior, n_boundary, seed
    )
    arrays = tuple(
        jnp.array(a, dtype=dtype)
        for a in (interior_points, boundary_points, f_interior, u_boundary)
    )
    if jax_device:
        arrays = tuple(jax.device_put(a, jax_device) for a in arrays)
    return arrays


def _epoch_permutations(n_interior, n_boundary, shuffle_seed, epoch):
    """Return ``(interior_perm, boundary_perm)`` index permutations.

    With ``shuffle_seed`` set, both come from a RandomState seeded with
    ``shuffle_seed + epoch`` (reproducible); otherwise NumPy's global RNG.
    """
    if shuffle_seed is not None:
        rng = np.random.RandomState(shuffle_seed + epoch)
        return rng.permutation(n_interior), rng.permutation(n_boundary)
    return np.random.permutation(n_interior), np.random.permutation(n_boundary)


def _host_mean(values):
    """Average a list of device scalars on the host in float64.

    Fetching everything with one ``jax.device_get`` avoids a blocking
    device-to-host transfer after every step. Upcasting to float64 gives the
    same result as averaging the equivalent Python floats.
    """
    return np.mean(np.asarray(jax.device_get(values), dtype=np.float64))


def _remove_jit_overhead(time_per_step, total_opt_time, phase_name):
    """Subtract the first epoch's excess (JIT-compile) time from the timings.

    ``time_per_step`` holds cumulative times. If the first entry exceeds the
    average duration of the remaining epochs, the surplus is removed from
    every entry and from ``total_opt_time`` (both clamped at 0).

    Returns:
        ``(total_opt_time, time_per_step)``, adjusted or unchanged.
    """
    if len(time_per_step) <= 1:
        return total_opt_time, time_per_step

    avg_step_time = (time_per_step[-1] - time_per_step[0]) / (len(time_per_step) - 1)
    if not (avg_step_time > 0 and time_per_step[0] > avg_step_time):
        print(f"First step time ({time_per_step[0]:.4f}s) appears normal, no adjustment needed")
        return total_opt_time, time_per_step

    time_adjustment = time_per_step[0] - avg_step_time
    adjusted_time_per_step = [max(0.0, t - time_adjustment) for t in time_per_step]
    adjusted_opt_time = max(0.0, total_opt_time - time_adjustment)

    print(f"Adjusted first step time: removed {time_adjustment:.4f}s of JIT compilation overhead")
    print(f"{phase_name} adjusted optimization time: {adjusted_opt_time:.2f} seconds")
    return adjusted_opt_time, adjusted_time_per_step


def train_with_optax_adam(
    initial_params,
    apply_fn,
    batch_stats,
    data,
    n_batches,
    epochs,
    learning_rate=0.001,
    beta1=0.9,
    beta2=0.999,
    eps=1e-8,
    resample_freq=0,
    n_interior_batch=0,
    n_boundary_batch=0,
    weight_pde=1.0,
    weight_bc=1.0,
    base_seed=42,
    resample_counter=0,
    print_every=100,
    phase_name="Optax Adam",
    jax_device=None,
    shuffle_seed=None,
    current_optax_state=None,
    dtype=jnp.float32
):
    """
    Train a model using Optax Adam optimizer for the Poisson PDE problem.

    Args:
        initial_params: Initial model parameters
        apply_fn: Model application function
        batch_stats: Initial batch statistics (if any)
        data: Training data ((X_interior, X_boundary), (f_interior, u_boundary))
        n_batches: Number of batches per epoch
        epochs: Number of epochs to train
        learning_rate: Learning rate for Adam
        beta1: Beta1 parameter for Adam
        beta2: Beta2 parameter for Adam
        eps: Epsilon parameter for Adam
        resample_freq: Resample data every this many epochs (0 for no resampling)
        n_interior_batch: Interior batch size; if <= 0 (and n_batches > 0) it
            defaults to len(X_interior) // n_batches
        n_boundary_batch: Boundary batch size; if <= 0 (and n_batches > 0) it
            defaults to len(X_boundary) // n_batches
        weight_pde: Weight for PDE loss
        weight_bc: Weight for boundary condition loss
        base_seed: Base random seed for resampling
        resample_counter: Resampling counter carried over from a previous phase
        print_every: Print progress every this many epochs (0/None disables)
        phase_name: Name of the training phase used in log output
        jax_device: JAX device to place resampled data on
        shuffle_seed: Seed for reproducible per-epoch shuffling (None uses
            NumPy's global RNG)
        current_optax_state: Optional existing Optax state to continue from
        dtype: Data type for resampled arrays

    Returns:
        compatible_train_state: Final state converted via create_compatible_state
        states_history: Converted states: the initial one plus one per epoch
        losses_both: Per-epoch mean total loss
        losses_pde: Per-epoch mean weighted PDE loss
        losses_bc: Per-epoch mean weighted boundary condition loss
        training_time: Total optimization time (excluding resampling)
        time_per_step: Cumulative optimization time recorded after each epoch
        resample_counter: Resample counter for next phase
        optax_state: Optax training state for saving and restoring
    """
    X_train, y_train = data
    X_interior, X_boundary = X_train
    f_interior, u_boundary = y_train
    n_interior_current = len(X_interior)
    n_boundary_current = len(X_boundary)

    # Zero batch sizes (the defaults) would slice empty batches every step and
    # make resampling generate no points; split the supplied data evenly.
    if n_batches > 0:
        if n_interior_batch <= 0:
            n_interior_batch = n_interior_current // n_batches
        if n_boundary_batch <= 0:
            n_boundary_batch = n_boundary_current // n_batches
    n_interior = n_interior_batch * n_batches
    n_boundary = n_boundary_batch * n_batches

    if current_optax_state is not None:
        state = current_optax_state
    else:
        tx = optax.adam(learning_rate=learning_rate, b1=beta1, b2=beta2, eps=eps)
        state = OptaxTrainState.create(
            apply_fn=apply_fn,
            params=initial_params,
            tx=tx,
            batch_stats=batch_stats
        )

    states_history = [create_compatible_state(state)]
    losses_both = []
    losses_pde = []
    losses_bc = []
    time_per_step = []  # Cumulative optimization time, one entry per epoch
    total_opt_time = 0.0  # Excludes time spent resampling

    # Continue the counter across phases so resampling seeds keep advancing.
    current_resample_counter = resample_counter

    for epoch in range(epochs):
        if resample_freq > 0 and epoch > 0 and epoch % resample_freq == 0:
            current_resample_counter += 1
            # Deterministic seed derived from the base seed and the counter.
            resample_seed = base_seed + 10000 * current_resample_counter
            X_interior, X_boundary, f_interior, u_boundary = _resample_dataset(
                n_interior, n_boundary, resample_seed, dtype, jax_device
            )
            n_interior_current = len(X_interior)
            n_boundary_current = len(X_boundary)

        # JAX dispatches asynchronously, so drain pending work before starting
        # the clock. The old guard checked `jax_device.type`, which JAX devices
        # do not have (they expose `.platform`), so it never ran. Syncing
        # unconditionally is harmless on CPU.
        jax.block_until_ready(state)
        opt_start_time = time.time()

        perm_interior, perm_boundary = _epoch_permutations(
            n_interior_current, n_boundary_current, shuffle_seed, epoch
        )
        x_int = X_interior[perm_interior]
        x_bnd = X_boundary[perm_boundary]
        y_int = f_interior[perm_interior]
        y_bnd = u_boundary[perm_boundary]

        step_losses = []
        step_losses_pde = []
        step_losses_bc = []
        for i in range(n_batches):
            sl_int = slice(i * n_interior_batch, (i + 1) * n_interior_batch)
            sl_bnd = slice(i * n_boundary_batch, (i + 1) * n_boundary_batch)
            batch = ((x_int[sl_int], x_bnd[sl_bnd]), (y_int[sl_int], y_bnd[sl_bnd]))
            state, loss, losses = optax_adam_train_step(
                state, batch, weight_pde=weight_pde, weight_bc=weight_bc
            )
            # Keep device scalars; converting with float() here would block
            # the host on every step.
            step_losses.append(loss)
            step_losses_pde.append(losses[0])
            step_losses_bc.append(losses[1])

        train_loss = _host_mean(step_losses)
        train_loss_pde = _host_mean(step_losses_pde)
        train_loss_bc = _host_mean(step_losses_bc)
        losses_both.append(train_loss)
        losses_pde.append(train_loss_pde)
        losses_bc.append(train_loss_bc)

        compat_state = create_compatible_state(state)
        states_history.append(compat_state)

        # Ensure all device work for this epoch is finished before stopping the timer.
        jax.block_until_ready(compat_state)
        total_opt_time += time.time() - opt_start_time
        time_per_step.append(total_opt_time)

        if print_every and (epoch + 1) % print_every == 0:
            print(f"{phase_name} - Epoch {epoch+1}/{epochs}, Loss: {train_loss:.6f}, PDE Loss: {train_loss_pde:.6f}, BC Loss: {train_loss_bc:.6f}")

    print(f"{phase_name} optimization time: {total_opt_time:.2f} seconds")

    total_opt_time, time_per_step = _remove_jit_overhead(
        time_per_step, total_opt_time, phase_name
    )

    final_compat_state = create_compatible_state(state)
    return final_compat_state, states_history, losses_both, losses_pde, losses_bc, total_opt_time, time_per_step, current_resample_counter, state