# triage_gradio_app_v4.py
# Focused scoring: You vs Truth (primary), AI vs Truth (secondary).
# Markdown renders correctly; Learn-more is collapsible.
# Works on older Gradio (no gr.Box): styling via elem_id + CSS.

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.rcParams["figure.max_open_warning"] = 0

from pathlib import Path
import base64, mimetypes
import gradio as gr
import numpy as np
from matplotlib.colors import ListedColormap

# True only when the installed Gradio exposes a `Modal` component; the UI
# falls back to a plain `gr.Group` for the results panel otherwise.
HAS_MODAL = hasattr(gr, "Modal")

# Optional image assets, resolved relative to the current working directory.
ASSET_DIR  = Path("assets")
ASSET_LOGO = ASSET_DIR / "Combined_Logo.jpg"
ASSET_HOW  = ASSET_DIR / "How_Triage_Works.png"

# Triage labels. Their position in LABELS (0, 1, 2) is used by the plotting
# helpers as the colour index into TRAFFIC_CMAP (green, orange, red).
MILD = "Mild (Green)"
INTERMEDIATE = "Intermediate (Orange)"
SEVERE = "Severe (Red)"
LABELS = [MILD, INTERMEDIATE, SEVERE]
TRAFFIC_CMAP = ListedColormap(["#2ecc71", "#f39c12", "#e74c3c"])
# Number of cases shown in the UI text; must be kept equal to len(CASES).
TOTAL_CASES = 7

# Stylesheet passed to gr.Blocks(css=...). Each rule targets a component by the
# elem_id it is given in the UI section (intro-box, case-info-box, score-box,
# case-box, learn-card), so no gr.Box wrapper is needed.
CSS = """
#intro-box {
  background: #f0f8ff;
  border: 2px solid #0ea5e9;
  padding: 22px; /* slightly larger */
  border-radius: 12px;
  margin: 10px 0 16px 0;
}
/* Make intro text more prominent without overwhelming */
#intro-box h2, #intro-box h3 { margin: 0.2rem 0 0.6rem 0; }
#intro-box h2 { font-size: 1.6rem; }
#intro-box h3 { font-size: 1.05rem; color: #0369a1; font-weight: 600; }
#intro-box p, #intro-box li { font-size: 1.02rem; line-height: 1.45; }
#intro-box ul { margin-top: 0.4rem; }

#case-info-box {
  background: #fff8e1;
  border: 1px solid #f59e0b;
  padding: 14px; border-radius: 10px; margin: 8px 0 14px 0;
}
#score-box {
  background: #f4f4f5;
  border: 1px solid #a1a1aa;
  padding: 14px; border-radius: 10px; margin: 8px 0 14px 0;
}
#case-box {
  background: #ffffff;
  border: 1px solid #e5e7eb;
  padding: 16px; border-radius: 10px; margin: 8px 0 14px 0;
}
#learn-card {
  border: 1px solid #e5e7eb; border-radius: 12px; padding: 12px;
  background:#fafafa; display:flex; flex-direction:column; align-items:center; gap:10px;
}
.gr-button { font-weight: 600; }
"""

# ---------- cartoon helpers ----------
def _cartoon_srcs():
    """Locate the optional cartoon image and describe it two ways.

    The first existing file among triage_cartoon.png/.jpg/.jpeg inside
    ASSET_DIR wins.

    Returns:
        A ``(file_src, data_uri)`` pair: a ``file/<path>`` URL string and a
        base64 ``data:`` URI holding the same bytes. ``(None, None)`` when no
        candidate file exists.
    """
    candidates = (
        ASSET_DIR / filename
        for filename in ("triage_cartoon.png", "triage_cartoon.jpg", "triage_cartoon.jpeg")
    )
    found = next((path for path in candidates if path.exists()), None)
    if found is None:
        return None, None

    # Fall back to PNG when the extension gives no MIME type.
    mime_type = mimetypes.guess_type(found.name)[0] or "image/png"
    encoded = base64.b64encode(found.read_bytes()).decode("ascii")
    return f"file/{found.as_posix()}", f"data:{mime_type};base64,{encoded}"

def cartoon_html_tag(height_px=120):
    """Build the ``<img>`` markup for the cartoon, if one is available.

    Args:
        height_px: Rendered image height in pixels.

    Returns:
        ``(html, True)`` when a cartoon file was found, else ``("", False)``.
        The tag loads the file URL first and swaps to the inline data URI
        from its ``onerror`` handler if that load fails.
    """
    primary_src, fallback_uri = _cartoon_srcs()
    if not primary_src:
        return "", False

    pieces = [
        f"<img src='{primary_src}' ",
        # Clear onerror first so a broken fallback cannot loop.
        f"onerror=\"this.onerror=null;this.src='{fallback_uri}';\" ",
        "alt='cartoon' ",
        f"style='height:{height_px}px;display:block;margin:0 auto;border-radius:10px;'/>",
    ]
    return "".join(pieces), True

