"""
Common utilities for Optax optimizers.
Contains shared code for state management, serialization, and loss functions.
"""

import jax
import jax.numpy as jnp
import numpy as np
import pickle
from pathlib import Path
from typing import Any, Optional, Dict
import optax
from flax.training import train_state
from foo.models import TrainState


class OptaxTrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with batch stats for normalization layers.

    Adds a single optional ``batch_stats`` field on top of the fields
    inherited from ``flax.training.train_state.TrainState``.
    """
    # Optional extra collection carried alongside params; None when unused.
    batch_stats: Optional[Any] = None

    @classmethod
    def create(cls, *, apply_fn, params, tx, batch_stats=None):
        """Build a fresh state at ``step=0``.

        Args:
            apply_fn: Model apply function stored on the state
            params: Parameters used to initialise the optimizer state
            tx: Optimizer; its ``init(params)`` result becomes ``opt_state``
            batch_stats: Optional batch statistics to attach (default None)

        Returns:
            A new instance of ``cls``
        """
        # Initialise the optimizer state from the supplied parameters.
        initial_opt_state = tx.init(params)
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=initial_opt_state,
            batch_stats=batch_stats,
        )


def optax_mse_loss(logits, targets):
    """Return the mean of the squared element-wise difference.

    Args:
        logits: Predicted values
        targets: Reference values, subtracted from ``logits``

    Returns:
        Scalar mean squared error over all elements
    """
    residual = logits - targets
    squared_error = residual * residual
    return jnp.mean(squared_error)


def create_compatible_state(optax_state):
    """
    Wrap an Optax state in a foo.models.TrainState for interface consistency.

    Only ``apply_fn``, ``params`` and ``batch_stats`` are copied from the
    input. The step counter is reset to 0 and the optimizer-related fields
    are set to None.

    Args:
        optax_state: State from OptaxTrainState

    Returns:
        Compatible TrainState instance
    """
    compat_fields = {
        # Fixed step 0 is used deliberately to maintain interface consistency.
        'step': 0,
        'apply_fn': optax_state.apply_fn,
        'params': optax_state.params,
        # tx, opt_state and rngs are passed explicitly as None; per the
        # original author's notes, TrainState expects these to be supplied.
        'tx': None,
        'opt_state': None,
        'rngs': None,
        'batch_stats': optax_state.batch_stats,
    }
    return TrainState(**compat_fields)


def save_optimizer_state(state, path, filename):
    """
    Save Optax optimizer state to file.

    The step counter, parameters, optimizer state and (when present) batch
    statistics are collected into a plain dict, array leaves are converted
    to NumPy arrays, and the dict is pickled to ``path / filename``.

    Args:
        state: Train state exposing ``step``, ``params`` and ``opt_state``.
            A ``batch_stats`` attribute is optional: it is saved only when
            it exists and is not None, so states without that attribute can
            be saved as well.
        path: Directory path (created, including parents, if missing)
        filename: Filename to save state

    Returns:
        bool: True once the file has been written. Failures are not
        reported through the return value; they propagate as exceptions.
    """
    target_dir = Path(path)
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / filename

    # Convert JAX arrays to numpy arrays for reliable serialization
    def convert_jax_to_numpy(x):
        if hasattr(x, 'dtype') and hasattr(x, 'shape'):
            return np.array(x)
        return x

    def to_numpy_tree(tree):
        return jax.tree_util.tree_map(convert_jax_to_numpy, tree)

    # Process state for serialization
    state_dict = {
        'step': int(state.step),
        'params': to_numpy_tree(state.params),
        'opt_state': to_numpy_tree(state.opt_state),
    }

    # Not every train state carries batch statistics; a missing attribute is
    # treated the same as None. The loader likewise accepts a checkpoint
    # without a 'batch_stats' entry.
    batch_stats = getattr(state, 'batch_stats', None)
    if batch_stats is not None:
        state_dict['batch_stats'] = to_numpy_tree(batch_stats)

    with open(target_file, 'wb') as f:
        pickle.dump(state_dict, f)

    print(f"Successfully saved optimizer state to {target_file}")
    return True


def load_optimizer_state(path, filename, apply_fn, learning_rate,
                         optimizer_creator_fn, params=None):
    """
    Load Optax optimizer state from file.

    Reads the pickled dict at ``path / filename``, converts array leaves to
    JAX arrays, and rebuilds an OptaxTrainState around a freshly created
    optimizer. Any exception raised along the way is printed together with
    its traceback and reported through the returned flag instead of being
    raised.

    Note: the file is read with ``pickle.load``, which can execute arbitrary
    code; only load files from trusted sources.

    Args:
        path: Directory path
        filename: Filename to load state from
        apply_fn: Model apply function
        learning_rate: Learning rate passed to ``optimizer_creator_fn``
        optimizer_creator_fn: Function that creates the optimizer
        params: Accepted for interface compatibility; currently unused, the
            parameters always come from the loaded file

    Returns:
        tuple: (loaded_state, True) on success, (None, False) on any error
    """
    try:
        checkpoint_file = Path(path) / filename
        with open(checkpoint_file, 'rb') as fh:
            state_dict = pickle.load(fh)

        # Array-like leaves go back to JAX arrays; other leaves pass through.
        def convert_numpy_to_jax(x):
            if hasattr(x, 'dtype') and hasattr(x, 'shape'):
                return jnp.array(x)
            return x

        def to_jax_tree(tree):
            return jax.tree_util.tree_map(convert_numpy_to_jax, tree)

        restored_params = to_jax_tree(state_dict['params'])
        restored_opt_state = to_jax_tree(state_dict['opt_state'])

        # The optimizer itself is not stored in the file; rebuild it.
        tx = optimizer_creator_fn(learning_rate)

        new_state = OptaxTrainState(
            step=state_dict['step'],
            apply_fn=apply_fn,
            params=restored_params,
            tx=tx,
            opt_state=restored_opt_state,
            batch_stats=None
        )

        # Batch statistics are optional in the saved dict.
        saved_batch_stats = state_dict.get('batch_stats')
        if saved_batch_stats is not None:
            new_state = new_state.replace(
                batch_stats=to_jax_tree(saved_batch_stats))

        print(f"Successfully loaded optimizer state from {checkpoint_file}")
        return new_state, True
    except Exception as e:
        # Best-effort loader: report the failure and signal it via the flag.
        print(f"Error loading optimizer state: {e}")
        import traceback
        traceback.print_exc()
        return None, False