import jax
import os

# Keep JAX in 32-bit mode (this is also JAX's default); the simulator code
# below casts its arrays to float32 explicitly.
jax.config.update("jax_enable_x64", False)
# NOTE(review): JAX reads JAX_PLATFORM_NAME from the environment when jax is
# imported, so setting it here (after `import jax` above) likely has no
# effect. To force a backend, set it before importing jax or use
# jax.config.update("jax_platform_name", ...). Confirm the intended behavior.
os.environ["JAX_PLATFORM_NAME"] = "gpu"

# Report which backend/devices JAX actually picked up.
print("default_backend:", jax.default_backend())
print("devices:", jax.devices())

import json
import time
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple, Callable

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import random, lax
from tqdm.auto import tqdm

import diffrax
import optax
from flax import nnx

# TFMPE imports
from tfmpe.estimators.tfmpe import TFMPE, NormalDistribution
from tfmpe.estimators.training import fit_bottom_up
from tfmpe.preprocessing.tokens import Tokens
from tfmpe.preprocessing.utils import Independence, Labeller
from tfmpe.nn.transformer import Transformer, TransformerConfig
from jaxtyping import PRNGKeyArray

# =============================================================================
# HEMODYNAMICS SIMULATOR
# =============================================================================

# Unit conversion used to turn the MAP target (mmHg) into pascals in default_rcr.
MMHG_IN_PA = 133.322  # Pa per mmHg


@dataclass
class Vessel:
    """A single straight arterial segment (lengths in metres, per the ``_m`` suffixes)."""

    length_m: float  # segment length [m]
    radius_m: float  # reference lumen radius [m]; defines the reference area A0
    thickness_m: float  # wall thickness [m]
    name: str  # vessel name; measurement sites are looked up by this name
    parent_idx: int = -1  # index of the upstream vessel in Network.vessels; -1 for the root


@dataclass
class Network:
    """A tree of vessels stored as a flat list linked by ``parent_idx``."""

    vessels: List[Vessel]
    n_vessels: int  # expected to equal len(vessels); not validated here

    def children(self, idx: int) -> List[int]:
        """Return the indices of vessels whose parent is ``idx``, in list order."""
        return [i for i, v in enumerate(self.vessels) if v.parent_idx == idx]

    def terminal_ids(self) -> List[int]:
        """Return the indices of vessels that have no children, in list order.

        A vessel is terminal exactly when no vessel names it as parent, so the
        set of parent indices is built once (O(n)) rather than scanning the
        whole list for every vessel via :meth:`children` (O(n^2)).
        """
        parents = {v.parent_idx for v in self.vessels}
        return [i for i in range(len(self.vessels)) if i not in parents]


@dataclass
class BaseParams:
    """Baseline physical constants and calibration targets for the simulator."""

    beta_base: float = 3.0e3  # NOTE(review): not referenced in the visible code
    rho: float = 1060.0  # blood density (presumably kg/m^3)
    mu_ref: float = 0.004  # reference dynamic viscosity (presumably Pa·s)
    T: float = 1.0  # cardiac cycle period (s); also the period of the prescribed inflow
    p_ext: float = 0.0  # external/reference pressure in the pressure-area law
    map_target_mmHg: float = 93.0  # target mean arterial pressure used to scale terminal resistances
    co_target_ml_s: float = 83.0  # cardiac output (mL/s) used by default_rcr when no inflow is given
    beta_scale_mean: float = 2.0e6  # prior median of beta_scale (log_beta prior is centred at its log)
    term_C_scale: float = 0.60  # terminal compliance factor: C = term_C_scale * tau_target_s / R_T
    tau_target_s: float = 0.40  # target terminal time constant (s) used to set compliances


@dataclass
class RCRParameters:
    """Per-vessel three-element Windkessel (RCR) outlet parameters.

    Each field is an array with one entry per vessel. As built by
    ``default_rcr``, only the vessels named in ``SITE_ORDER`` get non-zero
    values; all other entries are zero.
    """

    R1: jnp.ndarray  # proximal resistance (series with the outlet flow)
    R2: jnp.ndarray  # distal resistance (drains the capacitor pressure to Pext)
    C: jnp.ndarray  # compliance
    Pext: jnp.ndarray  # reference pressure the distal resistance drains to


# Canonical order of the measurement/outlet sites. It defines the site axis of
# the local parameters (log_Rt, log_C) and of the simulated observations, and
# every entry must match a Vessel.name in the network.
SITE_ORDER = ["descending_aorta", "innominate", "left_common_carotid", "left_subclavian"]


def create_arch_network() -> Network:
    """Build the five-vessel aortic-arch network.

    Vessel 0 is the aortic root (no parent); the innominate, left common
    carotid, left subclavian and descending aorta all branch directly from it.
    """
    # Columns: (length_m, radius_m, thickness_m, name, parent_idx)
    layout = (
        (0.06, 0.012, 0.0015, "aortic_root", -1),
        (0.05, 0.0065, 0.0010, "innominate", 0),
        (0.05, 0.0045, 0.0010, "left_common_carotid", 0),
        (0.06, 0.0060, 0.0010, "left_subclavian", 0),
        (0.10, 0.0110, 0.0015, "descending_aorta", 0),
    )
    vessels = [Vessel(*row) for row in layout]
    return Network(vessels, len(vessels))


def pack_geometry(net: Network, nx: int):
    """Return per-vessel geometry arrays for a uniform grid of ``nx`` nodes.

    Returns:
        ``(r0, h, A0, dx)``: reference radius, wall thickness, reference area
        ``pi * r0**2`` and node spacing ``length / (nx - 1)``, each a 1-D
        array ordered like ``net.vessels``.
    """
    vessels = net.vessels
    radii = jnp.array([ves.radius_m for ves in vessels])
    walls = jnp.array([ves.thickness_m for ves in vessels])
    n_intervals = nx - 1
    spacing = jnp.array([ves.length_m / n_intervals for ves in vessels])
    return radii, walls, jnp.pi * radii**2, spacing


def terminal_ids_ordered(net: Network) -> List[int]:
    """Return the vessel index of each site in ``SITE_ORDER``, in that order.

    Raises:
        KeyError: if a site name is not the name of any vessel in ``net``.
    """
    index_of = dict((vessel.name, idx) for idx, vessel in enumerate(net.vessels))
    return list(map(index_of.__getitem__, SITE_ORDER))


def st_like_refs(net: Network, base: BaseParams):
    """Return reference terminal resistances and compliances for the sites.

    Resistance scales as ``r**-4`` (relative to the first site) normalised to
    a mean of 1.2e8; compliance scales as ``r**3`` normalised to a mean of
    1.5e-9. Both are returned as float32 arrays in ``SITE_ORDER`` order.
    ``base`` is accepted but not used.
    """
    radii = np.array([net.vessels[idx].radius_m for idx in terminal_ids_ordered(net)])
    r_first = radii[0]

    resistance = (r_first ** 4) / (radii ** 4)
    resistance = resistance / resistance.mean() * 1.2e8

    compliance = (radii ** 3) / (r_first ** 3)
    compliance = compliance / compliance.mean() * 1.5e-9

    return resistance.astype(np.float32), compliance.astype(np.float32)


def psi_from_A(A, A0, beta, rho):
    """Pressure part of the momentum flux for the sqrt(A) tube law.

    Computes ``beta / (3 rho A0) * (A**1.5 - A0**1.5)``, with ``A`` floored
    at ``1e-16 * A0`` to keep the power well defined.
    """
    area = jnp.maximum(A, 1e-16 * A0)
    prefactor = beta / (3.0 * rho * A0)
    return prefactor * (area**1.5 - A0**1.5)


