import os
os.environ.setdefault("JAX_PLATFORMS", "cpu")                  # keep behavior close to your previous JAX scripts
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
from scipy.stats import qmc
import math
import time
import json
import argparse
import pickle
import platform

# ---------------- JAX / Flax / Optax ----------------
import jax
import jax.numpy as jnp
from jax import grad, jacrev, jvp, vmap
import flax.linen as nn
import optax
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Tuple

# --------------- Lightweight memory monitor ---------------
try:
    import psutil
except Exception:
    psutil = None
try:
    import resource as _resource
except Exception:
    _resource = None

def _read_proc_status_vmrss_bytes():
    """Return the ``VmRSS`` value from ``/proc/self/status`` in bytes.

    Returns ``None`` when the file cannot be read, when no usable ``VmRSS:``
    line is found, or when parsing fails. Best-effort by design: every
    exception is swallowed so that memory monitoring never aborts a run.
    """
    try:
        with open("/proc/self/status", "r") as status_file:
            for entry in status_file:
                if not entry.startswith("VmRSS:"):
                    continue
                fields = entry.split()
                if len(fields) >= 2:
                    # The second field is multiplied by 1024 (kB -> bytes).
                    return int(fields[1]) * 1024
    except Exception:
        return None
    return None

def get_process_memory_mb():
    """Return a memory reading for this process in MiB, or NaN if unavailable.

    Sources, in order of preference:
      1. ``psutil`` RSS, when psutil was importable;
      2. ``VmRSS`` from ``/proc/self/status``;
      3. ``ru_maxrss`` from ``resource.getrusage`` (used as-is on Darwin,
         multiplied by 1024 elsewhere).

    Any exception raised while probing is swallowed and reported as NaN.
    """
    try:
        if psutil is not None:
            rss = psutil.Process(os.getpid()).memory_info().rss
        else:
            rss = _read_proc_status_vmrss_bytes()
            if rss is None and _resource is not None:
                usage = _resource.getrusage(_resource.RUSAGE_SELF)
                on_darwin = platform.system() == "Darwin"
                rss = int(usage.ru_maxrss) if on_darwin else int(usage.ru_maxrss) * 1024
    except Exception:
        rss = None
    return float('nan') if rss is None else rss / (1024.0 * 1024.0)

# =========================================================
# Configuration
# =========================================================
class Config:
    """Experiment settings for the Fisher–KPP inverse problem.

    All values are class attributes; attributes that are derived from others
    (``L``, ``D_init``, ``r_init``) are evaluated once, at class-definition
    time, from the values defined above them.
    """
    # Domain and discretization
    xmin = 0.0
    xmax = 10.0
    L = xmax  # alias of xmax (not xmax - xmin)
    Nx = 256  # number of spatial grid nodes for the method-of-lines solver
    N_data_points = 18  # measurement locations per measurement time

    # Time domain
    tmin = 0.0
    tmax = 2.0

    # Measurement times
    t_eval = [1.0, 2.0]

    # Plot times
    plot_times = [0.0, 1.0, 2.0]

    # Ground-truth parameters (Fisher–KPP)
    D_true = 0.5
    r_true = 1.0

    # Initial guess noise (relative offset applied to the true parameters)
    initial_guess_noise = 5.0

    # Initial guesses: (1 + noise) * truth -> 3.0 and 6.0 with the defaults,
    # i.e. exactly the upper limits in parameter_bounds below.
    D_init = (1 + initial_guess_noise) * D_true
    r_init = (1 + initial_guess_noise) * r_true

    # Parameter bounds [(D_lo,D_hi), (r_lo,r_hi)]
    parameter_bounds = [(0.1, 3.0), (0.5, 6.0)]

    # Initial condition: u(x,0) = (1/10) * exp(-x)
    @staticmethod
    def initial_condition(x):
        """Evaluate the initial profile (1/10) * exp(-x) with NumPy."""
        return (1.0 / 10.0) * np.exp(-x)

    # Data noise level (std = data_error * |true value|)
    data_error = 0.3

    # PINN architecture
    hidden_layers = [20, 20]
    activation = 'tanh'

    # Collocation counts (physics residual, boundary, initial condition)
    N_f = 16384
    N_bc = 1024
    N_ic = 1024

    # Training config
    epochs = 20000
    learning_rate = 1e-2

    # Seed
    seed = 1234

    # Training method: "mdmm", "traditional", "nelder-mead"
    training_method = "mdmm"

    # Output ("." means the current working directory)
    output_folder = "."

    # Precision (32 or 64) and the matching NumPy / JAX dtypes
    precision = 64
    np_dtype = np.float64
    jax_dtype = jnp.float64

# =========================================================
# Utilities
# =========================================================
def generate_biased_samples(config, xmin, xmax, N, bias_strength=.1):
    """Draw ``N`` sorted, non-uniformly spread locations in ``[xmin, xmax]``.

    Log-normal draws (global ``np.random`` state) are divided by their
    maximum, mapped affinely onto the interval, clipped to it, cast to
    ``config.np_dtype`` and sorted ascending. The largest draw therefore
    lands on ``xmax``.

    Args:
        config: object exposing ``np_dtype``.
        xmin, xmax: interval end points.
        N: number of samples.
        bias_strength: ``sigma`` of the underlying log-normal distribution.

    Returns:
        1-D sorted array of length ``N`` with dtype ``config.np_dtype``.
    """
    draws = np.random.lognormal(mean=0, sigma=bias_strength, size=N)
    unit_fraction = draws / np.max(draws)
    locations = np.clip(xmin + unit_fraction * (xmax - xmin), xmin, xmax)
    return np.sort(locations.astype(config.np_dtype))

# =========================================================
# Fisher–KPP (MoL) forward model (NumPy / SciPy)
#   u_t = D u_xx + r u (1 - u) with Neumann BCs
# =========================================================
def fisher_kpp_mol_rhs(t, u, dx, D, r):
    """Method-of-lines right-hand side for u_t = D u_xx + r u (1 - u).

    Second-order central differences on a uniform grid with homogeneous
    Neumann boundaries imposed through mirrored ghost points. The interior
    stencil is evaluated with array slices rather than a per-node Python
    loop, because this function is called on every stage of the ODE solver.

    Args:
        t: time (unused; present for the ``solve_ivp`` callback signature).
        u: 1-D array of nodal values (at least two nodes).
        dx: grid spacing.
        D: diffusion coefficient.
        r: growth rate.

    Returns:
        Array shaped like ``u`` holding du/dt at every node.
    """
    dudt = np.empty_like(u)
    # Neumann at left: u_x=0 -> u_xx approx via ghost point
    dudt[0] = D * (2 * (u[1] - u[0]) / (dx**2)) + r * u[0] * (1 - u[0])
    # Interior nodes: same arithmetic order as the scalar stencil.
    mid = u[1:-1]
    dudt[1:-1] = D * (u[2:] - 2 * mid + u[:-2]) / (dx**2) + r * mid * (1 - mid)
    # Neumann at right
    dudt[-1] = D * (2 * (u[-2] - u[-1]) / (dx**2)) + r * u[-1] * (1 - u[-1])
    return dudt

