import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
from scipy.stats import qmc
from numba import njit
import math
import os
import time
import json
import argparse
import pickle

import os
os.environ["JAX_PLATFORMS"] = "cpu"        # force CPU backend; prevents CUDA init
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") 
# ------------------------------------------------
import jax
import jax.numpy as jnp
from jax import jacrev, vmap
from jax import tree_util as jtu
import flax.linen as nn
import optax
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Tuple

# ---------------------------------------------
# Memory monitoring utilities (lightweight, cross-platform fallbacks)
# ---------------------------------------------
import platform
try:
    import psutil  # preferred if available
except Exception:
    psutil = None
try:
    import resource as _resource  # Unix fallback
except Exception:
    _resource = None

def _read_proc_status_vmrss_bytes():
    try:
        with open("/proc/self/status", "r") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    parts = line.split()
                    if len(parts) >= 2:
                        return int(parts[1]) * 1024  # kB -> bytes
    except Exception:
        pass
    return None

def get_process_memory_mb():
    """
    Best-effort resident memory of the current process, in MB.

    Lookup order: psutil's RSS when psutil was importable; otherwise VmRSS
    from /proc/self/status; otherwise ru_maxrss from the resource module.
    NOTE: ru_maxrss is the *peak* resident size, so that last fallback can
    over-report current usage. Returns NaN when no source worked.
    """
    try:
        if psutil is not None:
            rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        else:
            rss_bytes = _read_proc_status_vmrss_bytes()
            if rss_bytes is None and _resource is not None:
                peak = int(_resource.getrusage(_resource.RUSAGE_SELF).ru_maxrss)
                # ru_maxrss unit: bytes on macOS, kilobytes on Linux
                rss_bytes = peak if platform.system() == "Darwin" else peak * 1024
    except Exception:
        rss_bytes = None
    return float('nan') if rss_bytes is None else rss_bytes / (1024.0 * 1024.0)

#############################################
# Configuration
#############################################
class Config:
    """
    Settings for the FitzHugh–Nagumo parameter-estimation experiment.

    All settings are class attributes except ``t_eval``, which is a property
    and therefore has to be read from an instance (``Config().t_eval``).
    """
    # Number of noisy measurement points
    N_data_points = 10

    # Time domain
    tmin = 0.0
    tmax = 12.5

    @property
    def t_eval(self):
        """Measurement times: N_data_points equally spaced points in (tmin, tmax]; tmin itself is excluded."""
        return np.linspace(self.tmin, self.tmax,
                           self.N_data_points + 1)[1:]

    # Fine grid (200 points) over the whole time span, used for plotting
    t_plot = np.linspace(tmin, tmax, 200)

    # Ground-truth parameters of dv/dt = v - v^3/3 - w, dw/dt = (v + a - b*w) / r
    a_true = 0.7
    b_true = 0.8
    r_true = 12.5

    # Relative perturbation of the initial guesses:
    # initial guess = (1 + initial_guess_noise) * (ground truth)
    initial_guess_noise = 5.0
    # NOTE(review): the plotting/saving helpers format config.initial_guess_error
    # into their folder names, but no such attribute is defined in this class —
    # confirm it is set on the config object elsewhere before they run.

    # Initial guesses for the parameters a, b and r
    a_init = (1 + initial_guess_noise) * a_true
    b_init = (1 + initial_guess_noise) * b_true
    r_init = (1 + initial_guess_noise) * r_true

    # (lo, hi) bounds for a, b and r: enforced as hard constraints by MDMM and
    # passed to Nelder–Mead; the "traditional" trainer does not read them.
    param_bounds = [(0, 10), (0, 10), (0, 100)]

    # Initial condition: [v(0), w(0)] = [0.0, 0.0]
    initial_condition = np.array([0.0, 0.0])

    # Relative data noise: each measurement gets Gaussian noise with
    # std = data_error * |true value|
    data_error = 0.2

    # PINN architecture: one input (time) and two outputs (v, w)
    hidden_layers = [20, 20]
    activation = 'tanh'

    # Number of Sobol collocation points for the ODE residual; keep a power
    # of 2 (scipy's Sobol sampler warns otherwise)
    N_f = 16384

    # Initial-condition points (a single point at t = tmin is enough;
    # create_ic_data always builds exactly one)
    N_ic = 1

    # Training length and learning rate.
    # NOTE(review): learning_rate is not read by the training functions in this
    # file; they use the fixed rates inside make_jittable_schedule.
    epochs = 3000
    learning_rate = 1e-1

    # Seed for network initialisation and Sobol sampling
    seed = 1234

    # Training method: "mdmm", "traditional" or "nelder-mead"
    training_method = "mdmm"

    # Output folder ("." means the current working directory)
    output_folder = "."

    # Precision for JAX/NumPy arrays: 32 or 64 (float64 JAX arrays need
    # jax_enable_x64 — confirm it is enabled elsewhere)
    precision = 64

#############################################
# FitzHugh–Nagumo Model Function
#############################################
@njit
def fitz_nagumo_ode(t, y, a, b, r):
    """Right-hand side of the FitzHugh–Nagumo system for the state y = (v, w).

        dv/dt = v - v^3/3 - w
        dw/dt = (v + a - b*w) / r

    ``t`` is unused (the system is autonomous) but is kept for the
    solve_ivp signature. Returns the tuple (dv/dt, dw/dt).
    """
    v, w = y
    return v - v ** 3 / 3 - w, (v + a - b * w) / r


#############################################
# Data Creation Functions
#############################################
def create_forward_data(config, device):
    """
    Simulate the FitzHugh–Nagumo model with the true parameters and add
    relative Gaussian noise, producing the inverse-problem dataset.

    Args:
      config: configuration object (time span, true parameters, noise level,
        precision).
      device: unused; kept for interface compatibility.

    Returns:
      t: the fine time grid for plotting (config.t_plot)
      u0: initial condition (config.initial_condition, shape (2,))
      t_data: measurement times as a column, shape (N_data, 1)
      u_data: noisy measurements, list of 2-element arrays [v, w], one per measurement time
      t_data_tensor, u_data_tensor: t_data and the stacked noisy data
        (shape (N_data, 2)) as JAX arrays in the configured precision.

    Raises:
      RuntimeError: if the forward solve does not succeed.
    """
    u0 = config.initial_condition
    t_eval = config.t_eval
    # Forward problem over the full time span with the true parameters
    sol = solve_ivp(lambda t, y: fitz_nagumo_ode(t, y, config.a_true, config.b_true, config.r_true),
                    (config.tmin, config.tmax), u0, t_eval=t_eval, method='RK45')
    if not sol.success:
        raise RuntimeError(f"Forward FitzHugh-Nagumo solve failed: {sol.message}")
    u_true = sol.y  # shape (2, N_data)

    # Per-component noise with std = data_error * |true value|, drawn time by time.
    # NOTE(review): uses the global NumPy RNG; config.seed is not applied here —
    # confirm the caller seeds np.random if reproducible data is required.
    u_data = [u_true[:, i] + np.random.normal(0, config.data_error * np.abs(u_true[:, i]))
              for i in range(len(t_eval))]

    dtype = jnp.float32 if config.precision == 32 else jnp.float64
    t_data = t_eval.reshape(-1, 1)  # shape (N_data, 1)
    t_data_tensor = jnp.array(t_data, dtype=dtype)
    u_data_tensor = jnp.array(np.vstack(u_data), dtype=dtype)  # shape (N_data, 2)
    return config.t_plot, u0, t_data, u_data, t_data_tensor, u_data_tensor

