import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")

import os
import json

import numpy as np

import jax
from jax import random
from tqdm.auto import tqdm

from cfd_data import (
    create_arch_network, Network, BaseParams,
    SITE_ORDER, terminal_ids_ordered,
    default_rcr, st_like_refs,
    sample_priors, pack_theta, token_mask,
    build_joint_observation, build_stage1_locals,
)

from cfd_sim import simulate_5cycles_then_sample
from cfd_viz import quick_metrics, plot_flows_last_cycle, plot_brachial_pressure, plot_diagnostics

# ------------------------------------------------------------
# Stage-1 & Stage-2 dataset builders
# ------------------------------------------------------------
def make_stage1_dataset(rng: np.random.Generator,
                        N_full_sims: int,
                        net: Network, base: BaseParams,
                        N_t: int, nx: int, dt_init: float,
                        eta: float):
    """Simulate the network and slice each run into per-terminal training rows.

    For every full simulation a fresh parameter set is drawn with
    ``sample_priors`` and run through ``simulate_5cycles_then_sample``.
    ``build_stage1_locals`` then turns the resampled output into one local
    observation per terminal. Every terminal becomes its own row, so the
    dataset has ``N_full_sims * n_term`` rows.

    Args:
        rng: NumPy generator. It seeds each simulation's JAX prior key and is
            also handed to ``build_stage1_locals``, so one seed reproduces the
            whole dataset.
        N_full_sims: Number of full network simulations; must be >= 1.
        net: Network topology; its terminals define the slices.
        base: Base physical parameters forwarded to the priors/simulator.
        N_t: Samples per cycle passed to the simulator.
        nx: Spatial resolution passed to the simulator.
        dt_init: Initial time step passed to the simulator.
        eta: Forwarded to ``build_stage1_locals``.

    Returns:
        Tuple ``(thetas, ys, masks, site_idx, y_mean, y_std)``:

        * Row ``i`` of ``thetas``, ``ys`` and ``masks`` belongs to terminal
          ``site_idx[i]``.
        * ``thetas`` repeats the packed parameter vector of its simulation
          (width ``D = 3 + 2*n_term``).
        * ``y_mean`` and ``y_std`` are per-column statistics over all rows.
          ``y_std`` is offset by 1e-6 so it can be used as a divisor.

    Raises:
        ValueError: If ``N_full_sims < 1`` (there would be nothing to stack).
    """
    if N_full_sims < 1:
        raise ValueError(f"N_full_sims must be >= 1, got {N_full_sims}")

    n_term = len(terminal_ids_ordered(net))
    rcr_ref = default_rcr(net, base)

    thetas = []
    y_list = []
    masks = []
    site_idx = []

    for _ in tqdm(range(N_full_sims), desc="Stage-1 sims"):
        # Generator.integers excludes `high`, so 2**32-1 itself is never drawn.
        # Kept as-is on purpose: changing the bound changes every derived key
        # and would break reproducibility of previously generated datasets.
        key = random.PRNGKey(rng.integers(0, 2**32-1))
        theta_g, theta_loc = sample_priors(key, net, base, rcr_ref)
        resampled, _diag = simulate_5cycles_then_sample(net, base, theta_g, theta_loc, N_t, nx, dt_init)
        y_locals = build_stage1_locals(resampled, eta, rng)
        # Convert once per simulation: every terminal slice shares the same
        # parameter vector, and np.stack below copies it into each row.
        theta_flat = np.array(pack_theta(theta_g, theta_loc))
        for s in range(n_term):
            thetas.append(theta_flat)
            y_list.append(y_locals[s])
            masks.append(token_mask(n_term, s))
            site_idx.append(s)

    thetas = np.stack(thetas, axis=0)          # [N_full_sims*n_term, D]
    ys = np.stack(y_list, axis=0)              # [N_full_sims*n_term, N_t]
    masks = np.stack(masks, axis=0)            # [N_full_sims*n_term, D]
    site_idx = np.array(site_idx)              # [N_full_sims*n_term]
    y_mean = ys.mean(axis=0)
    y_std = ys.std(axis=0) + 1e-6
    return thetas, ys, masks, site_idx, y_mean, y_std

