import os
os.environ.setdefault("JAX_PLATFORMS", "cpu")                 # avoid accidental GPU init if not desired
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
from scipy.stats import qmc, truncnorm
from numba import njit
import math
import time
import json
import argparse
import pickle
import inspect
import platform

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, jacrev, jvp
from jax import tree_util as jtu
import flax.linen as nn
import optax
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Tuple, Dict, Sequence, Optional

# ---------------------------------------------
# Memory monitoring utilities (portable fallbacks)
# ---------------------------------------------
try:
    import psutil  # preferred if available
except Exception:
    psutil = None
try:
    import resource as _resource  # Unix fallback
except Exception:
    _resource = None

def _read_proc_status_vmrss_bytes():
    try:
        with open("/proc/self/status", "r") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    parts = line.split()
                    if len(parts) >= 2:
                        return int(parts[1]) * 1024  # kB -> bytes
    except Exception:
        pass
    return None

def get_process_memory_mb():
    """
    Return the process resident memory in MB (best-effort), or NaN if unknown.

    Probes are tried in order and the first one that yields a value wins:
      1. psutil (current RSS), when the module was importable;
      2. /proc/self/status VmRSS (current RSS);
      3. resource.getrusage ru_maxrss (peak RSS, not current).

    A probe that raises is treated as "no value" and the next probe is tried,
    so a failing psutil call no longer hides the /proc and rusage fallbacks.
    """
    def _psutil_rss():
        if psutil is None:
            return None
        return psutil.Process(os.getpid()).memory_info().rss

    def _procfs_rss():
        return _read_proc_status_vmrss_bytes()

    def _rusage_peak():
        if _resource is None:
            return None
        ru = _resource.getrusage(_resource.RUSAGE_SELF)
        # ru_maxrss: bytes on macOS, kilobytes on Linux
        if platform.system() == "Darwin":
            return int(ru.ru_maxrss)
        return int(ru.ru_maxrss) * 1024

    for probe in (_psutil_rss, _procfs_rss, _rusage_peak):
        try:
            mem_bytes = probe()
        except Exception:
            # Memory monitoring must never break a run; fall through to the next probe.
            mem_bytes = None
        if mem_bytes is not None:
            return mem_bytes / (1024.0 * 1024.0)
    return float('nan')

# ---------------------------------------------
# Configuration
# ---------------------------------------------
class Config:
    """
    Default settings for the 1D viscous Burgers inverse problem (estimate viscosity D).

    Everything is declared as a class attribute. Derived values (L, D_init) are
    computed once when the class body executes, so overriding xmin/xmax or
    initial_guess_noise later does not update them automatically; main()
    recomputes D_init itself after applying the command-line overrides.
    """
    # Spatial domain and discretization for Method of Lines (forward solver)
    xmin = -1.0
    xmax = 1.0
    L = xmax - xmin            # domain length, also used to scale the Fourier features of the PINN
    Nx = 4000                  # number of grid nodes for the forward solver
    N_data_points = 14         # number of spatial measurement locations per measurement time

    # Time domain
    tmin = 0.0
    tmax = 0.5

    # Measurement times for inverse problem (passed directly as solve_ivp's t_eval)
    t_eval = [0.2, 0.5]

    # Times at which plot_results draws the solution panels
    plot_times = [0.0, 0.2, 0.5]

    # Parameter bounds [lower, upper] for the viscosity D
    param_bounds = [0.0, 0.07]

    # Ground-truth viscosity used to generate the synthetic data
    D_true = 0.01

    # Initial guess: D_init = (1 + noise) * D_true
    initial_guess_noise = 5.0
    D_init = (1 + initial_guess_noise) * D_true

    # Initial condition u(x,0) = -sin(pi x)
    # NOTE: save_outputs_and_summary reads this function's source text with
    # inspect.getsource and stores everything after the `def` line in the summary
    # file, so do not add a docstring or comments inside the function body.
    @staticmethod
    def initial_condition(x):
        return -np.sin(np.pi * x)

    # Data noise: standard deviation of the added Gaussian noise relative to |u|
    data_error = 0.3

    # PINN architecture
    hidden_layers = [20, 20]
    activation = 'tanh'
    num_fourier = 10  # number of sin/cos frequency pairs in the periodic x-embedding

    # Collocation points
    N_f = 16384    # PDE residual points (Sobol in x,t)
    N_bc = 1024    # boundary points per side (Sobol in t)
    N_ic = 1024    # initial-condition points (Sobol in x)

    # Training
    epochs = 30000
    learning_rate = 1e-2

    # Seed (NumPy RNG, Sobol samplers and the JAX PRNG key for network init)
    seed = 1234

    # Method: "mdmm", "traditional", or "nelder-mead"
    training_method = "mdmm"

    # Output ("." means the current working directory)
    output_folder = "."

    # Precision: 32 or 64
    precision = 64

# ---------------------------------------------
# Data utilities
# ---------------------------------------------
def generate_dense_midpoints(xmin, xmax, N_data_points, sigma_factor=0.25, shift=0.1):
    """
    Sample sorted locations in [xmin, xmax] clustered around a shifted midpoint.

    N_data_points values are drawn from a normal distribution truncated to
    [xmin, xmax], centred at `shift + (xmin + xmax) / 2` with standard deviation
    `sigma_factor * (xmax - xmin)`. The draws are sorted and the smallest and
    largest are dropped, so the returned array has N_data_points - 2 entries.
    Sampling uses SciPy's default (global NumPy) random state.
    """
    center = shift + (xmax + xmin) / 2.0
    spread = (xmax - xmin) * sigma_factor
    # truncnorm expects its bounds in standardised units.
    lower_z = (xmin - center) / spread
    upper_z = (xmax - center) / spread
    draws = truncnorm(lower_z, upper_z, loc=center, scale=spread).rvs(size=N_data_points)
    # Sort, then trim the two extremes.
    return np.sort(draws)[1:-1]

# ---------------------------------------------
# Burgers equation (Method of Lines) for forward solve
# ---------------------------------------------
@njit
def burgers_jit(t, u, dx, D):
    """
    Right-hand side du/dt of the semi-discretised viscous Burgers equation.

    Interior nodes use second-order central differences:
        du/dt = D * u_xx - u * u_x
    The first and last nodes get du/dt = 0, which keeps their values fixed at
    whatever the initial state holds there (Dirichlet boundaries).
    `t` is unused; it is present for the solve_ivp callback signature.
    """
    n_nodes = u.shape[0]
    rhs = np.empty_like(u)
    two_dx = 2.0 * dx
    dx_sq = dx * dx
    for k in range(1, n_nodes - 1):
        left = u[k - 1]
        mid = u[k]
        right = u[k + 1]
        slope = (right - left) / two_dx
        curvature = (right - 2.0 * mid + left) / dx_sq
        rhs[k] = D * curvature - mid * slope
    # Dirichlet BCs
    rhs[0] = 0.0
    rhs[n_nodes - 1] = 0.0
    return rhs


