import numpy as np

import jax
import jax.numpy as jnp
from jax import lax

from cfd_data import (
    Network, BaseParams,
    pack_geometry, terminal_ids_ordered, default_rcr
)

# ------------------------------------------------------------
# Wall model & derived terms
#   p(A) - p_ext = (beta/A0) * (sqrt(A) - sqrt(A0))
#   Flux term psi(A) = ∫ (A/ρ) (dp/dA) dA = (β / (3 ρ A0)) (A^{3/2} - A0^{3/2})
# ------------------------------------------------------------
def psi_from_A(A, A0, beta, rho):
    """Return the pressure flux term psi(A) of the momentum equation.

    psi(A) = beta / (3 rho A0) * (A^{3/2} - A0^{3/2}), which vanishes at A = A0.
    The area is floored at 1e-16 * A0 so the fractional power stays real.
    """
    area = jnp.maximum(A, 1e-16*A0)
    coeff = beta/(3.0*rho*A0)
    return coeff * (area**1.5 - A0**1.5)

def pressure_from_A(A, A0, beta, p_ext):
    """Return the transmural-law pressure p(A) = p_ext + (beta/A0)(sqrt(A) - sqrt(A0)).

    The area is floored at 1e-16 * A0 before taking the square root.
    """
    safe_area = jnp.maximum(A, 1e-16*A0)
    stiffness = beta/A0
    return p_ext + stiffness * (jnp.sqrt(safe_area) - jnp.sqrt(A0))

def A_from_pressure(p, A0, beta, p_ext):
    """Invert the wall law: return the area A whose pressure p(A) equals ``p``.

    From p - p_ext = (beta/A0)(sqrt(A) - sqrt(A0)) it follows that
    sqrt(A) = sqrt(A0) + (A0/beta)(p - p_ext).  sqrt(A) is floored at 1e-12,
    so very low pressures map to a tiny positive area instead of a negative root.
    """
    root_area = jnp.sqrt(A0) + (A0 / beta) * (p - p_ext)
    root_area = jnp.maximum(root_area, 1e-12)
    return root_area * root_area

def characteristic_speed(A, A0, beta, rho):
    """Return the wave speed c(A) = sqrt(beta sqrt(A) / (2 rho A0)).

    This follows from c^2 = (A/rho) dp/dA for the square-root wall law.  The
    area is floored at 1e-16 * A0 and c^2 at 1e-16, so the result is always
    a finite, positive number.
    """
    floored_area = jnp.maximum(A, 1e-16*A0)
    speed_sq = (beta/(2.0*rho*A0)) * jnp.sqrt(floored_area)
    return jnp.sqrt(jnp.maximum(speed_sq, 1e-16))