def fisher_kpp_wrapper(t, u, dx, D, r):
    """Forward all arguments unchanged to ``fisher_kpp_mol_rhs``."""
    rhs_values = fisher_kpp_mol_rhs(t, u, dx, D, r)
    return rhs_values

# =========================================================
# Data creation
# =========================================================
def create_forward_data(config):
    """Simulate the ground-truth PDE and build noisy measurements.

    The Fisher–KPP equation is integrated with ``config.D_true`` /
    ``config.r_true`` on a uniform grid; the solution at each time in
    ``config.t_eval`` is interpolated to biased measurement locations and
    perturbed with Gaussian noise whose standard deviation is
    ``config.data_error * |clean value|``.

    Returns:
        Tuple ``(x, dx, u0, x_data, X_data_tensor, u_data_tensor,
        u_noisy_list, t_span)`` where ``X_data_tensor`` stacks ``[x, t]``
        rows for every measurement time, ``u_data_tensor`` is the matching
        column of noisy values (both JAX arrays), and ``u_noisy_list`` holds
        one NumPy array per measurement time.
    """
    grid = np.linspace(config.xmin, config.xmax, config.Nx, dtype=config.np_dtype)
    spacing = grid[1] - grid[0]
    u_initial = config.initial_condition(grid)
    time_window = (config.tmin, config.tmax)

    def true_rhs(t, y):
        return fisher_kpp_wrapper(t, y, spacing, config.D_true, config.r_true)

    reference = solve_ivp(true_rhs, time_window, u_initial,
                          t_eval=config.t_eval, method='RK45')

    # Measurement locations come from the global NumPy RNG state.
    sensor_x = generate_biased_samples(config, config.xmin, config.xmax,
                                       config.N_data_points, bias_strength=.9)

    # Measurement noise uses its own generator seeded from the config.
    noise_rng = np.random.default_rng(config.seed)
    coordinate_blocks = []
    noisy_snapshots = []
    for column, t_meas in enumerate(config.t_eval):
        clean = np.interp(sensor_x, grid, reference.y[:, column])
        perturbation = noise_rng.normal(0.0, config.data_error * np.abs(clean),
                                        size=clean.shape)
        noisy_snapshots.append((clean + perturbation).astype(config.np_dtype))
        time_column = np.full_like(sensor_x, t_meas, dtype=config.np_dtype)
        coordinate_blocks.append(np.column_stack([sensor_x, time_column]))

    X_stacked = np.vstack(coordinate_blocks).astype(config.np_dtype)
    u_stacked = np.hstack(noisy_snapshots).astype(config.np_dtype)

    # JAX arrays
    X_data_tensor = jnp.asarray(X_stacked, dtype=config.jax_dtype)
    u_data_tensor = jnp.asarray(u_stacked.reshape(-1, 1), dtype=config.jax_dtype)
    return (grid, spacing, u_initial, sensor_x, X_data_tensor, u_data_tensor,
            noisy_snapshots, time_window)

def generate_collocation_points(config):
    """Return ``config.N_f`` scrambled-Sobol collocation points as ``[x, t]`` rows.

    Unit-square samples are rescaled to ``[xmin, xmax] x [tmin, tmax]`` and
    returned as a JAX array of dtype ``config.jax_dtype``.
    """
    unit_square = (qmc.Sobol(d=2, scramble=True, seed=config.seed)
                   .random(config.N_f)
                   .astype(config.np_dtype))
    x_coords = unit_square[:, 0] * (config.xmax - config.xmin) + config.xmin
    t_coords = unit_square[:, 1] * (config.tmax - config.tmin) + config.tmin

    points = np.zeros((config.N_f, 2), dtype=config.np_dtype)
    points[:, 0] = x_coords
    points[:, 1] = t_coords
    return jnp.asarray(points, dtype=config.jax_dtype)

def create_ic_bc_data(config):
    """Build initial-condition and boundary collocation sets as JAX arrays.

    Returns:
        ``(X_ic, u_ic_target, X_bc_left, X_bc_right)`` where
        * ``X_ic`` has rows ``[x, 0]`` with Sobol-sampled ``x``,
        * ``u_ic_target`` is ``config.initial_condition(x)`` as a column,
        * ``X_bc_left`` / ``X_bc_right`` have rows ``[xmin, t]`` / ``[xmax, t]``
          sharing the same Sobol-sampled times.
    """
    real = config.np_dtype
    x_span = config.xmax - config.xmin
    t_span = config.tmax - config.tmin

    # IC: x ~ Sobol (seed + 1), t = 0
    ic_unit = qmc.Sobol(d=1, scramble=True, seed=config.seed + 1).random(config.N_ic)
    x_ic = ic_unit.astype(real).reshape(-1) * x_span + config.xmin
    ic_points = np.column_stack([x_ic, np.zeros_like(x_ic, dtype=real)])
    ic_values = config.initial_condition(x_ic).astype(real)

    # BC: t ~ Sobol (seed + 2), x pinned to either end of the domain
    bc_unit = qmc.Sobol(d=1, scramble=True, seed=config.seed + 2).random(config.N_bc)
    t_bc = bc_unit.astype(real).reshape(-1) * t_span + config.tmin
    left_points = np.column_stack([np.full_like(t_bc, config.xmin), t_bc])
    right_points = np.column_stack([np.full_like(t_bc, config.xmax), t_bc])

    def to_jax(array):
        return jnp.asarray(array, dtype=config.jax_dtype)

    return (to_jax(ic_points), to_jax(ic_values.reshape(-1, 1)),
            to_jax(left_points), to_jax(right_points))

# =========================================================
# Model 
# =========================================================
def get_activation(name: str):
    """Look up an activation function by case-insensitive name.

    Supported names: ``tanh``, ``relu``, ``sigmoid``, ``gelu``, ``silu``.

    Raises:
        ValueError: if the (lower-cased) name is not supported.
    """
    name = name.lower()
    registry = {
        'tanh': jnp.tanh,
        'relu': jax.nn.relu,
        'sigmoid': jax.nn.sigmoid,
        'gelu': jax.nn.gelu,
        'silu': jax.nn.silu,
    }
    chosen = registry.get(name)
    if chosen is None:
        raise ValueError(f"Unsupported activation: {name}")
    return chosen

