"""
Optax L-BFGS optimizer implementation for Poisson PDE example.
"""

import jax
import jax.numpy as jnp
from functools import partial
import numpy as np
import time
import optax

from .common import (
    OptaxTrainState, 
    optax_mse_loss, 
    create_compatible_state,
    save_optimizer_state
)

# Single jitted training step for Optax L-BFGS (weights are static jit args).
@partial(jax.jit, static_argnums=(2, 3))
def optax_lbfgs_train_step(state, batch, weight_pde=1.0, weight_bc=1.0):
    """Train for a single step with the Optax L-BFGS optimizer.

    Args:
        state: Training state exposing ``params``, ``batch_stats`` (or ``None``),
            ``apply_fn``, ``tx`` (the L-BFGS transformation), ``opt_state`` and
            ``replace``.
        batch: ``(x, (y_interior, y_boundary))``. ``x`` is passed to
            ``apply_fn`` unchanged; the model must return ``(u_interior, u_boundary)``.
        weight_pde: Weight of the interior (PDE) loss. Static under ``jit``,
            so each distinct value triggers a recompile.
        weight_bc: Weight of the boundary-condition loss. Static under ``jit``.

    Returns:
        ``(new_state, loss, (loss_pde, loss_bc))``, where ``loss`` is the
        weighted total at the *pre-update* parameters. The two components are
        already multiplied by their weights.
    """
    x, y = batch
    y_interior, y_boundary = y
    # Resolved at trace time: the presence of batch stats is part of the
    # pytree structure, not a traced value.
    has_batch_stats = state.batch_stats is not None

    def _forward_and_losses(params):
        """Run the model and return ``((loss_pde, loss_bc), updated_batch_stats)``."""
        if has_batch_stats:
            (u_i, u_b), updated_batch_stats = state.apply_fn(
                {'params': params, 'batch_stats': state.batch_stats},
                x,
                mutable=['batch_stats'],
            )
        else:
            u_i, u_b = state.apply_fn({'params': params}, x)
            updated_batch_stats = None
        loss_i = weight_pde * optax_mse_loss(u_i, y_interior)
        loss_b = weight_bc * optax_mse_loss(u_b, y_boundary)
        return (loss_i, loss_b), updated_batch_stats

    def value_fn(params):
        """Scalar objective re-evaluated by the L-BFGS line search."""
        (loss_i, loss_b), _ = _forward_and_losses(params)
        return loss_i + loss_b

    def loss_fn(params):
        """Same objective as ``value_fn``, plus per-term losses and batch stats as aux."""
        (loss_i, loss_b), updated_batch_stats = _forward_and_losses(params)
        return loss_i + loss_b, ((loss_i, loss_b), updated_batch_stats)

    # Loss and gradient at the current parameters.
    (loss, (losses, updated_batch_stats)), grad = jax.value_and_grad(
        loss_fn, has_aux=True
    )(state.params)

    # optax.lbfgs needs the current value/grad and a value_fn for its line search.
    updates, new_opt_state = state.tx.update(
        grad,
        state.opt_state,
        params=state.params,
        value=loss,
        value_fn=value_fn,
        grad=grad,
    )
    new_params = optax.apply_updates(state.params, updates)

    new_state = state.replace(params=new_params, opt_state=new_opt_state)
    if updated_batch_stats is not None:
        # Flax returns mutable collections keyed by collection name.
        new_state = new_state.replace(batch_stats=updated_batch_stats['batch_stats'])

    loss_pde, loss_bc = losses
    return new_state, loss, (loss_pde, loss_bc)


def save_optax_lbfgs_state(state, path, filename):
    """Persist an Optax L-BFGS training state.

    Thin wrapper that forwards every argument to ``save_optimizer_state``.

    Args:
        state: The Optax training state to save.
        path: Destination directory, passed through unchanged.
        filename: Destination file name, passed through unchanged.

    Returns:
        Whatever ``save_optimizer_state`` returns.
    """
    saved = save_optimizer_state(state, path, filename)
    return saved