# ------------------------------------------------------------
# Numerical scheme (two-step Lax–Wendroff) with parent→k children junctions
# Momentum friction term: -(K_R Q)/A with K_R = 2π (ζ + 2) ν ; ν = μ / ρ
#   K_R carries units of m^2/s so that K_R Q / A matches ∂Q/∂t (m^3/s^2);
#   ζ = 2 (Poiseuille velocity profile) gives K_R = 8 π ν.
# ------------------------------------------------------------
def make_stepper(net: Network, base: BaseParams,
                 beta_scale: float, mu_pa_s: float,
                 nx: int, dt0: float,
                 Qin_amp_ml_s: float, q_cap: float):
    """Build a jitted time stepper for the 1-D (A, Q) flow model on ``net``.

    Args:
        net: vessel network. ``net.n_vessels`` and ``net.children(i)`` define
            the topology; geometry comes from ``pack_geometry(net, nx)``.
        base: baseline parameters. ``base.rho``, ``base.p_ext`` and ``base.T``
            (the period of the inflow waveform) are read.
        beta_scale: stiffness multiplier. The nominal beta is scaled by
            ``beta_scale / 1e6``, so 1e6 means "nominal".
        mu_pa_s: dynamic viscosity in Pa·s.
        nx: grid nodes per vessel. Index 0 and index -1 are boundary cells;
            the scheme advances nodes 1..nx-2.
        dt0: upper bound on the time step. Each step uses min(dt0, CFL step).
        Qin_amp_ml_s: mean inflow at vessel 0, in mL/s.
        q_cap: currently unused; kept for interface compatibility.

    Returns:
        ``(run_chunk, meta)``.
        ``run_chunk(U, P_c, t, dt, R1, R2, C, Pext, steps)`` advances the
        state by ``steps`` steps and returns ``(U, P_c, t, dt)``, where
        ``U`` has shape (n_vessels, 2, nx) holding (A, Q).
        ``meta`` is ``(r0, A0, beta_i, rho, children_list, is_terminal, dx)``.
    """
    n = net.n_vessels
    r0, _h, A0, dx = pack_geometry(net, nx)

    rho = jnp.asarray(base.rho)

    # --- vessel-wise β from a target wave speed profile c0(r) ---
    #   c0(r) = c0_root * (r/r_root)^{expo}; expo < 0 gives a mild increase
    #   in smaller vessels.
    #   β_i   = 2 ρ c0(r)^2 sqrt(A0_i), i.e. the inverse of
    #   c^2(A0) = β / (2 ρ sqrt(A0)).
    def beta_from_c0(A0, rho, c0_root=5.0, r0_root=None, expo=-0.05):
        r = jnp.sqrt(A0 / jnp.pi)
        r_ref = r0_root if r0_root is not None else r[0]
        c = c0_root * (r / r_ref) ** expo
        return 2.0 * rho * (c**2) * jnp.sqrt(A0)

    beta_nom = beta_from_c0(A0, rho, c0_root=5.0, r0_root=r0[0], expo=-0.05)
    beta_i = beta_nom * (beta_scale / 1.0e6)

    mu = jnp.asarray(mu_pa_s)
    nu = mu / rho
    # Friction coefficient for S_Q = -K_R Q / A. K_R must have units of m^2/s.
    # A coefficient of the form 2πν/r^2 (units 1/s) would make S_Q m/s^2
    # instead of m^3/s^2 and would over-damp the flow by ~1/(4 r^2).
    zeta = 2.0  # velocity-profile shape factor; 2 = Poiseuille
    K_R = 2.0 * jnp.pi * (zeta + 2.0) * nu

    # Topology helpers (plain Python lists: the loops over them unroll at trace time)
    children_list = [net.children(i) for i in range(n)]
    is_terminal = jnp.array([1 if len(kids) == 0 else 0 for kids in children_list], dtype=jnp.int32)
    junctions = [(p, kids) for p, kids in enumerate(children_list) if len(kids) > 0]

    dt0_j = jnp.asarray(dt0, dtype=jnp.float64)
    Qin_amp = jnp.asarray(Qin_amp_ml_s / 1e6, dtype=jnp.float64)  # mL/s -> m^3/s (1e-6)

    def lw_interior(U, dt):
        """Advance interior nodes 1..nx-2 with one Richtmyer two-step LW update.

        After the update, A is clamped to [0.6, 1.4]·A0. Any non-finite entry
        falls back to its value before the step.
        """
        A = U[:, 0, :]
        Q = U[:, 1, :]

        F1 = Q
        F2 = Q**2 / jnp.maximum(A, 1e-16) + psi_from_A(A, A0[:, None], beta_i[:, None], rho)
        F = jnp.stack([F1, F2], axis=1)

        # source (friction) acts on the momentum equation only
        S_Q = -(K_R * Q) / jnp.maximum(A, 1e-16)
        S = jnp.stack([jnp.zeros_like(Q), S_Q], axis=1)

        # half-step values at cell midpoints
        F_diff = F[:, :, 1:] - F[:, :, :-1]
        S_avg = 0.5 * (S[:, :, 1:] + S[:, :, :-1])
        dt_over_dx = (dt / dx)[:, None, None]
        U_half = 0.5*(U[:, :, :-1] + U[:, :, 1:]) - 0.5*dt_over_dx*F_diff + 0.5*dt*S_avg

        A_half = U_half[:, 0, :]
        Q_half = U_half[:, 1, :]
        F_half2 = Q_half**2 / jnp.maximum(A_half, 1e-16) + psi_from_A(A_half, A0[:, None], beta_i[:, None], rho)
        F_half = jnp.stack([Q_half, F_half2], axis=1)

        # full step for interior nodes; boundary nodes are set by the BCs
        U_new = U.at[:, :, 1:-1].set(
            U[:, :, 1:-1] - dt_over_dx * (F_half[:, :, 1:] - F_half[:, :, :-1]) + dt * S[:, :, 1:-1]
        )

        # clamp A to keep the state physical
        A_new = jnp.clip(U_new[:, 0, :], 0.6*A0[:, None], 1.4*A0[:, None])
        U_new = U_new.at[:, 0, :].set(A_new)
        U_new = jnp.where(jnp.isfinite(U_new), U_new, U)
        return U_new

    def inlet_bc(U, t):
        """Impose the inflow waveform at node 0 of vessel 0 at time ``t``."""
        # Smooth ramp + sinusoid inflow at aortic_root (idx 0)
        ramp = 1.0 - jnp.exp(-5.0 * t / base.T)
        q_in = Qin_amp * ramp * (1.0 + 0.5 * jnp.sin(2.0 * jnp.pi * t / base.T))  # m^3/s
        A_neighbor = U[0, 0, 1]
        U = U.at[0, 1, 0].set(q_in)
        # slight relaxation toward A0 instead of a hard copy of the neighbor
        U = U.at[0, 0, 0].set(0.9 * A_neighbor + 0.1 * A0[0])
        return U

    def junction_bc(U):
        """Couple every parent → children junction of the network.

        For each parent:
        - mirror its last interior node into its outlet node;
        - set each child's inlet area so the child's pressure equals the
          parent's pressure;
        - split the parent flow among the children in proportion to the
          inverse characteristic impedance 1/Z_c.
        """
        for p, kids in junctions:
            # Use the LAST INTERIOR node of the parent (index -2): the outlet
            # node of a non-terminal vessel is not advanced by the scheme.
            A_p_int = U[p, 0, -2]
            q_p_int = U[p, 1, -2]
            U = U.at[p, 0, -1].set(A_p_int)
            U = U.at[p, 1, -1].set(q_p_int)
            p_parent = pressure_from_A(A_p_int, A0[p], beta_i[p], base.p_ext)

            A_c = []
            Zc = []
            for cidx in kids:
                A_c_upd = A_from_pressure(p_parent, A0[cidx], beta_i[cidx], base.p_ext)
                A_c.append(A_c_upd)
                c_speed = characteristic_speed(A_c_upd, A0[cidx], beta_i[cidx], base.rho)
                Zc.append(base.rho * c_speed / jnp.maximum(A_c_upd, 1e-16))
            A_c = jnp.stack(A_c)
            Zc = jnp.stack(Zc)
            w = 1.0 / jnp.maximum(Zc, 1e-16)
            q_children = q_p_int * w / (jnp.sum(w) + 1e-16)

            for j, cidx in enumerate(kids):
                U = U.at[cidx, 0, 0].set(A_c[j])
                U = U.at[cidx, 1, 0].set(q_children[j])
        return U

    def rcr_outlets(U, P_c, dt, R1, R2, C, Pext):
        """Advance the RCR capacitor pressures and set the terminal outlet nodes.

        Only terminal vessels are updated: P_c gets a forward-Euler update,
        and the outlet area is set so that p(A) = R1 Q + P_c (clamped to
        [0.6, 1.4]·A0). Non-terminal outlets keep their values.
        """
        Q_out = U[:, 1, -1]
        C_safe = jnp.where(is_terminal==1, jnp.maximum(C, 1e-16), 1.0)
        R2_safe = jnp.where(is_terminal==1, jnp.maximum(R2, 1e-16), 1.0)
        dP_c = (Q_out - (P_c - Pext) / R2_safe) / C_safe
        P_c_next = jnp.where(is_terminal==1, P_c + dt * dP_c, P_c)
        P_out = jnp.where(is_terminal==1, R1 * Q_out + P_c_next, base.p_ext)

        # Enforce p(A_last) = P_out at terminals (explicit inverse)
        A_last = U[:, 0, -1]
        A_targ = A_from_pressure(P_out, A0, beta_i, base.p_ext)
        # clamp to keep terminals physical and avoid huge p from tiny Δsqrt(A)
        A_targ = jnp.clip(A_targ, 0.6*A0, 1.4*A0)
        A_new = jnp.where(is_terminal==1, A_targ, A_last)
        U = U.at[:, 0, -1].set(A_new)
        # Terminals copy Q from their last interior node; non-terminal outlet
        # nodes keep the value written by junction_bc.
        U = U.at[:, 1, -1].set(jnp.where(is_terminal==1, U[:, 1, -2], U[:, 1, -1]))
        return U, P_c_next

    def step(carry, _):
        """One adaptive time step: CFL dt, interior LW update, then all BCs."""
        U, P_c, t, dt, R1, R2, C, Pext = carry

        # CFL on interior nodes (Lax–Wendroff is stable for CFL <= ~0.5)
        A = U[:, 0, :]
        Q = U[:, 1, :]
        c_speed = characteristic_speed(A, A0[:, None], beta_i[:, None], rho)
        u = Q / jnp.maximum(A, 1e-16)
        smax = jnp.max(jnp.abs(u[:, 1:-1]) + c_speed[:, 1:-1])
        smax = jnp.where(jnp.isfinite(smax), smax, 1.0)  # fall back to 1.0 on NaN/inf
        dx_min = jnp.min(dx)
        dt_cfl = 0.45 * dx_min / jnp.maximum(smax, 1e-9)
        dt = jnp.minimum(dt_cfl, dt0_j)  # dt0 is an upper bound on the step
        t_next = t + dt

        # Interior LW advances the state from t to t_next ...
        U = lw_interior(U, dt)
        # ... so boundary data must be evaluated at the same new time level.
        U = inlet_bc(U, t_next)
        U = junction_bc(U)
        U, P_c = rcr_outlets(U, P_c, dt, R1, R2, C, Pext)

        return (U, P_c, t_next, dt, R1, R2, C, Pext), None

    @jax.jit
    def run_chunk(U, P_c, t, dt, R1, R2, C, Pext, steps: int):
        """Apply ``step`` ``steps`` times; returns the new (U, P_c, t, dt)."""
        def body(i, carry):
            carry, _ = step(carry, None)
            return carry
        U_n, P_c_n, t_n, dt_n, *_ = lax.fori_loop(0, steps, body, (U, P_c, t, dt, R1, R2, C, Pext))
        return U_n, P_c_n, t_n, dt_n

    return run_chunk, (r0, A0, beta_i, rho, children_list, is_terminal, dx)