class PINN(nn.Module):
    """Fully connected network: a stack of Dense+activation layers and a 1-unit head."""
    # Width of each hidden Dense layer, in application order.
    hidden_layers: Tuple[int, ...]
    # Activation name, resolved through get_activation on every call.
    activation: str
    # dtype forwarded to every nn.Dense layer.
    dtype: Any

    @nn.compact
    def __call__(self, X):
        """Apply the network to ``X`` and return one output column per row."""
        # X: (N,2) with columns [x, t]
        act = get_activation(self.activation)
        z = X
        for h in self.hidden_layers:
            z = nn.Dense(h, dtype=self.dtype)(z)
            z = act(z)
        # Linear output head (no activation).
        z = nn.Dense(1, dtype=self.dtype)(z)
        return z  # (N,1)

@dataclass
class TrainedPINN:
    # Result bundle returned by the PINN training routines.
    # Full parameter dictionary (network weights under 'net', PDE parameters under 'theta').
    params: Any
    # The model object whose .apply is used to evaluate the network.
    model: Any
    # "mdmm" or "traditional"; selects how D and r are decoded by get_dr_params.
    training_method: str
    # Per-iteration loss components (and memory readings) recorded during training.
    loss_history: dict
    # Per-iteration values of D and r recorded during training.
    param_history: dict

def model_apply(params_all, model, X):
    """Evaluate ``model`` on ``X`` using the weights stored under ``params_all['net']``."""
    variables = {'params': params_all['net']}
    return model.apply(variables, X)

def get_dr_params(params_all, training_method):
    """Return the PDE parameters ``(D, r)`` stored in ``params_all['theta']``.

    For ``training_method == "traditional"`` the parameters are stored as
    logarithms (``logD``, ``logr``) and are exponentiated; for any other
    method the raw ``D`` and ``r`` entries are returned.
    """
    theta = params_all['theta']
    if training_method != "traditional":
        return theta['D'], theta['r']
    return jnp.exp(theta['logD']), jnp.exp(theta['logr'])

# =========================================================
# Losses: IC, BC (Neumann), Physics residual
# =========================================================
def ic_loss_fn(p, model, X_ic, u_ic_target):
    """Mean squared mismatch between the network at ``X_ic`` and ``u_ic_target``."""
    mismatch = model_apply(p, model, X_ic) - u_ic_target
    return jnp.mean(mismatch**2)

def _u_scalar_fn(p, model):
    """Build a closure that evaluates the network at one point ``z = [x, t]``.

    The returned function reshapes ``z`` to a single-row batch and extracts
    the scalar output, which makes it suitable for ``jax.grad``.
    """
    def u_scalar(z):
        single_row = z.reshape(1, 2)
        prediction = model_apply(p, model, single_row)
        return prediction[0, 0]
    return u_scalar

def _ux_ut_uxx(p, model, X):
    """
    Compute u_x, u_t, and u_xx at points X (N,2) using grad + jvp.

    A single forward-mode pass through the gradient function along the
    x-direction returns both the gradient itself (u_x, u_t) and its
    directional derivative, whose first component is u_xx. This avoids
    evaluating the gradient a second time just to read off u_x and u_t.

    Returns:
        Tuple ``(u_x, u_t, u_xx)``, each reshaped to a column ``(N, 1)``.
    """
    grad_u = grad(_u_scalar_fn(p, model))  # z -> [u_x, u_t]

    def derivs(z):
        e_x = jnp.array([1., 0.], dtype=z.dtype)
        # primal: [u_x, u_t]; tangent: d/dx [u_x, u_t] = [u_xx, u_tx]
        first, along_x = jvp(grad_u, (z,), (e_x,))
        return first[0], first[1], along_x[0]

    ux, ut, uxx = vmap(derivs)(X)
    return ux.reshape(-1, 1), ut.reshape(-1, 1), uxx.reshape(-1, 1)

def bc_loss_fn(p, model, X_left, X_right):
    """Neumann boundary penalty: average of the mean squared u_x on both boundary sets."""
    grad_u = grad(_u_scalar_fn(p, model))  # gradient with respect to (x, t)

    def mean_squared_ux(points):
        ux = vmap(lambda z: grad_u(z)[0])(points).reshape(-1, 1)
        return jnp.mean(ux**2)

    left_term = mean_squared_ux(X_left)
    right_term = mean_squared_ux(X_right)
    return 0.5 * (left_term + right_term)

def physics_loss_fn(p, model, X_f, training_method):
    """Mean squared Fisher–KPP residual at the collocation points ``X_f``.

    The residual is ``u_t - D u_xx - r u (1 - u)`` with ``D`` and ``r``
    decoded from ``p`` according to ``training_method``.
    """
    D, r = get_dr_params(p, training_method)
    _, u_t, u_xx = _ux_ut_uxx(p, model, X_f)  # u_x is not needed here
    u = model_apply(p, model, X_f)
    residual = u_t - D * u_xx - r * u * (1. - u)
    return jnp.mean(residual**2)

# =========================================================
# Learning-rate schedule (linear -> constant)
# =========================================================
def make_jittable_schedule(config):
    """Build a step -> learning-rate function made only of ``jnp`` operations.

    Phase 1 ramps linearly from 1e-2 towards 1e-4 over
    ``max(1, config.epochs - 30000)`` steps; from that step on the rate is a
    constant 1e-4. When ``config.epochs <= 30001`` the ramp is a single
    step, so only step 0 uses 1e-2. ``config.learning_rate`` is not read.

    The result dtype is float32 when ``config.precision == 32`` and float64
    otherwise.
    """
    lr_start, lr_end, lr_floor = 1e-2, 1e-4, 1e-4
    ramp_steps = int(max(1, config.epochs - 30000))
    out_dtype = jnp.float32 if config.precision == 32 else jnp.float64

    def schedule(count):
        step = jnp.asarray(count, jnp.int32)
        n_ramp = jnp.asarray(ramp_steps, jnp.int32)
        fraction = jnp.clip(step / jnp.maximum(n_ramp, 1), 0, 1).astype(out_dtype)
        ramp_lr = (lr_start + fraction * (lr_end - lr_start)).astype(out_dtype)
        floor_lr = jnp.asarray(lr_floor, out_dtype)
        return jnp.where(step < n_ramp, ramp_lr, floor_lr)

    return schedule

# =========================================================
# MDMM 
# =========================================================
class LagrangeMultiplier(NamedTuple):
    # Marker wrapper: pytree nodes of this type are treated as Lagrange
    # multipliers and get their update sign flipped by mdmm_prepare_update.
    value: Any


def mdmm_prepare_update(tree):
    """Negate every ``LagrangeMultiplier`` in ``tree``; leave other leaves as-is.

    Applied to optimizer updates, the sign flip makes the multipliers move
    in the opposite direction to the ordinary parameters.

    Uses ``jax.tree_util.tree_map`` because the top-level ``jax.tree_map``
    alias is deprecated and absent from recent JAX releases.

    Args:
        tree: arbitrary pytree, possibly containing ``LagrangeMultiplier`` nodes.

    Returns:
        A pytree with the same structure as ``tree``.
    """
    def is_multiplier(node):
        return isinstance(node, LagrangeMultiplier)

    def flip(node):
        return LagrangeMultiplier(-node.value) if is_multiplier(node) else node

    # is_leaf stops traversal at the wrapper so flip sees the whole node.
    return jax.tree_util.tree_map(flip, tree, is_leaf=is_multiplier)