def generate_collocation_points(config, device):
    """
    Draw config.N_f scrambled-Sobol collocation times in [tmin, tmax) for
    the ODE residual.

    ``device`` is unused. Returns a JAX array of shape (N_f, 1) in the
    configured precision.
    """
    span = config.tmax - config.tmin
    unit_samples = qmc.Sobol(d=1, scramble=True, seed=config.seed).random(n=config.N_f)
    target_dtype = jnp.float64 if config.precision != 32 else jnp.float32
    return jnp.array(config.tmin + span * unit_samples, dtype=target_dtype)

def create_ic_data(config, device):
    """
    Build the single initial-condition training point at t = tmin.

    ``device`` is unused. Returns JAX arrays (t_ic_tensor of shape (1, 1),
    u_ic_target_tensor of shape (1, len(config.initial_condition))) in the
    configured precision.
    """
    target_dtype = jnp.float64 if config.precision != 32 else jnp.float32
    t_ic_tensor = jnp.array(np.array([[config.tmin]]), dtype=target_dtype)
    u_ic_target_tensor = jnp.array(config.initial_condition.reshape(1, -1), dtype=target_dtype)
    return t_ic_tensor, u_ic_target_tensor


#############################################
# PINN and Loss Function Definitions 
#############################################
def get_activation(name: str):
    """Return the JAX activation function for ``name`` (case-insensitive).

    Supported: tanh, relu, sigmoid, gelu, silu.

    Raises:
      ValueError: for any other name.
    """
    table = {
        'tanh': jnp.tanh,
        'relu': jax.nn.relu,
        'sigmoid': jax.nn.sigmoid,
        'gelu': jax.nn.gelu,
        'silu': jax.nn.silu,
    }
    key = name.lower()
    if key not in table:
        raise ValueError(f"Unsupported activation: {key}")
    return table[key]

class PINN(nn.Module):
    """Fully connected network mapping time t to the state [v, w].

    Attributes:
      hidden_layers: widths of the hidden Dense layers.
      activation: activation name understood by get_activation.
      dtype: dtype used both for the computation and for the stored weights.
    """
    hidden_layers: Tuple[int, ...]
    activation: str
    dtype: Any

    @nn.compact
    def __call__(self, t):
        """Map times ``t`` (shape (batch, 1) at the call sites) to outputs of shape (batch, 2) = [v, w]."""
        act = get_activation(self.activation)
        x = t
        for h in self.hidden_layers:
            # param_dtype=self.dtype: without it Flax stores the weights in
            # float32 even for precision=64, and only the computation runs in
            # float64.
            x = nn.Dense(h, dtype=self.dtype, param_dtype=self.dtype)(x)
            x = act(x)
        x = nn.Dense(2, dtype=self.dtype, param_dtype=self.dtype)(x)  # outputs [v, w]
        return x

@dataclass
class TrainedPINN:
    """Result of a PINN training run (MDMM or traditional).

    Bundles the final parameter pytree, the model it belongs to, the method
    name, and the per-epoch loss and parameter histories.
    """
    params: Any  # pytree with 'net' (weights) and 'theta' (a, b, r); MDMM runs also hold 'mdmm'
    model: Any  # the network module applied to params['net']
    training_method: str  # "mdmm" or "traditional"
    loss_history: dict  # lists under "data", "ic", "physics", "memory_mb"
    param_history: dict  # lists under "a", "b", "r"

def get_eta_prams(params_all, training_method):
    """Extract the ODE parameters (a, b, r) from a parameter pytree.

    The "traditional" trainer stores log-parameters under theta['log_a'],
    theta['log_b'], theta['log_r'], which are exponentiated here. Every other
    method stores a, b and r directly under theta['a'], theta['b'], theta['r'].
    """
    theta = params_all['theta']
    if training_method != "traditional":
        return theta['a'], theta['b'], theta['r']
    return tuple(jnp.exp(theta[key]) for key in ('log_a', 'log_b', 'log_r'))

def model_apply(params_all, model, t):
    """Apply ``model`` at ``t`` using only the network weights ``params_all['net']``."""
    net_variables = {'params': params_all['net']}
    return model.apply(net_variables, t)

def ic_loss_fn(params_all, model, t_ic_tensor, u_ic_target_tensor):
    """Mean squared error between the network output at the initial time(s) and the target initial state."""
    mismatch = model_apply(params_all, model, t_ic_tensor) - u_ic_target_tensor
    return jnp.mean(mismatch ** 2)

def physics_loss_fn(params_all, model, t_f_tensor, training_method):
    """
    Mean-squared residual of the FitzHugh–Nagumo system at the collocation times.

    With the network output u = [v, w], the residuals are
        res_v = dv/dt - (v - v**3 / 3 - w)
        res_w = dw/dt - (v + a - b*w) / r
    and the loss is mean(res_v**2) + mean(res_w**2).

    Args:
      params_all: pytree with 'net' (network weights) and 'theta' (unknown ODE
        parameters, decoded by get_eta_prams according to training_method).
      model: the network module applied to params_all['net'].
      t_f_tensor: collocation times, shape (N_f, 1).
      training_method: forwarded to get_eta_prams ("traditional" means
        log-parameterised theta).

    Returns:
      Scalar physics loss.
    """
    def net(t):
        return model_apply(params_all, model, t)

    # One forward-mode pass yields both u(t) and du/dt. Seeding the JVP with a
    # tangent of ones over the whole (N_f, 1) batch gives the per-sample time
    # derivative, because each output row of the PINN MLP depends only on the
    # same input row. This replaces vmap(jacrev(...)) plus a second forward pass.
    u, du_dt = jax.jvp(net, (t_f_tensor,), (jnp.ones_like(t_f_tensor),))

    v, w = u[:, 0:1], u[:, 1:2]
    v_t, w_t = du_dt[:, 0:1], du_dt[:, 1:2]

    a, b, r = get_eta_prams(params_all, training_method)

    res_v = v_t - (v - (v**3) / 3 - w)
    res_w = w_t - ((v + a - b * w) / r)

    return jnp.mean(res_v**2) + jnp.mean(res_w**2)