def pressure_from_A(A, A0, beta, p_ext):
    """Tube law: ``p = p_ext + beta / A0 * (sqrt(A) - sqrt(A0))``.

    ``A`` is floored at ``1e-16 * A0`` before taking the square root.
    """
    floored = jnp.maximum(A, 1e-16 * A0)
    sqrt_excess = jnp.sqrt(floored) - jnp.sqrt(A0)
    return p_ext + (beta / A0) * sqrt_excess


def A_from_pressure(p, A0, beta, p_ext):
    """Invert the tube law: area whose pressure is ``p``.

    Solves ``sqrt(A) = sqrt(A0) + A0 / beta * (p - p_ext)``; the root is
    floored at ``1e-12`` so the returned area stays positive.
    """
    root = jnp.sqrt(A0) + (A0 / beta) * (p - p_ext)
    return jnp.maximum(root, 1e-12) ** 2


def characteristic_speed(A, A0, beta, rho):
    """Pulse-wave speed ``c = sqrt(beta * sqrt(A) / (2 rho A0))``.

    ``A`` is floored at ``1e-16 * A0`` and ``c**2`` at ``1e-16``.
    """
    area = jnp.maximum(A, 1e-16 * A0)
    c_squared = (beta / (2.0 * rho * A0)) * jnp.sqrt(area)
    return jnp.sqrt(jnp.maximum(c_squared, 1e-16))


def default_rcr(net: Network, base: BaseParams, qin_ml_s: Optional[float] = None) -> RCRParameters:
    """Reference RCR outlet parameters calibrated to a target MAP.

    The site resistances from ``st_like_refs`` are rescaled so that their
    parallel combination equals ``MAP / CO``. ``CO`` is ``qin_ml_s`` if given,
    otherwise ``base.co_target_ml_s``; both are in mL/s and converted to m^3/s.
    For each site:
    ``R1 = 0.02 * R_T``, ``R2 = R_T``,
    ``C = base.term_C_scale * base.tau_target_s / R_T`` and ``Pext = base.p_ext``.
    Vessels not in ``SITE_ORDER`` keep zeros.
    """
    n_vessels = net.n_vessels
    R1, R2, C, Pext = (np.zeros(n_vessels, dtype=np.float32) for _ in range(4))

    Rt_shape, _ = st_like_refs(net, base)
    if qin_ml_s is None:
        co_ml_s = base.co_target_ml_s
    else:
        co_ml_s = qin_ml_s
    r_sys_target = (base.map_target_mmHg * MMHG_IN_PA) / (co_ml_s * 1e-6)
    # After scaling, sum(1 / Rt) == 1 / r_sys_target.
    scale = r_sys_target * np.sum(1.0 / Rt_shape)
    Rt = scale * Rt_shape

    C_sites = base.term_C_scale * (base.tau_target_s / Rt)

    for site, vessel_idx in enumerate(terminal_ids_ordered(net)):
        R1[vessel_idx] = 0.02 * Rt[site]
        R2[vessel_idx] = Rt[site]
        C[vessel_idx] = C_sites[site]
        Pext[vessel_idx] = base.p_ext

    return RCRParameters(jnp.array(R1), jnp.array(R2), jnp.array(C), jnp.array(Pext))