# ------------------------------------------------------------
# Simulation wrapper (5 cycles, keep the final one)
# ------------------------------------------------------------
def simulate_5cycles_then_sample(net: Network, base: BaseParams,
                                 theta_g: jnp.ndarray, theta_loc: jnp.ndarray,
                                 N_t: int, nx: int, dt_init: float):
    """Run the network model for five inflow periods and resample the last one.

    Args:
        net: vessel network. Vessel names come from ``net.vessels[i].name``.
        base: baseline parameters. ``base.T``, ``base.rho`` and ``base.p_ext``
            are read.
        theta_g: global log-parameters
            ``[log beta_scale, log mu (Pa·s), log Q_in (mL/s)]``.
        theta_loc: per-terminal log-parameters. Row ``s`` is
            ``[log R_T, log C_T]`` for the ``s``-th vessel of
            ``terminal_ids_ordered(net)``.
        N_t: number of samples in the resampled cycle.
        nx: grid nodes per vessel.
        dt_init: initial time step, also passed to the stepper as its maximum
            step.

    Returns:
        ``(res, diag)``.
        ``res[name]`` holds:
        - ``"t"``: N_t points in [0, T);
        - ``"q"``: mid-vessel flow;
        - ``"p"``: mid-vessel pressure, or the RCR capacitor pressure for a
          terminal vessel named ``left_subclavian``.
        ``diag`` holds per-chunk times, Courant numbers, area-clamp fractions
        and the flow residual at the root junction.
    """
    n_cycles = 5
    tids = terminal_ids_ordered(net)
    n = net.n_vessels

    # Map globals (theta_* are log-parameters)
    beta_scale = float(jnp.exp(theta_g[0]))       # multiplicative (centered at 1e6)
    mu_pa_s    = float(jnp.exp(theta_g[1]))       # Pa·s
    Qin_ml_s   = float(jnp.exp(theta_g[2]))       # mL/s

    # Diagnostic output for parameter sanity
    print(f"\n=== Simulation Parameters ===")
    print(f"beta_scale: {beta_scale:.2e} Pa·m")
    print(f"mu: {mu_pa_s:.4f} Pa·s")
    print(f"Q_in: {Qin_ml_s:.1f} mL/s ({Qin_ml_s/1e6:.2e} m³/s)")

    # Build RCR with locals (R2=R_T, C=C_T); keep R1 from defaults
    rcr = default_rcr(net, base, qin_ml_s=Qin_ml_s)
    R1, R2, C, Pext = rcr.R1, rcr.R2, rcr.C, rcr.Pext
    for s, vidx in enumerate(tids):
        R2 = R2.at[vidx].set(float(jnp.exp(theta_loc[s, 0])))
        C = C.at[vidx].set(float(jnp.exp(theta_loc[s, 1])))

    # Terminal boundary-condition parameters
    for vidx in tids:
        print(f"{net.vessels[vidx].name:25s}: R_T={float(R2[vidx]):.2e} Pa·s/m³, C_T={float(C[vidx]):.2e} m³/Pa")
    # Distal RC time constant of each terminal (same order as the lines above)
    for vidx in tids:
        tau = float(R2[vidx] * C[vidx])
        print(f"  → Time constant τ = R·C = {tau:.3f} s (should be << 1.0s for good pulsatility)")
    print("="*50 + "\n")

    # Stepper
    run_chunk, meta = make_stepper(net, base, beta_scale, mu_pa_s, nx, dt_init, Qin_ml_s, q_cap=1e9)
    _r0, A0, beta_i, _rho, children_list, is_terminal, dx = meta

    # State init: A = A0, Q = 0; terminal capacitor pressures start at Pext
    U = jnp.zeros((n, 2, nx), dtype=jnp.float64)
    U = U.at[:, 0, :].set(A0[:, None])
    P_c = jnp.where(is_terminal == 1, Pext, 0.0)

    # Time-based run target. Use small fixed step-count chunks so that the
    # recorded traces are dense; dt itself adapts inside every step.
    steps_per_cycle_est = int(np.ceil(base.T / dt_init))
    chunk = max(1, steps_per_cycle_est // 200)
    t_target = n_cycles * base.T

    # Loop-invariant host-side quantities for the per-chunk diagnostics
    mids = nx // 2
    beta_host = np.asarray(beta_i)
    A0_host = np.asarray(A0)
    A_min = 0.6 * A0_host[:, None]
    A_max = 1.4 * A0_host[:, None]
    # Same spacing the stepper's CFL bound uses, so the reported Courant
    # number is directly comparable with the stepper's 0.45 target.
    dx_min = float(np.min(np.asarray(dx)))
    root_kids = children_list[0]
    ls_idx = next((i for i, v in enumerate(net.vessels) if v.name == "left_subclavian"), None)
    # Only terminal vessels carry an RCR capacitor pressure.
    use_ls_cuff = ls_idx is not None and bool(np.asarray(is_terminal)[ls_idx])

    # Per-chunk recordings: one entry per chunk; q/p entries hold one value per vessel
    rec_t, rec_q, rec_p = [], [], []
    diag_t, diag_courant, diag_clamp, diag_junc = [], [], [], []

    t = jnp.asarray(0.0)
    dt = jnp.asarray(dt_init)
    U_dev, P_c_dev = U, P_c

    # Advance in fixed step-count chunks until we reach the time target
    while float(t) < float(t_target):
        U_dev, P_c_dev, t, dt = run_chunk(U_dev, P_c_dev, t, dt, R1, R2, C, Pext, int(chunk))

        th = float(t)
        Uh = np.array(U_dev)
        Ah = Uh[:, 0, :]
        Qh = Uh[:, 1, :]

        # Mid-vessel pressure from the wall law
        ph = base.p_ext + (beta_host / A0_host) * (np.sqrt(np.maximum(Ah[:, mids], 1e-16)) - np.sqrt(A0_host))
        if use_ls_cuff:
            # cuff-like terminal pressure for the left subclavian
            ph[ls_idx] = float(np.asarray(P_c_dev)[ls_idx])
        rec_t.append(th)
        rec_q.append(Qh[:, mids].copy())
        rec_p.append(ph)

        # Courant number actually achieved in this chunk
        c_speed = np.sqrt(np.maximum((beta_host[:, None] / (2.0 * base.rho * A0_host[:, None])) * np.sqrt(np.maximum(Ah, 1e-16)), 1e-16))
        u = Qh / np.maximum(Ah, 1e-16)
        smax = float(np.max(np.abs(u[:, 1:-1]) + c_speed[:, 1:-1]))
        diag_t.append(th)
        diag_courant.append(float(dt) * smax / max(dx_min, 1e-9))

        # Fraction of nodes sitting on the area clamp. Use a relative-only
        # tolerance: np.isclose's default atol=1e-8 is not negligible next to
        # small-vessel areas (m^2) and would report spurious clamps.
        n_clamped = (np.isclose(Ah, A_min, rtol=1e-6, atol=0.0).sum()
                     + np.isclose(Ah, A_max, rtol=1e-6, atol=0.0).sum())
        diag_clamp.append(float(n_clamped / Ah.size))

        # Flow residual at the root junction (parent 0)
        if len(root_kids) > 0:
            diag_junc.append(float(Uh[0, 1, -1]) - float(np.sum(Uh[root_kids, 1, 0])))
        else:
            diag_junc.append(0.0)

    t_all = np.asarray(rec_t, dtype=float)
    q_all = np.asarray(rec_q).reshape(t_all.size, n)
    p_all = np.asarray(rec_p).reshape(t_all.size, n)

    # === POST-SIMULATION DIAGNOSTICS ===
    print(f"\n=== Simulation Results ===")
    print(f"Final time: {float(t):.3f} s")
    for i, v in enumerate(net.vessels):
        q_arr = q_all[:, i]
        if q_arr.size > 0:
            q_mean = np.mean(q_arr)
            q_amp = np.max(q_arr) - np.min(q_arr)
            q_peak = np.max(np.abs(q_arr))
            print(f"{v.name:25s}: Q_mean={q_mean:.2e} m³/s, amplitude={q_amp:.2e}, peak={q_peak:.2e}")
            if q_peak > 0:
                pulsatility_index = q_amp / q_peak
                print(f"  → Pulsatility index: {pulsatility_index:.2%} (should be > 50%)")
    print("="*50 + "\n")
    # === END DIAGNOSTICS ===

    diag = {
        "t": np.array(diag_t),
        "courant": np.array(diag_courant),
        "clamp_fraction": np.array(diag_clamp),
        "junction_residual": np.array(diag_junc),
    }

    # Resample the final full cycle [t_target - T, t_target) onto N_t points.
    # Anchoring at a multiple of T (rather than at the last recorded time,
    # which overshoots t_target by a dt-dependent amount) keeps grid[0] at
    # the same phase of the periodic inflow in every run.
    grid = np.linspace(0.0, base.T, N_t, endpoint=False)
    t_query = (t_target - base.T) + grid
    res = {}
    for i, v in enumerate(net.vessels):
        res[v.name] = {
            "t": grid,
            "q": np.interp(t_query, t_all, q_all[:, i]),
            "p": np.interp(t_query, t_all, p_all[:, i]),
        }
    return res, diag