# ---------------------------------------------
# Forward data creation
# ---------------------------------------------
def create_forward_data(config: Config, device: str):
    """
    Build the spatial grid, solve the forward problem with the true viscosity,
    and sample noisy synthetic measurements at the measurement times.

    Args:
      config: settings object (grid, time span, t_eval, D_true, noise level, precision).
      device: unused here; kept for call-site compatibility.

    Returns:
      x, dx, u0, x_data, X_data_tensor (jnp[N_total,2]), u_data_tensor (jnp[N_total,1]),
      u_noisy_list (list per t in t_eval), t_span

    Raises:
      RuntimeError: if the forward solve fails or does not reach every time in t_eval.

    Random draws come from the global NumPy random state (locations first, then
    noise per measurement time), so seed NumPy beforehand for reproducibility.
    """
    # Grid & IC
    x = np.linspace(config.xmin, config.xmax, config.Nx)
    dx = x[1] - x[0]
    u0 = config.initial_condition(x)

    # Forward solve at t_eval
    t_span = (config.tmin, config.tmax)
    sol = solve_ivp(lambda t, u: burgers_jit(t, u, dx, config.D_true),
                    t_span, u0, t_eval=config.t_eval, method='Radau')
    # Fail loudly here rather than with an opaque IndexError on sol.y below.
    if (not sol.success) or sol.y.shape[1] < len(config.t_eval):
        raise RuntimeError(f"Forward solve for synthetic data failed: {sol.message}")

    # Measurement locations: request N + 4 draws; the helper trims two and the
    # slice trims two more, leaving N_data_points locations.
    x_data = generate_dense_midpoints(config.xmin, config.xmax, config.N_data_points + 4, sigma_factor=0.5)[1:-1]

    # Noise & assembly
    X_data_list = []
    u_noisy_list = []
    for i, t in enumerate(config.t_eval):
        u_interp = np.interp(x_data, x, sol.y[:, i])
        # Multiplicative noise: standard deviation proportional to |u|.
        noise = np.random.normal(0.0, config.data_error * np.abs(u_interp), size=u_interp.shape)
        u_noisy = u_interp + noise
        u_noisy_list.append(u_noisy)
        X_data_list.append(np.column_stack([x_data, np.full_like(x_data, t, dtype=float)]))

    X_data = np.vstack(X_data_list)                        # (N_total, 2)
    u_data = np.hstack(u_noisy_list).reshape(-1, 1)        # (N_total, 1)

    dtype = jnp.float64 if config.precision == 64 else jnp.float32
    X_data_tensor = jnp.array(X_data, dtype=dtype)
    u_data_tensor = jnp.array(u_data, dtype=dtype)

    return x, dx, u0, x_data, X_data_tensor, u_data_tensor, u_noisy_list, t_span

def generate_collocation_points(config: Config, device: str):
    """
    Draw config.N_f scrambled-Sobol collocation points in (x, t).

    Unit-square samples are mapped affinely onto [xmin, xmax] x [tmin, tmax]
    and returned as a jnp array of shape (N_f, 2) in the configured precision.
    `device` is unused.
    """
    unit_samples = qmc.Sobol(d=2, scramble=True, seed=config.seed).random(config.N_f)
    lower = np.array([config.xmin, config.tmin], dtype=float)
    extent = np.array([config.xmax - config.xmin, config.tmax - config.tmin], dtype=float)
    points = unit_samples * extent + lower
    if config.precision == 64:
        return jnp.array(points, dtype=jnp.float64)
    return jnp.array(points, dtype=jnp.float32)

def create_ic_bc_data(config: Config, device: str):
    """
    Build initial-condition and boundary-condition training points.

    IC: N_ic sorted Sobol locations in x at t = 0, with targets from
    config.initial_condition. BC: N_bc sorted Sobol times, paired with
    x = xmin (left) and x = xmax (right). `device` is unused.

    Returns (X_ic, u_ic_target, X_bc_left, X_bc_right) as jnp arrays in the
    configured precision; the X arrays have columns (x, t).
    """
    def sorted_sobol_1d(seed, count, lo, hi):
        # Scrambled 1-D Sobol draws mapped onto [lo, hi] and sorted ascending.
        unit = qmc.Sobol(d=1, scramble=True, seed=seed).random(count)[:, 0]
        return np.sort(unit * (hi - lo) + lo)

    # IC
    x_ic = sorted_sobol_1d(config.seed, config.N_ic, config.xmin, config.xmax)
    ic_points = np.column_stack([x_ic, np.zeros_like(x_ic)])
    ic_values = config.initial_condition(x_ic).reshape(-1, 1)

    # BC (t): a different seed so the time samples are not tied to the IC samples
    t_bc = sorted_sobol_1d(config.seed + 1, config.N_bc, config.tmin, config.tmax)
    left_points = np.column_stack([np.full_like(t_bc, config.xmin), t_bc])
    right_points = np.column_stack([np.full_like(t_bc, config.xmax), t_bc])

    dtype = jnp.float64 if config.precision == 64 else jnp.float32
    return tuple(jnp.array(arr, dtype=dtype)
                 for arr in (ic_points, ic_values, left_points, right_points))

# ---------------------------------------------
# Flax PINN with periodic Fourier features in x
# ---------------------------------------------
def get_activation(name: str):
    """
    Map an activation name (case-insensitive) to its JAX function.

    Supported names: tanh, relu, sigmoid, gelu, silu.
    Raises ValueError for anything else.
    """
    name = name.lower()
    registry = {
        'tanh': jnp.tanh,
        'relu': jax.nn.relu,
        'sigmoid': jax.nn.sigmoid,
        'gelu': jax.nn.gelu,
        'silu': jax.nn.silu,
    }
    chosen = registry.get(name)
    if chosen is None:
        raise ValueError(f"Unsupported activation: {name}")
    return chosen