def make_stage2_dataset(rng: np.random.Generator,
                        N_full_sims: int,
                        net: Network, base: BaseParams,
                        N_t: int, nx: int, dt_init: float,
                        eta: float):
    """Simulate the network and record one joint observation per run.

    Each simulation draws parameters with ``sample_priors``, runs
    ``simulate_5cycles_then_sample`` and keeps only the noisy output of
    ``build_joint_observation``. Unlike stage 1, every row uses an all-ones
    mask over the full parameter vector.

    Args:
        rng: NumPy generator. It seeds each simulation's JAX prior key and is
            also handed to ``build_joint_observation``.
        N_full_sims: Number of full network simulations; must be >= 1.
        net: Network topology.
        base: Base physical parameters forwarded to the priors/simulator.
        N_t: Samples per cycle, passed to the simulator and the observation
            builder.
        nx: Spatial resolution passed to the simulator.
        dt_init: Initial time step passed to the simulator.
        eta: Forwarded to ``build_joint_observation``.

    Returns:
        Tuple ``(thetas, ys, masks, y_mean, y_std)`` with one row per
        simulation:

        * ``masks`` is all ones with width ``D = 3 + 2*n_term``.
        * ``ys`` rows are cast to float64.
        * ``y_std`` is offset by 1e-6 so it can be used as a divisor.

    Raises:
        ValueError: If ``N_full_sims < 1`` (there would be nothing to stack).
    """
    if N_full_sims < 1:
        raise ValueError(f"N_full_sims must be >= 1, got {N_full_sims}")

    n_term = len(terminal_ids_ordered(net))
    D = 3 + 2*n_term   # 3 global tokens + (R_T, C_T) per terminal
    rcr_ref = default_rcr(net, base)

    thetas = []
    ys = []

    for _ in tqdm(range(N_full_sims), desc="Stage-2 sims"):
        key = random.PRNGKey(rng.integers(0, 2**32-1))
        theta_g, theta_loc = sample_priors(key, net, base, rcr_ref)
        resampled, _diag = simulate_5cycles_then_sample(net, base, theta_g, theta_loc, N_t, nx, dt_init)
        # Only the noisy joint observation is kept; the clean signal and the
        # third return value are discarded here.
        y, _y_clean, _sigmas = build_joint_observation(resampled, N_t, eta, rng)
        thetas.append(np.array(pack_theta(theta_g, theta_loc)))
        ys.append(y.astype(np.float64))

    thetas = np.stack(thetas, axis=0)   # [N2, D]
    ys = np.stack(ys, axis=0)           # [N2, 4*N_t+2]
    # Every row gets the same all-ones mask, so build it in one shot.
    masks = np.ones((N_full_sims, D), dtype=np.float64)   # [N2, D]
    y_mean = ys.mean(axis=0)
    y_std = ys.std(axis=0) + 1e-6
    return thetas, ys, masks, y_mean, y_std


# ------------------------------------------------------------
# Main
# ------------------------------------------------------------
def _build_dataset_meta(args, net: Network, base: BaseParams) -> dict:
    """Assemble the JSON-serialisable description written to dataset_meta.json.

    NOTE(review): the prior means/sds below are hard-coded and presumably
    mirror ``sample_priors`` in cfd_data; confirm they stay in sync.
    """
    tids = terminal_ids_ordered(net)
    Rt_ref, C_ref = st_like_refs(net, base)
    return {
        "N_t": int(args.N_t),
        "sites_order": SITE_ORDER,  # (descending, innominate, LCC, LSubcl)
        "D": int(3 + 2*len(tids)),
        "globals_tokens": ["log_beta_scale", "log_mu", "log_Qin"],
        "locals_per_terminal": ["log_R_T", "log_C_T"],
        "n_terminals": int(len(tids)),
        "terminal_vessels": [net.vessels[i].name for i in tids],
        "priors": {
            "globals": {
                "log_beta_scale": {"mean": np.log(base.beta_scale_mean), "sd": 0.3},
                "log_mu":         {"mean": np.log(0.004),  "sd": 0.2},
                "log_Qin":        {"mean": np.log(85.0),   "sd": 0.2}
            },
            "locals": {
                "log_R_T_sd": 0.4,
                "log_C_T_sd": 0.4,
                "R_T_ref": Rt_ref.tolist(),
                "C_T_ref":  C_ref.tolist(),
                "tau_target_s": base.tau_target_s,
                "term_C_scale": base.term_C_scale
            }
        },
        "observation_layout": "concat[Q_desc(0:N_t), Q_inn(N_t:2N_t), Q_LCC(2N_t:3N_t), Q_LSubcl(3N_t:4N_t), systolic, diastolic]",
        "noise": {"eta_flow": args.eta, "sigma_pressure_mmHg": 2.5},
        "sim_config": {
            "nx": args.nx, "dt_init": args.dt_init, "cycles": 5, "keep_final_cycle": True
        },
        "units": {
            "length": "m", "pressure": "Pa (mmHg where stated)", "flow": "m^3/s (internally); plots unscaled"
        },
        "notes": "RCR outlets with locals (R_T=R2, C_T). R_T_ref and C_T_ref from simple ST-like scaling.",
        # Settings needed to regenerate exactly these datasets.
        "generation": {
            "seed": int(args.seed),
            "stage1_full_sims": int(args.stage1_full_sims),
            "stage2_full_sims": int(args.stage2_full_sims),
        },
    }