def optax_prepare_update():
    """Stateless Optax transformation that applies ``mdmm_prepare_update`` to updates.

    Intended as the last link of an ``optax.chain`` so that the updates of
    ``LagrangeMultiplier`` entries have their sign flipped.
    """
    def init_state(params):
        # No optimizer state is needed; params are ignored.
        del params
        return optax.EmptyState()

    def flip_multiplier_updates(updates, state, params=None):
        del params
        flipped = mdmm_prepare_update(updates)
        return flipped, state

    return optax.GradientTransformation(init_state, flip_multiplier_updates)

class Constraint(NamedTuple):
    # Pair of callables describing one MDMM constraint (or a combination).
    # init(params) -> constraint parameters (the Lagrange multipliers).
    init: Callable
    # loss(constraint_params, params) -> (loss term, infeasibility).
    loss: Callable

def eq_constraint(fun, damping=1., weight=1., reduction=jnp.mean):
    """Build an MDMM equality constraint ``fun(p) == 0``.

    Args:
        fun: maps the parameters to the infeasibility (value that should be 0).
        damping: coefficient of the quadratic damping term.
        weight: overall multiplier applied to the reduced constraint loss.
        reduction: reduces the element-wise loss to a scalar.

    Returns:
        ``Constraint`` whose ``init`` creates a zero multiplier shaped like
        ``fun(p)`` and whose ``loss`` returns
        ``(weight * reduction(lambda * inf + damping * 0.5 * inf**2), inf)``.
    """
    def init_fn(p):
        zero_multiplier = jnp.zeros_like(fun(p))
        return {'lambda': LagrangeMultiplier(zero_multiplier)}

    def loss_fn(cparams, p):
        infeasibility = fun(p)
        multiplier = cparams['lambda'].value
        damped_lagrangian = multiplier * infeasibility + damping * 0.5 * infeasibility**2
        return weight * reduction(damped_lagrangian), infeasibility

    return Constraint(init_fn, loss_fn)

def bound_hard(fun_value, lo, hi, damping=1., weight=100., reduction=jnp.mean):
    """Equality constraint that is zero while ``fun_value(p)`` lies in ``[lo, hi]``.

    The infeasibility is ``clip(value, lo, hi) - value``, i.e. the signed
    distance back to the nearest bound when the value is outside the
    interval.
    """
    def bound_violation(p):
        return jnp.clip(fun_value(p), lo, hi) - fun_value(p)

    return eq_constraint(
        bound_violation,
        damping=damping,
        weight=weight,
        reduction=reduction,
    )

def combine_constraints(*constraints):
    """Merge several constraints into one ``Constraint``.

    The combined ``init`` returns a tuple with one entry per constraint; the
    combined ``loss`` takes that tuple and returns the sum of the individual
    loss terms together with a tuple of the individual infeasibilities.
    """
    init_fns = [c.init for c in constraints]
    loss_fns = [c.loss for c in constraints]

    def init_all(p):
        return tuple(init(p) for init in init_fns)

    def loss_all(cparams, p):
        total = 0
        infeasibilities = []
        for loss, cp in zip(loss_fns, cparams):
            term, infeasibility = loss(cp, p)
            total = total + term
            infeasibilities.append(infeasibility)
        return total, tuple(infeasibilities)

    return Constraint(init_all, loss_all)

# =========================================================
# Training: MDMM
# =========================================================
def train_mdmm(model, config, X_data, u_data, X_f, X_ic, u_ic_target, X_bc_left, X_bc_right):
    """Train the PINN with MDMM: data RMSE as objective, everything else as constraints.

    The physics residual, initial-condition and boundary losses are each
    imposed as an equality constraint (target value 0), and D and r are kept
    inside ``config.parameter_bounds`` through ``bound_hard`` constraints.
    D and r are optimized directly (no log-parameterization).

    Args:
        model: network whose ``init`` / ``apply`` are used.
        config: settings object (seed, dtype, initial guesses, bounds, epochs).
        X_data, u_data: measurement coordinates and noisy values.
        X_f: collocation points for the physics residual.
        X_ic, u_ic_target: initial-condition points and targets.
        X_bc_left, X_bc_right: boundary points.

    Returns:
        ``(TrainedPINN, train_runtime)`` with the runtime of the training
        loop in seconds.
    """
    dtype = config.jax_dtype
    rng = jax.random.PRNGKey(config.seed)
    net_params = model.init(rng, jnp.zeros((1, 2), dtype=dtype))['params']
    params_all = {
        'net': net_params,
        'theta': {
            'D': jnp.asarray(config.D_init, dtype),
            'r': jnp.asarray(config.r_init, dtype),
        }
    }

    # Constraints
    physics_c = eq_constraint(lambda p: physics_loss_fn(p, model, X_f, "mdmm"), damping=1., weight=1.)
    ic_c      = eq_constraint(lambda p: ic_loss_fn(p, model, X_ic, u_ic_target), damping=1., weight=1.)
    bc_c      = eq_constraint(lambda p: bc_loss_fn(p, model, X_bc_left, X_bc_right), damping=1., weight=1.)
    (D_lo, D_hi), (r_lo, r_hi) = config.parameter_bounds
    bound_D   = bound_hard(lambda p: p['theta']['D'], D_lo, D_hi, damping=1., weight=100.)
    bound_r   = bound_hard(lambda p: p['theta']['r'], r_lo, r_hi, damping=1., weight=100.)
    constraints = combine_constraints(physics_c, ic_c, bc_c, bound_D, bound_r)

    # Multipliers are created from the initial parameters and then stored
    # alongside them so a single optimizer updates everything.
    mdmm_params = constraints.init(params_all)
    params_all = {'net': net_params, 'theta': params_all['theta'], 'mdmm': mdmm_params}

    # Optimizer: Adam with the schedule from make_jittable_schedule, followed
    # by the sign flip of the Lagrange-multiplier updates.
    schedule = make_jittable_schedule(config)
    optimizer = optax.chain(
        optax.adam(learning_rate=schedule),
        optax_prepare_update(),
    )
    opt_state = optimizer.init(params_all)

    def data_rmse_loss(p):
        # Root-mean-square misfit between the network and the noisy data.
        u_pred = model_apply(p, model, X_data)
        return jnp.sqrt(jnp.mean((u_pred - u_data)**2))

    def total_loss_and_logs(p):
        # Differentiated value: data RMSE plus the combined constraint term.
        dloss = data_rmse_loss(p)
        mdmm_term, _ = constraints.loss(p['mdmm'], p)
        value = dloss + mdmm_term
        # Auxiliary outputs used only for logging.
        ic_val   = ic_loss_fn(p, model, X_ic, u_ic_target)
        phys_val = physics_loss_fn(p, model, X_f, "mdmm")
        bc_val   = bc_loss_fn(p, model, X_bc_left, X_bc_right)
        return value, (dloss, ic_val, bc_val, phys_val)

    @jax.jit
    def mdmm_step(p, state):
        (val, logs), grads = jax.value_and_grad(total_loss_and_logs, has_aux=True)(p)
        updates, state = optimizer.update(grads, state, p)
        p = optax.apply_updates(p, updates)
        dloss, ic_val, bc_val, phys_val = logs
        return p, state, dloss, ic_val, bc_val, phys_val

    loss_history = {"data": [], "ic": [], "bc": [], "physics": [], "memory_mb": []}
    param_history = {"D": [], "r": []}

    start_train = time.time()
    for it in range(config.epochs):
        params_all, opt_state, dloss, ic_val, bc_val, phys_val = mdmm_step(params_all, opt_state)
        # Logged losses were evaluated before this step's update; D and r
        # below are read after it.
        loss_history["data"].append(float(dloss))
        loss_history["ic"].append(float(ic_val))
        loss_history["bc"].append(float(bc_val))
        loss_history["physics"].append(float(phys_val))
        D_val, r_val = get_dr_params(params_all, "mdmm")
        param_history["D"].append(float(D_val))
        param_history["r"].append(float(r_val))
        loss_history["memory_mb"].append(float(get_process_memory_mb()))
        if (it + 1) % 500 == 0:
            print(f"Iter {it+1:5d} | Data {float(dloss):.5e} | IC {float(ic_val):.5e} "
                  f"| BC {float(bc_val):.5e} | Phys {float(phys_val):.5e} "
                  f"| D {float(D_val):.5f} | r {float(r_val):.5f}")
    train_runtime = time.time() - start_train

    trained = TrainedPINN(params=params_all, model=model, training_method="mdmm",
                          loss_history=loss_history, param_history=param_history)
    return trained, train_runtime