#############################################
# Learning-Rate Scheduler 
#############################################
def make_jittable_schedule(config, *, lr_init=5e-3, lr_final=1e-4, lr_const=1e-4,
                           constant_tail_steps=30000):
    """
    Build a jit-compatible two-phase learning-rate schedule ``schedule(count) -> lr``.

    Phase 1, steps [0, config.epochs - constant_tail_steps): linear decay from
    lr_init to lr_final.
    Phase 2, every other step: constant lr_const.

    With the defaults, any run with config.epochs <= constant_tail_steps
    (e.g. 3000 epochs) has an empty phase 1 and trains at a constant 1e-4.
    config.learning_rate is not used.

    Args:
      config: provides ``epochs`` and ``precision`` (32 -> float32, else float64).
      lr_init, lr_final: start and end learning rate of the linear phase.
      lr_const: learning rate of the constant phase.
      constant_tail_steps: length of the constant phase at the end of training.

    Returns:
      A function mapping the optimizer step count (int or integer array) to an
      absolute learning rate.
    """
    dtype = jnp.float32 if config.precision == 32 else jnp.float64
    p1_start = jnp.asarray(0, jnp.int32)
    p1_end = jnp.asarray(int(config.epochs - constant_tail_steps), jnp.int32)

    def schedule(count):
        count = jnp.asarray(count, jnp.int32)
        # Guard against division by zero when phase 1 is empty (or negative).
        length = jnp.maximum(p1_end - p1_start, jnp.int32(1))
        progress = jnp.clip((count - p1_start) / length, 0, 1).astype(dtype)
        lr_lin = (lr_init + progress * (lr_final - lr_init)).astype(dtype)
        in_phase1 = (count >= p1_start) & (count < p1_end)
        return jnp.where(in_phase1, lr_lin, jnp.asarray(lr_const, dtype))

    return schedule

#############################################
# MDMM (JAX) — constraints and gradient transform
#############################################
class LagrangeMultiplier(NamedTuple):
    """Wrapper marking a pytree leaf as a Lagrange multiplier for MDMM."""
    value: Any


def mdmm_prepare_update(tree):
    """Negate every LagrangeMultiplier leaf of ``tree``; other leaves pass through unchanged.

    Applied to optimizer updates, the sign flip turns the descent step into an
    ascent step for the multipliers only.
    """
    def is_multiplier(node):
        return isinstance(node, LagrangeMultiplier)

    def flip(node):
        if is_multiplier(node):
            return LagrangeMultiplier(-node.value)
        return node

    return jtu.tree_map(flip, tree, is_leaf=is_multiplier)

def optax_prepare_update():
    """
    Stateless Optax transformation that passes updates through
    mdmm_prepare_update, negating the LagrangeMultiplier leaves. Chained after
    the optimizer, it makes the multipliers ascend while every other parameter
    descends.
    """
    def _init(params):
        del params  # nothing to derive from the parameters
        return optax.EmptyState()

    def _update(updates, state, params=None):
        del params
        flipped = mdmm_prepare_update(updates)
        return flipped, state

    return optax.GradientTransformation(_init, _update)

class Constraint(NamedTuple):
    """Pair of callables describing an MDMM constraint.

    init(*args) creates the constraint's own parameters (multipliers, slacks);
    loss(params, *args) returns (Lagrangian term, constraint violation).
    """
    init: Callable  # builds the constraint parameters
    loss: Callable  # returns (loss term, violation)

def eq_constraint(fun, damping=1., weight=1., reduction=jnp.sum):
    """Build an MDMM equality constraint ``fun(...) == 0``.

    init(*args) -> {'lambda': LagrangeMultiplier(zeros shaped like fun(*args))}
    loss(params, *args) -> (weight * reduction(lambda * c + damping * c**2 / 2), c)
    where c = fun(*args) is the constraint violation.
    """
    def init_fn(*args, **kwargs):
        violation = fun(*args, **kwargs)
        return {'lambda': LagrangeMultiplier(jnp.zeros_like(violation))}

    def loss_fn(params, *args, **kwargs):
        violation = fun(*args, **kwargs)
        multiplier = params['lambda'].value
        penalty = multiplier * violation + damping * violation ** 2 / 2
        return weight * reduction(penalty), violation

    return Constraint(init_fn, loss_fn)

def ineq_constraint(fun, damping=1., weight=1., reduction=jnp.sum):
    """Build an MDMM inequality constraint ``fun(...) >= 0`` using a slack variable.

    The slack s turns the inequality into the equality fun(...) - s**2 == 0.
    init(*args) -> {'lambda': zero multipliers, 'slack': sqrt(relu(fun(*args)))}
    loss(params, *args) -> (weight * reduction(lambda * c + damping * c**2 / 2), c)
    with c = fun(*args) - slack**2.
    """
    def init_fn(*args, **kwargs):
        initial = fun(*args, **kwargs)
        multiplier = LagrangeMultiplier(jnp.zeros_like(initial))
        # Start with zero violation wherever the inequality already holds.
        slack = jax.nn.relu(initial) ** 0.5
        return {'lambda': multiplier, 'slack': slack}

    def loss_fn(params, *args, **kwargs):
        violation = fun(*args, **kwargs) - params['slack'] ** 2
        multiplier = params['lambda'].value
        penalty = multiplier * violation + damping * violation ** 2 / 2
        return weight * reduction(penalty), violation

    return Constraint(init_fn, loss_fn)

def combine_constraints(*args):
    """Merge several Constraints into one.

    init returns a tuple with each sub-constraint's parameters, in argument
    order. loss(params, ...) sums the sub-losses and returns the tuple of
    sub-violations in the same order.
    """
    init_fns, loss_fns = zip(*args)

    def init_fn(*a, **k):
        return tuple(make(*a, **k) for make in init_fns)

    def loss_fn(params, *a, **k):
        total = 0
        violations = []
        for sub_params, sub_loss in zip(params, loss_fns):
            value, violation = sub_loss(sub_params, *a, **k)
            total = total + value
            violations.append(violation)
        return total, tuple(violations)

    return Constraint(init_fn, loss_fn)