class PINN(nn.Module):
    """
    Fully connected network u(x, t) with a periodic Fourier embedding in x.

    The input x is mapped to sin/cos of k * 2*pi*(x - xmin)/L for
    k = 1..num_fourier; these features are concatenated with t and fed through
    Dense layers of the given widths, each followed by the activation, and a
    final linear layer with one output.

    Parameters are created in `dtype` as well (param_dtype), so a float64
    model keeps float64 weights instead of Flax's float32 default. For
    float32 models nothing changes; float64 models get different initial
    weight values than they did with float32 parameter storage.
    """
    hidden_layers: Tuple[int, ...]
    activation: str
    dtype: Any
    num_fourier: int
    xmin: float
    L: float

    @nn.compact
    def __call__(self, X):  # X: (N,2) with columns (x,t)
        act = get_activation(self.activation)
        x = X[:, 0:1]
        t = X[:, 1:2]
        freqs = jnp.arange(1, self.num_fourier + 1, dtype=self.dtype).reshape(1, -1)
        x_scaled = 2.0 * jnp.pi * (x - self.xmin) / self.L
        sin_feat = jnp.sin(x_scaled * freqs)  # (N, num_fourier)
        cos_feat = jnp.cos(x_scaled * freqs)  # (N, num_fourier)
        features = jnp.concatenate([sin_feat, cos_feat, t], axis=1)  # (N, 2*num_fourier + 1)

        h = features
        for hdim in self.hidden_layers:
            # dtype controls the computation, param_dtype the stored weights.
            h = nn.Dense(hdim, dtype=self.dtype, param_dtype=self.dtype)(h)
            h = act(h)
        out = nn.Dense(1, dtype=self.dtype, param_dtype=self.dtype)(h)
        return out  # (N,1)

@dataclass
class TrainedPINN:
    """Result bundle returned by the PINN training routines (train_mdmm / train_traditional)."""
    # Full parameter tree: 'net' (network weights) and 'theta' (viscosity parameter);
    # the MDMM trainer also stores its multipliers under 'mdmm'.
    params: Dict[str, Any]
    # The Flax module whose .apply is used for prediction.
    model: Any
    # "mdmm" or "traditional"; selects how get_D decodes 'theta'.
    training_method: str
    # Per-iteration logs keyed "data", "ic", "bc", "physics", "memory_mb".
    loss_history: Dict[str, Sequence[float]]
    # Per-iteration estimate of D under the key "D".
    param_history: Dict[str, Sequence[float]]

def model_apply(params_all, model, X):
    """Evaluate `model` on X using the network weights stored under params_all['net']."""
    variables = {'params': params_all['net']}
    return model.apply(variables, X)

def get_D(params_all, method: str):
    """
    Return the viscosity D from the parameter tree.

    The "traditional" method stores log(D) under theta['logD'] and is decoded
    with exp; every other method reads theta['D'] directly.
    """
    theta = params_all['theta']
    if method != "traditional":
        return theta['D']
    return jnp.exp(theta['logD'])

# ---------------------------------------------
# Losses
# ---------------------------------------------
def ic_loss_fn(params_all, model, X_ic, u_ic_target):
    """Mean squared error between the network prediction at X_ic and u_ic_target."""
    residual = model_apply(params_all, model, X_ic) - u_ic_target
    return jnp.mean(residual ** 2)

def bc_loss_fn(params_all, model, X_bc_left, X_bc_right):
    """
    Homogeneous Dirichlet boundary loss.

    Returns the average of the two per-side mean squared errors between the
    network prediction and zero at the left and right boundary points.
    """
    def side_mse(X_side):
        # Explicit zero target with the same dtype as the boundary points.
        target = jnp.zeros((X_side.shape[0], 1), dtype=X_side.dtype)
        prediction = model_apply(params_all, model, X_side)
        return jnp.mean((prediction - target) ** 2)

    left_term = side_mse(X_bc_left)
    right_term = side_mse(X_bc_right)
    return 0.5 * (left_term + right_term)

def physics_loss_fn(params_all, model, X_f, method: str, eps=0.0):
    """
    PDE residual MSE: f = u_t + u * u_x - D * u_xx

    Args:
      params_all: parameter tree with 'net' and 'theta'.
      model: network passed through to model_apply.
      X_f: (N, 2) collocation points with columns (x, t).
      method: training method name, forwarded to get_D to decode D.
      eps: unused; kept for signature compatibility.

    Returns:
      Scalar mean of the squared residual over all collocation points.

    A single forward-mode pass (JVP) through grad(u) in the x direction yields
    both the gradient (u_x, u_t) and the needed Hessian entry u_xx, so the
    gradient is evaluated once per point and no full Hessian is formed.
    """
    # Unit tangent along x: differentiating grad(u) in this direction gives (u_xx, u_tx).
    ex = jnp.array([1.0, 0.0], dtype=X_f.dtype)

    def u_scalar(z):  # z: (2,)
        return model_apply(params_all, model, z.reshape(1, 2))[0, 0]

    # No inner jax.jit: callers jit the whole training step, and a wrapper
    # created per call would never hit the compilation cache anyway.
    grad_u = jax.grad(u_scalar)

    def derivs(z):
        g, dg_dx = jvp(grad_u, (z,), (ex,))
        u_x = g[0]
        u_t = g[1]
        u_xx = dg_dx[0]
        return u_t, u_x, u_xx

    u_vals = model_apply(params_all, model, X_f)[:, 0]            # (N,)
    u_t, u_x, u_xx = vmap(derivs)(X_f)                            # (N,), (N,), (N,)
    D = get_D(params_all, method)
    f = u_t + u_vals * u_x - D * u_xx
    return jnp.mean(f**2)

def data_rrms_loss(params_all, model, X_data, u_data, eps=1e-6):
    """
    Root of the mean squared relative error between prediction and data.

    Each residual is divided by (u_data + eps) before squaring.
    """
    # NOTE(review): eps is added to the signed target, so for negative u_data it
    # shrinks the denominator instead of guarding it; abs(u_data) + eps may be
    # what was intended. Left as-is to stay consistent with the other RRMS
    # computations in this file.
    relative_error = (model_apply(params_all, model, X_data) - u_data) / (u_data + eps)
    return jnp.sqrt(jnp.mean(relative_error ** 2))

# ---------------------------------------------
# Learning-rate schedules 
# ---------------------------------------------
def make_linear_const_schedule(init_lr: float, final_lr1: float, const_lr2: float,
                               switch_epoch: int, total_epochs: int, dtype):
    """
    Build a learning-rate schedule: linear ramp, then a constant.

    For step counts in [0, switch_epoch) the rate moves linearly from init_lr
    towards final_lr1; from switch_epoch onwards it is const_lr2. With
    switch_epoch <= 0 the ramp is empty and the rate is const_lr2 throughout.

    Args:
      init_lr: rate at step 0 of the ramp.
      final_lr1: rate the ramp approaches at switch_epoch.
      const_lr2: rate used once the ramp is over.
      switch_epoch: number of ramp steps.
      total_epochs: unused; kept for call-site compatibility.
      dtype: dtype of the returned rate.

    Returns:
      A function count -> learning rate. `count` may be a JAX/NumPy scalar or
      a plain Python number.
    """
    ramp_start = jnp.asarray(0, jnp.int32)
    ramp_end = jnp.asarray(int(switch_epoch), jnp.int32)
    # Guard against a zero-length ramp when computing progress.
    ramp_len = jnp.maximum(ramp_end - ramp_start, jnp.int32(1))

    def schedule(count):
        # asarray lets plain Python ints through (e.g. when plotting the schedule).
        count = jnp.asarray(count).astype(jnp.int32)
        progress = jnp.clip((count - ramp_start) / ramp_len, 0, 1).astype(dtype)
        lr_lin = (init_lr + progress * (final_lr1 - init_lr)).astype(dtype)
        in_ramp = (count >= ramp_start) & (count < ramp_end)
        return jnp.where(in_ramp, lr_lin, jnp.asarray(const_lr2, dtype))
    return schedule