# =========================================================
# Training: Traditional (sum of losses) with log-parameterization
# =========================================================
def train_traditional(model, config, X_data, u_data, X_f, X_ic, u_ic_target, X_bc_left, X_bc_right):
    """Train the PINN by minimizing the unweighted sum of all loss components.

    D and r are stored as ``logD`` / ``logr`` and exponentiated when used,
    which keeps them positive. The total loss is
    data RMSE + IC MSE + BC MSE + physics MSE.

    Args:
        model: network whose ``init`` / ``apply`` are used.
        config: settings object (seed, dtype, initial guesses, epochs).
        X_data, u_data: measurement coordinates and noisy values.
        X_f: collocation points for the physics residual.
        X_ic, u_ic_target: initial-condition points and targets.
        X_bc_left, X_bc_right: boundary points.

    Returns:
        ``(TrainedPINN, train_runtime)`` with the runtime of the training
        loop in seconds.
    """
    dtype = config.jax_dtype
    rng = jax.random.PRNGKey(config.seed)
    net_params = model.init(rng, jnp.zeros((1, 2), dtype=dtype))['params']
    params_all = {
        'net': net_params,
        'theta': {
            'logD': jnp.asarray(np.log(config.D_init), dtype),
            'logr': jnp.asarray(np.log(config.r_init), dtype),
        }
    }
    # Adam with the learning-rate schedule from make_jittable_schedule.
    schedule = make_jittable_schedule(config)
    optimizer = optax.adam(learning_rate=schedule)
    opt_state = optimizer.init(params_all)

    def component_losses(p):
        # Note: the data term is an RMSE, the other three are mean squares.
        loss_data = jnp.sqrt(jnp.mean((model_apply(p, model, X_data) - u_data)**2))
        loss_ic   = ic_loss_fn(p, model, X_ic, u_ic_target)
        loss_bc   = bc_loss_fn(p, model, X_bc_left, X_bc_right)
        loss_phys = physics_loss_fn(p, model, X_f, "traditional")
        return loss_data, loss_ic, loss_bc, loss_phys

    def total_loss(p):
        # Scalar objective plus the individual components as auxiliary output.
        ld, li, lb, lp = component_losses(p)
        return ld + li + lb + lp, (ld, li, lb, lp)

    @jax.jit
    def step(p, state):
        (tot, logs), grads = jax.value_and_grad(total_loss, has_aux=True)(p)
        updates, state = optimizer.update(grads, state, p)
        p = optax.apply_updates(p, updates)
        return p, state, tot, *logs

    loss_history = {"data": [], "ic": [], "bc": [], "physics": [], "memory_mb": []}
    param_history = {"D": [], "r": []}

    start_train = time.time()
    for it in range(config.epochs):
        params_all, opt_state, tot, ld, li, lb, lp = step(params_all, opt_state)
        # Logged losses were evaluated before this step's update; D and r
        # below are read after it.
        loss_history["data"].append(float(ld))
        loss_history["ic"].append(float(li))
        loss_history["bc"].append(float(lb))
        loss_history["physics"].append(float(lp))
        D_val, r_val = get_dr_params(params_all, "traditional")
        param_history["D"].append(float(D_val))
        param_history["r"].append(float(r_val))
        loss_history["memory_mb"].append(float(get_process_memory_mb()))
        if (it + 1) % 500 == 0:
            print(f"Iter {it+1:5d} | Total {float(tot):.5e} | Data {float(ld):.5e} | IC {float(li):.5e} "
                  f"| BC {float(lb):.5e} | Phys {float(lp):.5e} | D {float(D_val):.5f} | r {float(r_val):.5f}")
    train_runtime = time.time() - start_train

    trained = TrainedPINN(params=params_all, model=model, training_method="traditional",
                          loss_history=loss_history, param_history=param_history)
    return trained, train_runtime

