from typing import Dict
import numpy as np
import matplotlib.pyplot as plt

from cfd_data import SITE_ORDER, MMHG_IN_PA

# ------------------------------------------------------------
# Diagnostics & plotting
# ------------------------------------------------------------
def _reduce_or_zero(values, reducer) -> float:
    """Return ``float(reducer(values))``, or 0.0 when ``values`` has no samples."""
    arr = np.asarray(values)
    return float(reducer(arr)) if arr.size else 0.0


def quick_metrics(res_all: Dict[str, Dict[str, np.ndarray]], diag: Dict[str, np.ndarray]) -> Dict[str, float]:
    """Summarize a simulation result and its diagnostics as scalar metrics.

    Args:
        res_all: Per-site results keyed by site name. Reads ``"q"`` for
            ``descending_aorta``, ``innominate``, ``left_common_carotid`` and
            ``left_subclavian``, plus ``"p"`` for ``left_subclavian``.
        diag: Diagnostic series; reads ``"courant"``, ``"clamp_fraction"`` and
            ``"junction_residual"``.

    Returns:
        Dict of plain floats:
          * ``q_peak_*``: maximum (signed) flow at each of the four sites;
          * ``brachial_systolic_mmHg`` / ``brachial_diastolic_mmHg``: max/min of
            the left-subclavian pressure divided by ``MMHG_IN_PA``;
          * ``courant_mean`` / ``courant_max`` / ``clamp_fraction_max``;
          * ``junction_residual_rms_over_qpeak``: RMS junction residual divided
            by the peak |flow| in the descending aorta.
        Every diagnostic-derived metric is 0.0 when its series is empty, rather
        than NaN or an exception.

    Raises:
        KeyError: if a required site or diagnostic key is missing.
        ValueError: if a flow/pressure array is empty (``np.max`` of nothing).
    """
    q_desc = res_all["descending_aorta"]["q"]
    q_inn  = res_all["innominate"]["q"]
    q_lcc  = res_all["left_common_carotid"]["q"]
    q_lsub = res_all["left_subclavian"]["q"]
    p_bra  = res_all["left_subclavian"]["p"] / MMHG_IN_PA

    # Normalise the junction residual by peak |Q| in the descending aorta; the
    # 1e-9 keeps the ratio finite if that flow is identically zero.
    q_scale = np.max(np.abs(q_desc)) + 1e-9
    # Previously unguarded: an empty residual series gave NaN plus a RuntimeWarning.
    residual_rms = _reduce_or_zero(
        diag["junction_residual"], lambda r: np.sqrt(np.mean(np.square(r)))
    )

    return {
        "q_peak_desc": float(np.max(q_desc)),
        "q_peak_inn":  float(np.max(q_inn)),
        "q_peak_lcc":  float(np.max(q_lcc)),
        "q_peak_lsub": float(np.max(q_lsub)),
        "brachial_systolic_mmHg": float(np.max(p_bra)),
        "brachial_diastolic_mmHg": float(np.min(p_bra)),
        "courant_mean": _reduce_or_zero(diag["courant"], np.mean),
        "courant_max":  _reduce_or_zero(diag["courant"], np.max),
        "clamp_fraction_max": _reduce_or_zero(diag["clamp_fraction"], np.max),
        "junction_residual_rms_over_qpeak": float(residual_rms / q_scale),
    }

def plot_flows_last_cycle(res_all: Dict[str, Dict[str, np.ndarray]], out_png: str):
    """Save stacked final-cycle flow plots, one panel per site in ``SITE_ORDER``.

    Args:
        res_all: Per-site results keyed by site name. Each site listed in
            ``SITE_ORDER`` must provide a ``"q"`` array. The shared x-axis is
            ``res_all["descending_aorta"]["t"]``.
        out_png: Path of the PNG file to write (160 dpi).
    """
    t = res_all["descending_aorta"]["t"]
    n_sites = len(SITE_ORDER)
    # One row per site. The row count was previously hard-coded to 4, which
    # silently dropped or left blank panels whenever SITE_ORDER had another
    # length. squeeze=False keeps `axes` an array even for a single site.
    fig, axes = plt.subplots(
        n_sites, 1, figsize=(9, 2 * n_sites), sharex=True, squeeze=False
    )
    axes = axes[:, 0]
    for ax, name in zip(axes, SITE_ORDER):
        ax.plot(t, res_all[name]["q"])
        ax.set_ylabel(f"Q: {name}")
        ax.grid(True, alpha=0.3)
    axes[-1].set_xlabel("Time in final cycle (s)")
    fig.suptitle(f"Flows at {n_sites} clinical sites (final cycle)")
    fig.tight_layout()
    fig.savefig(out_png, dpi=160, bbox_inches="tight")
    plt.close(fig)

def plot_brachial_pressure(res_all: Dict[str, Dict[str, np.ndarray]], out_png: str):
    """Save a final-cycle plot of the left-subclavian (brachial proxy) pressure.

    Args:
        res_all: Per-site results; reads ``"t"`` and ``"p"`` of
            ``res_all["left_subclavian"]``.
        out_png: Path of the PNG file to write (160 dpi).

    The pressure is divided by ``MMHG_IN_PA`` so the y-axis is in mmHg.
    """
    site = res_all["left_subclavian"]
    time_axis = site["t"]
    pressure_mmhg = site["p"] / MMHG_IN_PA

    fig, ax = plt.subplots(figsize=(8, 3.5))
    ax.plot(time_axis, pressure_mmhg)
    ax.set_xlabel("Time in final cycle (s)")
    ax.set_ylabel("Pressure (mmHg)")
    ax.set_title("Brachial proxy (left subclavian) — final cycle")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_png, dpi=160, bbox_inches="tight")
    plt.close(fig)

def plot_diagnostics(diag: Dict[str, np.ndarray], out_png: str):
    """Save stacked plots of the solver diagnostic series against ``diag["t"]``.

    Panels, top to bottom: junction residual, Courant number, clamp fraction,
    sharing one x-axis.

    Args:
        diag: Diagnostic series; reads ``"t"``, ``"junction_residual"``,
            ``"courant"`` and ``"clamp_fraction"``.
        out_png: Path of the PNG file to write (160 dpi).
    """
    time_axis = diag["t"]
    panels = (
        ("junction_residual", "Junction residual (Q)"),
        ("courant", "Courant"),
        ("clamp_fraction", "Clamp fraction"),
    )

    fig, axes = plt.subplots(len(panels), 1, figsize=(9, 7), sharex=True)
    for ax, (key, label) in zip(axes, panels):
        ax.plot(time_axis, diag[key])
        ax.set_ylabel(label)
    axes[-1].set_xlabel("Time (s)")
    for ax in axes:
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(out_png, dpi=160, bbox_inches="tight")
    plt.close(fig)