def load_optax_lbfgs_state(path, filename, apply_fn, 
                         max_linesearch_steps=20,
                         memory_size=10,
                         scale_init_precond=True,
                         sufficient_decrease=0.1,
                         curvature=0.9,
                         initial_guess_strategy='one'):
    """
    Placeholder for restoring Optax L-BFGS optimizer state; nothing is loaded.

    L-BFGS state is not restored between runs because its curvature history
    depends on the trajectory that produced it. This function never reads
    from disk. It prints a warning and returns ``(None, False)``.

    Args:
        path, filename, apply_fn: Accepted for interface compatibility; ignored.
        max_linesearch_steps, memory_size, scale_init_precond,
        sufficient_decrease, curvature, initial_guess_strategy:
            L-BFGS settings accepted for interface compatibility; ignored.

    Returns:
        tuple: ``(None, False)``. No state was loaded; per the printed warning,
        the caller is expected to build a new optimizer.
    """
    message = "Warning: Loading L-BFGS state is not fully supported. Creating new optimizer."
    print(message)
    restored_state, was_loaded = None, False
    return restored_state, was_loaded


def _to_device_arrays(arrays, dtype, jax_device):
    """Convert each array-like to a JAX array of ``dtype``, moved to ``jax_device`` if one is given."""
    converted = [jnp.array(a, dtype=dtype) for a in arrays]
    if jax_device:
        converted = [jax.device_put(a, jax_device) for a in converted]
    return converted


def _remove_jit_overhead(time_per_step, total_opt_time, phase_name):
    """Treat the first epoch's excess over the later-epoch average as JIT overhead and subtract it.

    ``time_per_step`` holds cumulative times, one entry per epoch. Returns the
    (possibly adjusted) ``(total_opt_time, time_per_step)``. Values are unchanged
    when there are fewer than two epochs or the first epoch is not slower than
    the average of the rest.
    """
    if len(time_per_step) < 2:
        return total_opt_time, time_per_step

    first_step_time = time_per_step[0]
    # Entries are cumulative, so last - first is the total of the later epochs.
    avg_step_time = (time_per_step[-1] - first_step_time) / (len(time_per_step) - 1)

    if not (avg_step_time > 0 and first_step_time > avg_step_time):
        print(f"First step time ({first_step_time:.4f}s) appears normal, no adjustment needed")
        return total_opt_time, time_per_step

    time_adjustment = first_step_time - avg_step_time
    # Clamp at zero so float rounding never yields a negative time.
    adjusted_time_per_step = [max(0.0, t - time_adjustment) for t in time_per_step]
    adjusted_opt_time = max(0.0, total_opt_time - time_adjustment)

    print(f"Adjusted first step time: removed {time_adjustment:.4f}s of JIT compilation overhead")
    print(f"{phase_name} adjusted optimization time: {adjusted_opt_time:.2f} seconds")
    return adjusted_opt_time, adjusted_time_per_step