# =========================================================
# Nelder–Mead training (NumPy/SciPy solve_ivp)
# =========================================================
def train_nelder_mead(config, x, dx, u0, x_data, u_noisy_list, t_span):
    """Estimate (D, r) by Nelder–Mead on the misfit of the method-of-lines solution.

    Every objective evaluation integrates the PDE, interpolates the solution
    to the measurement locations and accumulates the per-snapshot mean
    squared errors. The square root of that sum is returned to the
    optimizer, while the un-rooted sum is what gets stored in the log.

    Returns:
        ``(res.x, log)`` where ``log`` has the per-evaluation lists
        ``"loss"``, ``"D"``, ``"r"`` and ``"memory_mb"``.
    """
    log = {"loss": [], "D": [], "r": [], "memory_mb": []}

    def objective_fn(p):
        D_est, r_est = p
        # Negative parameters get a large penalty and are not logged.
        if D_est < 0 or r_est < 0:
            return 1e12

        def candidate_rhs(t, y):
            return fisher_kpp_wrapper(t, y, dx, D_est, r_est)

        solution = solve_ivp(candidate_rhs, t_span, u0,
                             t_eval=config.t_eval, method='RK45')
        snapshots = np.array(solution.y)

        summed_mse = 0.0
        for column in range(len(config.t_eval)):
            simulated = np.interp(x_data, x, snapshots[:, column])
            summed_mse += np.mean((simulated - u_noisy_list[column])**2)

        log["loss"].append(float(summed_mse))
        log["D"].append(float(D_est))
        log["r"].append(float(r_est))
        log["memory_mb"].append(float(get_process_memory_mb()))
        return math.sqrt(summed_mse)

    start_point = [config.D_init, config.r_init]
    res = minimize(objective_fn, start_point, method='Nelder-Mead',
                   bounds=config.parameter_bounds)
    print("Nelder–Mead optimization success:", res.success)
    return res.x, log

# =========================================================
# Plotting
# =========================================================
def plot_results(model_or_params, config, x, x_data, u_noisy_list, u0, t_span, dx, device_str):
    """Plot true vs. estimated solutions at ``config.plot_times`` and save them.

    One subplot per plot time shows the method-of-lines solution with the
    true parameters, the method-of-lines solution with the estimated
    parameters, the network prediction (only for a ``TrainedPINN``) and the
    noisy data when the plot time is also a measurement time.

    Writes ``Figures/Solution_comparison.png`` and ``trajectories.npz`` into
    the run folder. ``device_str`` is accepted but not used here.

    Args:
        model_or_params: a ``TrainedPINN`` or an indexable ``(D, r)`` pair.
    """
    # Run folder is named after config.training_method and the experiment settings.
    base_folder = os.getcwd() if config.output_folder == "." else config.output_folder
    folder_name = f"FisherKPP_{config.training_method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}"
    full_folder = os.path.join(base_folder, folder_name)
    figures_folder = os.path.join(full_folder, "Figures")
    os.makedirs(figures_folder, exist_ok=True)

    num_plots = len(config.plot_times)
    plt.figure(figsize=(7 * num_plots, 6))

    # Resolve the parameter estimates from either kind of input.
    if isinstance(model_or_params, TrainedPINN):
        D_est, r_est = get_dr_params(model_or_params.params, model_or_params.training_method)
        D_est, r_est = float(D_est), float(r_est)
    else:
        D_est, r_est = float(model_or_params[0]), float(model_or_params[1])

    def f_est(t, u): return fisher_kpp_mol_rhs(t, u, dx, D_est, r_est)
    # Lookup of measurements by time; matched against plot times by exact equality.
    measurement_data = {t: (x_data, u_noisy) for t, u_noisy in zip(config.t_eval, u_noisy_list)}

    # Arrays collected here are written to trajectories.npz at the end.
    trajectories = {}

    for idx, t_plot in enumerate(config.plot_times):
        # Both reference solutions are re-integrated from t_span[0] for every plot time.
        sol_true = solve_ivp(lambda t, y: fisher_kpp_wrapper(t, y, dx, config.D_true, config.r_true),
                             t_span, u0, t_eval=[t_plot], method='RK45')
        u_true = np.array(sol_true.y)[:, 0]
        sol_est = solve_ivp(f_est, t_span, u0, t_eval=[t_plot], method='RK45')
        u_est = np.array(sol_est.y)[:, 0]

        ax = plt.subplot(1, num_plots, idx + 1)
        ax.plot(x, u_true, 'k--', label='True solution (MoL)')
        ax.plot(x, u_est, 'g-', label='Estimated trajectory')

        trajectories[f"x_time_{t_plot}"] = x
        trajectories[f"u1_time_{t_plot}_true"] = u_true
        trajectories[f"u1_time_{t_plot}_estimated"] = u_est

        if isinstance(model_or_params, TrainedPINN):
            # Evaluate the network on the full grid at this time.
            X_plot = jnp.asarray(np.column_stack((x, np.full_like(x, t_plot))), dtype=config.jax_dtype)
            u_pinn = np.array(model_or_params.model.apply({'params': model_or_params.params['net']}, X_plot)).flatten()
            ax.plot(x, u_pinn, 'b-', label='PINN prediction')
            trajectories[f"u1_time_{t_plot}_PINN"] = u_pinn

        if t_plot in measurement_data:
            x_meas, u_meas = measurement_data[t_plot]
            ax.plot(x_meas, u_meas, 'ro', label='Noisy data')

        ax.set_xlabel('x'); ax.set_ylabel('u'); ax.set_title(f"t = {t_plot}")
        ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(figures_folder, "Solution_comparison.png"), dpi=300)
    np.savez(os.path.join(full_folder, "trajectories.npz"), **trajectories)

def plot_loss(loss_log, method, config):
    """Save a semilog plot of the loss history as ``Figures/Loss_evolution.png``.

    For ``"mdmm"`` / ``"traditional"`` a 2x2 grid shows the data, IC, BC and
    physics losses; for any other method a single panel shows
    ``loss_log["loss"]``. The run folder is named after ``method``.
    """
    root = os.getcwd() if config.output_folder == "." else config.output_folder
    run_dir = os.path.join(
        root,
        f"FisherKPP_{method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}",
    )
    fig_dir = os.path.join(run_dir, "Figures")
    os.makedirs(fig_dir, exist_ok=True)

    if method in ("mdmm", "traditional"):
        fig, axs = plt.subplots(2, 2, figsize=(10, 8))
        # (grid position, history key, panel title)
        panels = (
            ((0, 0), "data", "Data Loss"),
            ((0, 1), "ic", "IC Loss"),
            ((1, 0), "bc", "BC Loss"),
            ((1, 1), "physics", "Physics Loss"),
        )
        for (row, col), key, title in panels:
            axs[row, col].semilogy(loss_log[key])
            axs[row, col].set_title(title)
        for panel in axs.flatten():
            panel.set_xlabel("Epoch")
        plt.tight_layout()
    else:
        plt.figure(figsize=(6,4))
        plt.semilogy(loss_log["loss"])
        plt.xlabel("Iteration")
        plt.title("Objective Loss (Nelder–Mead)")
        plt.tight_layout()

    plt.savefig(os.path.join(fig_dir, "Loss_evolution.png"), dpi=300)