# ---------------- Cases (unique Learn-more content per case) ----------------
# Ordered list of the teaching cases. Each entry holds:
#   - "id", "title", "story": display metadata for the case card
#   - "patient_name", "age", "ethnicity", "comorbidity": shown to the user and
#     in the "AI view" text, but not read by predict_label
#   - "hr", "spo2", "temp", "rr": the vitals used to seed the sliders
#   - "ground_truth": the reference label the user and the AI are scored against
#   - "script_override": None, "force_mild" (prediction forced to MILD) or
#     "ood_low_confidence" (confidence capped), both handled in predict_label
#   - "teaching": the concept name plus three "Learn more" paragraphs (p1-p3)
CASES = [
    {
        "id": 1,
        "title": "Case 1 — Grandad at the market",
        "story": "A 78-year-old man became breathless walking home from the market after a night of fever. He’s dizzy and leaning forward to breathe.",
        "patient_name": "John D.", "age": 78, "hr": 115, "spo2": 88, "temp": 38.5, "rr": 28,
        "ethnicity": "South Asian", "comorbidity": "COPD",
        "ground_truth": SEVERE, "script_override": None,
        "teaching": {
            "concept": "Supervised learning",
            "learn": {
                "p1": "In **supervised learning**, a model learns a rule by seeing many examples where the ‘right answer’ (the label) is known. For triage, labels could be Green/Orange/Red assigned by clinicians. The model tries to map vital signs → label.",
                "p2": "This case shows a classic **labeled-examples → rule** setup. The inputs are oxygen level, breathing rate, temperature, and heart rate; the label is triage level. With enough examples, the model finds boundaries that separate Green/Orange/Red.",
                "p3": "Why it matters: when we later see a **new** patient who looks like past severe cases (low SpO₂ + high breathing rate), the model leans Red. If we trained on poor labels, the learned rule will also be poor—garbage in, garbage out."
            }
        }
    },
    {
        "id": 2,
        "title": "Case 2 — Teacher with a stubborn fever",
        "story": "A 63-year-old teacher has a high fever and gets breathless walking up stairs, but can speak in full sentences.",
        "patient_name": "Maria S.", "age": 63, "hr": 104, "spo2": 91, "temp": 39.0, "rr": 24,
        "ethnicity": "White", "comorbidity": "None",
        "ground_truth": INTERMEDIATE, "script_override": None,
        "teaching": {
            "concept": "Data representation",
            "learn": {
                "p1": "**Data representation** means how we turn the world into numbers a model can use. Simple choices (like SpO₂ as a % or “fever yes/no”) can change what the model learns.",
                "p2": "In this case, a **high temperature** and **moderate breathlessness** push toward Orange. If temperature were discretized too coarsely, we could lose the difference between 37.9°C and 39.0°C—reducing model sensitivity.",
                "p3": "Design tip: include signals that clinicians actually use (e.g., ability to speak in full sentences) and encode them clearly. Missing or blunt features lead to fuzzy decisions and potential safety issues."
            }
        }
    },
    {
        "id": 3,
        "title": "Case 3 — Student with a sore throat",
        "story": "A 35-year-old student with a sore throat and mild cough walked to the clinic without difficulty.",
        "patient_name": "Leo P.", "age": 35, "hr": 78, "spo2": 98, "temp": 37.4, "rr": 16,
        "ethnicity": "East Asian", "comorbidity": "None",
        "ground_truth": MILD, "script_override": None,
        "teaching": {
            "concept": "Evaluation (what ‘good’ means)",
            "learn": {
                "p1": "**Evaluation** isn’t one size fits all. In safety-critical settings, missing a severe case (a **false negative**) is worse than raising an extra alert (a **false positive**).",
                "p2": "This mild case reminds us that the model should not over-call Red. We want **high sensitivity** for severe, while keeping reasonable specificity so we don’t overwhelm staff.",
                "p3": "Metrics to watch: sensitivity/recall for Red, false-negative rate for Red, and overall confusion matrix. Choose thresholds to reflect clinical priorities—not just a single accuracy number."
            }
        }
    },
    {
        "id": 4,
        "title": "Case 4 — Office worker winded on stairs",
        "story": "A 52-year-old office worker felt feverish and noticed getting winded on stairs today.",
        "patient_name": "Omar R.", "age": 52, "hr": 102, "spo2": 92, "temp": 38.0, "rr": 22,
        "ethnicity": "Arab", "comorbidity": "Asthma",
        "ground_truth": INTERMEDIATE, "script_override": None,
        "teaching": {
            "concept": "Decision boundaries",
            "learn": {
                "p1": "**Decision boundaries** are the lines (or surfaces) that separate classes. Near a boundary, small changes can flip the prediction.",
                "p2": "Here HR ≈ 102 and SpO₂ ≈ 92% put us close to the **Green/Orange** boundary. Try nudging the sliders: a 1–2 point shift can move the case across the line.",
                "p3": "This is why we show confidence: **near boundaries = lower confidence**. Good interfaces communicate uncertainty instead of pretending precision."
            }
        }
    },
    {
        "id": 5,
        "title": "Case 5 — Parent speaking in short phrases",
        "story": "A 45-year-old parent has a harsh cough, high fever, and can only speak in short phrases.",
        "patient_name": "Tasha B.", "age": 45, "hr": 105, "spo2": 90, "temp": 39.2, "rr": 26,
        "ethnicity": "Black", "comorbidity": "Asthma",
        "ground_truth": INTERMEDIATE, "script_override": "force_mild",
        "teaching": {
            "concept": "Bias & coverage",
            "learn": {
                "p1": "**Bias & coverage**: models inherit patterns from their data. If some groups are under-represented or labels were inconsistent, performance can be worse for those patients.",
                "p2": "We intentionally inject a mismatch here (**force_mild**), simulating a dataset that under-learned this pattern. If you catch it, you’ll get the **Bias Buster** badge.",
                "p3": "Mitigations: audit performance by subgroup, add data where under-represented, calibrate confidence, and keep humans-in-the-loop for sensitive decisions."
            }
        }
    },
    {
        "id": 6,
        "title": "Case 6 — Retiree with chronic lung disease",
        "story": "A 68-year-old with COPD is more breathless today though there’s no fever.",
        "patient_name": "Ewan K.", "age": 68, "hr": 118, "spo2": 89, "temp": 37.2, "rr": 30,
        "ethnicity": "White", "comorbidity": "COPD",
        "ground_truth": SEVERE, "script_override": None,
        "teaching": {
            "concept": "Explainability & confidence",
            "learn": {
                "p1": "**Explainability** helps users understand which features pushed the decision. In triage, SpO₂ and breathing rate often dominate.",
                "p2": "Confidence here is tied to **distance from thresholds**—the farther the vitals are into the Red region, the higher the confidence.",
                "p3": "Design principle: pair predictions with plain-language reasons and confidence. It helps clinicians judge when to trust the model—or override it."
            }
        }
    },
    {
        "id": 7,
        "title": "Case 7 — Palpitations during pregnancy",
        "story": "A 29-year-old in the second trimester has palpitations and mild breathlessness.",
        "patient_name": "Aisha M.", "age": 29, "hr": 112, "spo2": 96, "temp": 37.0, "rr": 22,
        "ethnicity": "South Asian", "comorbidity": "None",
        "ground_truth": INTERMEDIATE, "script_override": "ood_low_confidence",
        "teaching": {
            "concept": "Out-of-distribution (OOD)",
            "learn": {
                "p1": "**Out-of-distribution** means the new case isn’t well represented in training data (e.g., pregnancy not encoded). Models should get cautious here.",
                "p2": "We cap confidence for this case to reflect missing context. That’s safer than overstating certainty when the model may not know.",
                "p3": "Mitigations: add/encode the missing context, monitor for OOD patterns in deployment, and allow graceful **deferral to humans** when confidence is low."
            }
        }
    },
]