def bound_hard(fun_value, lo, hi, damping=1., weight=1., reduction=jnp.sum):
    """Equality constraint keeping ``fun_value(...)`` inside [lo, hi].

    The violation is clip(x, lo, hi) - x, which is zero exactly when
    lo <= x <= hi.
    """
    def out_of_bounds(*args, **kwargs):
        x = fun_value(*args, **kwargs)
        return jnp.clip(x, lo, hi) - x

    return eq_constraint(out_of_bounds, damping=damping, weight=weight, reduction=reduction)

#############################################
# Training Function using MDMM (per-step jitted)
#############################################
def train_mdmm(model, config, X_data_tensor, u_data_tensor, t_f_tensor, t_ic_tensor, u_ic_target_tensor):
    """
    Train the PINN with MDMM (JAX/Flax/Optax): Adam minimizes the data loss
    plus Lagrangian terms for the constraints physics loss == 0, IC loss == 0,
    and hard bounds on a, b, r from config.param_bounds. One jitted step per epoch.

    Args:
      model: the PINN module.
      config: configuration (seed, initial guesses, bounds, epochs, precision).
      X_data_tensor, u_data_tensor: measurement times and noisy [v, w] data.
      t_f_tensor: collocation times for the physics residual.
      t_ic_tensor, u_ic_target_tensor: initial time and target initial state.

    Returns:
      (TrainedPINN, training wall-clock time in seconds).
    """
    dtype = jnp.float32 if config.precision == 32 else jnp.float64

    # Initialize model params
    rng = jax.random.PRNGKey(config.seed)
    net_params = model.init(rng, jnp.zeros((1, 1), dtype=dtype))['params']

    # Initialize theta (unknown parameters)
    theta = {
        'a': jnp.array(config.a_init, dtype),
        'b': jnp.array(config.b_init, dtype),
        'r': jnp.array(config.r_init, dtype),
    }

    params_all = {'net': net_params, 'theta': theta}

    # Constraints: physics == 0, ic == 0, hard bounds on a, b, r
    physics_constraint = eq_constraint(lambda p: physics_loss_fn(p, model, t_f_tensor, "mdmm"), damping=1., weight=1.)
    ic_constraint = eq_constraint(lambda p: ic_loss_fn(p, model, t_ic_tensor, u_ic_target_tensor), damping=1., weight=1.)

    (a_lo, a_hi), (b_lo, b_hi), (r_lo, r_hi) = config.param_bounds
    bound_a = bound_hard(lambda p: p['theta']['a'], a_lo, a_hi, damping=1., weight=1.)
    bound_b = bound_hard(lambda p: p['theta']['b'], b_lo, b_hi, damping=1., weight=1.)
    bound_r = bound_hard(lambda p: p['theta']['r'], r_lo, r_hi, damping=1., weight=1.)

    # The order here fixes the order of the returned violations:
    # (physics, ic, bound_a, bound_b, bound_r).
    constraints = combine_constraints(physics_constraint, ic_constraint, bound_a, bound_b, bound_r)
    mdmm_params = constraints.init(params_all)
    params_all['mdmm'] = mdmm_params

    # Adam with the step schedule, then flip the multiplier updates (ascent)
    schedule = make_jittable_schedule(config)
    optimizer = optax.chain(
        optax.adam(learning_rate=schedule),
        optax_prepare_update(),
    )
    opt_state = optimizer.init(params_all)

    def data_loss_fn(p):
        # Root-mean-square *relative* error against the noisy data.
        # NOTE(review): divides by the measurements themselves, so values near
        # zero dominate this loss — confirm that is intended.
        u_pred_data = model_apply(p, model, X_data_tensor)
        return jnp.sqrt(jnp.mean(((u_pred_data - u_data_tensor) / (u_data_tensor))**2))

    def total_loss_and_logs(p):
        dloss = data_loss_fn(p)
        mdmm_term, violations = constraints.loss(p['mdmm'], p)
        # The physics and IC violations are exactly physics_loss_fn / ic_loss_fn
        # evaluated at p; reuse them instead of recomputing the expensive
        # residual just for logging.
        phys_val, ic_val = violations[0], violations[1]
        return dloss + mdmm_term, (dloss, ic_val, phys_val)

    @jax.jit
    def mdmm_step(p, state):
        (val, logs), grads = jax.value_and_grad(total_loss_and_logs, has_aux=True)(p)
        updates, state = optimizer.update(grads, state, p)
        p = optax.apply_updates(p, updates)
        dloss, ic_val, phys_val = logs
        return p, state, dloss, ic_val, phys_val

    loss_history = {"data": [], "ic": [], "physics": []}
    param_history = {"a": [], "b": [], "r": []}
    memory_mb_history = []

    start_train_time = time.time()
    for it in range(config.epochs + 1):
        params_all, opt_state, dloss, ic_val, phys_val = mdmm_step(params_all, opt_state)

        # Losses are those of the parameters *before* this step's update.
        loss_history["data"].append(float(dloss))
        loss_history["ic"].append(float(ic_val))
        loss_history["physics"].append(float(phys_val))

        a_val, b_val, r_val = get_eta_prams(params_all, "mdmm")
        param_history["a"].append(float(a_val))
        param_history["b"].append(float(b_val))
        param_history["r"].append(float(r_val))

        # Memory usage sampling (MB)
        memory_mb_history.append(float(get_process_memory_mb()))

        if it % 500 == 0:
            print(f"Iter {it:5d}, Data Loss: {float(dloss):.5e}, "
                  f"IC Loss: {float(ic_val):.5e}, Physics Loss: {float(phys_val):.5e}, "
                  f"a: {float(a_val):.5f}, b: {float(b_val):.5f}, "
                  f"r: {float(r_val):.5f}")
    end_train_time = time.time()
    train_runtime_seconds = end_train_time - start_train_time

    loss_history["memory_mb"] = memory_mb_history

    trained = TrainedPINN(params=params_all, model=model, training_method="mdmm",
                          loss_history=loss_history, param_history=param_history)
    return trained, train_runtime_seconds