# ---------------------------------------------
# MDMM 
# ---------------------------------------------
class LagrangeMultiplier(NamedTuple):
    """Wrapper that tags a value as a Lagrange multiplier so mdmm_prepare_update can find it and negate its update."""
    # The multiplier value; eq_constraint initialises it to zeros shaped like the constraint output.
    value: Any

def mdmm_prepare_update(tree):
    """
    Negate every LagrangeMultiplier entry in `tree`; leave all other leaves untouched.

    Applied to optimizer updates, this reverses the update direction for the
    multipliers relative to the remaining parameters.
    """
    def is_multiplier(node):
        return isinstance(node, LagrangeMultiplier)

    def flip_if_multiplier(node):
        if is_multiplier(node):
            return LagrangeMultiplier(-node.value)
        return node

    # is_leaf stops traversal at the wrapper so it is handled as one unit.
    return jtu.tree_map(flip_if_multiplier, tree, is_leaf=is_multiplier)

def optax_prepare_update():
    """
    Stateless Optax transformation that applies mdmm_prepare_update to the updates.

    Intended to be chained after the actual optimizer so multiplier updates
    are negated before they are applied.
    """
    def _init(params):
        del params  # no state to derive from the parameters
        return optax.EmptyState()

    def _update(updates, state, params=None):
        del params
        flipped = mdmm_prepare_update(updates)
        return flipped, state

    return optax.GradientTransformation(_init, _update)

class Constraint(NamedTuple):
    """Pair of callables describing a constraint (or a combination of constraints)."""
    # init(*args, **kwargs) -> multiplier parameters for the constraint.
    init: Callable
    # loss(params, *args, **kwargs) -> (scalar loss term, raw constraint value(s)).
    loss: Callable

def eq_constraint(fun, damping=1., weight=1., reduction=jnp.sum):
    """
    Build an equality constraint fun(...) == 0.

    init(*args) evaluates `fun` once and returns {'lambda': multiplier} with a
    zero multiplier of the same shape. loss(params, *args) returns
        weight * reduction(lambda * g + damping * g**2 / 2),  g
    where g = fun(*args): the multiplier term plus a quadratic damping term,
    together with the raw constraint value.
    """
    def init_fn(*args, **kwargs):
        violation = fun(*args, **kwargs)
        multiplier = LagrangeMultiplier(jnp.zeros_like(violation))
        return {'lambda': multiplier}

    def loss_fn(params, *args, **kwargs):
        violation = fun(*args, **kwargs)
        lam = params['lambda'].value
        penalized = lam * violation + damping * violation**2 / 2.0
        return weight * reduction(penalized), violation

    return Constraint(init_fn, loss_fn)

def combine_constraints(*constraints):
    """
    Merge several constraints into one.

    The combined init returns a tuple with each constraint's parameters, in
    the order given. The combined loss takes that tuple and returns the sum of
    the individual loss terms plus a tuple of the raw constraint values, again
    in the same order.
    """
    init_fns, loss_fns = zip(*constraints)

    def init_fn(*a, **k):
        return tuple(make_params(*a, **k) for make_params in init_fns)

    def loss_fn(params, *a, **k):
        total = 0
        values = []
        for member_params, member_loss in zip(params, loss_fns):
            term, value = member_loss(member_params, *a, **k)
            total = total + term
            values.append(value)
        return total, tuple(values)

    return Constraint(init_fn, loss_fn)

def bound_hard(fun_value, lo, hi, damping=1., weight=1., reduction=jnp.sum):
    """
    Equality constraint that keeps fun_value(...) inside [lo, hi].

    The constraint value is clip(v, lo, hi) - v, which is zero while v lies
    within the bounds and equals the signed distance back to the nearest bound
    otherwise. `fun_value` is evaluated once per call.
    """
    def bound_violation(*args, **kwargs):
        value = fun_value(*args, **kwargs)
        return jnp.clip(value, lo, hi) - value

    return eq_constraint(bound_violation, damping=damping, weight=weight, reduction=reduction)