# ---------------- Model ----------------
def thresholds(caution: float):
    """Return the five decision thresholds for a given caution level.

    Args:
        caution: Slider value, nominally 0-100. Higher values raise the SpO₂
            cut-offs and lower the rate/temperature cut-offs, so more cases
            land in Intermediate or Severe.

    Returns:
        ``(spo2_severe, rr_severe, spo2_inter, hr_inter, temp_inter)``.
    """
    frac = caution/100
    severe_cutoffs = (90 + 2*frac, 28 - 4*frac)
    intermediate_cutoffs = (93 + 2*frac, 105 - 15*frac, 38.5 - 1.0*frac)
    return (*severe_cutoffs, *intermediate_cutoffs)

def predict_label(hr:int, spo2:int, temp:float, rr:int, script_override=None, caution: float=50):
    """Classify one set of vitals with the rule-based triage model.

    Args:
        hr: Heart rate (bpm).
        spo2: Oxygen saturation (%).
        temp: Temperature (°C).
        rr: Breathing rate (/min).
        script_override: ``None``, ``"force_mild"`` (prediction is forced to
            MILD) or ``"ood_low_confidence"`` (confidence is capped at 0.30).
        caution: Caution level passed to ``thresholds``.

    Returns:
        ``(label, confidence)`` where ``label`` is one of MILD / INTERMEDIATE /
        SEVERE and ``confidence`` is a float in [0, 1].
    """
    s_spo2, s_rr, i_spo2, i_hr, i_temp = thresholds(caution)

    # Severe rules take precedence over intermediate ones.
    severe = (spo2 <= s_spo2) or (rr >= s_rr)
    intermediate = (spo2 <= i_spo2) or (hr >= i_hr) or (temp >= i_temp)
    pred = SEVERE if severe else (INTERMEDIATE if intermediate else MILD)
    if script_override == "force_mild":
        pred = MILD

    # Normalised distance from each vital to each of its thresholds.
    d = [
        abs(spo2 - s_spo2)/10.0,
        abs(rr   - s_rr)/10.0,
        abs(spo2 - i_spo2)/10.0,
        abs(hr   - i_hr)/30.0,
        abs(temp - i_temp)/1.5,
    ]
    # Confidence grows with the distance to the nearest threshold: a case
    # sitting exactly on a boundary scores 0, not 1. (The previous
    # `1 - min(d)` was inverted relative to the "distance to nearest boundary"
    # label shown next to the confidence bar.)
    conf = float(np.clip(min(d), 0.0, 1.0))
    if script_override == "ood_low_confidence":
        conf = min(conf, 0.30)
    return pred, conf

# ---------------- Plots ----------------
def plot_boundary_spo2_hr(temp, rr, spo2_pt, hr_pt, caution):
    """Draw the model's decision regions over SpO₂ (x) and heart rate (y).

    Temperature and breathing rate are held fixed at ``temp`` / ``rr`` while
    a 120x120 grid of SpO₂ and heart-rate values is classified with
    ``predict_label`` (no script override). The current patient is marked
    with an open circle at ``(spo2_pt, hr_pt)``.

    Returns:
        The matplotlib Figure. It is detached from pyplot before returning so
        repeated calls do not accumulate open figures.
    """
    spo2_grid = np.linspace(85, 100, 120)
    hr_grid   = np.linspace(60, 160, 120)
    Z = np.zeros((len(hr_grid), len(spo2_grid)))
    for i, h in enumerate(hr_grid):
        for j, s in enumerate(spo2_grid):
            # Grid values are truncated to whole numbers before classification.
            pred, _ = predict_label(int(h), int(s), temp, rr, None, caution)
            Z[i, j] = LABELS.index(pred)
    fig, ax = plt.subplots()
    ax.imshow(Z, extent=[spo2_grid.min(), spo2_grid.max(), hr_grid.min(), hr_grid.max()],
              origin="lower", aspect="auto", cmap=TRAFFIC_CMAP, vmin=0, vmax=2)
    ax.set_xlabel("SpO₂ (%)"); ax.set_ylabel("Heart Rate (bpm)")
    ax.scatter([spo2_pt], [hr_pt], s=100, marker="o", edgecolor="black", facecolor="none")
    ax.set_title("Decision regions: SpO₂ vs Heart Rate")
    # This runs on every slider change; without closing, pyplot keeps a
    # reference to every figure ever created. The Figure object itself stays
    # usable (it can still be saved/rendered) after being closed.
    plt.close(fig)
    return fig