def make_stepper(
    net: Network,
    base: BaseParams,
    beta_scale: float,
    mu_pa_s: float,
    nx: int,
    dt0: float,
    Qin_amp_ml_s: float,
    q_cap: float,
):
    """Build a jitted time stepper for the 1D blood-flow model on ``net``.

    The state ``U`` is indexed ``U[vessel, field, node]``: field 0 is the
    cross-sectional area A and field 1 is the flow Q, with ``nx`` nodes per
    vessel. Each step does four things:

    * a two-step Lax-Wendroff update of the interior nodes;
    * a prescribed pulsatile inflow at the inlet of vessel 0;
    * a junction coupling vessel 0 to its children;
    * RCR (Windkessel) outlet conditions on terminal vessels.

    Args:
        net: Vessel network (geometry and parent/child topology).
        base: Baseline parameters; ``rho``, ``T`` and ``p_ext`` are used here.
        beta_scale: Stiffness multiplier; the nominal wall stiffness is scaled
            by ``beta_scale / 1e6``.
        mu_pa_s: Dynamic viscosity (presumably Pa·s, per the name).
        nx: Number of grid nodes per vessel.
        dt0: Upper bound on the time step; the CFL condition may reduce it.
        Qin_amp_ml_s: Inflow amplitude in mL/s (divided by 1e6 below, i.e.
            converted to m^3/s).
        q_cap: Not used in this function.

    Returns:
        ``(run_chunk, (r0, A0, beta_i, rho, children_list, is_terminal, dx))``.
        ``run_chunk(U, P_c, t, dt, R1, R2, C, Pext, steps)`` advances the
        state by ``steps`` steps and returns ``(U, P_c, t, dt)``. ``P_c`` is
        the per-vessel Windkessel capacitor pressure.

    Note:
        A new ``jax.jit`` function is created on every call, so each call to
        ``make_stepper`` will likely trigger a fresh compilation.
    """
    n = net.n_vessels
    # r0: reference radius, h: wall thickness (not used below),
    # A0: reference area, dx: node spacing per vessel.
    r0, h, A0, dx = pack_geometry(net, nx)
    rho = jnp.asarray(base.rho)

    # Wall stiffness chosen so that characteristic_speed(A0, A0, beta, rho)
    # equals c = c0_root * (r / r_ref)**expo, i.e. beta = 2 rho c^2 sqrt(A0).
    def beta_from_c0(A0, rho, c0_root=5.0, r0_root=None, expo=-0.05):
        r = jnp.sqrt(A0 / jnp.pi)
        r_ref = r0_root if r0_root is not None else r[0]
        c = c0_root * (r / r_ref) ** expo
        return 2.0 * rho * (c**2) * jnp.sqrt(A0)

    beta_nom = beta_from_c0(A0, rho, c0_root=5.0, r0_root=r0[0], expo=-0.05)
    beta_i = beta_nom * (beta_scale / 1.0e6)

    mu = jnp.asarray(mu_pa_s)
    nu = mu / rho  # kinematic viscosity
    r_shape = r0
    # Friction coefficient used in the source term S_Q = -K_R * Q / A.
    # NOTE(review): the common Poiseuille form is K_R = 8*pi*nu; dividing by
    # r0**2 here changes the units of S_Q and greatly increases friction for
    # mm-scale radii. Confirm this is intentional.
    K_R = 2.0 * jnp.pi * nu / jnp.maximum(r_shape**2, 1e-12)

    children_list = [net.children(i) for i in range(n)]
    # 1 for vessels with no children (by topology), else 0.
    is_terminal = jnp.array([1 if len(children_list[i]) == 0 else 0 for i in range(n)], dtype=jnp.int32)

    dt0_j = jnp.asarray(dt0, dtype=jnp.float32)
    Qin_amp = jnp.asarray(Qin_amp_ml_s / 1e6, dtype=jnp.float32)

    def lw_interior(U, dt):
        # Two-step (predictor at midpoints, corrector at nodes) Lax-Wendroff
        # update of interior nodes; boundary nodes are set by the BC functions.
        A = U[:, 0, :]
        Q = U[:, 1, :]

        # Conservative flux F = (Q, Q^2 / A + psi(A)).
        F1 = Q
        F2 = Q**2 / jnp.maximum(A, 1e-16) + psi_from_A(A, A0[:, None], beta_i[:, None], rho)
        F = jnp.stack([F1, F2], axis=1)

        # Source term: no mass source; friction on the momentum equation.
        S_Q = -(K_R[:, None] * Q) / jnp.maximum(A, 1e-16)
        S = jnp.stack([jnp.zeros_like(Q), S_Q], axis=1)

        F_diff = F[:, :, 1:] - F[:, :, :-1]
        S_avg = 0.5 * (S[:, :, 1:] + S[:, :, :-1])
        dt_over_dx = (dt / dx)[:, None, None]

        # Predictor: state at cell midpoints after dt / 2.
        U_half = 0.5 * (U[:, :, :-1] + U[:, :, 1:]) - 0.5 * dt_over_dx * F_diff + 0.5 * dt * S_avg
        A_half = U_half[:, 0, :]
        Q_half = U_half[:, 1, :]

        F_half2 = Q_half**2 / jnp.maximum(A_half, 1e-16) + psi_from_A(A_half, A0[:, None], beta_i[:, None], rho)
        F_half = jnp.stack([Q_half, F_half2], axis=1)

        # Corrector: full step on interior nodes using midpoint fluxes; the
        # source term uses the old nodal values S.
        U_new = U.at[:, :, 1:-1].set(
            U[:, :, 1:-1] - dt_over_dx * (F_half[:, :, 1:] - F_half[:, :, :-1]) + dt * S[:, :, 1:-1]
        )

        # Safety: clip area at all nodes to [0.6, 1.4] * A0 and revert any
        # non-finite entries to the previous state.
        A_new = jnp.clip(U_new[:, 0, :], 0.6 * A0[:, None], 1.4 * A0[:, None])
        U_new = U_new.at[:, 0, :].set(A_new)
        U_new = jnp.where(jnp.isfinite(U_new), U_new, U)

        return U_new

    def inlet_bc(U, t):
        # Prescribed inflow at node 0 of vessel 0: a start-up ramp
        # (1 - exp(-5 t / T)) times a sinusoid with mean Qin_amp and +/-50%
        # amplitude over period T.
        ramp = 1.0 - jnp.exp(-5.0 * t / base.T)
        q_in = Qin_amp * ramp * (1.0 + 0.5 * jnp.sin(2.0 * jnp.pi * t / base.T))
        # Despite the name, this is the current area at neighbouring node 1.
        A0_neighbor = U[0, 0, 1]
        U = U.at[0, 1, 0].set(q_in)
        # Inlet area: neighbour value relaxed 10% toward the reference area.
        U = U.at[0, 0, 0].set(0.9 * A0_neighbor + 0.1 * A0[0])
        return U

    def junction_bc(U):
        # Only vessel 0 is treated as a parent; children of any other vessel
        # are not coupled by this function.
        kids = children_list[0]
        if len(kids) == 0:
            return U

        # Parent outlet node: zero-order extrapolation from the last interior node.
        A_p_int = U[0, 0, -2]
        q_p_int = U[0, 1, -2]
        U = U.at[0, 0, -1].set(A_p_int)
        U = U.at[0, 1, -1].set(q_p_int)

        p_parent = pressure_from_A(A_p_int, A0[0], beta_i[0], base.p_ext)

        # Child inlet areas are set so each child's pressure matches the
        # parent's; Zc = rho * c / A is each child's characteristic impedance.
        A_c = []
        Zc = []
        for cidx in kids:
            A_c_upd = A_from_pressure(p_parent, A0[cidx], beta_i[cidx], base.p_ext)
            A_c.append(A_c_upd)
            c_speed = characteristic_speed(A_c_upd, A0[cidx], beta_i[cidx], base.rho)
            Zc.append(base.rho * c_speed / jnp.maximum(A_c_upd, 1e-16))

        A_c = jnp.stack(A_c)
        Zc = jnp.stack(Zc)
        w = 1.0 / jnp.maximum(Zc, 1e-16)

        # Parent outflow split among children in proportion to admittance 1/Zc.
        q_parent_out = q_p_int
        q_children = q_parent_out * w / (jnp.sum(w) + 1e-16)
        for j, cidx in enumerate(kids):
            U = U.at[cidx, 0, 0].set(A_c[j])
            U = U.at[cidx, 1, 0].set(q_children[j])

        return U

    def rcr_outlets(U, P_c, dt, R1, R2, C, Pext):
        # Outlet-node flow as it stands before it is overwritten below.
        Q_out = U[:, 1, -1]
        # Guard the divisions; values on non-terminal vessels are discarded by
        # the jnp.where calls that follow.
        C_safe = jnp.where(is_terminal == 1, jnp.maximum(C, 1e-16), 1.0)
        R2_safe = jnp.where(is_terminal == 1, jnp.maximum(R2, 1e-16), 1.0)

        # Forward-Euler step of C dP_c/dt = Q_out - (P_c - Pext) / R2.
        dP_c = (Q_out - (P_c - Pext) / R2_safe) / C_safe
        P_c_next = jnp.where(is_terminal == 1, P_c + dt * dP_c, P_c)

        # Outlet pressure = R1 * Q_out + updated capacitor pressure; converted to
        # an area and clipped to the same [0.6, 1.4] * A0 band as the interior.
        P_out = jnp.where(is_terminal == 1, R1 * Q_out + P_c_next, base.p_ext)
        A_targ = A_from_pressure(P_out, A0, beta_i, base.p_ext)
        A_targ = jnp.clip(A_targ, 0.6 * A0, 1.4 * A0)

        A_new = jnp.where(is_terminal == 1, A_targ, U[:, 0, -1])
        U = U.at[:, 0, -1].set(A_new)
        # Terminal outlet flow copied from the last interior node.
        U = U.at[:, 1, -1].set(jnp.where(is_terminal == 1, U[:, 1, -2], U[:, 1, -1]))

        return U, P_c_next

    def step(carry, _):
        # The incoming dt in the carry is not used; dt is recomputed below.
        U, P_c, t, dt, R1, R2, C, Pext = carry

        A = U[:, 0, :]
        Q = U[:, 1, :]

        # Largest wave speed |u| + c over interior nodes of all vessels,
        # falling back to 1.0 if it is not finite.
        c_speed = characteristic_speed(A, A0[:, None], beta_i[:, None], rho)
        u = Q / jnp.maximum(A, 1e-16)
        smax = jnp.max(jnp.abs(u[:, 1:-1]) + c_speed[:, 1:-1])
        smax = jnp.where(jnp.isfinite(smax), smax, 1.0)

        # CFL-limited step (Courant number 0.45), capped at dt0.
        dx_min = jnp.min(dx)
        dt_cfl = 0.45 * dx_min / jnp.maximum(smax, 1e-9)
        dt = jnp.minimum(dt_cfl, dt0_j)

        # Order: interior update, inlet, junction, then outlets.
        U = lw_interior(U, dt)
        U = inlet_bc(U, t)
        U = junction_bc(U)
        U, P_c = rcr_outlets(U, P_c, dt, R1, R2, C, Pext)

        t_next = t + dt
        return (U, P_c, t_next, dt, R1, R2, C, Pext), None

    @jax.jit
    def run_chunk(U, P_c, t, dt, R1, R2, C, Pext, steps: int):
        # `steps` is not marked static, so it is traced and lax.fori_loop runs
        # with a dynamic trip count (no recompilation when it changes).
        def body(i, carry):
            carry, _ = step(carry, None)
            return carry

        U_n, P_c_n, t_n, dt_n, *_ = lax.fori_loop(0, steps, body, (U, P_c, t, dt, R1, R2, C, Pext))
        return U_n, P_c_n, t_n, dt_n

    return run_chunk, (r0, A0, beta_i, rho, children_list, is_terminal, dx)