# ---------------------------------------------
# Training: MDMM
# ---------------------------------------------
def train_mdmm(model: PINN, config: Config,
               X_data, u_data, X_f, X_ic, u_ic, X_bc_left, X_bc_right):
    """
    Train the PINN and the viscosity D with MDMM-style constrained optimisation.

    The data RRMS loss is the objective; the physics, IC and BC losses and the
    bounds on D enter as equality constraints with Lagrange multipliers. Adam
    updates all parameters, and optax_prepare_update negates the multiplier
    updates afterwards.

    Args:
      model: the PINN module.
      config: settings (precision, seed, D_init, param_bounds, epochs).
      X_data, u_data: measurement points and noisy values.
      X_f: collocation points for the PDE residual.
      X_ic, u_ic: initial-condition points and targets.
      X_bc_left, X_bc_right: boundary points.

    Returns:
      TrainedPINN with the final parameters and per-iteration histories.
    """
    dtype = jnp.float64 if config.precision == 64 else jnp.float32
    rng = jax.random.PRNGKey(config.seed)
    net_params = model.init(rng, jnp.zeros((1, 2), dtype=dtype))['params']

    theta = {'D': jnp.array(config.D_init, dtype)}
    params_all = {'net': net_params, 'theta': theta}

    # Constraints on scalar losses
    physics_c = eq_constraint(lambda p: physics_loss_fn(p, model, X_f, "mdmm"), damping=1., weight=1.)
    ic_c      = eq_constraint(lambda p: ic_loss_fn(p, model, X_ic, u_ic), damping=1., weight=1.)
    bc_c      = eq_constraint(lambda p: bc_loss_fn(p, model, X_bc_left, X_bc_right), damping=1., weight=1.)
    bound_D   = bound_hard(lambda p: p['theta']['D'], config.param_bounds[0], config.param_bounds[1], damping=1., weight=1.)
    constraints = combine_constraints(physics_c, ic_c, bc_c, bound_D)

    mdmm_params = constraints.init(params_all)
    params_all['mdmm'] = mdmm_params

    # LR schedule: 1e-3 -> 1e-4 over the first (epochs - 30000) steps, then const 1e-4.
    # With epochs <= 30000 the ramp is empty and the rate is 1e-4 throughout.
    schedule = make_linear_const_schedule(init_lr=1e-3, final_lr1=1e-4, const_lr2=1e-4,
                                          switch_epoch=max(0, config.epochs - 30000),
                                          total_epochs=config.epochs, dtype=dtype)
    optimizer = optax.chain(
        optax.adam(learning_rate=schedule),
        optax_prepare_update(),
    )
    opt_state = optimizer.init(params_all)

    def total_loss_and_logs(p):
        dloss = data_rrms_loss(p, model, X_data, u_data)
        # constraints.loss already returns the raw constraint values in the order
        # the constraints were combined (physics, ic, bc, bound), so reuse them
        # for logging instead of evaluating every loss a second time.
        mdmm_term, (phys_val, ic_val, bc_val, _bound_val) = constraints.loss(p['mdmm'], p)
        return dloss + mdmm_term, (dloss, ic_val, bc_val, phys_val)

    @jax.jit
    def mdmm_step(p, state):
        (val, logs), grads = jax.value_and_grad(total_loss_and_logs, has_aux=True)(p)
        updates, state = optimizer.update(grads, state, p)
        p = optax.apply_updates(p, updates)
        dloss, ic_val, bc_val, phys_val = logs
        return p, state, dloss, ic_val, bc_val, phys_val

    loss_hist = {"data": [], "ic": [], "bc": [], "physics": [], "memory_mb": []}
    param_hist = {"D": []}

    for it in range(config.epochs):
        params_all, opt_state, dloss, ic_val, bc_val, phys_val = mdmm_step(params_all, opt_state)

        # Logged losses are those evaluated before this step's update;
        # D is read after the update.
        loss_hist["data"].append(float(dloss))
        loss_hist["ic"].append(float(ic_val))
        loss_hist["bc"].append(float(bc_val))
        loss_hist["physics"].append(float(phys_val))
        D_val = float(get_D(params_all, "mdmm"))
        param_hist["D"].append(D_val)
        loss_hist["memory_mb"].append(float(get_process_memory_mb()))

        if (it + 1) % 500 == 0:
            print(f"Iter {it+1:5d}, Data: {float(dloss):.5e}, Phys: {float(phys_val):.5e}, "
                  f"IC: {float(ic_val):.5e}, BC: {float(bc_val):.5e}, D: {D_val:.5f}")

    trained = TrainedPINN(params=params_all, model=model, training_method="mdmm",
                          loss_history=loss_hist, param_history=param_hist)
    return trained

# ---------------------------------------------
# Training: Traditional (data + IC + BC + physics)
# ---------------------------------------------
def train_traditional(model: PINN, config: Config,
                      X_data, u_data, X_f, X_ic, u_ic, X_bc_left, X_bc_right):
    """
    Train the PINN by minimising the unweighted sum data + IC + BC + physics.

    The viscosity is optimised as log(D) so that D = exp(logD) stays positive.
    Adam is used for all parameters. Returns a TrainedPINN holding the final
    parameters and the per-iteration loss, memory and D histories.
    """
    dtype = jnp.float64 if config.precision == 64 else jnp.float32
    init_key = jax.random.PRNGKey(config.seed)
    dummy_input = jnp.zeros((1, 2), dtype=dtype)
    params_all = {
        'net': model.init(init_key, dummy_input)['params'],
        'theta': {'logD': jnp.array(np.log(config.D_init), dtype)},
    }

    # LR schedule: 1e-2 -> 1e-4 then const 1e-4
    lr_schedule = make_linear_const_schedule(init_lr=1e-2, final_lr1=1e-4, const_lr2=1e-4,
                                             switch_epoch=max(0, config.epochs - 30000),
                                             total_epochs=config.epochs, dtype=dtype)
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params_all)

    def objective(p):
        # Individual terms are returned as auxiliary output for logging.
        data_term = data_rrms_loss(p, model, X_data, u_data)
        ic_term = ic_loss_fn(p, model, X_ic, u_ic)
        bc_term = bc_loss_fn(p, model, X_bc_left, X_bc_right)
        pde_term = physics_loss_fn(p, model, X_f, "traditional")
        return data_term + ic_term + bc_term + pde_term, (data_term, ic_term, bc_term, pde_term)

    @jax.jit
    def step(p, state):
        (total, terms), grads = jax.value_and_grad(objective, has_aux=True)(p)
        updates, state = optimizer.update(grads, state, p)
        p = optax.apply_updates(p, updates)
        return p, state, total, terms

    loss_hist = {"data": [], "ic": [], "bc": [], "physics": [], "memory_mb": []}
    param_hist = {"D": []}

    for it in range(config.epochs):
        params_all, opt_state, tot, (ld, li, lb, lp) = step(params_all, opt_state)

        for key, term in (("data", ld), ("ic", li), ("bc", lb), ("physics", lp)):
            loss_hist[key].append(float(term))
        D_val = float(get_D(params_all, "traditional"))
        param_hist["D"].append(D_val)
        loss_hist["memory_mb"].append(float(get_process_memory_mb()))

        if (it + 1) % 500 == 0:
            print(f"Iter {it+1:5d}, Total: {float(tot):.5e}, Data: {float(ld):.5e}, "
                  f"IC: {float(li):.5e}, BC: {float(lb):.5e}, Phys: {float(lp):.5e}, D: {D_val:.5f}")

    return TrainedPINN(params=params_all, model=model, training_method="traditional",
                       loss_history=loss_hist, param_history=param_hist)