def plot_boundary_temp_rr(hr, spo2, temp_pt, rr_pt, caution):
    """Draw the model's decision regions over temperature (x) and breathing rate (y).

    Heart rate and SpO₂ are held fixed at ``hr`` / ``spo2`` while a grid of
    temperature and breathing-rate values is classified with
    ``predict_label`` (no script override). The current patient is marked
    with an open circle at ``(temp_pt, rr_pt)``.

    Returns:
        The matplotlib Figure. It is detached from pyplot before returning so
        repeated calls do not accumulate open figures.
    """
    temp_grid = np.linspace(35.5, 40.5, 110)
    rr_grid   = np.linspace(10, 36, 108)
    Z = np.zeros((len(rr_grid), len(temp_grid)))
    for i, r in enumerate(rr_grid):
        for j, t in enumerate(temp_grid):
            # Breathing rate is truncated to a whole number; temperature is not.
            pred, _ = predict_label(hr, spo2, float(t), int(r), None, caution)
            Z[i, j] = LABELS.index(pred)
    fig, ax = plt.subplots()
    ax.imshow(Z, extent=[temp_grid.min(), temp_grid.max(), rr_grid.min(), rr_grid.max()],
              origin="lower", aspect="auto", cmap=TRAFFIC_CMAP, vmin=0, vmax=2)
    ax.set_xlabel("Temperature (°C)"); ax.set_ylabel("Breathing Rate (/min)")
    ax.scatter([temp_pt], [rr_pt], s=100, marker="o", edgecolor="black", facecolor="none")
    ax.set_title("Decision regions: Temp vs Breathing Rate")
    # This runs on every slider change; without closing, pyplot keeps a
    # reference to every figure ever created. The Figure object itself stays
    # usable (it can still be saved/rendered) after being closed.
    plt.close(fig)
    return fig

# ---------------- Text builders ----------------
def score_md(user_t, ai_t):
    """Render the score card as Markdown.

    Args:
        user_t: Tally dict for the user with keys ``correct``, ``incorrect``,
            ``missed_severe`` and ``total``.
        ai_t: Tally dict for the AI with the same keys.

    Returns:
        A Markdown string: a heading followed by one line per tally.
    """
    def _row(title, tally):
        # One "<who> vs Truth" line of the card.
        return (
            f"**{title}** — ✅ {tally['correct']} • ❌ {tally['incorrect']} "
            f"(Missed Severe: {tally['missed_severe']}) • Total: {tally['total']}"
        )

    rows = [_row("You vs Truth", user_t), _row("AI vs Truth", ai_t)]
    return "### 📊 Scores\n" + "\n\n".join(rows)

def case_info_md(idx:int):
    """Render the case-info card (case counter and teaching focus) as Markdown.

    Args:
        idx: Zero-based index into CASES.
    """
    focus = CASES[idx]["teaching"]["concept"]
    card_lines = [
        f"### 📂 Case {idx+1}/{TOTAL_CASES}",
        # Two trailing spaces force a Markdown line break.
        f"**Teaching focus:** {focus}.  ",
        "_Supervised learning is the most common ML type—see how labeled examples guide predictions._",
    ]
    return "\n".join(card_lines)

def learn_more_rich(idx):
    """Build the "Learn more" Markdown for a case.

    Args:
        idx: Zero-based index into CASES.

    Returns:
        The bolded concept name followed by paragraphs p1-p3, separated by
        blank lines. A missing p1 falls back to a generic sentence; missing
        p2/p3 become empty paragraphs.
    """
    teaching = CASES[idx]["teaching"]
    concept = teaching["concept"]
    paragraphs = teaching.get("learn", {})
    sections = (
        f"**{concept}**",
        paragraphs.get("p1", f"About **{concept}**."),
        paragraphs.get("p2", ""),
        paragraphs.get("p3", ""),
    )
    return "\n\n".join(f"{section}" for section in sections)

def confidence_bar_html(value: float):
    """Render a confidence value as a small HTML progress bar with a caption.

    Args:
        value: Confidence as a fraction; values outside [0, 1] are clamped.

    Returns:
        An HTML snippet whose inner bar width and caption show the confidence
        as a whole-number percentage.
    """
    # Round rather than truncate: 100*0.29 is 28.999... in floating point and
    # int() would display it as 28%. Clamp so the bar can never overflow.
    pct = int(round(100 * min(max(value, 0.0), 1.0)))
    return f"""
    <div style="width:100%;background:#eee;border-radius:8px;height:12px;position:relative;">
      <div style="width:{pct}%;height:12px;background:#4caf50;border-radius:8px;"></div>
    </div>
    <div style="font-size:12px;color:#555;margin-top:4px;">Confidence (distance to nearest boundary): {pct}%</div>
    """

def badge_strip_html(badges):
    """Render the four badges as a right-aligned row of HTML chips.

    Args:
        badges: Mapping of badge name to a truthy/falsy "earned" flag; names
            that are absent count as not earned.

    Returns:
        An HTML ``<div>`` with one ``<span>`` per badge, always in the same
        order. Earned badges are green with a check mark, others grey.
    """
    chips = []
    for title in ("Boundary Breaker", "Bias Buster", "Safety First", "OOD Aware"):
        if badges.get(title, False):
            background, colour, tick = "#d1fae5", "#065f46", " ✅"
        else:
            background, colour, tick = "#eee", "#555", ""
        chips.append(
            f'<span style="padding:3px 8px;border-radius:999px;margin-right:6px;'
            f'background:{background};color:{colour}">{title}{tick}</span>'
        )
    return "<div style='text-align:right'>" + "".join(chips) + "</div>"