def simulate_5cycles_then_sample(net: Network, base: BaseParams,
                                 theta_g: jnp.ndarray, theta_loc: jnp.ndarray,
                                 N_t: int, nx: int, dt_init: float,
                                 verbose: bool = False):
    """Simulate five cardiac cycles and resample the last one onto a uniform grid.

    Args:
        net: Vessel network; every name in ``SITE_ORDER`` must be a vessel.
        base: Baseline physical parameters (``base.T`` is the cycle length).
        theta_g: Global log-parameters ``[log beta_scale, log mu, log Q_in]``.
        theta_loc: Per-site log-parameters indexed ``[site, 0]`` = log R2
            (distal resistance) and ``[site, 1]`` = log C, in ``SITE_ORDER``.
        N_t: Number of output samples per cycle.
        nx: Grid nodes per vessel.
        dt_init: Initial / maximum solver time step.
        verbose: If True, print a one-line summary after the run.

    Returns:
        ``{vessel_name: {"t": grid, "q": flow, "p": pressure}}``. Here
        ``grid = linspace(0, T, N_t, endpoint=False)``, and ``q``/``p`` are the
        mid-vessel flow and pressure over the final cycle, linearly
        interpolated onto ``grid``.

    Raises:
        ValueError: if no time steps were taken (``base.T`` not positive).
    """
    DTYPE = jnp.float32

    tids = terminal_ids_ordered(net)
    n = net.n_vessels

    # Cast inputs to a consistent dtype.
    theta_g = jnp.asarray(theta_g, dtype=DTYPE)
    theta_loc = jnp.asarray(theta_loc, dtype=DTYPE)

    beta_scale = float(jnp.exp(theta_g[0]))
    mu_pa_s = float(jnp.exp(theta_g[1]))
    Qin_ml_s = float(jnp.exp(theta_g[2]))

    # Reference RCR values calibrated to the sampled inflow; the distal
    # resistance and compliance at each site are then replaced by the samples.
    rcr = default_rcr(net, base, qin_ml_s=Qin_ml_s)
    R1 = rcr.R1.astype(DTYPE)
    R2 = rcr.R2.astype(DTYPE)
    C = rcr.C.astype(DTYPE)
    Pext = rcr.Pext.astype(DTYPE)
    for s, vidx in enumerate(tids):
        R2 = R2.at[vidx].set(DTYPE(jnp.exp(theta_loc[s, 0])))
        C = C.at[vidx].set(DTYPE(jnp.exp(theta_loc[s, 1])))

    run_chunk, meta = make_stepper(net, base, beta_scale, mu_pa_s, nx, dt_init, Qin_ml_s, q_cap=1e9)
    _, A0, beta_i, _, _, is_terminal, _ = meta

    # Initial state: reference area, zero flow; Windkessel pressure starts at
    # Pext on terminal vessels.
    U = jnp.zeros((n, 2, nx), dtype=jnp.float32)
    U = U.at[:, 0, :].set(A0[:, None])
    P_c = jnp.where(is_terminal == 1, Pext, 0.0)

    # Take a snapshot about every 1/200 of a cycle, estimated from dt_init.
    # The CFL limit may shrink dt, so snapshot spacing in time can vary.
    steps_per_cycle_est = int(np.ceil(base.T / dt_init))
    chunk = max(1, steps_per_cycle_est // 200)
    t_target = 5.0 * base.T

    mid = nx // 2
    # Loop-invariant host-side terms of the pressure-area law.
    A0_host = np.array(A0)
    p_coeff = np.array(beta_i) / A0_host
    sqrt_A0_host = np.sqrt(A0_host)

    t_hist = []  # snapshot times (s)
    q_hist = []  # per snapshot: mid-node flow of every vessel, shape (n,)
    p_hist = []  # per snapshot: mid-node pressure of every vessel, shape (n,)

    t = jnp.asarray(0.0)
    dt = jnp.asarray(dt_init)
    while float(t) < float(t_target):
        U, P_c, t, dt = run_chunk(U, P_c, t, dt, R1, R2, C, Pext, int(chunk))

        Uh = np.array(U)
        A_mid = np.maximum(Uh[:, 0, mid], 1e-16)
        t_hist.append(float(t))
        q_hist.append(Uh[:, 1, mid].copy())
        p_hist.append(base.p_ext + p_coeff * (np.sqrt(A_mid) - sqrt_A0_host))

    if not t_hist:
        raise ValueError("No time steps were taken; base.T must be positive.")

    t_all = np.array(t_hist)
    q_all = np.stack(q_hist, axis=1)  # (n, n_snapshots)
    p_all = np.stack(p_hist, axis=1)  # (n, n_snapshots)

    # Select the final cycle [t_end - T, t_end] and shift it to start at 0.
    t_end = t_all[-1]
    t_start = t_end - base.T + 1e-12
    in_last = (t_all >= t_start) & (t_all <= t_end + 1e-12)
    t_last = t_all[in_last] - t_start

    grid = np.linspace(0.0, base.T, N_t, endpoint=False)
    res = {}
    for i, v in enumerate(net.vessels):
        res[v.name] = {
            "t": grid,
            "q": np.interp(grid, t_last, q_all[i, in_last]),
            "p": np.interp(grid, t_last, p_all[i, in_last]),
        }

    if verbose:
        print(f"simulate_5cycles_then_sample: {len(t_hist)} snapshots, t_end={t_end:.4f} s")
    return res


# =============================================================================
# TFMPE INTERFACE FUNCTIONS
# =============================================================================

@dataclass
class TFMPEConfig:
    """Configuration for TFMPE hemodynamics inference.

    Simulator settings come first. The model/training fields after them are
    presumably consumed by the inference code beyond this section (confirm).
    """
    N_t: int = 50  # time samples per cardiac cycle in each simulated trace
    nx: int = 81  # grid nodes per vessel
    dt_init: float = 2e-4  # initial / maximum solver time step (s)
    eta: float = 0.05  # relative observation-noise level; NOTE(review): the visible simulator uses sigma = 0, so this is not applied
    latent_dim: int = 32  # presumably transformer embedding width
    n_encoder: int = 2  # presumably number of encoder layers
    n_decoder: int = 2  # presumably number of decoder layers
    n_heads: int = 2  # presumably number of attention heads
    n_ff: int = 2  # presumably feed-forward layers per block
    n_rounds: int = 1  # fit_bottom_up currently only supports n_rounds=1
    n_samples_per_round: int = 200  # presumably simulations per training round
    n_val_samples: int = 20  # presumably validation simulations
    n_iter_per_round: int = 500  # presumably optimisation iterations per round
    batch_size: int = 32
    learning_rate: float = 1e-3
    n_posterior_samples: int = 500  # presumably posterior draws to generate
    output_dir: str = "tfmpe_hemo_results"  # directory for plots and tables


def create_prior_fn(net: Network, base: BaseParams, n_terminals: int) -> Callable:
    """Build a sampler for the joint prior over global and local parameters.

    The returned ``prior_fn(rng, n, n_samples=1)`` draws the following.

    Globals, one per sample:

    * ``log_beta ~ N(log(base.beta_scale_mean), 0.3^2)``;
    * ``log_mu ~ N(log(base.mu_ref), 0.2^2)``;
    * ``log_Qin ~ N(log(85), 0.2^2)``. This is centred at 85 mL/s, not at
      ``base.co_target_ml_s``.

    Locals, per sample, patient and site: log R_T and log C_T are the
    reference values from ``default_rcr`` plus and minus ``0.4 * eps``, with
    ``eps`` a shared standard normal. Hence ``log R_T + log C_T``, i.e. the
    RC time constant, stays at its reference value for every draw.

    Returns:
        ``(params, None)`` where ``params`` holds float32 arrays:
        ``log_beta``/``log_mu``/``log_Qin`` of shape ``(n_samples, 1, 1)`` and
        ``log_Rt``/``log_C`` of shape ``(n_samples, n, n_terminals, 1)``.
    """
    rcr_ref = default_rcr(net, base)
    site_idx = jnp.asarray(terminal_ids_ordered(net))
    log_Rt_ref = jnp.log(rcr_ref.R2[site_idx])
    log_C_ref = jnp.log(rcr_ref.C[site_idx])

    def prior_fn(rng: PRNGKeyArray, n: int, n_samples: int = 1):
        k1, k2, k3, k4 = random.split(rng, 4)

        log_beta = random.normal(k1, (n_samples,)) * 0.3 + jnp.log(base.beta_scale_mean)
        # Use the configured reference viscosity rather than a hard-coded value
        # (base.mu_ref defaults to 0.004, the previous constant).
        log_mu = random.normal(k2, (n_samples,)) * 0.2 + jnp.log(base.mu_ref)
        log_Qin = random.normal(k3, (n_samples,)) * 0.2 + jnp.log(85.0)

        # One shared perturbation with opposite signs keeps R_T * C_T fixed.
        eps = random.normal(k4, (n_samples, n, n_terminals))
        log_Rt = log_Rt_ref[None, None, :] + 0.4 * eps
        log_C = log_C_ref[None, None, :] - 0.4 * eps

        return {
            "log_beta": log_beta[:, None, None].astype(jnp.float32),
            "log_mu": log_mu[:, None, None].astype(jnp.float32),
            "log_Qin": log_Qin[:, None, None].astype(jnp.float32),
            "log_Rt": log_Rt[..., None].astype(jnp.float32),  # (S, P, n_terminals, 1)
            "log_C": log_C[..., None].astype(jnp.float32),  # (S, P, n_terminals, 1)
        }, None

    return prior_fn

def create_local_fn(net: Network, base: BaseParams, n_terminals: int) -> Callable:
    """Build a sampler for per-patient local RCR parameters.

    The returned ``local_fn(rng, global_samples, n)`` draws one set of local
    parameters for each global sample and each of ``n`` patients. At every
    site, log R_T and log C_T are the ``default_rcr`` reference values plus
    and minus ``0.4 * eps``, with a shared standard-normal ``eps``. The
    global values are used only to read the batch size.

    Returns:
        ``(locals, None)`` with ``log_Rt``/``log_C`` float32 arrays of shape
        ``(S, n, n_terminals, 1)``.
    """
    reference = default_rcr(net, base)
    site_vessels = terminal_ids_ordered(net)
    R_sites = np.array([float(reference.R2[v]) for v in site_vessels])
    C_sites = np.array([float(reference.C[v]) for v in site_vessels])

    def local_fn(
        rng: PRNGKeyArray,
        global_samples: Dict[str, jnp.ndarray],
        n: int,
    ) -> Tuple[Dict[str, jnp.ndarray], None]:
        batch = global_samples["log_beta"].shape[0]
        shift = 0.4 * random.normal(rng, (batch, n, n_terminals))

        centre_R = jnp.log(jnp.asarray(R_sites))[None, None, :]
        centre_C = jnp.log(jnp.asarray(C_sites))[None, None, :]

        return {
            "log_Rt": (centre_R + shift)[..., None].astype(jnp.float32),  # (S, n, n_terminals, 1)
            "log_C": (centre_C - shift)[..., None].astype(jnp.float32),  # (S, n, n_terminals, 1)
        }, None

    return local_fn


def _numpy_seed_from_key(key: PRNGKeyArray) -> int:
    """Derive a deterministic NumPy seed from a JAX PRNG key.

    The raw key bits are read with ``jax.random.key_data``. This accepts both
    legacy ``uint32[2]`` keys (``random.PRNGKey``) and typed keys
    (``random.key``), which do not support plain integer indexing. The first
    two 32-bit words are then mixed. For a legacy key this gives the same seed
    as reading ``key[0]`` and ``key[1]`` directly.

    Returns:
        An int in ``[0, 2**32 - 2]`` suitable for ``np.random.default_rng``.
    """
    words = np.asarray(jax.random.key_data(key)).reshape(-1)
    k0 = int(np.uint32(words[0]))
    k1 = int(np.uint32(words[1]))
    return int(((k0 << 16) ^ k1) % (2**32 - 1))

def create_simulator_fn(net: Network, base: BaseParams, config: TFMPEConfig, n_terminals: int) -> Callable:
    """Build the simulator callable consumed by TFMPE.

    The returned ``simulator_fn(rng, params_dict, n)`` runs the hemodynamics
    model once per (parameter draw, patient) pair. It returns the mid-vessel
    flow over the final cycle at every site in ``SITE_ORDER``.
    """

    def simulator_fn(rng: PRNGKeyArray, params_dict, n: int):
        """Simulate site flow traces for a batch of parameter draws.

        Args:
            rng: JAX PRNG key; only used to seed the NumPy RNG for observation noise.
            params_dict: Mapping, or an object whose ``decode()`` returns one.
                ``log_beta``/``log_mu``/``log_Qin`` are read as ``[:, 0, 0]``;
                ``log_Rt``/``log_C`` are read as ``[sample, patient, site, 0]``.
            n: Number of independent patients (arterial networks) per draw.
                Patients share the draw's global parameters but have their own
                per-site RCR parameters.

        Returns:
            ``({"y": y}, None)`` where ``y`` is float32 with shape
            ``(n_draws, n, n_terminals, N_t, 1)``.

        Raises:
            ValueError: if the parameters have fewer than ``n_terminals`` sites.
            RuntimeError: if any simulated value is non-finite.
        """
        if hasattr(params_dict, "decode"):
            params_dict = params_dict.decode()

        n_patients = int(n)

        # Host copies of the parameters. They are cast back to float32 before
        # each simulation, so float64 here does not add solver precision.
        log_beta_np = np.array(params_dict["log_beta"][:, 0, 0], dtype=np.float64)
        log_mu_np = np.array(params_dict["log_mu"][:, 0, 0], dtype=np.float64)
        log_Qin_np = np.array(params_dict["log_Qin"][:, 0, 0], dtype=np.float64)
        log_Rt_np = np.array(params_dict["log_Rt"][:, :, :, 0], dtype=np.float64)  # (S, P, sites)
        log_C_np = np.array(params_dict["log_C"][:, :, :, 0], dtype=np.float64)  # (S, P, sites)
        nsamp = int(log_beta_np.shape[0])

        if log_Rt_np.shape[2] < n_terminals or log_C_np.shape[2] < n_terminals:
            raise ValueError(
                f"Expected at least {n_terminals} sites in log_Rt/log_C, got "
                f"{log_Rt_np.shape[2]} and {log_C_np.shape[2]}."
            )

        rng_np = np.random.default_rng(_numpy_seed_from_key(rng))
        all_y = []

        for i in tqdm(range(nsamp), desc=f"Simulating (n_patients={n_patients})", leave=False):
            theta_g_i = jnp.array([log_beta_np[i], log_mu_np[i], log_Qin_np[i]], dtype=jnp.float32)
            y_groups = np.zeros((n_patients, n_terminals, config.N_t), dtype=np.float64)

            for p in range(n_patients):
                # (n_terminals, 2): column 0 = log R_T, column 1 = log C_T.
                theta_loc_phys = np.stack(
                    [log_Rt_np[i, p, :n_terminals], log_C_np[i, p, :n_terminals]], axis=-1
                )
                theta_loc_i = jnp.array(theta_loc_phys, dtype=jnp.float32)

                resampled = simulate_5cycles_then_sample(
                    net, base, theta_g_i, theta_loc_i,
                    N_t=config.N_t, nx=config.nx, dt_init=config.dt_init
                )

                for s, site_name in enumerate(SITE_ORDER):
                    q = np.asarray(resampled[site_name]["q"], dtype=np.float64)
                    # Observation noise is currently disabled: sigma = 0, so
                    # config.eta is not applied. To enable relative noise use
                    # sigma = config.eta * max(float(np.max(np.abs(q))), 1e-9).
                    sigma = 0.0
                    y_groups[p, s, :] = q + rng_np.normal(0.0, sigma, size=q.shape)
            all_y.append(y_groups)

        y_np = np.stack(all_y, axis=0)
        if not np.isfinite(y_np).all():
            raise RuntimeError("Non-finite values detected in simulator outputs.")

        # Cast to float32 for TFMPE and add a trailing feature axis.
        y = jnp.asarray(y_np, dtype=jnp.float32)[..., None]
        return {"y": y}, None

    return simulator_fn

def _repeat_context_for_sampling(y_obs: Dict[str, jnp.ndarray], n_rep: int) -> Dict[str, jnp.ndarray]:
    """
    Repeat a single observation along sample axis so context batch matches params batch.
    Expect y_obs['y'] shape: (1, n_patients, 4, N_t, 1).
    Return shape: (n_rep, n_patients, 4, N_t, 1).
    """
    y = y_obs["y"]
    if y.shape[0] == n_rep:
        return y_obs
    if y.shape[0] != 1:
        # If user passes multiple observations, do not alter
        return y_obs
    return {"y": jnp.repeat(y, repeats=n_rep, axis=0)}




# =============================================================================
# VISUALIZATION AND TABLES
# =============================================================================

def plot_posterior_marginals(theta_g_samples: np.ndarray, theta_l_samples: np.ndarray,
                             true_theta_g: np.ndarray, true_theta_l: np.ndarray,
                             output_dir: str):
    """Save histograms of the global and local posterior marginals.

    Writes ``posterior_global.png`` (one panel per global parameter) and
    ``posterior_local.png`` (one row per site; columns log R_T and log C_T)
    into ``output_dir``. True values and posterior means are drawn as lines.

    Args:
        theta_g_samples: Global draws indexed ``[draw, k]``, with k = 0..2 for
            log beta_scale, log mu and log Q_in.
        theta_l_samples: Local draws indexed
            ``[draw, site, {0: log R_T, 1: log C_T}]``.
        true_theta_g: True globals, indexed like ``theta_g_samples[draw]``.
        true_theta_l: True locals, indexed like ``theta_l_samples[draw]``.
        output_dir: Existing directory for the PNG files.
    """
    # Global parameters
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    param_names = [r'$\log \beta_{scale}$', r'$\log \mu$', r'$\log Q_{in}$']

    for i, (ax, name) in enumerate(zip(axes, param_names)):
        samples = theta_g_samples[:, i]
        ax.hist(samples, bins=30, density=True, alpha=0.7, color='steelblue', edgecolor='white')
        ax.axvline(true_theta_g[i], color='red', linestyle='--', linewidth=2, label='True')
        ax.axvline(np.mean(samples), color='green', linestyle='-', linewidth=2, label='Mean')
        ax.set_xlabel(name, fontsize=12)
        ax.set_ylabel('Density' if i == 0 else '')
        ax.legend()
        ax.set_title(f'Posterior: {name}')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'posterior_global.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)

    # Local parameters. squeeze=False keeps `axes` 2-D even with a single
    # site; otherwise plt.subplots returns a 1-D array and axes[s, j] fails.
    n_terminals = theta_l_samples.shape[1]
    fig, axes = plt.subplots(n_terminals, 2, figsize=(10, 3 * n_terminals), squeeze=False)
    for s in range(n_terminals):
        for j, param_name in enumerate([r'$\log R_T$', r'$\log C_T$']):
            ax = axes[s, j]
            samples = theta_l_samples[:, s, j]
            ax.hist(samples, bins=30, density=True, alpha=0.7, color='steelblue', edgecolor='white')
            ax.axvline(true_theta_l[s, j], color='red', linestyle='--', linewidth=2, label='True')
            ax.axvline(np.mean(samples), color='green', linestyle='-', linewidth=2, label='Mean')
            ax.set_xlabel(param_name)
            ax.set_title(f'{SITE_ORDER[s]}: {param_name}')
            if s == 0 and j == 0:
                ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'posterior_local.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


def plot_training_losses(all_losses: List, output_dir: str):
    """Save per-round training/validation loss curves to ``training_losses.png``.

    ``all_losses`` holds one ``(train_local, val_local, train_global,
    val_global)`` tuple per round. Each round gets a row with two panels:
    local likelihood on the left and global posterior on the right. The
    global panel is left empty when ``train_global`` has no entries.
    """
    rounds = len(all_losses)
    # squeeze=False always yields a (rounds, 2) array of axes.
    fig, grid = plt.subplots(rounds, 2, figsize=(12, 4 * rounds), squeeze=False)

    def finish(ax, title):
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Loss')
        ax.set_title(title)
        ax.legend()
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)

    for r, losses in enumerate(all_losses):
        train_local, val_local, train_global, val_global = losses
        left, right = grid[r]

        left.plot(train_local, label='Train', alpha=0.8)
        left.plot(val_local, label='Val', alpha=0.8)
        finish(left, f'Round {r}: Local Likelihood')

        if len(train_global) > 0:
            right.plot(train_global, label='Train', alpha=0.8)
            right.plot(val_global, label='Val', alpha=0.8)
        finish(right, f'Round {r}: Global Posterior')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_losses.png'), dpi=150, bbox_inches='tight')
    plt.close()