# ---------------------------------------------
# Nelder–Mead (optimize D via forward solves)
# ---------------------------------------------
def train_nelder_mead(config: Config, x, dx, u0, x_data, u_noisy_list, t_span):
    """
    Estimate D by Nelder-Mead on the forward-solve misfit.

    Each objective evaluation integrates the Burgers system with the candidate
    D, interpolates the solution to the measurement locations and averages the
    relative RMS error over the measurement times. Negative D or a failed
    solve yields a large constant cost (1e12).

    Returns (res.x, log). `log` holds "loss" and "D" for successful
    evaluations and "memory_mb" for every evaluation, so the lists can differ
    in length.
    """
    log = {"loss": [], "D": [], "memory_mb": []}
    failure_cost = 1e12

    def relative_rms(simulated, measured):
        return float(np.sqrt(np.mean(((simulated - measured) / (measured + 1e-6))**2)))

    def objective_fn(p):
        D_est = p[0]
        # memory sampling happens on every call, including rejected candidates
        log["memory_mb"].append(float(get_process_memory_mb()))
        if D_est < 0.0:
            return failure_cost
        sol = solve_ivp(lambda t, u: burgers_jit(t, u, dx, D_est),
                        t_span, u0, t_eval=config.t_eval, method='Radau')
        reached_all_times = sol.success and len(sol.t) >= len(config.t_eval)
        if not reached_all_times:
            return failure_cost
        states = np.array(sol.y)
        err = 0.0
        for col in range(len(config.t_eval)):
            u_model = np.interp(x_data, x, states[:, col])
            err += relative_rms(u_model, u_noisy_list[col])
        curr = err / len(config.t_eval)
        log["loss"].append(curr)
        log["D"].append(float(D_est))
        return curr

    lower, upper = config.param_bounds[0], config.param_bounds[1]
    res = minimize(objective_fn, [config.D_init], method='Nelder-Mead', bounds=[(lower, upper)])
    print("Nelder–Mead optimization success:", res.success)
    return res.x, log

# ---------------------------------------------
# Plotting
# ---------------------------------------------
def plot_results(model_or_params, config: Config, x, x_data, u_noisy_list, u0, t_span, dx, device: str):
    """
    Plot true vs. estimated solutions at config.plot_times and save the curves.

    For every plot time the Burgers system is solved with the true D and with
    the estimated D; for a TrainedPINN the network prediction is drawn as
    well, and noisy measurements are overlaid at times that are measurement
    times. Writes Figures/Solution_comparison.png and trajectories.npz into
    the run folder. `device` is unused.
    """
    out_root = os.getcwd() if config.output_folder == "." else config.output_folder
    run_dir = os.path.join(
        out_root,
        f"Burgers_{config.training_method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}")
    fig_dir = os.path.join(run_dir, "Figures")
    os.makedirs(fig_dir, exist_ok=True)

    n_panels = len(config.plot_times)
    fig = plt.figure(figsize=(7 * n_panels, 6))

    has_network = isinstance(model_or_params, TrainedPINN)
    if has_network:
        D_est = float(get_D(model_or_params.params, model_or_params.training_method))
    else:
        D_est = float(model_or_params[0])

    def rhs_true(t, u):
        return burgers_jit(t, u, dx, config.D_true)

    def rhs_est(t, u):
        return burgers_jit(t, u, dx, D_est)

    # Measurements keyed by their time so panels can look them up directly.
    measurements_by_time = dict(zip(config.t_eval, ((x_data, noisy) for noisy in u_noisy_list)))
    trajectories = {}

    for panel, t_plot in enumerate(config.plot_times, start=1):
        u_true = np.array(solve_ivp(rhs_true, t_span, u0, t_eval=[t_plot], method='Radau').y)[:, 0]
        est_solution = solve_ivp(rhs_est, t_span, u0, t_eval=[t_plot], method='Radau')
        if est_solution.success:
            u_est = np.array(est_solution.y)[:, 0]
        else:
            u_est = np.full_like(u_true, np.nan)

        ax = fig.add_subplot(1, n_panels, panel)
        ax.plot(x, u_true, 'k--', label='True solution (MoL)')
        ax.plot(x, u_est, 'g-', label='Estimated trajectory')
        trajectories[f"x_time_{t_plot}"] = x
        trajectories[f"u1_time_{t_plot}_true"] = u_true
        trajectories[f"u1_time_{t_plot}_estimated"] = u_est

        if has_network:
            grid_dtype = jnp.float64 if config.precision == 64 else jnp.float32
            X_plot = jnp.array(np.column_stack([x, np.full_like(x, t_plot)]), dtype=grid_dtype)
            net_out = model_or_params.model.apply({'params': model_or_params.params['net']}, X_plot)
            u_model = np.array(net_out).reshape(-1)
            ax.plot(x, u_model, 'b-', label='PINN prediction')
            trajectories[f"u1_time_{t_plot}_PINN"] = u_model

        if t_plot in measurements_by_time:
            x_meas, u_meas = measurements_by_time[t_plot]
            ax.plot(x_meas, u_meas, 'ro', label='Noisy data')
        ax.set_xlabel('x')
        ax.set_ylabel('u')
        ax.set_title(f"t = {t_plot}")
        ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Solution_comparison.png"), dpi=300)
    np.savez(os.path.join(run_dir, "trajectories.npz"), **trajectories)

def plot_loss(loss_log, method: str, config: Config):
    """
    Save Figures/Loss_evolution.png for the given method.

    For "mdmm" and "traditional" a 2x2 grid shows the data, IC, BC and physics
    losses on log-log axes; otherwise the single "loss" series is drawn on a
    semilog-y axis.
    """
    out_root = os.getcwd() if config.output_folder == "." else config.output_folder
    fig_dir = os.path.join(
        out_root,
        f"Burgers_{method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}",
        "Figures")
    os.makedirs(fig_dir, exist_ok=True)

    if method in ["mdmm", "traditional"]:
        panels = (("data", "Data Loss"), ("ic", "IC Loss"),
                  ("bc", "BC Loss"), ("physics", "Physics Loss"))
        fig, axs = plt.subplots(2, 2, figsize=(10, 8))
        for ax, (key, title) in zip(axs.flat, panels):
            ax.semilogy(loss_log[key])
            ax.set_title(title)
            ax.set_xlabel("Epoch")
            ax.set_xscale('log')
            ax.set_yscale('log')
        plt.tight_layout()
    else:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.semilogy(loss_log["loss"])
        ax.set_xlabel("Iteration")
        ax.set_title("Objective Loss (Nelder–Mead)")
        plt.tight_layout()
    plt.savefig(os.path.join(fig_dir, "Loss_evolution.png"), dpi=300)

def plot_parameters(param_log, method: str, config: Config):
    """
    Save Figures/Parameter_evolution.png: the history of D on log-log axes,
    with a dashed red line at config.D_true.
    """
    out_root = os.getcwd() if config.output_folder == "." else config.output_folder
    fig_dir = os.path.join(
        out_root,
        f"Burgers_{method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}",
        "Figures")
    os.makedirs(fig_dir, exist_ok=True)

    # Gradient-based methods log per epoch, Nelder-Mead per objective evaluation.
    if method in ["mdmm", "traditional"]:
        x_label = "Epoch"
    else:
        x_label = "Iteration"

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(param_log["D"])
    ax.set_xlabel(x_label)
    ax.set_ylabel("D")
    ax.set_yscale("log")
    ax.set_xscale("log")
    ax.axhline(y=config.D_true, color='r', linestyle='--')
    ax.set_title("D parameter")
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Parameter_evolution.png"), dpi=300)