def plot_parameters(param_log, method, config):
    """Save the D and r histories as ``Figures/Parameter_evolution.png``.

    Each panel shows the recorded parameter values with a dashed red line at
    the ground-truth value. The x-axis label is "Epoch" for the PINN methods
    and "Iteration" otherwise. The run folder is named after ``method``.
    """
    root = os.getcwd() if config.output_folder == "." else config.output_folder
    run_dir = os.path.join(
        root,
        f"FisherKPP_{method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}",
    )
    fig_dir = os.path.join(run_dir, "Figures")
    os.makedirs(fig_dir, exist_ok=True)

    x_label = "Epoch" if method in ["mdmm", "traditional"] else "Iteration"
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    panels = (
        (axs[0], "D", "D parameter", config.D_true),
        (axs[1], "r", "r parameter", config.r_true),
    )
    for panel, key, title, truth in panels:
        panel.plot(param_log[key])
        panel.set_title(title)
        panel.axhline(y=truth, color='r', linestyle='--')
        panel.set_xlabel(x_label)
    plt.tight_layout()
    plt.savefig(os.path.join(fig_dir, "Parameter_evolution.png"), dpi=300)

# =========================================================
# Max distance metric
# =========================================================
def measure_max_distance(model_or_params, config, x, dx, u0, t_span):
    """Largest pointwise gap between the network and the re-simulated PDE solution.

    For every time in ``config.t_eval`` the PDE is integrated with the
    estimated (D, r) and compared on the grid ``x`` with the network output.
    When ``model_or_params`` is not a ``TrainedPINN`` the comparison is made
    against the simulation itself, so the distance is 0.

    Returns:
        The maximum absolute difference over all times, or NaN when
        ``config.t_eval`` is empty.
    """
    is_pinn = isinstance(model_or_params, TrainedPINN)
    per_time_gap = []
    for t in list(config.t_eval):
        if is_pinn:
            D_raw, r_raw = get_dr_params(model_or_params.params, model_or_params.training_method)
            D_val, r_val = float(D_raw), float(r_raw)
        else:
            D_val, r_val = float(model_or_params[0]), float(model_or_params[1])

        def estimated_rhs(tt, y):
            return fisher_kpp_wrapper(tt, y, dx, D_val, r_val)

        simulated = solve_ivp(estimated_rhs, t_span, u0, t_eval=[t], method='RK45')
        u_ml = np.array(simulated.y)[:, 0]
        grid_at_t = np.column_stack((x, np.full_like(x, t)))
        X_compare = jnp.asarray(grid_at_t, dtype=config.jax_dtype)
        if is_pinn:
            net_variables = {'params': model_or_params.params['net']}
            u_pinn = np.array(model_or_params.model.apply(net_variables, X_compare)).flatten()
        else:
            u_pinn = u_ml
        per_time_gap.append(float(np.max(np.abs(u_ml - u_pinn))))

    if not per_time_gap:
        return float('nan')
    return max(per_time_gap)