#############################################
# Traditional Training Function (per-step jitted)
#############################################
def train_traditional(model, config, X_data_tensor, u_data_tensor, t_f_tensor, t_ic_tensor, u_ic_target_tensor):
    """
    Train the PINN by plain penalty minimization with Adam:
        total = data loss + IC loss + physics loss
    a, b, r are optimized in log space (theta['log_a'], ...), which keeps them
    positive; config.param_bounds is not used. One jitted step per epoch.

    Returns:
      (TrainedPINN, training wall-clock time in seconds).
    """
    float_dtype = jnp.float64 if config.precision != 32 else jnp.float32

    params_all = {
        'net': model.init(jax.random.PRNGKey(config.seed),
                          jnp.zeros((1, 1), dtype=float_dtype))['params'],
        'theta': {
            'log_a': jnp.array(np.log(config.a_init), float_dtype),
            'log_b': jnp.array(np.log(config.b_init), float_dtype),
            'log_r': jnp.array(np.log(config.r_init), float_dtype),
        },
    }

    optimizer = optax.adam(learning_rate=make_jittable_schedule(config))
    opt_state = optimizer.init(params_all)

    def total_loss(p):
        # Root-mean-square relative error against the noisy data
        relative_error = (model_apply(p, model, X_data_tensor) - u_data_tensor) / (u_data_tensor)
        loss_data = jnp.sqrt(jnp.mean(relative_error**2))
        loss_ic = ic_loss_fn(p, model, t_ic_tensor, u_ic_target_tensor)
        loss_phys = physics_loss_fn(p, model, t_f_tensor, "traditional")
        return loss_data + loss_ic + loss_phys, (loss_data, loss_ic, loss_phys)

    @jax.jit
    def trad_step(p, state):
        (tot, (ld, li, lp)), grads = jax.value_and_grad(total_loss, has_aux=True)(p)
        updates, state = optimizer.update(grads, state, p)
        return optax.apply_updates(p, updates), state, tot, ld, li, lp

    loss_history = {"data": [], "ic": [], "physics": []}
    param_history = {"a": [], "b": [], "r": []}
    memory_mb_history = []

    t_start = time.time()
    for it in range(config.epochs + 1):
        params_all, opt_state, tot, loss_data, loss_ic, loss_phys = trad_step(params_all, opt_state)

        for key, value in (("data", loss_data), ("ic", loss_ic), ("physics", loss_phys)):
            loss_history[key].append(float(value))

        a_val, b_val, r_val = get_eta_prams(params_all, "traditional")
        for key, value in zip(("a", "b", "r"), (a_val, b_val, r_val)):
            param_history[key].append(float(value))

        # Memory usage sampling (MB)
        memory_mb_history.append(float(get_process_memory_mb()))

        if it % 500 == 0:
            print(f"Iter {it:5d}, Total Loss: {float(tot):.5e}, "
                  f"Data Loss: {float(loss_data):.5e}, IC Loss: {float(loss_ic):.5e}, "
                  f"Physics Loss: {float(loss_phys):.5e}, "
                  f"a: {float(a_val):.5f}, b: {float(b_val):.5f}, "
                  f"r: {float(r_val):.5f}")
    train_runtime_seconds = time.time() - t_start

    loss_history["memory_mb"] = memory_mb_history

    trained = TrainedPINN(params=params_all, model=model, training_method="traditional",
                          loss_history=loss_history, param_history=param_history)
    return trained, train_runtime_seconds

#############################################
# Nelder–Mead Training Function
#############################################
def train_nelder_mead(config, t_data, u_data):
    """
    Estimate [a, b, r] with SciPy's bounded Nelder–Mead.

    The objective is the root-mean-square relative error between the forward
    solution (solve_ivp, RK45, at config.t_eval) and the noisy measurements.
    A failed or truncated solve scores 1e12.

    Args:
      config: configuration (initial guesses, bounds, time span).
      t_data: unused (the objective evaluates at config.t_eval); kept for
        interface compatibility.
      u_data: noisy measurements, one [v, w] row per measurement time.

    Returns:
      (optimal [a, b, r] as ndarray, log dict, wall-clock time in seconds).
      log["loss"], ["a"], ["b"], ["r"] hold one entry per *successful*
      evaluation; log["memory_mb"] holds one entry per evaluation, so it can
      be longer.
    """
    log = {"loss": [], "a": [], "b": [], "r": [], "memory_mb": []}

    # Loop-invariant inputs, prepared once instead of on every objective call.
    t_eval = config.t_eval
    u_data_arr = np.vstack(u_data)  # shape (N_data, 2): one [v, w] row per time

    def objective_fn(p):
        a_est, b_est, r_est = p
        sol = solve_ivp(lambda t, y: fitz_nagumo_ode(t, y, a_est, b_est, r_est),
                        (config.tmin, config.tmax), config.initial_condition,
                        t_eval=t_eval, method='RK45')
        if (not sol.success) or (len(sol.t) < len(t_eval)):
            # Failed/truncated solve: heavy penalty; only memory is logged.
            log["memory_mb"].append(float(get_process_memory_mb()))
            return 1e12
        y = sol.y  # shape (2, N_data)
        error = np.sqrt(np.mean(((y.T - u_data_arr) / u_data_arr)**2))
        log["loss"].append(float(error))
        log["a"].append(float(a_est))
        log["b"].append(float(b_est))
        log["r"].append(float(r_est))
        # Memory usage sampling (MB)
        log["memory_mb"].append(float(get_process_memory_mb()))
        return error

    initial_guess = [config.a_init, config.b_init, config.r_init]
    bounds = config.param_bounds
    start_train_time = time.time()
    res = minimize(objective_fn, initial_guess, method='Nelder-Mead', bounds=bounds)
    end_train_time = time.time()
    train_runtime_seconds = end_train_time - start_train_time
    print("Nelder–Mead optimization success:", res.success)
    return res.x, log, train_runtime_seconds