# ---------------------------------------------
# Distance metric
# ---------------------------------------------
def measure_max_distance(model_or_params, config: Config, x, dx, u0, t_span, device: str):
    """
    Largest pointwise gap between the forward solve with the estimated D and
    the PINN prediction, taken over all measurement times.

    For a TrainedPINN the network is evaluated on the grid `x` at each time in
    config.t_eval and compared with the Method-of-Lines solution for the
    estimated D. For a plain parameter vector there is no network, so the
    solution is compared with itself. A failed solve contributes NaNs, which
    are ignored by the nan-aware maxima. `device` is unused.
    """
    has_network = isinstance(model_or_params, TrainedPINN)
    if has_network:
        D_val = float(get_D(model_or_params.params, model_or_params.training_method))
    else:
        D_val = float(model_or_params[0])
    grid_dtype = jnp.float64 if config.precision == 64 else jnp.float32

    def rhs(tt, u):
        return burgers_jit(tt, u, dx, D_val)

    per_time_gaps = []
    for t in config.t_eval:
        solution = solve_ivp(rhs, t_span, u0, t_eval=[t], method='Radau')
        if solution.success and solution.y.ndim != 1:
            u_ml = np.array(solution.y)[:, 0]
        else:
            u_ml = np.full(u0.shape, np.nan)

        if has_network:
            X_compare = jnp.array(np.column_stack([x, np.full_like(x, t)]), dtype=grid_dtype)
            net_out = model_or_params.model.apply({'params': model_or_params.params['net']}, X_compare)
            u_pinn = np.array(net_out).reshape(-1)
        else:
            u_pinn = u_ml

        per_time_gaps.append(np.nanmax(np.abs(u_ml - u_pinn)))

    return float(np.nanmax(per_time_gaps))

# ---------------------------------------------
# Save outputs & summary
# ---------------------------------------------
def _finite_memory_stats(samples):
    """Return (mean, max) over the finite entries of `samples`, or (None, None) if there are none."""
    finite_vals = [m for m in samples if np.isfinite(m)]
    if not finite_vals:
        return None, None
    return float(np.mean(finite_vals)), float(np.max(finite_vals))