# =========================================================
# Save outputs and summary
# =========================================================
def save_outputs_and_summary(config, method, x, dx, u0, x_data, u_noisy_list, t_span, runtime,
                             trained_model=None, optimal_params=None, log=None, max_distance=None):
    """Persist run artifacts and write ``summary_statistics.json``.

    For ``"mdmm"`` / ``"traditional"`` the trained parameters (pickled), the
    noisy data and the loss / parameter histories are saved from
    ``trained_model``; for ``"nelder-mead"`` the optimum and the evaluation
    log are saved from ``optimal_params`` and ``log``. The summary contains
    the data RMSE of the re-simulated solution, the relative parameter
    error, the run settings and memory statistics.

    Memory statistics are computed over the finite samples only and are
    ``None`` when there are none (for example when every reading was NaN
    because no memory probe was available), rather than failing on an empty
    reduction.

    NOTE(review): the run folder is named after ``config.training_method``,
    not ``method`` — confirm both always match at the call site.

    Args:
        config: settings object.
        method: "mdmm", "traditional" or "nelder-mead".
        x, dx, u0, t_span: grid, spacing, initial state and time window for
            the re-simulation.
        x_data, u_noisy_list: measurement locations and noisy snapshots.
        runtime: runtime in seconds to report.
        trained_model: ``TrainedPINN`` (PINN methods).
        optimal_params: ``(D, r)`` estimate (Nelder–Mead).
        log: evaluation log dictionary (Nelder–Mead).
        max_distance: optional precomputed network-vs-simulation distance.
    """
    base_folder = os.getcwd() if config.output_folder == "." else config.output_folder
    folder_name = f"FisherKPP_{config.training_method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}"
    full_folder = os.path.join(base_folder, folder_name)
    figures_folder = os.path.join(full_folder, "Figures")
    os.makedirs(figures_folder, exist_ok=True)

    is_pinn_method = method in ["mdmm", "traditional"]

    if is_pinn_method:
        with open(os.path.join(full_folder, "trained_model.pt"), "wb") as f:
            pickle.dump({
                'params': trained_model.params,
                'training_method': trained_model.training_method,
                'hidden_layers': config.hidden_layers,
                'activation': config.activation,
                'precision': config.precision
            }, f)
        np.savez(os.path.join(full_folder, "noisy_data.npz"), x_data=x_data, u_noisy_list=u_noisy_list)
        np.savez(os.path.join(full_folder, "loss_history.npz"), **trained_model.loss_history)
        np.savez(os.path.join(full_folder, "parameter_history.npz"), **trained_model.param_history)
    elif method == "nelder-mead":
        np.savez(os.path.join(full_folder, "optimal_params.npz"), optimal_params=optimal_params)
        np.savez(os.path.join(full_folder, "loss_log.npz"), **log)
        np.savez(os.path.join(full_folder, "noisy_data.npz"), x_data=x_data, u_noisy_list=u_noisy_list)
        np.savez(os.path.join(full_folder, "parameter_history.npz"), D=log["D"], r=log["r"])

    # Determine estimates
    if is_pinn_method:
        D_est, r_est = get_dr_params(trained_model.params, method)
        D_est, r_est = float(D_est), float(r_est)
    else:
        D_est, r_est = float(optimal_params[0]), float(optimal_params[1])

    # RMSE over measurement snapshots
    errs = []
    for i, t in enumerate(config.t_eval):
        sol_ml = solve_ivp(lambda tt, y: fisher_kpp_wrapper(tt, y, dx, D_est, r_est),
                           t_span, u0, t_eval=[t], method='RK45')
        u_ml = np.array(sol_ml.y)[:, 0]
        u_sim = np.interp(x_data, x, u_ml)
        errs.append(np.mean((u_sim - u_noisy_list[i])**2))
    rmse_data = float(np.sqrt(np.mean(errs)))
    rrms_params = float(np.sqrt((((D_est - config.D_true)/config.D_true)**2 + ((r_est - config.r_true)/config.r_true)**2)/2))

    # Memory stats
    if is_pinn_method:
        mem_list = trained_model.loss_history.get("memory_mb", [])
    else:
        mem_list = log.get("memory_mb", [])
    # Guard on the *filtered* list: np.max of an empty sequence raises.
    finite_mem = [m for m in mem_list if np.isfinite(m)]
    memory_usage_avg = float(np.mean(finite_mem)) if finite_mem else None
    memory_usage_max = float(np.max(finite_mem)) if finite_mem else None

    # Serialize summary
    summary = {
        "training_method": config.training_method,
        "runtime_seconds": float(runtime),
        "initial_condition": "(1/10) * exp(-x)",
        "boundary_conditions": "Neumann BC: u_x=0 at x=xmin and x=xmax",
        "parameter_initial_guess": [float(config.D_init), float(config.r_init)],
        "parameter_ground_truth": [float(config.D_true), float(config.r_true)],
        "data_noise": float(config.data_error),
        "initial_guess_error": float(config.initial_guess_noise),
        "root_mse_data": rmse_data,
        "root_relative_mse_parameters": rrms_params,
        "plot_times": config.plot_times,
        "xmin": config.xmin, "xmax": config.xmax,
        "tmin": config.tmin, "tmax": config.tmax,
        "t_eval": config.t_eval,
        "hidden_layers": config.hidden_layers,
        "activation": config.activation,
        "parameter_estimates": [float(D_est), float(r_est)],
        "max_distance": float(max_distance) if max_distance is not None else (float('nan') if method=="nelder-mead" else None),
        "average_memory_usage_mb": memory_usage_avg,
        "max_memory_usage_mb": memory_usage_max
    }
    with open(os.path.join(full_folder, "summary_statistics.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=4)

# =========================================================
# Main
# =========================================================
def main():
    """Command-line entry point: configure, train, plot, and save a summary.

    Parses the CLI options, copies them onto a fresh ``Config``, switches JAX
    between 32- and 64-bit mode, builds the data / collocation / IC / BC
    arrays, and then runs the selected training method:

    * ``mdmm`` or ``traditional`` -- trains the ``PINN`` model, plots results,
      loss and parameter histories, and measures the max distance.
    * ``nelder-mead`` -- runs ``train_nelder_mead`` and plots its log.

    Finally hands everything to ``save_outputs_and_summary``.

    The reported runtime covers the training call only for every method, so
    the figures written to the summary are comparable across methods.

    Raises:
        ValueError: If ``--training_method`` is not one of ``mdmm``,
            ``traditional`` or ``nelder-mead`` (matched case-insensitively).
    """
    pinn_methods = ("mdmm", "traditional")
    known_methods = pinn_methods + ("nelder-mead",)

    parser = argparse.ArgumentParser(description="Train PINN for Fisher–KPP (JAX/Flax/Optax)")
    parser.add_argument("--epochs", type=int, default=20000, help="Number of training epochs")
    parser.add_argument("--initial_guess_error", type=float, default=5.0, help="Initial guess error multiplier")
    parser.add_argument("--data_error", type=float, default=0.3, help="Data noise level")
    parser.add_argument("--N_data_points", type=int, default=18, help="Number of data points")
    parser.add_argument("--training_method", type=str, default="mdmm", help="mdmm | traditional | nelder-mead")
    # Restrict to the two supported widths; previously any other value was
    # silently treated as 32-bit.
    parser.add_argument("--precision", type=int, default=64, choices=[32, 64], help="Precision for JAX (32 or 64)")
    args = parser.parse_args()

    # Fail fast: reject an unknown method before any data generation or
    # model construction happens.
    method = args.training_method.lower()
    if method not in known_methods:
        raise ValueError(f"Unknown training method: {args.training_method}")

    config = Config()
    config.epochs = args.epochs
    config.N_data_points = args.N_data_points
    config.initial_guess_noise = args.initial_guess_error
    config.data_error = args.data_error
    # Initial guesses are the true values scaled by (1 + error multiplier).
    config.D_init = (1 + config.initial_guess_noise) * config.D_true
    config.r_init = (1 + config.initial_guess_noise) * config.r_true
    config.training_method = args.training_method
    config.precision = args.precision

    # The x64 flag is set here, before any JAX arrays are created in main().
    use_x64 = config.precision == 64
    jax.config.update("jax_enable_x64", use_x64)
    config.np_dtype = np.float64 if use_x64 else np.float32
    config.jax_dtype = jnp.float64 if use_x64 else jnp.float32

    np.random.seed(config.seed)
    device_str = jax.devices()[0].platform
    print("Using device:", device_str)

    # Data and collocation
    x, dx, u0, x_data, X_data, u_data, u_noisy_list, t_span = create_forward_data(config)
    X_f = generate_collocation_points(config)
    X_ic, u_ic_target, X_bc_left, X_bc_right = create_ic_bc_data(config)

    hidden_layers = tuple(config.hidden_layers)
    model = PINN(hidden_layers=hidden_layers, activation=config.activation, dtype=config.jax_dtype)

    if method in pinn_methods:
        if method == "mdmm":
            print("Training using MDMM...")
            trained_model, train_runtime = train_mdmm(model, config, X_data, u_data, X_f, X_ic, u_ic_target, X_bc_left, X_bc_right)
        else:
            print("Training using traditional loss minimization...")
            trained_model, train_runtime = train_traditional(model, config, X_data, u_data, X_f, X_ic, u_ic_target, X_bc_left, X_bc_right)

        D_est, r_est = get_dr_params(trained_model.params, method)
        print(f"\nFinal recovered parameters: D = {float(D_est):.5f}, r = {float(r_est):.5f}")
        plot_results(trained_model, config, x, x_data, u_noisy_list, u0, t_span, dx, device_str)
        plot_loss(trained_model.loss_history, method, config)
        plot_parameters(trained_model.param_history, method, config)
        max_distance = measure_max_distance(trained_model, config, x, dx, u0, t_span)
        print("Max distance (MoL vs PINN):", max_distance)
        runtime = train_runtime
        optimal_params = None
        log = None
    else:  # method == "nelder-mead" (guaranteed by the validation above)
        print("Training using Nelder–Mead...")
        # Time only the optimizer call so the runtime excludes data
        # generation and plotting, like the PINN branches' train_runtime.
        train_start = time.perf_counter()
        optimal_params, log = train_nelder_mead(config, x, dx, u0, x_data, u_noisy_list, t_span)
        runtime = time.perf_counter() - train_start
        print(f"\nNelder–Mead: Final recovered parameters: D = {optimal_params[0]:.5f}, r = {optimal_params[1]:.5f}")
        plot_results(optimal_params, config, x, x_data, u_noisy_list, u0, t_span, dx, device_str)
        plot_loss(log, method, config)
        plot_parameters(log, method, config)
        trained_model = None
        max_distance = None

    # Each branch above already sets the values it does not produce to None,
    # so they can be forwarded as-is.
    save_outputs_and_summary(config, method, x, dx, u0, x_data, u_noisy_list, t_span, runtime,
                             trained_model=trained_model,
                             optimal_params=optimal_params,
                             log=log,
                             max_distance=max_distance)

# Run the CLI only when this file is executed as a script, so importing the
# module does not start a training run.
if __name__ == "__main__":
    main()