# ---------------- Scoring & Callbacks ----------------
def reveal(user_choice, hr, spo2, temp, rr, script_override,
           user_t, ai_t, already_revealed, badges, caution, idx):
    """Reveal the AI prediction for the current case and update the scores.

    Args:
        user_choice: Label picked by the user.
        hr, spo2, temp, rr: Current slider values for the vitals.
        script_override: Scenario knob of the current case (see predict_label).
        user_t: User tally dict; updated in place.
        ai_t: AI tally dict; updated in place.
        already_revealed: True when this case was already scored; scoring and
            badge awards are then skipped.
        badges: Badge flags dict; updated in place.
        caution: Current caution slider value.
        idx: Zero-based index of the current case in CASES.

    Returns:
        A 14-tuple: AI message, revealed flag (True), user tally, AI tally,
        badges, score card Markdown, confidence bar HTML, badge strip HTML,
        the AI prediction, an update disabling the reveal button, and four
        updates that make the vitals sliders interactive.
    """
    # AI prediction
    ai_pred, conf = predict_label(hr, spo2, temp, rr, script_override, caution)

    # Take the reference label straight from the case table. The copy kept in
    # ai_t["truth_label"] is only current as long as next_case's in-place
    # mutation of that state dict survives between events; keep it in sync
    # here rather than depending on it.
    truth = CASES[idx]["ground_truth"]
    ai_t["truth_label"] = truth

    # Only score once per case
    if not already_revealed:
        # User vs Truth
        if user_choice == truth: user_t["correct"] += 1
        else:
            user_t["incorrect"] += 1
            if truth == SEVERE and user_choice != SEVERE:
                user_t["missed_severe"] += 1
        user_t["total"] += 1

        # AI vs Truth
        if ai_pred == truth: ai_t["correct"] += 1
        else:
            ai_t["incorrect"] += 1
            if truth == SEVERE and ai_pred != SEVERE:
                ai_t["missed_severe"] += 1
        ai_t["total"] += 1

        # Badges from scenario knobs
        # NOTE(review): these are granted on reveal regardless of what the user
        # chose, although the Case 5 text says "If you catch it" — confirm intent.
        if script_override == "force_mild":         badges["Bias Buster"] = True
        if script_override == "ood_low_confidence": badges["OOD Aware"]  = True
        if caution >= 70:                           badges["Safety First"] = True

    # Teaching note if AI is wrong on bias/OOD
    note = ""
    if ai_pred != truth:
        concept = CASES[idx]["teaching"]["concept"].lower()
        if "bias" in concept:
            note = "💡 **AI miss (bias):** training didn’t cover this pattern well."
        elif "out-of-distribution" in concept or "ood" in concept:
            note = "💡 **AI miss (OOD):** case is unlike training data; confidence should be lower."

    # strip() drops the trailing blank lines when there is no note.
    msg = f"**AI prediction:** {ai_pred}\n\n{note}".strip()
    conf_html = confidence_bar_html(conf)
    badges_html = badge_strip_html(badges)

    # Enable sliders after reveal
    return (
        msg, True, user_t, ai_t, badges,
        score_md(user_t, ai_t), conf_html, badges_html,  # score card + conf + badges
        ai_pred,                                           # s_last_pred
        gr.update(value='Revealed ✅', interactive=False), # reveal button
        gr.update(interactive=True), gr.update(interactive=True),
        gr.update(interactive=True), gr.update(interactive=True)
    )

def adjust(hr, spo2, temp, rr, caution, last_pred):
    """Recompute the prediction and both decision-region plots for new slider values.

    The prediction here never applies a script override.

    Args:
        hr, spo2, temp, rr: Current vitals slider values.
        caution: Current caution slider value.
        last_pred: Previously stored prediction; ``""`` or ``None`` means none.

    Returns:
        ``(info_markdown, spo2_hr_figure, temp_rr_figure, prediction, broke)``
        where ``broke`` is True when a previous prediction exists and the new
        one differs from it.
    """
    current, _unused_conf = predict_label(hr, spo2, temp, rr, None, caution)
    spo2_hr_fig = plot_boundary_spo2_hr(temp, rr, spo2, hr, caution)
    temp_rr_fig = plot_boundary_temp_rr(hr, spo2, temp, rr, caution)

    has_baseline = last_pred not in ("", None)
    crossed = has_baseline and (current != last_pred)
    return f"If vitals changed → **{current}**", spo2_hr_fig, temp_rr_fig, current, crossed

def update_badges(badges, broke):
    """Award "Boundary Breaker" when ``broke`` is truthy and re-render the strip.

    Returns:
        ``(badges, badge_strip_html)``; ``badges`` is the same dict, updated
        in place.
    """
    if broke:
        badges.update({"Boundary Breaker": True})
    strip = badge_strip_html(badges)
    return badges, strip