def train_with_optax_lbfgs(
    initial_params,
    apply_fn,
    batch_stats,
    data,
    n_batches,
    epochs,
    max_linesearch_steps=20,
    memory_size=10,
    scale_init_precond=True,
    sufficient_decrease=0.1,
    curvature=0.9,
    initial_guess_strategy='one',
    resample_freq=0,
    n_boundary_batch=0,
    n_interior_batch=0,
    weight_pde=1.0,
    weight_bc=1.0,
    base_seed=42,
    resample_counter=0,
    print_every=100,
    phase_name="Optax L-BFGS",
    jax_device=None,
    shuffle_seed=None,
    dtype=jnp.float32
):
    """
    Train a model using the Optax L-BFGS optimizer for the Poisson PDE.

    Each epoch optionally resamples the collocation points, shuffles them, and
    runs ``n_batches`` L-BFGS steps on consecutive slices of size
    ``n_interior_batch`` / ``n_boundary_batch``. Data beyond
    ``n_batches * batch_size`` is not visited in that epoch.

    Args:
        initial_params: Initial model parameters.
        apply_fn: Model application function; must return
            ``(u_interior, u_boundary)``.
        batch_stats: Initial batch statistics, or ``None``.
        data: Training data ``((X_interior, X_boundary), (f_interior, u_boundary))``.
        n_batches: Number of batches (L-BFGS steps) per epoch.
        epochs: Number of epochs to train.
        max_linesearch_steps: Maximum number of zoom line-search steps.
        memory_size: Size of the L-BFGS history buffer.
        scale_init_precond: Whether to scale the initial preconditioner.
        sufficient_decrease: Armijo sufficient-decrease parameter (``slope_rtol``).
        curvature: Wolfe curvature parameter (``curv_rtol``).
        initial_guess_strategy: Initial step-size strategy for the line search.
        resample_freq: Resample data every this many epochs (0 disables).
        n_boundary_batch: Boundary batch size. Values <= 0 are derived as
            ``len(X_boundary) // n_batches``.
        n_interior_batch: Interior batch size. Values <= 0 are derived as
            ``len(X_interior) // n_batches``.
        weight_pde: Weight of the PDE (interior) loss.
        weight_bc: Weight of the boundary-condition loss.
        base_seed: Base random seed for resampling.
        resample_counter: Resample counter carried over from a previous phase.
        print_every: Print progress every this many epochs (0/None disables).
        phase_name: Name of the training phase, used in log messages.
        jax_device: JAX device to place resampled data on, or ``None``.
        shuffle_seed: Seed for per-epoch shuffling (``None`` uses the global
            NumPy RNG).
        dtype: Data type for resampled arrays.

    Returns:
        A 9-tuple:
        compatible_train_state: ``create_compatible_state`` of the final state.
        states_history: Compatible states; the initial state followed by one per epoch.
        losses_both: Mean total loss per epoch.
        losses_pde: Mean (weighted) PDE loss per epoch.
        losses_bc: Mean (weighted) boundary loss per epoch.
        training_time: Total optimization time, excluding resampling and
            adjusted for first-epoch JIT overhead.
        time_per_step: Cumulative optimization time after each epoch.
        resample_counter: Resample counter for the next phase.
        optax_state: Raw Optax training state, for saving and restoring.
    """
    # Import here to avoid circular imports
    from data import generate_data

    X_train, y_train = data
    X_interior, X_boundary = X_train
    f_interior, u_boundary = y_train

    # A batch size of 0 (the default) makes every slice empty and every loss
    # NaN, and resampling would generate no points. Interpret it as "split the
    # provided data evenly across n_batches".
    if n_batches > 0:
        if n_interior_batch <= 0:
            n_interior_batch = len(X_interior) // n_batches
        if n_boundary_batch <= 0:
            n_boundary_batch = len(X_boundary) // n_batches
    n_interior = n_interior_batch * n_batches
    n_boundary = n_boundary_batch * n_batches

    tx = optax.lbfgs(
        memory_size=memory_size,
        linesearch=optax.scale_by_zoom_linesearch(
            max_linesearch_steps=max_linesearch_steps,
            slope_rtol=sufficient_decrease,  # Armijo-Goldstein sufficient decrease parameter
            curv_rtol=curvature,             # Wolfe curvature condition parameter
            initial_guess_strategy=initial_guess_strategy,
            verbose=False
        ),
        scale_init_precond=scale_init_precond
    )

    state = OptaxTrainState.create(
        apply_fn=apply_fn,
        params=initial_params,
        tx=tx,
        batch_stats=batch_stats
    )

    states_history = [create_compatible_state(state)]
    losses_both = []
    losses_pde = []
    losses_bc = []
    time_per_step = []  # cumulative optimization time, one entry per epoch

    # Optimization time is tracked separately from resampling time.
    total_opt_time = 0.0
    # Continue counting from the previous phase so resample seeds never repeat.
    current_resample_counter = resample_counter

    for epoch in range(epochs):
        if resample_freq > 0 and epoch > 0 and epoch % resample_freq == 0:
            current_resample_counter += 1
            # Deterministic seed derived from the base seed and resample count.
            resample_seed = base_seed + 10000 * current_resample_counter
            interior_points, boundary_points, f_new, u_new = generate_data(
                n_interior,
                n_boundary,
                resample_seed
            )
            X_interior, X_boundary, f_interior, u_boundary = _to_device_arrays(
                (interior_points, boundary_points, f_new, u_new), dtype, jax_device
            )

        n_interior_current = len(X_interior)
        n_boundary_current = len(X_boundary)

        # JAX dispatches asynchronously on every backend. Drain pending work,
        # including any resampled-data transfers, so it is not charged to this
        # epoch. The previous `.type == 'gpu'` guard never matched a JAX device
        # (they expose `.platform`), so it never synced.
        jax.block_until_ready((state, X_interior, X_boundary, f_interior, u_boundary))
        opt_start_time = time.perf_counter()

        if shuffle_seed is not None:
            # Epoch-dependent seed: reproducible, but a different order each epoch.
            rng = np.random.RandomState(shuffle_seed + epoch)
            indices_interior = rng.permutation(n_interior_current)
            indices_boundary = rng.permutation(n_boundary_current)
        else:
            indices_interior = np.random.permutation(n_interior_current)
            indices_boundary = np.random.permutation(n_boundary_current)

        x_interior_shuffled = X_interior[indices_interior]
        x_boundary_shuffled = X_boundary[indices_boundary]
        y_interior_shuffled = f_interior[indices_interior]
        y_boundary_shuffled = u_boundary[indices_boundary]

        epoch_losses = []
        epoch_losses_pde = []
        epoch_losses_bc = []

        for i in range(n_batches):
            interior_slice = slice(i * n_interior_batch, (i + 1) * n_interior_batch)
            boundary_slice = slice(i * n_boundary_batch, (i + 1) * n_boundary_batch)
            batch = (
                (x_interior_shuffled[interior_slice], x_boundary_shuffled[boundary_slice]),
                (y_interior_shuffled[interior_slice], y_boundary_shuffled[boundary_slice]),
            )
            state, loss, (loss_pde, loss_bc) = optax_lbfgs_train_step(
                state,
                batch,
                weight_pde=weight_pde,
                weight_bc=weight_bc
            )
            epoch_losses.append(float(loss))
            epoch_losses_pde.append(float(loss_pde))
            epoch_losses_bc.append(float(loss_bc))

        train_loss = np.mean(epoch_losses)
        train_loss_pde = np.mean(epoch_losses_pde)
        train_loss_bc = np.mean(epoch_losses_bc)
        losses_both.append(train_loss)
        losses_pde.append(train_loss_pde)
        losses_bc.append(train_loss_bc)

        compat_state = create_compatible_state(state)
        states_history.append(compat_state)

        # Make sure all device work for this epoch is finished before stopping the timer.
        jax.block_until_ready((state, compat_state))
        total_opt_time += time.perf_counter() - opt_start_time
        time_per_step.append(total_opt_time)

        if print_every and (epoch + 1) % print_every == 0:
            print(f"{phase_name} - Epoch {epoch+1}/{epochs}, Loss: {train_loss:.6f}, PDE Loss: {train_loss_pde:.6f}, BC Loss: {train_loss_bc:.6f}")

    print(f"{phase_name} optimization time: {total_opt_time:.2f} seconds")

    total_opt_time, time_per_step = _remove_jit_overhead(
        time_per_step, total_opt_time, phase_name
    )

    final_compat_state = create_compatible_state(state)

    return final_compat_state, states_history, losses_both, losses_pde, losses_bc, total_opt_time, time_per_step, current_resample_counter, state