#############################################
# Unified Plotting Functions
#############################################
def plot_results(model_or_params, config, t_plot, t_data, u_data, u0, device):
    """
    Plot v and w over ``t_plot``: the true solution, the trajectory obtained by
    integrating the model with the estimated parameters, the PINN prediction
    (PINN runs only) and the noisy data. Saves Figures/Solution_comparison.png
    and trajectories.npz in the run folder; the figure is left open.

    Args:
      model_or_params: a TrainedPINN, or the estimated (a, b, r) as a
        3-element sequence (e.g. the Nelder–Mead result).
      config: configuration (true parameters, time span, output settings).
      t_plot: 1-D time grid for the curves.
      t_data, u_data: measurement times and noisy [v, w] measurements.
      u0: initial condition [v(0), w(0)].
      device: unused.
    """
    plt.figure(figsize=(12,5))
    is_pinn = isinstance(model_or_params, TrainedPINN)

    def simulate(a, b, r):
        """Integrate the model over t_plot and return (v, w)."""
        sol = solve_ivp(lambda t, y: fitz_nagumo_ode(t, y, a, b, r),
                        (config.tmin, config.tmax), u0, t_eval=t_plot, method='RK45')
        return sol.y[0, :], sol.y[1, :]

    # True solution with the ground-truth parameters
    true_v, true_w = simulate(config.a_true, config.b_true, config.r_true)

    if is_pinn:
        a_est, b_est, r_est = (float(x) for x in get_eta_prams(model_or_params.params,
                                                               model_or_params.training_method))
    else:
        # Plain parameter vector: exactly (a, b, r), matching
        # fitz_nagumo_ode(t, y, a, b, r).
        a_est, b_est, r_est = (float(x) for x in model_or_params)
    est_v, est_w = simulate(a_est, b_est, r_est)

    if is_pinn:
        dtype = jnp.float32 if config.precision == 32 else jnp.float64
        t_plot_tensor = jnp.array(t_plot.reshape(-1, 1), dtype=dtype)
        pinn_pred = np.array(model_or_params.model.apply({'params': model_or_params.params['net']}, t_plot_tensor))
        pinn_v = pinn_pred[:, 0]
        pinn_w = pinn_pred[:, 1]

    # Plot for component v
    plt.subplot(1, 2, 1)
    plt.plot(t_plot, true_v, 'k--', label='True solution (v)')
    plt.plot(t_plot, est_v, 'g-', label='Estimated trajectory (v)')
    if is_pinn:
        plt.plot(t_plot, pinn_v, 'b-', label='PINN prediction (v)')
    plt.scatter(t_data.flatten(), [ud[0] for ud in u_data], color='r', marker='o', label='Noisy data (v)')
    plt.xlabel("Time")
    plt.ylabel("v")
    plt.legend()
    plt.title("Component v over time")

    # Plot for component w
    plt.subplot(1, 2, 2)
    plt.plot(t_plot, true_w, 'k--', label='True solution (w)')
    plt.plot(t_plot, est_w, 'g-', label='Estimated trajectory (w)')
    if is_pinn:
        plt.plot(t_plot, pinn_w, 'b-', label='PINN prediction (w)')
    plt.scatter(t_data.flatten(), [ud[1] for ud in u_data], color='r', marker='o', label='Noisy data (w)')
    plt.xlabel("Time")
    plt.ylabel("w")
    plt.legend()
    plt.title("Component w over time")

    plt.tight_layout()
    base_folder = os.getcwd() if config.output_folder == "." else config.output_folder
    folder_name = f"FitzHughNagumo_{config.training_method}_initial_guess_error_{config.initial_guess_error}_data_error_{config.data_error}_N_data_points_{config.N_data_points}"
    full_folder = os.path.join(base_folder, folder_name)
    figures_folder = os.path.join(full_folder, "Figures")
    os.makedirs(figures_folder, exist_ok=True)
    plt.savefig(os.path.join(figures_folder, "Solution_comparison.png"), dpi=300)

    # Save the trajectories and the time grid next to the Figures folder
    trajectories = {"t": t_plot, "u1_true": true_v, "u1_estimated": est_v}
    if is_pinn:
        trajectories["u1_PINN"] = pinn_v
    trajectories["u2_true"] = true_w
    trajectories["u2_estimated"] = est_w
    if is_pinn:
        trajectories["u2_PINN"] = pinn_w
    np.savez(os.path.join(full_folder, "trajectories.npz"), **trajectories)

def plot_loss(loss_log, method, config):
    """
    Plot the loss history and save it as Figures/Loss_evolution.png in the
    run folder named after ``method``.

    "mdmm" and "traditional" get three log-scale panels (data, IC and physics
    loss per epoch); any other method gets a single log-scale curve of
    loss_log["loss"] per Nelder–Mead iteration.
    """
    folder_name = f"FitzHughNagumo_{method}_initial_guess_error_{config.initial_guess_error}_data_error_{config.data_error}_N_data_points_{config.N_data_points}"
    base_folder = config.output_folder if config.output_folder != "." else os.getcwd()
    figures_folder = os.path.join(base_folder, folder_name, "Figures")
    os.makedirs(figures_folder, exist_ok=True)

    if method in ["mdmm", "traditional"]:
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        panels = (
            ("data", "Data loss", "Data Loss"),
            ("ic", "IC loss", "IC Loss"),
            ("physics", "Physics loss", "Physics Loss"),
        )
        for ax, (key, curve_label, panel_title) in zip(axs, panels):
            ax.semilogy(loss_log[key], label=curve_label)
            ax.set_title(panel_title)
            ax.set_xlabel("Epoch")
    else:
        plt.figure(figsize=(6, 4))
        plt.semilogy(loss_log["loss"], label="Objective loss")
        plt.xlabel("Iteration")
        plt.title("Objective Loss (Nelder–Mead)")
    plt.tight_layout()
    plt.savefig(os.path.join(figures_folder, "Loss_evolution.png"), dpi=300)

def plot_parameters(param_log, method, config):
    """
    Plot the history of the a, b and r estimates, each with its true value as
    a red dashed line, and save Figures/Parameter_evolution.png in the run
    folder.

    Args:
      param_log: dict with per-step lists under "a", "b" and "r".
      method: training method; "mdmm"/"traditional" label the x-axis "Epoch",
        anything else "Iteration".
      config: configuration (true values, folder naming, output settings).
    """
    # Same "FitzHughNagumo_..." run folder as plot_results and
    # save_outputs_and_summary, so all outputs of one run end up together.
    folder_name = f"FitzHughNagumo_{config.training_method}_initial_guess_error_{config.initial_guess_error}_data_error_{config.data_error}_N_data_points_{config.N_data_points}"
    base_folder = os.getcwd() if config.output_folder == "." else config.output_folder
    full_folder = os.path.join(base_folder, folder_name)
    figures_folder = os.path.join(full_folder, "Figures")
    os.makedirs(figures_folder, exist_ok=True)

    x_label = "Epoch" if method in ["mdmm", "traditional"] else "Iteration"
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for ax, name, true_value in zip(axs, ("a", "b", "r"),
                                    (config.a_true, config.b_true, config.r_true)):
        ax.plot(param_log[name], label=name)
        ax.axhline(y=true_value, color='r', linestyle='--')
        ax.set_title(f"{name} parameter")
        ax.set_xlabel(x_label)
    plt.tight_layout()
    plt.savefig(os.path.join(figures_folder, "Parameter_evolution.png"), dpi=300)

#############################################
# Function to save outputs and summary statistics
#############################################
def _finite_memory_stats(mem_list):
    """
    Summarize a memory log.

    Parameters
    ----------
    mem_list : iterable of float or None
        Memory readings in MB. ``None`` and non-finite entries (e.g. NaN) are skipped.

    Returns
    -------
    (float, float)
        ``(mean, max)`` of the finite readings, or ``(nan, nan)`` if there are none.
    """
    if mem_list is None:
        return float('nan'), float('nan')
    finite_vals = [float(m) for m in mem_list if m is not None and np.isfinite(m)]
    if not finite_vals:
        return float('nan'), float('nan')
    return float(np.mean(finite_vals)), float(np.max(finite_vals))