def next_case(idx, user_t, ai_t):
    """Advance to the following case and build everything the UI shows for it.

    Args:
        idx: Zero-based index of the case currently on screen.
        user_t: User tally dict (unused here; kept for the callback signature).
        ai_t: AI tally dict; its ``truth_label`` is updated in place to the
            new case's ground truth.

    Returns:
        A 13-tuple: new index, revealed flag (False), title Markdown, story,
        patient Markdown, the four vitals (hr, spo2, temp, rr), case-info
        Markdown, the case's script override, the "AI view" text, and a flag
        that is True when the new case is the last one.
    """
    last_idx = len(CASES) - 1
    # Clamp instead of indexing past the end: a second click that arrives
    # before the "Next" button is hidden on the last case would otherwise
    # raise IndexError.
    idx = min(idx + 1, last_idx)
    c = CASES[idx]
    header = f"**{c['title']}**"
    story  = c["story"]
    human  = (f"**Patient:** {c['patient_name']} (age {c['age']})  \n"
              f"- Heart rate: **{c['hr']} bpm**  \n- Oxygen saturation (SpO₂): **{c['spo2']}%**  \n"
              f"- Temperature: **{c['temp']} °C**  \n- Breathing rate: **{c['rr']} /min**  \n"
              f"- Ethnicity: {c['ethnicity']}  \n- Comorbidity: {c['comorbidity']}")
    ai_txt = (f"Age: {c['age']}\nHR: {c['hr']}\nSpO₂: {c['spo2']}\nTemp: {c['temp']}\n"
              f"RespRate: {c['rr']}\nEthnicity: {c['ethnicity']}\nComorbidity: {c['comorbidity']}")
    ai_t["truth_label"] = c["ground_truth"]
    at_last = (idx == last_idx)
    return (idx, False, header, story, human, c["hr"], c["spo2"], c["temp"], c["rr"],
            case_info_md(idx), c["script_override"], ai_txt, at_last)

def finish_text(user_t, ai_t):
    """Return the final-results Markdown: a congratulation line plus the score card."""
    scores = score_md(user_t, ai_t)
    congratulation = "**Great job!** You completed all " + str(TOTAL_CASES) + " cases.  "
    return congratulation + "\n\n" + scores

def restart():
    """Reset the session to the first case with fresh tallies and badges.

    Returns:
        A 21-tuple in the order expected by the Restart button's outputs:
        index 0, revealed flag (False), title, story, patient Markdown, the
        four vitals, case-info Markdown, script override, fresh user tally,
        fresh AI tally, fresh badges, empty last prediction, "AI view" text,
        badge strip HTML, at-last flag (False), score card Markdown,
        learn-open flag (False) and the collapsed learn-button label.
    """
    first = CASES[0]
    counters = ("correct", "incorrect", "missed_severe", "total")
    fresh_user = dict.fromkeys(counters, 0)
    fresh_ai = {**dict.fromkeys(counters, 0), "truth_label": first["ground_truth"]}
    fresh_badges = dict.fromkeys(
        ("Boundary Breaker", "Bias Buster", "Safety First", "OOD Aware"), False
    )

    # Markdown list; two trailing spaces before each newline force a line break.
    patient_md = "  \n".join([
        f"**Patient:** {first['patient_name']} (age {first['age']})",
        f"- Heart rate: **{first['hr']} bpm**",
        f"- Oxygen saturation (SpO₂): **{first['spo2']}%**",
        f"- Temperature: **{first['temp']} °C**",
        f"- Breathing rate: **{first['rr']} /min**",
        f"- Ethnicity: {first['ethnicity']}",
        f"- Comorbidity: {first['comorbidity']}",
    ])
    features_txt = "\n".join([
        f"Age: {first['age']}",
        f"HR: {first['hr']}",
        f"SpO₂: {first['spo2']}",
        f"Temp: {first['temp']}",
        f"RespRate: {first['rr']}",
        f"Ethnicity: {first['ethnicity']}",
        f"Comorbidity: {first['comorbidity']}",
    ])

    return (
        0, False, f"**{first['title']}**", first["story"], patient_md,
        first["hr"], first["spo2"], first["temp"], first["rr"],
        case_info_md(0), first["script_override"],
        fresh_user, fresh_ai, fresh_badges, "", features_txt,
        badge_strip_html(fresh_badges), False,
        score_md(fresh_user, fresh_ai),
        False, "Learn more about this concept ▸",
    )

def set_next_visibility_from_last(at_last: bool):
    """Return visibility updates for two buttons: the first is shown unless
    ``at_last``, the second only when ``at_last``."""
    next_update = gr.update(visible=not at_last)
    results_update = gr.update(visible=at_last)
    return next_update, results_update

def show_results_top(user_t, ai_t):
    """Return an update that shows the results container, plus the final summary text."""
    reveal_container = gr.update(visible=True)
    summary = finish_text(user_t, ai_t)
    return reveal_container, summary

def lock_sliders_and_reset_reveal():
    """Return updates that lock four sliders, clear the revealed flag and
    re-arm the reveal button with its original label."""
    locked = [gr.update(interactive=False) for _ in range(4)]
    rearmed_button = gr.update(value="Reveal AI prediction", interactive=True)
    return (*locked, False, rearmed_button)

# Toggle learn-more open/close
def toggle_learn(idx, is_open):
    """Flip the "Learn more" panel between collapsed and expanded.

    Args:
        idx: Zero-based index of the current case in CASES.
        is_open: Whether the panel is currently expanded.

    Returns:
        ``(markdown, group_visibility_update, new_open_flag, button_update)``.
    """
    opening = not is_open
    if opening:
        body, label = learn_more_rich(idx), "Hide details ▲"
    else:
        body, label = "", "Learn more about this concept ▸"
    return body, gr.update(visible=opening), opening, gr.update(value=label)