def plot_posterior_predictive(net: Network, base: BaseParams, config: TFMPEConfig,
                             theta_g_samples: np.ndarray, theta_l_samples: np.ndarray,
                             y_obs: np.ndarray, output_dir: str, n_ppc: int = 20):
    """Overlay posterior-predictive flow traces on the observed data.

    Picks up to ``n_ppc`` posterior draws without replacement, using NumPy's
    global RNG. Each draw is simulated once, and its flow at every site in
    ``SITE_ORDER`` is plotted against observation 0 / patient 0 of
    ``y_obs``, which is indexed ``[obs, patient, site, t, 0]``. The figure is
    saved as ``posterior_predictive.png`` in ``output_dir``.

    Simulation is best effort: a draw whose simulation raises is skipped. The
    number of failures is printed rather than silently discarded.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.flatten()
    t_grid = np.linspace(0, base.T, config.N_t, endpoint=False)

    # Each simulation returns traces for every vessel, so simulate every draw
    # once and reuse it in all panels; all sites then show the same draws.
    n_draws = min(n_ppc, len(theta_g_samples))
    indices = np.random.choice(len(theta_g_samples), n_draws, replace=False)
    predictions = []
    n_failed = 0
    last_error = None
    for idx in tqdm(indices, desc="PPC", leave=False):
        try:
            resampled = simulate_5cycles_then_sample(
                net, base, jnp.array(theta_g_samples[idx]), jnp.array(theta_l_samples[idx]),
                N_t=config.N_t, nx=config.nx, dt_init=config.dt_init
            )
        except Exception as err:  # best effort: a failing draw is skipped, not fatal
            n_failed += 1
            last_error = repr(err)
            continue
        predictions.append(resampled)

    if n_failed:
        print(f"Posterior predictive: {n_failed}/{n_draws} simulations failed and were "
              f"skipped (last error: {last_error})")

    for s_idx, (ax, site_name) in enumerate(zip(axes, SITE_ORDER)):
        y_site = y_obs[0, 0, s_idx, :, 0]  # observation 0, patient 0, site s_idx
        ax.plot(t_grid, y_site, 'k-', linewidth=2, label='Observed', zorder=10)
        for resampled in predictions:
            ax.plot(t_grid, resampled[site_name]['q'], 'b-', alpha=0.2, linewidth=0.5)

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Flow (m³/s)')
        ax.set_title(f'{site_name}')
        ax.grid(True, alpha=0.3)
        if s_idx == 0:
            ax.plot([], [], 'b-', alpha=0.5, label='Posterior predictive')
            ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'posterior_predictive.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


def create_summary_table(theta_g_samples: np.ndarray, theta_l_samples: np.ndarray,
                         true_theta_g: np.ndarray, true_theta_l: np.ndarray,
                         output_dir: str) -> pd.DataFrame:
    """Tabulate posterior summaries and save them as CSV and LaTeX.

    There is one row per global parameter, followed by one row per
    (site, local parameter). Each row holds the true value, the posterior
    mean and std, and the 2.5%/97.5% percentiles, all formatted to 3
    decimals. It also records whether that central 95% interval covers the
    true value. Writes ``posterior_summary.csv`` and ``posterior_summary.tex``
    into ``output_dir`` and returns the DataFrame.
    """
    def summarize(param, site, samples, truth):
        lower = np.percentile(samples, 2.5)
        upper = np.percentile(samples, 97.5)
        return {
            'Parameter': param,
            'Site': site,
            'True': f'{truth:.3f}',
            'Mean': f'{np.mean(samples):.3f}',
            'Std': f'{np.std(samples):.3f}',
            '2.5%': f'{lower:.3f}',
            '97.5%': f'{upper:.3f}',
            'Covers True': 'Yes' if lower <= truth <= upper else 'No',
        }

    global_rows = [
        summarize(param, 'Global', theta_g_samples[:, k], true_theta_g[k])
        for k, param in enumerate(['log_beta', 'log_mu', 'log_Qin'])
    ]
    local_rows = [
        summarize(param, site_name, theta_l_samples[:, site_k, j], true_theta_l[site_k, j])
        for site_k, site_name in enumerate(SITE_ORDER)
        for j, param in enumerate(['log_R_T', 'log_C_T'])
    ]

    table = pd.DataFrame(global_rows + local_rows)
    table.to_csv(os.path.join(output_dir, 'posterior_summary.csv'), index=False)

    latex_src = table.to_latex(index=False, caption='Posterior summary statistics',
                               label='tab:posterior_summary')
    with open(os.path.join(output_dir, 'posterior_summary.tex'), 'w') as handle:
        handle.write(latex_src)

    return table


# =============================================================================
# MAIN INFERENCE PIPELINE
# =============================================================================

def run_inference(config: Optional[TFMPEConfig] = None):
    """Run the end-to-end TFMPE inference pipeline on the synthetic hemodynamics problem.

    Steps:
      1. Build the vessel network (``create_arch_network``), draw "true"
         parameters from the prior for ``n_patients`` patients and simulate
         the observations ``y``.
      2. Build a TFMPE estimator (transformer vector field, normal base
         distribution, ``diffrax.Dopri5`` ODE solver) and train it with
         ``fit_bottom_up``.
      3. Draw ``config.n_posterior_samples`` posterior samples conditioned on
         the observations.
      4. Write a summary table, marginal and posterior-predictive plots, the
         raw samples (``posterior_samples.npz``) and the run configuration
         (``config.json``) into ``config.output_dir``.

    Only patient ``0``'s local parameters are summarised, plotted and saved.

    Args:
        config: Run configuration. ``None`` uses the ``TFMPEConfig()`` defaults.

    Returns:
        ``(trained_tfmpe, posterior_samples, summary_df)``: the trained
        estimator returned by ``fit_bottom_up``, the object returned by
        ``sample_posterior``, and the DataFrame from ``create_summary_table``.
    """
    if config is None:
        config = TFMPEConfig()

    os.makedirs(config.output_dir, exist_ok=True)

    print("=" * 70)
    print("TFMPE HEMODYNAMICS INFERENCE")
    print("=" * 70)

    # Parameter names, defined once so that the labeller, the training call,
    # the sampling template and the sample extraction cannot drift apart.
    # Globals have one value per sample; locals have one value per patient and
    # per terminal (see the template shapes below).
    global_names = ['log_beta', 'log_mu', 'log_Qin']
    local_names = ['log_Rt', 'log_C']
    # Only this patient's local parameters are reported / saved.
    patient_idx = 0

    net = create_arch_network()
    base = BaseParams(T=1.0)
    n_patients = 2
    print(f"Number of patients:{n_patients}")
    n_terminals = len(terminal_ids_ordered(net))  # should be 4 (== len(SITE_ORDER))

    print(f"\nNetwork: {net.n_vessels} vessels, {n_terminals} terminals")
    print(f"Sites: {SITE_ORDER}")
    print(f"Config: N_t={config.N_t}, n_rounds={config.n_rounds}, n_samples/round={config.n_samples_per_round}")

    labeller = Labeller.for_keys([*global_names, *local_names, 'y'])

    print("\nCreating interface functions...")
    prior_fn = create_prior_fn(net, base, n_terminals)
    local_fn = create_local_fn(net, base, n_terminals)
    simulator_fn = create_simulator_fn(net, base, config, n_terminals)

    print("Generating true parameters and observations...")
    # Fixed seed; the key is split once per consumer, in a fixed order, so the
    # whole run is reproducible.
    rng = random.PRNGKey(42)
    rng, key = random.split(rng)
    true_params, _ = prior_fn(key, n=n_patients, n_samples=1)

    # Index 0 on the leading axis selects the single prior draw (n_samples=1);
    # presumably that axis is the sample axis -- confirm against prior_fn.
    true_theta_g = np.array([float(true_params[name][0, 0, 0]) for name in global_names])
    true_theta_l = np.stack(
        [np.array(true_params[name][0, patient_idx, :, 0]) for name in local_names],
        axis=-1,
    )

    print("\nTrue parameters:")
    print(f"  beta_scale = {np.exp(true_theta_g[0]):.2e}")
    print(f"  mu = {np.exp(true_theta_g[1]):.4f}")
    print(f"  Q_in = {np.exp(true_theta_g[2]):.1f} mL/s")

    print("\nRunning simulation to generate observations...")
    rng, key = random.split(rng)
    y_obs_dict, _ = simulator_fn(key, true_params, n=n_patients)
    y_obs = {'y': y_obs_dict['y'].astype(jnp.float32)}
    print(f"Observation shape: {y_obs['y'].shape}, dtype: {y_obs['y'].dtype}")

    print("\nInitializing TFMPE model...")
    rng, key = random.split(rng)
    # Build the SAME independence used in training/sampling masks.
    independence = Independence()

    # IMPORTANT: token template must match training group count (n_groups).
    template_params, _ = prior_fn(key, n=n_patients, n_samples=10)
    params_tokens = Tokens.from_pytree(
        template_params,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )

    transformer_config = TransformerConfig(
        latent_dim=config.latent_dim,
        n_encoder=config.n_encoder,
        n_decoder=config.n_decoder,
        n_heads=config.n_heads,
        n_ff=config.n_ff,
    )

    rngs = nnx.Rngs(params=random.PRNGKey(0), dropout=random.PRNGKey(1))
    transformer = Transformer(config=transformer_config, tokens=params_tokens, rngs=rngs)
    base_dist = NormalDistribution(rngs=rngs)

    tfmpe = TFMPE(
        vf_network=transformer,
        base_dist=base_dist,
        solver=diffrax.Dopri5(),
        ode_kwargs={'rtol': 1e-3, 'atol': 1e-3}
    )

    optimizer = optax.adam(learning_rate=config.learning_rate)
    opt = nnx.Optimizer(tfmpe, optimizer, wrt=nnx.Param)
    # Never request a batch larger than the per-round training set.
    effective_batch_size = min(config.batch_size, config.n_samples_per_round)

    print("\n" + "=" * 70)
    print("TRAINING (fit_bottom_up)")
    print("=" * 70)
    print(f"  n_samples_per_round: {config.n_samples_per_round}")
    print(f"  n_iter_per_round: {config.n_iter_per_round}")
    print(f"  batch_size: {effective_batch_size} (adjusted from {config.batch_size})")
    print(f"  learning_rate: {config.learning_rate}")
    print("\nStarting training...")

    # perf_counter is monotonic, so the measured duration is not skewed by
    # system clock adjustments (unlike time.time()).
    t_start = time.perf_counter()
    rng, key = random.split(rng)

    trained_tfmpe, all_losses = fit_bottom_up(
        tfmpe=tfmpe,
        y_obs=y_obs,
        simulator_fn=simulator_fn,
        prior_fn=prior_fn,
        local_fn=local_fn,
        global_names=list(global_names),
        n_groups=n_patients,
        n_rounds=config.n_rounds,
        n_samples_per_round=config.n_samples_per_round,
        n_val_samples=config.n_val_samples,
        opt=opt,
        n_iter_per_round=config.n_iter_per_round,
        batch_size=effective_batch_size,
        rng=key,
        independence=independence,
        labeller=labeller,
    )

    t_train = time.perf_counter() - t_start
    print(f"\nTraining complete! (took {t_train:.1f}s)")

    plot_training_losses(all_losses, config.output_dir)

    print("\n" + "=" * 70)
    print("POSTERIOR SAMPLING (sample_posterior)")
    print("=" * 70)

    print("Creating context tokens from observations...")
    y_obs_for_sampling = _repeat_context_for_sampling(y_obs, config.n_posterior_samples)
    context_tokens = Tokens.from_pytree(
        y_obs_for_sampling,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )
    print(f"Creating parameter template for {config.n_posterior_samples} samples...")
    # Zero-valued template with a leading sample axis: global leaves are
    # (n_samples, 1, 1), local leaves (n_samples, n_patients, n_terminals, 1).
    # Insertion order (globals, then locals) matches the original layout.
    n_post = config.n_posterior_samples
    params_template = {
        name: jnp.zeros((n_post, 1, 1), dtype=jnp.float32) for name in global_names
    }
    params_template.update({
        name: jnp.zeros((n_post, n_patients, n_terminals, 1), dtype=jnp.float32)
        for name in local_names
    })
    params_tokens = Tokens.from_pytree(
        params_template,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )
    print(f"Drawing {config.n_posterior_samples} posterior samples...")
    with tqdm(total=1, desc="Sampling posterior") as pbar:
        posterior_samples = trained_tfmpe.sample_posterior(
            context=context_tokens,
            params=params_tokens
        )
        pbar.update(1)

    samples_dict = posterior_samples.decode()
    # theta_g: (n_samples, n_global); theta_l: (n_samples, n_terminals, n_local)
    theta_g_samples = np.stack(
        [np.array(samples_dict[name][:, 0, 0]) for name in global_names],
        axis=-1,
    )
    theta_l_samples = np.stack(
        [np.array(samples_dict[name][:, patient_idx, :, 0]) for name in local_names],
        axis=-1,
    )

    print(f"Posterior samples shape: theta_g={theta_g_samples.shape}, theta_l={theta_l_samples.shape}")

    print("\n" + "=" * 70)
    print("GENERATING OUTPUTS")
    print("=" * 70)

    print("\n[1/4] Creating summary table...")
    summary_df = create_summary_table(
        theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, config.output_dir
    )
    print("\nPosterior Summary:")
    print(summary_df.to_string(index=False))

    print("\n[2/4] Creating posterior marginal plots...")
    plot_posterior_marginals(theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, config.output_dir)

    print("\n[3/4] Creating posterior predictive plot...")
    plot_posterior_predictive(
        net, base, config, theta_g_samples, theta_l_samples,
        np.array(y_obs['y']), config.output_dir, n_ppc=10
    )

    print("\n[4/4] Saving results...")
    np.savez(
        os.path.join(config.output_dir, 'posterior_samples.npz'),
        theta_g=theta_g_samples,
        theta_l=theta_l_samples,
        true_theta_g=true_theta_g,
        true_theta_l=true_theta_l,
        site_names=SITE_ORDER
    )

    # Snapshot of the config's instance attributes (a copy, not the live dict).
    config_dict = dict(vars(config))
    with open(os.path.join(config.output_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config_dict, f, indent=2)

    print("\n" + "=" * 70)
    print("COMPLETE!")
    print("=" * 70)
    print(f"\nOutputs saved to: {os.path.abspath(config.output_dir)}")

    return trained_tfmpe, posterior_samples, summary_df


# =============================================================================
# ENTRY POINT
# =============================================================================

if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, type, default, help) rows; a help of
    # None is the same as not passing help= to add_argument.
    _cli_options = (
        ("--n_rounds", int, 1, "Number of rounds (currently only 1 supported)"),
        ("--n_samples", int, 200, None),
        ("--n_iter", int, 500, None),
        ("--batch_size", int, 32, None),
        ("--lr", float, 1e-3, None),
        ("--n_posterior", int, 500, None),
        ("--output_dir", str, "tfmpe_hemo_results", None),
        ("--N_t", int, 50, None),
    )

    cli = argparse.ArgumentParser(description="TFMPE Hemodynamics Inference")
    for flag, arg_type, default_value, help_text in _cli_options:
        cli.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    cli_args = cli.parse_args()

    # Map CLI flag names onto the TFMPEConfig field names and run the pipeline.
    run_inference(
        TFMPEConfig(
            n_rounds=cli_args.n_rounds,
            n_samples_per_round=cli_args.n_samples,
            n_iter_per_round=cli_args.n_iter,
            batch_size=cli_args.batch_size,
            learning_rate=cli_args.lr,
            n_posterior_samples=cli_args.n_posterior,
            output_dir=cli_args.output_dir,
            N_t=cli_args.N_t,
        )
    )