def _to_jsonable(value):
    """Convert NumPy arrays/scalars (also inside lists/tuples) to plain Python objects for json.dump."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(v) for v in value]
    return value


def _integrate_fhn(config, u0, t_eval, a, b, r):
    """
    Integrate ``fitz_nagumo_ode`` with parameters (a, b, r) on [config.tmin, config.tmax].

    Uses RK45 sampled at ``t_eval``. Returns ``sol.y.T`` (one row per sample), or
    None with a printed warning when ``solve_ivp`` reports failure. A failed
    solution is truncated, so it cannot be compared sample-by-sample.
    """
    sol = solve_ivp(lambda t, y: fitz_nagumo_ode(t, y, a, b, r),
                    (config.tmin, config.tmax), u0, t_eval=t_eval, method='RK45')
    if not sol.success:
        print(f"Warning: solve_ivp failed for a={a}, b={b}, r={r}: {sol.message}")
        return None
    return sol.y.T


def save_outputs_and_summary(config, method, t_plot, u0, t_data, u_data, runtime,
                             trained_model=None, optimal_params=None, log=None):
    """
    Persist the artifacts of a finished run and write ``summary_statistics.json``.

    Everything is written to
    ``<output_folder>/FitzHughNagumo_<training_method>_initial_guess_error_<..>_data_error_<..>_N_data_points_<..>/``.

    Parameters
    ----------
    config : Config
        Run configuration: folder naming fields, time span, true and initial
        parameters, and network settings.
    method : str
        Training method: "mdmm", "traditional" or "nelder-mead". ``main``
        passes the lower-cased value.
    t_plot : np.ndarray
        Dense time grid used for the max-distance metric.
    u0 : array-like
        Initial state passed to ``solve_ivp``.
    t_data, u_data : array-like
        Noisy measurements. Both are saved verbatim, and ``u_data`` is compared
        with the ODE solution at ``config.t_eval``.
    runtime : float
        Training wall time in seconds.
    trained_model : optional
        Trained PINN wrapper; used for "mdmm"/"traditional".
    optimal_params : sequence of 3 floats, optional
        (a, b, r) found by Nelder–Mead; used for "nelder-mead".
    log : dict, optional
        Nelder–Mead history; used for "nelder-mead".

    Trajectory-based metrics are recorded as NaN when ``solve_ivp`` fails.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported methods.
    """
    is_pinn = method in ["mdmm", "traditional"]
    if not is_pinn and method != "nelder-mead":
        raise ValueError(f"Unknown training method: {method}")

    # Create output folder structure.
    # NOTE(review): plot_parameters builds its folder with a "Biodiesel_" prefix,
    # so its figures do not land in this folder; confirm which prefix is intended.
    base_folder = os.getcwd() if config.output_folder == "." else config.output_folder
    folder_name = f"FitzHughNagumo_{config.training_method}_initial_guess_error_{config.initial_guess_error}_data_error_{config.data_error}_N_data_points_{config.N_data_points}"
    full_folder = os.path.join(base_folder, folder_name)
    figures_folder = os.path.join(full_folder, "Figures")
    os.makedirs(figures_folder, exist_ok=True)

    # Save the trained model / optimum, noisy data, loss history, and parameter history
    if is_pinn:
        # Despite the ".pt" extension this is a plain pickle file.
        with open(os.path.join(full_folder, "trained_model.pt"), "wb") as f:
            pickle.dump({
                'params': trained_model.params,
                'training_method': trained_model.training_method,
                'hidden_layers': config.hidden_layers,
                'activation': config.activation,
                'precision': config.precision
            }, f)
        np.savez(os.path.join(full_folder, "noisy_data.npz"), t_data=t_data, u_data=u_data)
        np.savez(os.path.join(full_folder, "loss_history.npz"), **trained_model.loss_history)
        np.savez(os.path.join(full_folder, "parameter_history.npz"), **trained_model.param_history)
    else:
        np.savez(os.path.join(full_folder, "optimal_params.npz"), optimal_params=optimal_params)
        np.savez(os.path.join(full_folder, "loss_log.npz"), **log)
        np.savez(os.path.join(full_folder, "noisy_data.npz"), t_data=t_data, u_data=u_data)
        if "a" in log and "b" in log and "r" in log:
            np.savez(os.path.join(full_folder, "parameter_history.npz"), a=log["a"], b=log["b"], r=log["r"])

    # Determine estimated parameters
    if is_pinn:
        a_est, b_est, r_est = (float(v) for v in get_eta_prams(trained_model.params, method))
    else:
        a_est, b_est, r_est = (float(v) for v in optimal_params)

    # Root relative MSE between the recovered-parameter ODE solution and the data.
    # NOTE(review): assumes u_data is sampled at config.t_eval. The error divides
    # by the noisy data, so entries close to zero dominate this metric.
    u_ml = _integrate_fhn(config, u0, config.t_eval, a_est, b_est, r_est)
    u_data_arr = np.vstack(u_data)
    if u_ml is None:
        root_relative_mse_data = float('nan')
    else:
        rel_errors = ((u_ml - u_data_arr) / u_data_arr)**2
        root_relative_mse_data = float(np.sqrt(np.mean(rel_errors)))

    # Root relative MSE of the parameters compared to the true values
    root_relative_mse_parameters = float(np.sqrt((((a_est - config.a_true)/config.a_true)**2 + ((b_est - config.b_true)/config.b_true)**2 +
                               ((r_est - config.r_true)/config.r_true)**2)/3))

    # Max pointwise distance on the dense grid. With a PINN, the reference is the
    # network prediction; otherwise it is the ODE solution with the true parameters.
    est_curve = _integrate_fhn(config, u0, t_plot, a_est, b_est, r_est)
    if trained_model is not None:
        dtype = jnp.float32 if config.precision == 32 else jnp.float64
        t_plot_tensor = jnp.array(t_plot.reshape(-1, 1), dtype=dtype)
        ref_curve = np.array(trained_model.model.apply({'params': trained_model.params['net']}, t_plot_tensor))
    else:
        ref_curve = _integrate_fhn(config, u0, t_plot, config.a_true, config.b_true, config.r_true)
    if est_curve is None or ref_curve is None:
        max_distance = float('nan')
    else:
        max_distance = float(np.max(np.abs(ref_curve - est_curve)))

    # Memory usage statistics (average and maximum over epochs/iterations)
    if is_pinn and trained_model is not None:
        memory_usage_avg, memory_usage_max = _finite_memory_stats(trained_model.loss_history.get("memory_mb", []))
    elif method == "nelder-mead" and log is not None:
        memory_usage_avg, memory_usage_max = _finite_memory_stats(log.get("memory_mb", []))
    else:
        memory_usage_avg, memory_usage_max = float('nan'), float('nan')

    summary = {
        "training_method": config.training_method,
        "runtime_seconds": float(runtime),
        "initial_condition": config.initial_condition.tolist() if isinstance(config.initial_condition, np.ndarray) else str(config.initial_condition),
        "parameter_initial_guess": [float(config.a_init), float(config.b_init), float(config.r_init)],
        "parameter_ground_truth": [float(config.a_true), float(config.b_true), float(config.r_true)],
        "data_noise": float(config.data_error),
        "initial_guess_error": float(config.initial_guess_error),
        "root_relative_mse_data": float(root_relative_mse_data),
        "root_relative_mse_parameters": float(root_relative_mse_parameters),
        "max_distance": float(max_distance),
        "memory_usage_avg": float(memory_usage_avg),
        "memory_usage_max": float(memory_usage_max)
    }
    # Additional configuration parameters; NumPy values are converted so json.dump accepts them.
    summary["tmin"] = _to_jsonable(config.tmin)
    summary["tmax"] = _to_jsonable(config.tmax)
    summary["hidden_layers"] = _to_jsonable(config.hidden_layers)
    summary["activation"] = config.activation
    summary["param_bounds"] = _to_jsonable(config.param_bounds)
    summary["parameter_estimates"] = [float(a_est), float(b_est), float(r_est)]

    with open(os.path.join(full_folder, "summary_statistics.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=4)

#############################################
# Main Function
#############################################
def main():
    """
    Command-line entry point for the FitzHugh-Nagumo parameter-estimation run.

    Parses arguments, generates forward/noisy data, and fits (a, b, r) with the
    chosen method (MDMM-trained PINN, traditionally trained PINN, or
    Nelder–Mead). It then prints, plots and saves the results.

    Raises
    ------
    ValueError
        If ``--training_method`` is not "mdmm", "traditional" or "nelder-mead".
        The comparison is case-insensitive.
    """
    parser = argparse.ArgumentParser(description="Train PINN for FitzHugh-Nagumo system (JAX/Flax/Optax)")
    parser.add_argument("--epochs", type=int, default=50000, help="Number of training epochs")
    parser.add_argument("--initial_guess_error", type=float, default=5.0, help="Initial guess error multiplier")
    parser.add_argument("--data_error", type=float, default=0.3, help="Data noise level")
    parser.add_argument("--N_data_points", type=int, default=8, help="Number of data points")
    parser.add_argument("--training_method", type=str, default="mdmm", help="Choice of minimization method (mdmm, traditional or nelder-mead)")
    # Only 32 and 64 are meaningful. This function maps 64 -> float64 and anything
    # else -> float32, while save_outputs_and_summary maps 32 -> float32 and
    # anything else -> float64. Any other value would therefore mix dtypes.
    parser.add_argument("--precision", type=int, default=64, choices=[32, 64], help="Precision for calculations (32 or 64)")
    args = parser.parse_args()

    config = Config()
    # Update configuration with command line arguments
    config.epochs = args.epochs
    config.initial_guess_error = args.initial_guess_error
    config.data_error = args.data_error
    config.N_data_points = args.N_data_points
    config.training_method = args.training_method
    config.precision = args.precision

    # Validate the method before any expensive data generation.
    method = config.training_method.lower()
    if method not in ("mdmm", "traditional", "nelder-mead"):
        raise ValueError(f"Unknown training method: {config.training_method}")

    # NOTE(review): the initial guesses use config.initial_guess_noise, not the
    # --initial_guess_error value assigned above. If Config derives the noise from
    # initial_guess_error when it is constructed, this flag never changes the
    # guess. Confirm against Config.
    config.a_init = (1 + config.initial_guess_noise) * config.a_true
    config.b_init = (1 + config.initial_guess_noise) * config.b_true
    config.r_init = (1 + config.initial_guess_noise) * config.r_true

    # Set JAX default dtype based on precision
    if config.precision == 64:
        jax.config.update("jax_enable_x64", True)
        dtype = jnp.float64
    else:
        jax.config.update("jax_enable_x64", False)
        dtype = jnp.float32

    np.random.seed(config.seed)
    device = jax.devices()[0].platform
    print("Using device:", device)

    # Create forward data and noisy measurements
    t_plot_dense, u0, t_data, u_data, X_data_tensor, u_data_tensor = create_forward_data(config, device)
    t_f_tensor = generate_collocation_points(config, device)
    t_ic_tensor, u_ic_target_tensor = create_ic_data(config, device)

    hidden_layers = tuple(config.hidden_layers)
    model = PINN(hidden_layers=hidden_layers, activation=config.activation, dtype=dtype)

    if method in ["mdmm", "traditional"]:
        if method == "mdmm":
            print("Training using MDMM...")
            trained_model, runtime = train_mdmm(model, config, X_data_tensor, u_data_tensor, t_f_tensor, t_ic_tensor, u_ic_target_tensor)
        else:
            print("Training using traditional loss minimization...")
            trained_model, runtime = train_traditional(model, config, X_data_tensor, u_data_tensor, t_f_tensor, t_ic_tensor, u_ic_target_tensor)
        a_est, b_est, r_est = get_eta_prams(trained_model.params, method)
        print(f"\nFinal recovered parameters: a = {float(a_est):.5f}, b = {float(b_est):.5f}, r = {float(r_est):.5f}")
        plot_results(trained_model, config, config.t_plot, t_data, u_data, u0, device)
        plot_loss(trained_model.loss_history, method, config)
        plot_parameters(trained_model.param_history, method, config)
        optimal_params = None
        log = None
    else:  # "nelder-mead" (validated above)
        print("Training using Nelder–Mead...")
        optimal_params, log, runtime = train_nelder_mead(config, t_data, u_data)
        print(f"\nNelder–Mead: Final recovered parameters: a = {optimal_params[0]:.5f}, b = {optimal_params[1]:.5f}, r = {optimal_params[2]:.5f}")
        plot_results(optimal_params, config, config.t_plot, t_data, u_data, u0, device)
        plot_loss(log, method, config)
        plot_parameters(log, method, config)
        trained_model = None

    # Each branch already set the objects that do not apply to it to None.
    save_outputs_and_summary(config, method, config.t_plot, u0, t_data, u_data, runtime,
                             trained_model=trained_model,
                             optimal_params=optimal_params,
                             log=log)

if __name__ == "__main__":
    main()