# ---------------- UI ----------------
# Builds the whole Gradio interface: layout first, then per-session state,
# then the event wiring between the two.
with gr.Blocks(title="AI in Healthcare: Can You Predict Patient Priority?", css=CSS) as demo:
    # Title + logo
    with gr.Row():
        gr.Markdown("# AI in Healthcare: Can You Predict Patient Priority?")
        with gr.Column(scale=1):
            if ASSET_LOGO.exists():
                gr.Image(value=str(ASSET_LOGO), height=48, show_label=False)
            else:
                # Name the path that was actually checked, not a stale filename.
                gr.Markdown(f"_(Logo file not found: {ASSET_LOGO.as_posix()})_")

    # Intro — playful hook + academic subtitle + beginner vital-sign explainer
    intro_md = gr.Markdown(
        "## Beat the Machine — Can You Triage Better than AI?\n"
        "### *(AI in Healthcare: Triage Through 7 Cases and When It Fails)*\n\n"
        "**Triage** means deciding how urgently each patient needs care so the right help arrives at the right time. "
        "In busy clinics and emergency rooms, it helps teams focus on those who need help first.\n\n"
        "In this demo, you’ll work through **7 short cases** and choose a level:\n"
        "- **Severe (Red)** — needs urgent attention\n"
        "- **Intermediate (Orange)** — monitor closely\n"
        "- **Mild (Green)** — likely safe to wait\n\n"
        "**Vital signs (plain language):**\n"
        "- **Heart Rate (HR):** heartbeats per minute. Very **high** or very low can be concerning.\n"
        "- **Respiratory Rate (RR):** breaths per minute. **Higher** can mean the body is working hard.\n"
        "- **Oxygen Saturation (SpO₂):** how much oxygen is in the blood. **Lower** is dangerous.\n"
        "- **Temperature:** body heat (°C). **Higher** can signal infection or illness.\n\n"
        "You’ll see how an **AI system** would classify the same cases, and where it can go wrong—"
        "for example due to **bias in training data** or **out-of-distribution cases**. "
        "We’ll track **your score vs the ground truth** and compare it to the **AI’s score**.",
        elem_id="intro-box"
    )

    # TOP ROW: left visual | center score | right badges + learn
    with gr.Row():
        with gr.Column(scale=1):
            if ASSET_HOW.exists():
                gr.Image(value=str(ASSET_HOW), show_label=False, height=120)

        with gr.Column(scale=2):
            score_card = gr.Markdown(
                score_md({"correct":0,"incorrect":0,"missed_severe":0,"total":0},
                         {"correct":0,"incorrect":0,"missed_severe":0,"total":0}),
                elem_id="score-box"
            )

        with gr.Column(scale=2):
            badge_strip_top = gr.HTML("<div style='text-align:right'></div>")
            cartoon_html, has_cartoon = cartoon_html_tag(height_px=120)
            with gr.Group(elem_id="learn-card"):
                cartoon_box = gr.HTML(cartoon_html, visible=has_cartoon)
                learn_btn = gr.Button("Learn more about this concept ▸")
                learn_group = gr.Group(visible=False)
                with learn_group:
                    learn_md = gr.Markdown()

    # Results near top
    if HAS_MODAL:
        results_container = gr.Modal("Final Results", visible=False)
        with results_container:
            final_text = gr.Markdown()
            close_modal = gr.Button("Close")
    else:
        results_container = gr.Group(visible=False)
        with results_container:
            gr.Markdown("## Final Results")
            final_text = gr.Markdown()
            close_modal = gr.Button("Close")

    # Case info card
    case_info = gr.Markdown(case_info_md(0), elem_id="case-info-box")

    # Main row: LEFT (case card) | RIGHT (controls)
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group(elem_id="case-box"):
                header = gr.Markdown(f"**{CASES[0]['title']}**")
                story  = gr.Markdown(CASES[0]["story"])
                c0 = CASES[0]
                human_view = gr.Markdown(
                    f"**Patient:** {c0['patient_name']} (age {c0['age']})  \n"
                    f"- Heart rate: **{c0['hr']} bpm**  \n- Oxygen saturation (SpO₂): **{c0['spo2']}%**  \n"
                    f"- Temperature: **{c0['temp']} °C**  \n- Breathing rate: **{c0['rr']} /min**  \n"
                    f"- Ethnicity: {c0['ethnicity']}  \n- Comorbidity: {c0['comorbidity']}"
                )
                with gr.Row():
                    reveal_btn = gr.Button("Reveal AI prediction", variant="primary")
                    next_btn = gr.Button("Next case ▶")
                    see_results_btn = gr.Button("See results ✅", visible=False)
                    restart_btn = gr.Button("Restart 🔄")
                with gr.Accordion("AI view (features)", open=False):
                    ai_view = gr.Textbox(
                        value=(f"Age: {c0['age']}\nHR: {c0['hr']}\nSpO₂: {c0['spo2']}\nTemp: {c0['temp']}\n"
                               f"RespRate: {c0['rr']}\nEthnicity: {c0['ethnicity']}\nComorbidity: {c0['comorbidity']}"),
                        lines=7, interactive=False, show_copy_button=True, label=""
                    )
                ai_out = gr.Markdown()
                conf_bar = gr.HTML("")
                with gr.Accordion("Adjust & explore (unlocks after Reveal)", open=True):
                    adjust_info = gr.Markdown()
                    with gr.Row():
                        db_plot1 = gr.Plot(value=plot_boundary_spo2_hr(c0["temp"], c0["rr"], c0["spo2"], c0["hr"], 50),
                                           label="SpO₂ vs Heart Rate")
                        db_plot2 = gr.Plot(value=plot_boundary_temp_rr(c0["hr"], c0["spo2"], c0["temp"], c0["rr"], 50),
                                           label="Temperature vs Breathing Rate")

        with gr.Column(scale=1):
            gr.Markdown("### Your decision (traffic-light)")
            user_choice = gr.Radio(LABELS, value=INTERMEDIATE, show_label=False)
            gr.Markdown("### Caution level")
            caution = gr.Slider(0, 100, value=50, step=1, label="More cautious → more Severe/Intermediate")
            gr.Markdown("### Vital signs")
            hr = gr.Slider(60, 160, step=1, label="Heart Rate (bpm)", interactive=False, value=c0["hr"])
            spo2 = gr.Slider(85, 100, step=1, label="SpO₂ (%)", interactive=False, value=c0["spo2"])
            temp = gr.Slider(35.5, 40.5, step=0.1, label="Temperature (°C)", interactive=False, value=c0["temp"])
            rr = gr.Slider(10, 36, step=1, label="Breathing rate (/min)", interactive=False, value=c0["rr"])

    # States (one of each; every event below refers to these objects)
    s_idx = gr.State(0)
    s_revealed = gr.State(False)
    s_user_vs_truth = gr.State({"correct":0, "incorrect":0, "missed_severe":0, "total":0})
    s_ai_vs_truth   = gr.State({"correct":0, "incorrect":0, "missed_severe":0, "total":0, "truth_label": c0["ground_truth"]})
    s_script_override = gr.State(c0["script_override"])
    s_badges = gr.State({"Boundary Breaker": False, "Bias Buster": False, "Safety First": False, "OOD Aware": False})
    s_last_pred = gr.State("")
    s_at_last = gr.State(False)
    s_learn_open = gr.State(False)

    # ----- Interactions -----
    def on_change(H, S, T, R, C, last, badges, revealed):
        """Refresh the explore panel; award Boundary Breaker only after Reveal."""
        info, fig1, fig2, pred_now, broke = adjust(H, S, T, R, C, last)
        # Slider `change` also fires when Next/Restart load a new case, where
        # `last` still holds the previous case's prediction. Requiring the case
        # to be revealed stops that from counting as crossing a boundary.
        badges2, badges_html = update_badges(badges, broke and revealed)
        return info, fig1, fig2, pred_now, badges2, badges_html

    def collapse_learn():
        """Outputs that close the learn-more panel and reset its button label."""
        return "", gr.update(visible=False), False, gr.update(value="Learn more about this concept ▸")

    def clear_ai_output():
        """Outputs that blank the AI prediction text and the confidence bar."""
        return "", ""

    for w in (hr, spo2, temp, rr, caution):
        w.change(fn=on_change,
                 inputs=[hr, spo2, temp, rr, caution, s_last_pred, s_badges, s_revealed],
                 outputs=[adjust_info, db_plot1, db_plot2, s_last_pred, s_badges, badge_strip_top])

    # Toggle learn-more
    learn_btn.click(
        fn=toggle_learn,
        inputs=[s_idx, s_learn_open],
        outputs=[learn_md, learn_group, s_learn_open, learn_btn]
    )

    # Reveal
    reveal_btn.click(
        fn=reveal,
        inputs=[user_choice, hr, spo2, temp, rr, s_script_override,
                s_user_vs_truth, s_ai_vs_truth, s_revealed, s_badges, caution, s_idx],
        outputs=[ai_out, s_revealed, s_user_vs_truth, s_ai_vs_truth, s_badges,
                 score_card, conf_bar, badge_strip_top, s_last_pred,
                 reveal_btn, hr, spo2, temp, rr]
    )

    # Next case
    next_btn.click(
        fn=next_case,
        inputs=[s_idx, s_user_vs_truth, s_ai_vs_truth],
        outputs=[s_idx, s_revealed, header, story, human_view, hr, spo2, temp, rr,
                 case_info, s_script_override, ai_view, s_at_last]
    ).then(
        # collapse learn-more when moving on
        fn=collapse_learn,
        inputs=None, outputs=[learn_md, learn_group, s_learn_open, learn_btn]
    ).then(
        # the previous case's prediction must not stay on screen
        fn=clear_ai_output, inputs=None, outputs=[ai_out, conf_bar]
    ).then(
        fn=lock_sliders_and_reset_reveal, inputs=None,
        outputs=[hr, spo2, temp, rr, s_revealed, reveal_btn]
    ).then(
        fn=set_next_visibility_from_last, inputs=[s_at_last], outputs=[next_btn, see_results_btn]
    )

    # Results (top)
    see_results_btn.click(
        fn=show_results_top,
        inputs=[s_user_vs_truth, s_ai_vs_truth],
        outputs=[results_container, final_text]
    )
    close_modal.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=[results_container])

    # Restart
    restart_btn.click(
        fn=restart, inputs=[],
        outputs=[s_idx, s_revealed, header, story, human_view, hr, spo2, temp, rr,
                 case_info, s_script_override, s_user_vs_truth, s_ai_vs_truth,
                 s_badges, s_last_pred, ai_view, badge_strip_top, s_at_last, score_card, s_learn_open, learn_btn]
    ).then(
        fn=lambda: gr.update(visible=False), inputs=None, outputs=[results_container]
    ).then(
        # restart() resets the button label and flag; also close the panel itself
        fn=collapse_learn,
        inputs=None, outputs=[learn_md, learn_group, s_learn_open, learn_btn]
    ).then(
        fn=clear_ai_output, inputs=None, outputs=[ai_out, conf_bar]
    ).then(
        fn=lambda at_last: (gr.update(visible=True), gr.update(visible=False)),
        inputs=[s_at_last], outputs=[next_btn, see_results_btn]
    ).then(
        fn=lock_sliders_and_reset_reveal, inputs=None,
        outputs=[hr, spo2, temp, rr, s_revealed, reveal_btn]
    )

# Script entry point: make sure the assets directory exists, then start the
# Gradio server. The directory is created only here, after the UI above has
# already checked which asset files exist.
if __name__ == "__main__":
    ASSET_DIR.mkdir(exist_ok=True)
    demo.launch()

