from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple

import numpy as np

import jax
import jax.numpy as jnp
from jax import random

# ------------------------------------------------------------
# Constants
# ------------------------------------------------------------

MMHG_IN_PA = 133.322  # pressure conversion factor: pascals in one mmHg

# Canonical order of the four outlet flow sites. Observation slices and the
# rows of the per-terminal parameter array follow this order.
SITE_ORDER = [
    "descending_aorta",
    "innominate",
    "left_common_carotid",
    "left_subclavian",
]

# ------------------------------------------------------------
# Data classes & Containers
# ------------------------------------------------------------

@dataclass
class Vessel:
    """A single 1-D vessel segment with geometry in metres.

    ``parent_idx`` is the index of the parent vessel within ``Network.vessels``;
    the default of -1 marks a vessel with no parent (the inlet/root).
    """

    length_m: float      # segment length [m]
    radius_m: float      # reference lumen radius [m]
    thickness_m: float   # wall thickness [m]
    name: str            # identifier used for site lookups by name
    parent_idx: int = -1  # index of parent vessel, -1 for the root

@dataclass
class Network:
    """A tree of vessels linked through ``Vessel.parent_idx``.

    ``n_vessels`` is stored separately from ``vessels`` and is not validated
    against ``len(vessels)``; callers are expected to keep them consistent.
    """

    vessels: List[Vessel]
    n_vessels: int

    def children(self, idx: int) -> List[int]:
        """Return the indices (in list order) of vessels whose parent is ``idx``."""
        return [i for i, v in enumerate(self.vessels) if v.parent_idx == idx]

    def terminal_ids(self) -> List[int]:
        """Return the indices (in list order) of vessels that have no children.

        Collects every referenced parent index once, so this is O(n) rather
        than calling ``children`` for each vessel (O(n^2)).
        """
        parents = {v.parent_idx for v in self.vessels}
        return [i for i, _ in enumerate(self.vessels) if i not in parents]

@dataclass
class BaseParams:
    """Baseline physical constants and calibration targets for the arch model.

    ``default_rcr`` reads map_target_mmHg, co_target_ml_s, term_C_scale,
    tau_target_s and p_ext; ``sample_priors`` centres the log-normal prior on
    the stiffness scale at ``beta_scale_mean``.
    """

    # Baseline tube-law stiffness β.
    # NOTE(review): chosen so that c0 is O(5–10 m/s) under the p–A law used by
    # the solver (not defined in this file) — confirm against that law.
    beta_base: float = 3.0e3      # Pa·m
    rho: float = 1060.0           # blood density, kg/m^3
    mu_ref: float = 0.004         # reference dynamic viscosity, Pa·s
    T: float = 1.0                # cardiac period, s
    # Outlet external/reference pressure; copied into RCRParameters.Pext for
    # every terminal by default_rcr (presumably Pa — confirm).
    p_ext: float = 0.0
    # ---- physiological targets (named here to avoid magic numbers) ----
    map_target_mmHg: float = 93.0     # mean arterial pressure used to size total resistance
    co_target_ml_s:  float = 83.0     # cardiac output, ~5 L/min; fallback inflow for calibration
    beta_scale_mean: float = 2.0e6    # median of the log-normal prior on the stiffness scale
    # Multiplier on terminal compliance. default_rcr sets
    # C = term_C_scale * tau_target_s / R2, so lower values give smaller C and a
    # faster terminal time constant (which tunes pulse pressure).
    term_C_scale:    float = 0.60
    # Nominal terminal time constant. Because of term_C_scale, the realized
    # R2*C per outlet is term_C_scale * tau_target_s (0.24 s with these defaults).
    tau_target_s:    float = 0.40

@dataclass
class RCRParameters:
    """Per-vessel three-element Windkessel (RCR) outlet parameters.

    Every array holds one entry per vessel; ``default_rcr`` fills only the
    terminals named in SITE_ORDER and leaves the other entries at zero.
    R1 is kept small; the per-terminal inference targets ("local tokens")
    are R2 (R_T) and C (C_T).
    """

    R1: jnp.ndarray    # proximal resistance, Pa·s/m³
    R2: jnp.ndarray    # distal resistance, Pa·s/m³ — local token R_T
    C:  jnp.ndarray    # terminal compliance, m³/Pa — local token C_T
    Pext: jnp.ndarray  # external/reference pressure at each outlet

# ------------------------------------------------------------
# Network & Geometry
# ------------------------------------------------------------

def create_arch_network() -> Network:
    """Build the five-vessel aortic-arch network.

    Topology (index: name, parent):
      0: aortic_root (inlet), parent=-1
         ├─ 1: innominate (terminal)
         ├─ 2: left_common_carotid (terminal)
         ├─ 3: left_subclavian (terminal)
         └─ 4: descending_aorta (terminal segment here)
    """
    # Columns: length [m], radius [m], wall thickness [m], name, parent index.
    spec = (
        (0.06, 0.012, 0.0015, "aortic_root", -1),
        (0.05, 0.0065, 0.0010, "innominate", 0),
        (0.05, 0.0045, 0.0010, "left_common_carotid", 0),
        (0.06, 0.0060, 0.0010, "left_subclavian", 0),
        (0.10, 0.0110, 0.0015, "descending_aorta", 0),
    )
    vessels = [Vessel(*row) for row in spec]
    return Network(vessels, len(vessels))

def pack_geometry(net: Network, nx: int):
    """Stack per-vessel geometry into JAX arrays, in ``net.vessels`` order.

    Args:
        net: network whose vessels provide radius, wall thickness and length.
        nx: number of grid points per vessel; the spacing is length/(nx-1),
            so the points span the full vessel length. Must be >= 2.

    Returns:
        (r0, h, A0, dx): reference radius, wall thickness, reference area
        pi*r0**2, and grid spacing, one entry per vessel.

    Raises:
        ValueError: if nx < 2 (the spacing would divide by zero or be negative).
    """
    if nx < 2:
        raise ValueError(f"nx must be >= 2 to define a grid spacing, got {nx}")
    vessels = net.vessels
    r0 = jnp.array([v.radius_m for v in vessels])
    h = jnp.array([v.thickness_m for v in vessels])
    A0 = jnp.pi * r0**2
    dx = jnp.array([v.length_m / (nx - 1) for v in vessels])
    return r0, h, A0, dx


def terminal_ids_ordered(net: Network) -> List[int]:
    """Map each site in SITE_ORDER to its vessel index, preserving SITE_ORDER.

    This keeps theta_loc[s] aligned with the observation slice for
    SITE_ORDER[s]. If two vessels share a name, the later one wins; a site
    name with no matching vessel raises KeyError.
    """
    index_of = {vessel.name: pos for pos, vessel in enumerate(net.vessels)}
    return list(map(index_of.__getitem__, SITE_ORDER))

# ------------------------------------------------------------
# RCR Reference & Constructor
# ------------------------------------------------------------

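# ------------------------------------------------------------
# Reference terminal values from ST-like scaling (very simple)
#   R_T,ref ∝ 1 / r^4 ;  C_T,ref ∝ r^3
# Only the R_T shape is consumed downstream: default_rcr rescales it to the
# MAP target and sets compliance from a time-constant rule instead. The
# resulting RCR values are what centre the LogNormal priors.
# ------------------------------------------------------------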
def st_like_refs(net: Network, base: BaseParams, *,
                 rt_mean: float = 1.2e8, c_mean: float = 1.5e-9):
    """Reference terminal resistance/compliance from simple radius scaling.

    Uses R_T,ref ∝ 1/r^4 and C_T,ref ∝ r^3 over the terminals in SITE_ORDER
    order. Each array is normalised so that its arithmetic mean equals
    ``rt_mean`` / ``c_mean``.

    Args:
        net: network supplying terminal radii.
        base: currently unused; kept for interface compatibility.
        rt_mean: mean of the returned resistances [Pa·s/m³].
        c_mean: mean of the returned compliances [m³/Pa].

    Returns:
        (Rt_ref, C_ref): float64 arrays with one entry per SITE_ORDER site.
    """
    tids = terminal_ids_ordered(net)
    r = np.array([net.vessels[i].radius_m for i in tids])
    # Dividing by r[0] only fixes the relative scale; the mean normalisation
    # below removes it, so the absolute level is set by rt_mean / c_mean alone.
    rt_shape = (r[0] ** 4) / (r ** 4)
    Rt_ref = rt_shape / rt_shape.mean() * rt_mean
    c_shape = (r ** 3) / (r[0] ** 3)
    C_ref = c_shape / c_shape.mean() * c_mean
    return Rt_ref.astype(np.float64), C_ref.astype(np.float64)


# ------------------------------------------------------------
# Outlet model: RCR  (R1 small, locals are R2=R_T and C)
# ------------------------------------------------------------

def default_rcr(net: Network, base: BaseParams, qin_ml_s: Optional[float] = None,
                *, r1_fraction: float = 0.02) -> RCRParameters:
    """Build baseline RCR outlet parameters calibrated to the MAP target.

    Resistances: the 1/r^4 shape from ``st_like_refs`` is rescaled so that the
    parallel combination of the terminal R2 values equals
    R_sys = MAP_target / Q. Here Q is ``qin_ml_s`` when given and
    ``base.co_target_ml_s`` otherwise. Each terminal also gets
    R1 = r1_fraction * R2.
    Compliances: C = term_C_scale * tau_target_s / R2, so every outlet has the
    same R2*C = term_C_scale * tau_target_s.

    Only vessels named in SITE_ORDER are filled; all other entries stay zero.
    Calibration arithmetic is done in float64 even if ``qin_ml_s`` is a JAX
    scalar.

    Args:
        net: the vessel network.
        base: physiological targets (MAP, CO, term_C_scale, tau_target_s, p_ext).
        qin_ml_s: inflow in mL/s used for calibration; None uses the target CO.
        r1_fraction: proximal resistance as a fraction of R2 (default 0.02).

    Returns:
        RCRParameters with one entry per vessel.

    Raises:
        ValueError: if the calibration inflow is not a positive, finite number.
    """
    n = net.n_vessels
    R1 = np.zeros(n, dtype=np.float64)
    R2 = np.zeros(n, dtype=np.float64)
    C = np.zeros(n, dtype=np.float64)
    Pext = np.zeros(n, dtype=np.float64)

    # Only the resistance shape is used; the ST-like compliance reference is
    # superseded by the time-constant rule below.
    rt_shape, _ = st_like_refs(net, base)

    co_used = float(base.co_target_ml_s if qin_ml_s is None else qin_ml_s)
    if not (np.isfinite(co_used) and co_used > 0.0):
        raise ValueError(f"calibration inflow must be positive and finite, got {co_used} mL/s")

    # Total systemic resistance in Pa·s/m³ (mL/s -> m³/s via 1e-6).
    r_sys_target = (base.map_target_mmHg * MMHG_IN_PA) / (co_used * 1e-6)
    # Parallel R of (scale * rt_shape) is scale / Σ(1/rt_shape); solve for scale.
    scale = r_sys_target * np.sum(1.0 / rt_shape)
    Rt_ref = scale * rt_shape
    # NOTE(review): R1 is added on top of R2, so the realized parallel outlet
    # resistance is (1 + r1_fraction) * r_sys_target, slightly above target.
    C_ref = base.term_C_scale * (base.tau_target_s / Rt_ref)

    idx = np.asarray(terminal_ids_ordered(net), dtype=np.intp)
    R1[idx] = r1_fraction * Rt_ref
    R2[idx] = Rt_ref
    C[idx] = C_ref
    Pext[idx] = base.p_ext
    return RCRParameters(jnp.array(R1), jnp.array(R2), jnp.array(C), jnp.array(Pext))

# ------------------------------------------------------------
# Priors & Parameter Unpacking
# ------------------------------------------------------------

# ------------------------------------------------------------
# Priors (all in log space)
#   Globals: log beta_scale ~ N(log base.beta_scale_mean, 0.3)  (median 2e6 by default),
#            log mu ~ N(log 0.004, 0.2) [Pa·s], log Qin ~ N(log 85, 0.2) [mL/s]
#   Locals per terminal s: log R_T[s] = log R_T,ref(s) + 0.4 * eps_s,
#                          log C_T[s] = log C_T,ref(s) - 0.4 * eps_s,
#   with one shared eps_s ~ N(0, 1) per site, so R_T[s] * C_T[s] stays at its reference.
# ------------------------------------------------------------
def sample_priors(rng: jax.Array, net: Network, base: BaseParams, rcr_ref: RCRParameters,
                  *, beta_sigma: float = 0.3, mu_sigma: float = 0.2,
                  qin_median_ml_s: float = 85.0, qin_sigma: float = 0.2,
                  local_sigma: float = 0.4):
    """Draw one prior sample of the global and per-terminal log-parameters.

    Globals (normal in log space, i.e. log-normal in natural units):
      log beta_scale ~ N(log base.beta_scale_mean, beta_sigma)
      log mu         ~ N(log base.mu_ref, mu_sigma)          [Pa·s]
      log Qin        ~ N(log qin_median_ml_s, qin_sigma)     [mL/s]

    Locals, for each terminal s (SITE_ORDER order):
      log R_T[s] = log R2_ref[s] + local_sigma * eps_s
      log C_T[s] = log C_ref[s]  - local_sigma * eps_s
    One eps_s ~ N(0, 1) is shared per site, so R_T and C_T are perfectly
    anti-correlated in log space. R_T[s] * C_T[s] therefore equals the
    reference time constant on every draw (up to rounding), not just on
    average.

    Args:
        rng: JAX PRNG key.
        net: network; terminal order comes from ``terminal_ids_ordered``.
        base: supplies beta_scale_mean and mu_ref.
        rcr_ref: reference RCR values whose R2 and C centre the local priors.

    Returns:
        theta_g: shape (3,), [log beta_scale, log mu, log Qin].
        theta_loc: shape (n_term, 2), columns [log R_T, log C_T].
    """
    tids = terminal_ids_ordered(net)
    n_term = len(tids)
    # Only keys[0..3] are consumed, but the split count is kept at
    # 3 + 2*n_term: depending on the JAX PRNG implementation, changing the
    # number of splits can change every subkey and break seed reproducibility.
    keys = random.split(rng, 3 + 2 * n_term)

    # Globals
    log_beta_scale = random.normal(keys[0]) * beta_sigma + jnp.log(base.beta_scale_mean)
    log_mu = random.normal(keys[1]) * mu_sigma + jnp.log(base.mu_ref)
    log_Qin = random.normal(keys[2]) * qin_sigma + jnp.log(qin_median_ml_s)  # mL/s
    theta_g = jnp.array([log_beta_scale, log_mu, log_Qin])

    # Locals: gather the reference values of the terminals in SITE_ORDER order.
    idx = np.asarray(tids, dtype=np.intp)
    Rt_ref = np.asarray(rcr_ref.R2, dtype=np.float64)[idx]
    C_ref = np.asarray(rcr_ref.C, dtype=np.float64)[idx]

    eps = random.normal(keys[3], (n_term,))
    log_Rt = jnp.log(Rt_ref) + local_sigma * eps
    log_C = jnp.log(C_ref) - local_sigma * eps
    theta_loc = jnp.stack([log_Rt, log_C], axis=1)  # [n_term, 2]
    return theta_g, theta_loc

def pack_theta(theta_g: jnp.ndarray, theta_loc: jnp.ndarray) -> jnp.ndarray:
    """Flatten (globals, locals) into one vector: globals first, then locals row-major."""
    pieces = (jnp.ravel(theta_g), jnp.ravel(theta_loc))
    return jnp.concatenate(pieces)

def unpack_theta(theta_flat: jnp.ndarray, n_term: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Split a flat parameter vector into (globals[3], locals[n_term, 2])."""
    n_global = 3  # log beta_scale, log mu, log Qin
    head, tail = theta_flat[:n_global], theta_flat[n_global:]
    return head, tail.reshape(n_term, 2)

def token_mask(n_term: int, site_idx: Optional[int]) -> np.ndarray:
    """
    Mask over the flat parameter vector of length 3 + 2*n_term.

    - None => all ones (joint stage)
    - s    => ones for the 3 globals + site s's two locals (positions 3+2s, 4+2s)

    Raises:
        ValueError: if site_idx is not in [0, n_term). Previously such values
            silently produced a wrong mask (slice past the end, or negative
            indices landing on the global slots).
    """
    D = 3 + 2 * n_term
    if site_idx is None:
        return np.ones((D,), dtype=np.float64)
    if not 0 <= site_idx < n_term:
        raise ValueError(f"site_idx must be in [0, {n_term}), got {site_idx}")
    m = np.zeros((D,), dtype=np.float64)
    m[:3] = 1.0
    j = 3 + 2 * site_idx
    m[j:j + 2] = 1.0
    return m

# ------------------------------------------------------------
# Observations
# ------------------------------------------------------------

# ------------------------------------------------------------
# Observables & noise
#   y = [Q_desc, Q_inn, Q_lcc, Q_lsubcl] (N_t each) + [brachial S, brachial D]
#   σ_site = η * max_t |Q_site(t)| ; η=0.05
#   pressure noise ~ N(0, σ_P^2) with σ_P ≈ 2.5 mmHg
# ------------------------------------------------------------

def build_joint_observation(resampled: Dict[str, Dict[str, np.ndarray]],
                            N_t: int,
                            eta: float,
                            rng: np.random.Generator,
                            *,
                            sigma_p_mmHg: float = 2.5):
    """Assemble the noisy joint observation vector and its noise-free version.

    Layout: the flow traces for each site in SITE_ORDER (N_t samples each,
    concatenated), followed by [brachial systolic, brachial diastolic] in mmHg.
    The brachial values use left_subclavian pressure as a proxy.

    Noise:
      flows:     N(0, sigma_site^2), sigma_site = eta * max_t |Q_site(t)| (floored at 1e-9)
      pressures: N(0, sigma_p_mmHg^2)
    RNG draws happen site by site in SITE_ORDER, then systolic, then diastolic.

    Args:
        resampled: per-site dicts with a flow trace under "q".
            "left_subclavian" must also have "p" (pressure in Pa).
            NOTE(review): presumably the final cardiac cycle — depends on the caller.
        N_t: samples per flow trace; every "q" must have shape (N_t,).
        eta: relative flow-noise level.
        rng: NumPy random generator.
        sigma_p_mmHg: pressure noise standard deviation in mmHg.

    Returns:
        (y, y_clean, sigmas): noisy and clean observation vectors, and the
        per-site flow noise standard deviations (float64, SITE_ORDER order).

    Raises:
        ValueError: if any flow trace is not a 1-D array of length N_t (the
            noise segments would otherwise be misaligned).
    """
    flows = []
    sigmas = []
    for name in SITE_ORDER:
        q = np.asarray(resampled[name]["q"])
        if q.shape != (N_t,):
            raise ValueError(f"flow trace for {name!r} has shape {q.shape}, expected ({N_t},)")
        flows.append(q)
        sigmas.append(eta * max(np.max(np.abs(q)), 1e-9))
    y_flow = np.concatenate(flows, axis=0)

    p_brach_mmhg = np.asarray(resampled["left_subclavian"]["p"]) / MMHG_IN_PA
    syst = float(np.max(p_brach_mmhg))
    diast = float(np.min(p_brach_mmhg))

    y_noisy = y_flow.copy()
    for i, sigma in enumerate(sigmas):
        seg = slice(i * N_t, (i + 1) * N_t)
        y_noisy[seg] += rng.normal(0.0, sigma, size=N_t)
    syst_noisy = syst + rng.normal(0.0, sigma_p_mmHg)
    diast_noisy = diast + rng.normal(0.0, sigma_p_mmHg)

    y = np.concatenate([y_noisy, np.array([syst_noisy, diast_noisy])], axis=0)
    y_clean = np.concatenate([y_flow, np.array([syst, diast])], axis=0)
    return y, y_clean, np.array(sigmas, dtype=np.float64)

def build_stage1_locals(resampled: Dict[str, Dict[str, np.ndarray]],
                        eta: float, rng: np.random.Generator) -> np.ndarray:
    """
    Noisy per-site flow observations for stage-1 (one terminal at a time).

    Row s (s = 0..len(SITE_ORDER)-1) is the flow of SITE_ORDER[s] plus
    N(0, sigma_s^2) noise, where sigma_s = eta * max_t |q_s(t)| (floored at
    1e-9). This is the same per-site scale as in build_joint_observation.
    Draws are made site by site in SITE_ORDER. The row index s is the same
    site index that ``token_mask`` takes; masks are not built here.

    Returns:
        Array of shape (len(SITE_ORDER), N_t). np.stack requires all flow
        traces to share one shape.
    """
    noisy_rows = []
    for name in SITE_ORDER:
        q = resampled[name]["q"]
        sigma = eta * max(np.max(np.abs(q)), 1e-9)
        noisy_rows.append(q + rng.normal(0.0, sigma, size=q.shape))
    return np.stack(noisy_rows, axis=0)