def save_outputs_and_summary(config: Config, method: str, x, dx, u0, x_data, u_noisy_list, t_span, runtime,
                             trained_model: Optional[TrainedPINN] = None,
                             optimal_params: Optional[np.ndarray] = None,
                             log: Optional[Dict[str, Any]] = None,
                             max_distance: Optional[float] = None):
    """
    Persist run artefacts and write summary_statistics.json into the run folder.

    PINN methods ("mdmm", "traditional") save a pickled parameter bundle, the
    noisy data and the loss/parameter histories. "nelder-mead" saves the
    optimal parameters, the noisy data and, when `log` is given, the loss log
    and D history. The summary holds the estimated D, its relative error
    against D_true, the relative RMS misfit of a forward solve with the
    estimated D against the noisy data, memory statistics and the main
    configuration values.
    """
    uses_network = method in ["mdmm", "traditional"]

    out_root = os.getcwd() if config.output_folder == "." else config.output_folder
    run_dir = os.path.join(
        out_root,
        f"Burgers_{config.training_method}_initial_guess_error_{config.initial_guess_noise}_data_error_{config.data_error}_N_data_points_{config.N_data_points}")
    os.makedirs(os.path.join(run_dir, "Figures"), exist_ok=True)

    def in_run_dir(filename):
        return os.path.join(run_dir, filename)

    if uses_network:
        with open(in_run_dir("trained_model.pt"), "wb") as f:
            pickle.dump({
                'params': trained_model.params,
                'training_method': trained_model.training_method,
                'hidden_layers': config.hidden_layers,
                'activation': config.activation,
                'precision': config.precision,
                'num_fourier': config.num_fourier,
                'xmin': config.xmin,
                'L': config.L
            }, f)
        np.savez(in_run_dir("noisy_data.npz"), x_data=x_data, u_noisy_list=u_noisy_list)
        np.savez(in_run_dir("loss_history.npz"), **trained_model.loss_history)
        np.savez(in_run_dir("parameter_history.npz"), **trained_model.param_history)
    elif method == "nelder-mead":
        np.savez(in_run_dir("optimal_params.npz"), optimal_params=optimal_params)
        if log is not None:
            np.savez(in_run_dir("loss_log.npz"), **log)
        np.savez(in_run_dir("noisy_data.npz"), x_data=x_data, u_noisy_list=u_noisy_list)
        if log is not None:
            np.savez(in_run_dir("parameter_history.npz"), D=log.get("D", []))

    # Estimate D
    D_est = float(get_D(trained_model.params, method)) if uses_network else float(optimal_params[0])

    # RRMS on data (MoL with estimated D vs noisy measurements)
    eps = 1e-6

    def rhs(tt, u):
        return burgers_jit(tt, u, dx, D_est)

    rrms_values = []
    for idx, t in enumerate(config.t_eval):
        solution = solve_ivp(rhs, t_span, u0, t_eval=[t], method='Radau')
        if solution.success and solution.y.ndim != 1:
            u_ml = np.array(solution.y)[:, 0]
        else:
            u_ml = np.full(u0.shape, np.nan)
        u_sim = np.interp(x_data, x, u_ml)
        measured = u_noisy_list[idx]
        rrms_values.append(float(np.sqrt(np.mean(((u_sim - measured) / (measured + eps))**2))))
    root_relative_mse_data = float(np.nanmean(rrms_values))
    root_relative_mse_parameters = float(abs((D_est - config.D_true) / config.D_true))

    # Extract IC function body (human-readable): everything after the `def`
    # line, with a leading "return " removed.
    ic_lines = inspect.getsource(Config.initial_condition).splitlines()
    def_index = next((k for k, text in enumerate(ic_lines) if text.strip().startswith("def")), None)
    body = "" if def_index is None else "\n".join(ic_lines[def_index + 1:]).strip()
    if body.startswith("return "):
        body = body[len("return "):].strip()

    # Memory statistics
    if uses_network and trained_model is not None:
        avg_mem, max_mem = _finite_memory_stats(trained_model.loss_history.get("memory_mb", []))
    elif method == "nelder-mead" and log is not None:
        avg_mem, max_mem = _finite_memory_stats(log.get("memory_mb", []))
    else:
        avg_mem, max_mem = None, None

    if max_distance is not None:
        max_distance_entry = float(max_distance)
    elif method == "nelder-mead":
        max_distance_entry = float('nan')
    else:
        max_distance_entry = None

    summary = {
        "training_method": config.training_method,
        "runtime_seconds": float(runtime),
        "initial_condition": body,
        "boundary_conditions": "Dirichlet BC: u(xmin,t)=0 and u(xmax,t)=0",
        "parameter_initial_guess": [float(config.D_init)],
        "parameter_ground_truth": [float(config.D_true)],
        "data_noise": float(config.data_error),
        "initial_guess_error": float(config.initial_guess_noise),
        "root_relative_mse_data": float(root_relative_mse_data),
        "root_relative_mse_parameters": float(root_relative_mse_parameters),
        "plot_times": config.plot_times,
        "xmin": config.xmin, "xmax": config.xmax,
        "tmin": config.tmin, "tmax": config.tmax,
        "t_eval": list(config.t_eval),
        "hidden_layers": list(config.hidden_layers),
        "activation": config.activation,
        "param_bounds": list(config.param_bounds),
        "parameter_estimates": [float(D_est)],
        "max_distance": max_distance_entry,
        "average_memory_usage_mb": avg_mem,
        "max_memory_usage_mb": max_mem
    }

    with open(in_run_dir("summary_statistics.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=4)

# ---------------------------------------------
# Main
# ---------------------------------------------
def main():
    """
    Command-line entry point: parse options, train with the selected method, plot, and save.

    Reads ``--epochs``, ``--initial_guess_error``, ``--data_error``,
    ``--N_data_points``, ``--training_method`` and ``--precision`` from the
    command line, copies them onto a fresh ``Config``, sets
    ``config.D_init = (1 + initial_guess_error) * config.D_true``, and then
    runs one of three pipelines:

    * ``mdmm`` / ``traditional``: train the PINN with ``train_mdmm`` or
      ``train_traditional``, print the recovered ``D``, plot results, loss
      and parameter history, and call ``measure_max_distance``.
    * ``nelder-mead``: run ``train_nelder_mead`` and plot its result and log.

    The elapsed time from just after argument validation to the end of
    training/plotting is handed to ``save_outputs_and_summary``.

    Raises:
        ValueError: if ``--training_method`` (compared case-insensitively) is
            not ``mdmm``, ``traditional`` or ``nelder-mead``.
        SystemExit: raised by argparse for malformed arguments, including a
            ``--precision`` other than 32 or 64.
    """
    parser = argparse.ArgumentParser(description="Train PINN (JAX/Flax/Optax) for 1D viscous Burgers equation")
    parser.add_argument("--epochs", type=int, default=30000, help="Number of training epochs")
    parser.add_argument("--initial_guess_error", type=float, default=5.0, help="Initial guess error multiplier")
    parser.add_argument("--data_error", type=float, default=0.3, help="Data noise level")
    parser.add_argument("--N_data_points", type=int, default=14, help="Number of spatial data points")
    parser.add_argument("--training_method", type=str, default="mdmm", help="mdmm, traditional, or nelder-mead")
    # Only 32 and 64 are meaningful; without `choices` any other value would
    # silently fall through to float32 in the dtype selection below.
    parser.add_argument("--precision", type=int, default=64, choices=(32, 64), help="Precision (32 or 64)")
    args = parser.parse_args()

    # Fail fast on an unknown method, before any data generation or model setup runs.
    method = args.training_method.lower()
    if method not in ("mdmm", "traditional", "nelder-mead"):
        raise ValueError(f"Unknown training method: {args.training_method}")

    # perf_counter is monotonic, so the reported runtime cannot be skewed by
    # system clock adjustments during a long training run.
    start_time = time.perf_counter()
    config = Config()

    # Update config
    config.epochs = args.epochs
    config.N_data_points = args.N_data_points
    config.initial_guess_noise = args.initial_guess_error
    config.data_error = args.data_error
    config.training_method = args.training_method  # stored as typed; `method` is the lowercased form
    config.D_init = (1 + config.initial_guess_noise) * config.D_true
    config.precision = args.precision

    # Precision / dtype: toggle JAX's x64 flag to match the requested precision.
    use_x64 = config.precision == 64
    jax.config.update("jax_enable_x64", use_x64)
    dtype = jnp.float64 if use_x64 else jnp.float32

    # Seeds & device info
    np.random.seed(config.seed)
    device = jax.devices()[0].platform
    print("Using device:", device)

    # Build data (done for every method, in the original order)
    x, dx, u0, x_data, X_data, u_data, u_noisy_list, t_span = create_forward_data(config, device)
    X_f = generate_collocation_points(config, device)
    X_ic, u_ic, X_bc_left, X_bc_right = create_ic_bc_data(config, device)

    # Model
    model = PINN(hidden_layers=tuple(config.hidden_layers),
                 activation=config.activation,
                 dtype=dtype,
                 num_fourier=config.num_fourier,
                 xmin=config.xmin, L=config.L)

    if method == "nelder-mead":
        print("Training using Nelder–Mead...")
        optimal_params, log = train_nelder_mead(config, x, dx, u0, x_data, u_noisy_list, t_span)
        print(f"\nNelder–Mead: Final recovered parameter: D = {optimal_params[0]:.5f}")
        plot_results(optimal_params, config, x, x_data, u_noisy_list, u0, t_span, dx, device)
        plot_loss(log, method, config)
        plot_parameters(log, method, config)
    else:
        # The two PINN pipelines differ only in trainer, banner and display label.
        if method == "mdmm":
            label, banner, train_fn = "MDMM", "Training using MDMM...", train_mdmm
        else:
            label, banner, train_fn = (
                "Traditional",
                "Training using traditional loss minimization...",
                train_traditional,
            )
        print(banner)
        trained_model = train_fn(model, config, X_data, u_data, X_f, X_ic, u_ic, X_bc_left, X_bc_right)
        D_est = float(get_D(trained_model.params, method))
        print(f"\n{label}: Final recovered parameter: D = {D_est:.5f}")
        plot_results(trained_model, config, x, x_data, u_noisy_list, u0, t_span, dx, device)
        plot_loss(trained_model.loss_history, method, config)
        plot_parameters(trained_model.param_history, method, config)
        max_distance = measure_max_distance(trained_model, config, x, dx, u0, t_span, device)
        print(f"Maximum distance between MoL and PINN ({label}):", max_distance)

    # Runtime covers setup, training and plotting, but not writing the outputs.
    runtime = time.perf_counter() - start_time

    if method == "nelder-mead":
        save_outputs_and_summary(config, method, x, dx, u0, x_data, u_noisy_list, t_span, runtime,
                                 optimal_params=optimal_params, log=log)
    else:
        save_outputs_and_summary(config, method, x, dx, u0, x_data, u_noisy_list, t_span, runtime,
                                 trained_model=trained_model, max_distance=max_distance)

# Run the command-line workflow only when this file is executed as a script,
# so importing the module does not start a training run.
if __name__ == "__main__":
    main()