def main():
    """Command-line entry point: run one sanity simulation, then write both datasets.

    Everything is written under ``--outdir``:

    * ``figs/``: final-cycle flow/pressure plots and solver diagnostics.
    * ``sanity_metrics.json``: output of ``quick_metrics`` for the sanity run.
    * ``datasets/heldout_obs_case.npz``: clean and noisy joint observation of
      the sanity run, plus its packed ground-truth parameters (``theta_true``).
    * ``datasets/stage1_sliced.npz`` and ``datasets/stage2_joint.npz``:
      the training sets.
    * ``datasets/dataset_meta.json``: layout, priors and generation settings.
    """
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--N_t", type=int, default=50, help="samples per cycle per site")
    p.add_argument("--eta", type=float, default=0.05, help="heteroscedastic flow noise fraction")
    p.add_argument("--stage1_full_sims", type=int, default=20)
    p.add_argument("--stage2_full_sims", type=int, default=20)
    p.add_argument("--nx", type=int, default=81)
    p.add_argument("--dt_init", type=float, default=2e-4)
    p.add_argument("--outdir", type=str, default="artifacts")
    args = p.parse_args()

    # Also seed NumPy's legacy global RNG in case downstream code draws from
    # it; this script itself only uses explicit default_rng Generators.
    np.random.seed(args.seed)

    # Folders
    figs_dir = os.path.join(args.outdir, "figs")
    ds_dir   = os.path.join(args.outdir, "datasets")
    os.makedirs(figs_dir, exist_ok=True)
    os.makedirs(ds_dir, exist_ok=True)

    print(f"Using JAX devices: {[d.platform for d in jax.devices()]}")

    # Network & base params
    net = create_arch_network()
    base = BaseParams(T=1.0)

    # ---- One sanity simulation (for plots/metrics/held-out case) ----
    # The parameters are a random prior draw from a FIXED key (999). They do
    # not depend on --seed, so the sanity case is identical across runs.
    rng0 = random.PRNGKey(999)
    rcr_ref = default_rcr(net, base)
    theta_g0, theta_l0 = sample_priors(rng0, net, base, rcr_ref)
    resampled, diag = simulate_5cycles_then_sample(net, base, theta_g0, theta_l0,
                                                   N_t=args.N_t, nx=args.nx, dt_init=args.dt_init)

    # Plots
    plot_flows_last_cycle(resampled, os.path.join(figs_dir, "flows_4sites_final_cycle.png"))
    plot_brachial_pressure(resampled, os.path.join(figs_dir, "brachial_pressure_final_cycle.png"))
    plot_diagnostics(diag, os.path.join(figs_dir, "diagnostics.png"))

    # Metrics
    metrics = quick_metrics(resampled, diag)
    with open(os.path.join(args.outdir, "sanity_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
    print("Sanity metrics:", json.dumps(metrics, indent=2))

    # Held-out observation case (clean + noisy joint y) from the sanity run.
    # The packed ground-truth parameters are saved alongside so posterior
    # recovery on this case can be scored against the known answer.
    rng_np = np.random.default_rng(args.seed+101)
    y_obs, y_obs_clean, _ = build_joint_observation(resampled, args.N_t, args.eta, rng_np)
    theta_true = np.array(pack_theta(theta_g0, theta_l0))
    np.savez_compressed(os.path.join(ds_dir, "heldout_obs_case.npz"),
                        y_obs=y_obs, y_clean=y_obs_clean, theta_true=theta_true,
                        N_t=args.N_t, sites=SITE_ORDER)

    # ---- Stage-1 dataset (sliced locals) ----
    th1, y1, m1, site_idx, y1_mean, y1_std = make_stage1_dataset(
        rng=np.random.default_rng(args.seed+1),
        N_full_sims=args.stage1_full_sims,
        net=net, base=base,
        N_t=args.N_t, nx=args.nx, dt_init=args.dt_init,
        eta=args.eta
    )
    np.savez_compressed(os.path.join(ds_dir, "stage1_sliced.npz"),
                        thetas=th1, y=y1, masks=m1, site_idx=site_idx,
                        y_mean=y1_mean, y_std=y1_std)

    # ---- Stage-2 dataset (joint) ----
    th2, y2, m2, y2_mean, y2_std = make_stage2_dataset(
        rng=np.random.default_rng(args.seed+2),
        N_full_sims=args.stage2_full_sims,
        net=net, base=base,
        N_t=args.N_t, nx=args.nx, dt_init=args.dt_init,
        eta=args.eta
    )
    np.savez_compressed(os.path.join(ds_dir, "stage2_joint.npz"),
                        thetas=th2, y=y2, masks=m2, y_mean=y2_mean, y_std=y2_std)

    # ---- Metadata ----
    meta = _build_dataset_meta(args, net, base)
    with open(os.path.join(ds_dir, "dataset_meta.json"), "w") as f:
        json.dump(meta, f, indent=2)

    print(f"\nArtifacts written to: {os.path.abspath(args.outdir)}")
    print(" - figs/: flows_4sites_final_cycle.png, brachial_pressure_final_cycle.png, diagnostics.png")
    print(" - datasets/: stage1_sliced.npz, stage2_joint.npz, dataset_meta.json, heldout_obs_case.npz")
    print(" - sanity_metrics.json")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when this module is imported.
    main()
