"""Modal runner for the consistency-loss experiment.

── V1 entrypoints (original 4-variant run, GPT-2-small, DO NOT RE-RUN) ─────────

    modal run modal_run.py::run_full_gpt2_small

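The function runs all V1 variants on an A10G and writes outputs to a
timestamped directory on the Modal volume (mounted at /results inside the
container). The zip archive has a fixed, non-timestamped name, so a re-run
overwrites it. Download via (remote paths are relative to the volume root,
not to the /results mount point):

    modal volume get consistency-loss-experiment-results /full_gpt2_small_outputs.zip .

Previous completed V1 run directory: /results/full_gpt2_small_20260430_174439
(container path; /full_gpt2_small_20260430_174439 on the volume)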

── V2 entrypoint (stronger 4-variant ablation ladder, GPT-2-small) ─────────────

    modal run modal_run.py::run_full_gpt2_small_v2

Runs the four stronger ablation variants:
  no_claim_to_claim_attention, claims_from_explanation_only,
  surface_bottleneck_consistency, surface_bottleneck_no_expl_lm

Outputs go to /results/full_gpt2_small_stronger_<timestamp> inside the container
(i.e. /full_gpt2_small_stronger_<timestamp> at the volume root) and never
overwrite the V1 run directory. Download after completion via:

    modal volume get consistency-loss-experiment-results /full_gpt2_small_stronger_<timestamp>.zip .

    # To list all result directories on the volume (paths are relative to the
    # volume root, not the /results mount point):
    modal volume ls consistency-loss-experiment-results /
"""

from __future__ import annotations

import subprocess
import time
from pathlib import Path

import modal


# Local checkout of the experiment; copied into the image so that
# run_experiment.py and its sibling modules are importable in the container.
PROJECT_ROOT = Path(__file__).resolve().parent
APP_NAME = "consistency-loss-full-gpt2-small"
VOLUME_NAME = "consistency-loss-experiment-results"


# Runtime dependencies installed into the container image (unpinned).
_PIP_PACKAGES = (
    "torch",
    "numpy",
    "pandas",
    "matplotlib",
    "scikit-learn",
    "rouge-score",
    "scipy",
)

# Python 3.12 slim base + dependencies + the project tree at a fixed path.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(*_PIP_PACKAGES)
    .add_local_dir(PROJECT_ROOT, remote_path="/root/consistency_loss_experiment")
)


# Persistent results storage (created if it does not exist yet); every
# function below mounts it at /results.
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
app = modal.App(APP_NAME, image=image)


@app.function(
    gpu="A10G",
    timeout=60 * 60 * 12,
    volumes={"/results": volume},
)
def run_full_gpt2_small() -> dict:
    """Run all V1 variants for the full experiment using the GPT-2-small config.

    Invokes ``run_experiment.py --full --model gpt2_small`` from the mounted
    project directory, streams its combined stdout/stderr into a log file on
    the volume, and zips the output directory on success.

    Container paths (the volume is mounted at ``/results``):
      - outputs: /results/full_gpt2_small_<run_id>/
      - log:     /results/full_gpt2_small_<run_id>.log
      - archive: /results/full_gpt2_small_outputs.zip  (fixed name, so a
        re-run overwrites the previous archive)

    Returns:
        dict with ``run_id``, the container paths above, and a
        ``download_command`` whose remote path is relative to the volume root,
        which is what ``modal volume get`` expects.

    Raises:
        RuntimeError: if the experiment exits non-zero. The log has already
            been committed to the volume at that point.
        subprocess.CalledProcessError: if zipping the outputs fails.
    """

    workdir = Path("/root/consistency_loss_experiment")
    run_id = time.strftime("%Y%m%d_%H%M%S")
    output_dir = Path("/results") / f"full_gpt2_small_{run_id}"
    zip_path = Path("/results") / "full_gpt2_small_outputs.zip"
    log_path = Path("/results") / f"full_gpt2_small_{run_id}.log"

    cmd = [
        "python",
        "run_experiment.py",
        "--full",
        "--model",
        "gpt2_small",
        "--output-dir",
        str(output_dir),
    ]

    try:
        with log_path.open("w", encoding="utf-8") as log_file:
            log_file.write("Command: " + " ".join(cmd) + "\n")
            log_file.flush()
            proc = subprocess.run(
                cmd,
                cwd=workdir,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                text=True,
            )
    finally:
        # Commit unconditionally, including when subprocess.run itself raises,
        # so the log is always downloadable for debugging.
        volume.commit()

    if proc.returncode != 0:
        raise RuntimeError(
            f"Experiment failed with exit code {proc.returncode}; log: {log_path} "
            f"(download: modal volume get {VOLUME_NAME} /{log_path.name} .)"
        )

    # The zipfile CLI stores the directory under its basename, so the archive
    # contains a single top-level folder full_gpt2_small_<run_id>/.
    subprocess.run(
        ["python", "-m", "zipfile", "-c", str(zip_path), str(output_dir)],
        check=True,
    )
    volume.commit()

    return {
        "run_id": run_id,
        "output_dir": str(output_dir),
        "zip_path": str(zip_path),
        "log_path": str(log_path),
        # Volume paths are rooted at the volume, not at the /results mount.
        "download_command": f"modal volume get {VOLUME_NAME} /{zip_path.name} .",
    }


@app.function(
    gpu="A10G",
    timeout=60 * 60 * 2,
    volumes={"/results": volume},
)
def generate_qualitative_comparison(
    run_dir_name: str = "full_gpt2_small_20260430_174439",
    n_examples: int = 20,
) -> dict:
    """Generate side-by-side epoch-20 explanations for all V1 variants.

    This reuses the existing full-run checkpoints and produces review artifacts
    only; it does not retrain any model.

    Args:
        run_dir_name: Name of a completed V1 run directory under /results.
        n_examples: Number of validation examples to decode per variant
            (capped at the size of the validation split). Must be >= 1.

    Writes CSV, JSON and Markdown files to
    /results/<run_dir_name>_qualitative_review/ and commits the volume.

    Returns:
        dict with the output paths, row count, device, and a download command
        whose remote path is relative to the volume root.

    Raises:
        ValueError: if ``n_examples < 1``.
        FileNotFoundError: if any variant's ``epoch_020.pt`` is missing. This
            check runs before any model is loaded, so no GPU time is spent.
        RuntimeError: if the validation split is empty.
    """

    import csv
    import gc
    import json
    import re
    import sys

    import torch

    sys.path.insert(0, "/root/consistency_loss_experiment")

    from dataset import build_dataset, build_tokenizer, split_dataset
    from model import ConsistencyTransformer
    from trainer import greedy_decode

    if n_examples < 1:
        raise ValueError(f"n_examples must be >= 1, got {n_examples}")

    variants = [
        "consistency_loss",
        "no_consistency_loss",
        "claim_only_pooling",
        "random_label_consistency",
    ]

    run_dir = Path("/results") / run_dir_name
    review_dir = Path("/results") / f"{run_dir_name}_qualitative_review"

    # Check every checkpoint before loading any, so a missing file fails fast
    # instead of after earlier variants have already been decoded.
    ckpt_paths = {
        variant: run_dir / "checkpoints" / variant / "epoch_020.pt"
        for variant in variants
    }
    missing = [str(p) for p in ckpt_paths.values() if not p.exists()]
    if missing:
        raise FileNotFoundError(
            "Missing epoch-20 checkpoint(s):\n  " + "\n  ".join(missing)
        )

    review_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Same dataset size/seed/split as the training run, so val_examples lines up.
    examples = build_dataset(n=3000, seed=42)
    _, val_examples = split_dataset(examples, val_size=500, seed=42)
    tokenizer = build_tokenizer(examples)

    selected = val_examples[:n_examples]
    if not selected:
        raise RuntimeError("Validation split is empty; nothing to review.")

    def clean_decoded(text: str) -> str:
        # The toy tokenizer decodes every token with spaces. Join runs of
        # one-character alphabetic tokens so "C o m p u t e s" is readable.
        toks = text.split()
        out = []
        i = 0
        while i < len(toks):
            if re.fullmatch(r"[A-Za-z]", toks[i]):
                letters = []
                while i < len(toks) and re.fullmatch(r"[A-Za-z]", toks[i]):
                    letters.append(toks[i])
                    i += 1
                out.append("".join(letters))
            else:
                out.append(toks[i])
                i += 1
        s = " ".join(out)
        s = re.sub(r"\s+([.,;:!?])", r"\1", s)
        s = s.replace("( n )", "(n)").replace("( n^2 )", "(n^2)").replace("( 1 )", "(1)")
        s = s.replace("O (n)", "O(n)").replace("O (n^2)", "O(n^2)").replace("O (1)", "O(1)")
        return s

    def prose_part(text: str) -> str:
        # Prose is everything before the first <claim> marker.
        return text.split("<claim>", 1)[0].strip()

    def md_cell(text: str) -> str:
        # A raw "|" would split the table cell and a newline would end the row.
        return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")

    rows = []
    for variant in variants:
        # weights_only=False: presumably required because checkpoints hold
        # non-tensor objects such as model_cfg. Only load trusted files.
        ckpt = torch.load(ckpt_paths[variant], map_location=device, weights_only=False)
        model = ConsistencyTransformer(ckpt["model_cfg"]).to(device)
        model.load_state_dict(ckpt["model_state"])
        model.eval()

        for sample_i, ex in enumerate(selected, start=1):
            raw = greedy_decode(
                model,
                tokenizer,
                ex.code_snippet,
                max_new_tokens=80,
                device=device,
            )
            clean = clean_decoded(raw)
            prose = prose_part(clean)
            rows.append({
                "sample": sample_i,
                "example_idx": ex.idx,
                "template_name": ex.template_name,
                "variant": variant,
                "time_complexity": ex.time_complexity,
                "space_complexity": ex.space_complexity,
                "correctness": ex.correctness,
                "true_explanation": ex.true_explanation,
                "mismatched_explanation": ex.mismatched_explanation,
                "code_snippet": ex.code_snippet,
                "generation_raw": raw,
                "generation_clean": clean,
                "prose_clean": prose,
                "emits_time_claim": f"time_complexity={ex.time_complexity}" in raw,
                "emits_space_claim": f"space_complexity={ex.space_complexity}" in raw,
                "emits_correctness_claim": f"correctness={ex.correctness}" in raw,
            })

        # Free GPU memory before loading the next variant's weights.
        del model, ckpt
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    csv_path = review_dir / "qualitative_side_by_side.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    json_path = review_dir / "qualitative_side_by_side.json"
    json_path.write_text(json.dumps(rows, indent=2), encoding="utf-8")

    md_path = review_dir / "qualitative_side_by_side.pplx.md"
    lines = [
        "# Qualitative Side-by-Side Review Inputs",
        "",
        f"Run directory: `{run_dir_name}`",
        f"Validation examples reviewed: {len(selected)}",
        "",
        "This file contains raw model generations for manual review. Scores are intentionally left for the reviewer rather than inferred automatically.",
        "",
    ]
    by_sample = {}
    for row in rows:
        by_sample.setdefault(row["sample"], []).append(row)

    for sample, sample_rows in sorted(by_sample.items()):
        first = sample_rows[0]
        lines.extend([
            f"## Sample {sample}: `{first['template_name']}`",
            "",
            f"Ground truth: time={first['time_complexity']}, space={first['space_complexity']}, correct={first['correctness']}",
            "",
            f"Reference explanation: {first['true_explanation']}",
            "",
            f"Mismatched training explanation: {first['mismatched_explanation']}",
            "",
            "```python",
            first["code_snippet"],
            "```",
            "",
            "| Variant | Generated prose | Full generation |",
            "|---|---|---|",
        ])
        for row in sample_rows:
            prose = md_cell(row["prose_clean"])
            full = md_cell(row["generation_clean"])
            lines.append(f"| `{row['variant']}` | {prose} | {full} |")
        lines.append("")

    md_path.write_text("\n".join(lines), encoding="utf-8")

    volume.commit()
    return {
        "review_dir": str(review_dir),
        "csv_path": str(csv_path),
        "json_path": str(json_path),
        "md_path": str(md_path),
        "n_rows": len(rows),
        "device": str(device),
        # Volume paths are rooted at the volume, not at the /results mount.
        "download_command": (
            f"modal volume get {VOLUME_NAME} /{review_dir.name} ./qualitative_review"
        ),
    }


# ────────────────────────────────────────────────────────────────────────────────
# V2 — Stronger Ablation Ladder (GPT-2-small, A10G)
# Run with: modal run modal_run.py::run_full_gpt2_small_v2
# ────────────────────────────────────────────────────────────────────────────────

# Separate App so V2 can be deployed/invoked independently of V1.
# Both apps share the same Modal volume (VOLUME_NAME) and write to separate
# timestamped subdirectories, so they never overwrite each other.
APP_NAME_V2 = "consistency-loss-full-gpt2-small-v2"
# Second App sharing the module-level image; its functions are registered below.
app_v2 = modal.App(name=APP_NAME_V2, image=image)

# Re-use the same volume object declared above: every function in both apps
# mounts it at /results, so app_v2 and app read and write the same persistent
# storage. A container path /results/<x> appears as /<x> in
# `modal volume ls` / `modal volume get` commands.


@app_v2.function(
    gpu="A10G",
    timeout=60 * 60 * 2,   # 2 h; qualitative generation only, no training
    volumes={"/results": volume},
)
def generate_qualitative_comparison_v2(
    run_dir_name: str = "full_gpt2_small_stronger_20260430_200556",
    n_examples: int = 20,
    include_epoch_initial: bool = True,
) -> dict:
    """Generate side-by-side initial- and final-checkpoint explanations
    for all four V2 stronger-ablation variants.

    Loads existing checkpoints from the completed V2 run — no retraining.
    Writes CSV, JSON, and Markdown review artifacts to:
      /results/<run_dir_name>_qualitative_review/

    The V1 review directory (full_gpt2_small_20260430_174439_qualitative_review)
    is NEVER touched.

    Checkpoint selection (per variant):
      - "final"   = epoch_020.pt (required; FileNotFoundError if missing).
      - "initial" = epoch_005.pt if present, otherwise the earliest
        epoch_*.pt by filename. If that resolves to the same file as the
        final checkpoint, the initial pass is skipped instead of emitting
        duplicate rows.
      All variants are probed before any model is loaded, so a missing
      checkpoint fails fast with a message listing what does exist.

    Args:
        run_dir_name: Completed V2 run directory name under /results.
        n_examples: Validation examples decoded per (variant, checkpoint),
            capped at the validation split size. Must be >= 1.
        include_epoch_initial: Also decode with the initial checkpoint.

    Returns:
        dict with output paths, row count, device, the checkpoint manifest,
        per-variant final-checkpoint metric averages, and a download command
        whose remote path is relative to the volume root.

    Raises:
        ValueError: if ``n_examples < 1``.
        FileNotFoundError: if a checkpoint directory or required checkpoint
            is missing.
        RuntimeError: if the validation split is empty.

    Launch command:
        modal run modal_run.py::generate_qualitative_comparison_v2

    Detached launch (returns immediately):
        modal run --detach modal_run.py::generate_qualitative_comparison_v2

    With non-default arguments:
        modal run modal_run.py::generate_qualitative_comparison_v2 \\
            --run-dir-name full_gpt2_small_stronger_20260430_200556 \\
            --n-examples 30

    Download outputs after completion (remote paths are relative to the
    volume root, not the /results mount point):
        modal volume ls consistency-loss-experiment-results \\
            /full_gpt2_small_stronger_20260430_200556_qualitative_review
        modal volume get consistency-loss-experiment-results \\
            /full_gpt2_small_stronger_20260430_200556_qualitative_review \\
            ./v2_qualitative_review
    """

    import csv
    import gc
    import json
    import re
    import sys

    import torch

    sys.path.insert(0, "/root/consistency_loss_experiment")

    from dataset import build_dataset, build_tokenizer, split_dataset
    from model import ConsistencyTransformer
    from trainer import greedy_decode, simple_bleu1, rouge_l_score

    if n_examples < 1:
        raise ValueError(f"n_examples must be >= 1, got {n_examples}")

    # ── V2 variants (hardcoded so the set is visible in logs) ────────────────
    v2_variants = [
        "no_claim_to_claim_attention",
        "claims_from_explanation_only",
        "surface_bottleneck_consistency",
        "surface_bottleneck_no_expl_lm",
    ]

    # ── Paths ─────────────────────────────────────────────────────────────────
    run_dir    = Path("/results") / run_dir_name
    ckpt_root  = run_dir / "checkpoints"
    review_dir = Path("/results") / f"{run_dir_name}_qualitative_review"

    # ── Checkpoint probe: find initial and final checkpoint paths ─────────────
    def _find_checkpoints(variant: str) -> tuple[Path, Path]:
        """Return (initial_path, final_path) for a variant, or raise
        FileNotFoundError with a listing of what is available."""
        ckpt_dir = ckpt_root / variant
        if not ckpt_dir.is_dir():
            # Guard the diagnostic itself: iterdir() on a missing parent would
            # raise its own, less helpful error.
            present = (
                sorted(p.name for p in ckpt_root.iterdir())
                if ckpt_root.is_dir()
                else "<checkpoints directory does not exist>"
            )
            raise FileNotFoundError(
                f"Checkpoint directory missing: {ckpt_dir}\n"
                f"Available under {ckpt_root}: {present}"
            )

        # Earliest by filename sort (names are zero-padded, e.g. epoch_005.pt).
        available = sorted(ckpt_dir.glob("epoch_*.pt"))
        if not available:
            raise FileNotFoundError(f"No epoch_*.pt files in {ckpt_dir}")

        final_path = ckpt_dir / "epoch_020.pt"
        if not final_path.exists():
            raise FileNotFoundError(
                f"Final checkpoint epoch_020.pt not found in {ckpt_dir}.\n"
                f"Available: {[p.name for p in available]}"
            )

        initial_path = ckpt_dir / "epoch_005.pt"
        if not initial_path.exists():
            initial_path = available[0]
        return initial_path, final_path

    # Resolve every variant up front so a missing checkpoint fails before any
    # GPU time is spent. Each plan entry is (variant, label, checkpoint path).
    plan: list[tuple[str, str, Path]] = []
    for variant in v2_variants:
        initial_path, final_path = _find_checkpoints(variant)
        # Skip "initial" when it falls back to the final file itself;
        # otherwise it would just duplicate the final rows under another label.
        if include_epoch_initial and initial_path != final_path:
            plan.append((variant, "initial", initial_path))
        plan.append((variant, "final", final_path))

    review_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── Dataset (same seed/size as training) ─────────────────────────────────
    examples = build_dataset(n=3000, seed=42)
    _, val_examples = split_dataset(examples, val_size=500, seed=42)
    tokenizer = build_tokenizer(examples)

    selected = val_examples[:n_examples]
    if not selected:
        raise RuntimeError("Validation split is empty; nothing to review.")

    # ── Helpers ───────────────────────────────────────────────────────────────
    def _clean_decoded(text: str) -> str:
        """Join single-char alphabetic tokens and fix punctuation spacing."""
        toks = text.split()
        out, i = [], 0
        while i < len(toks):
            if re.fullmatch(r"[A-Za-z]", toks[i]):
                letters = []
                while i < len(toks) and re.fullmatch(r"[A-Za-z]", toks[i]):
                    letters.append(toks[i])
                    i += 1
                out.append("".join(letters))
            else:
                out.append(toks[i])
                i += 1
        s = " ".join(out)
        s = re.sub(r"\s+([.,;:!?])", r"\1", s)
        for spaced, fixed in [
            ("( n )", "(n)"), ("( n^2 )", "(n^2)"), ("( 1 )", "(1)"),
            ("O (n)", "O(n)"), ("O (n^2)", "O(n^2)"), ("O (1)", "O(1)"),
        ]:
            s = s.replace(spaced, fixed)
        return s

    def _prose_part(text: str) -> str:
        """Return only the prose portion before the first <claim> token."""
        return text.split("<claim>", 1)[0].strip()

    def _md_cell(text: str) -> str:
        """Escape text for a Markdown table cell (pipes split cells, newlines end rows)."""
        return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")

    def _mark(flag: bool) -> str:
        return "✓" if flag else "✗"

    def _load_model(ckpt_path: Path):
        """Load one checkpoint and return (model in eval mode, ckpt['epoch'])."""
        # weights_only=False: presumably required because checkpoints hold
        # non-tensor objects such as model_cfg. Only load trusted files.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        use_surface = ckpt.get("use_surface_heads", False)
        model = ConsistencyTransformer(
            ckpt["model_cfg"], use_surface_heads=use_surface
        ).to(device)
        model.load_state_dict(ckpt["model_state"])
        model.eval()
        return model, ckpt["epoch"]

    # ── Generate rows for all planned (variant, checkpoint) pairs ─────────────
    rows = []          # one dict per (variant, checkpoint_label, example)
    manifest = []      # summary of what was loaded, for the return value

    for variant, ckpt_label, ckpt_path in plan:
        model, epoch_num = _load_model(ckpt_path)

        manifest.append({
            "variant":      variant,
            "ckpt_label":   ckpt_label,
            "ckpt_path":    str(ckpt_path),
            "epoch_number": epoch_num,
        })

        for sample_i, ex in enumerate(selected, start=1):
            raw = greedy_decode(
                model, tokenizer, ex.code_snippet,
                max_new_tokens=80, device=device,
            )
            clean = _clean_decoded(raw)
            prose = _prose_part(clean)

            bleu  = simple_bleu1(prose, ex.true_explanation)
            rouge = rouge_l_score(prose, ex.true_explanation)

            rows.append({
                # Identity
                "sample":               sample_i,
                "example_idx":          ex.idx,
                "template_name":        ex.template_name,
                "variant":              variant,
                "checkpoint_label":     ckpt_label,
                "epoch_number":         epoch_num,
                # Ground truth
                "time_complexity":      ex.time_complexity,
                "space_complexity":     ex.space_complexity,
                "correctness":          ex.correctness,
                "true_explanation":     ex.true_explanation,
                "mismatched_explanation": ex.mismatched_explanation,
                "code_snippet":         ex.code_snippet,
                # Generation
                "generation_raw":       raw,
                "generation_clean":     clean,
                "prose_clean":          prose,
                # Automatic metrics
                "bleu1":                round(bleu,  4),
                "rouge_l":              round(rouge, 4),
                # Claim emission flags
                "emits_time_claim":
                    f"time_complexity={ex.time_complexity}" in raw,
                "emits_space_claim":
                    f"space_complexity={ex.space_complexity}" in raw,
                "emits_correctness_claim":
                    f"correctness={ex.correctness}" in raw,
            })

        # Free GPU memory before loading the next checkpoint.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # ── Write CSV ─────────────────────────────────────────────────────────────
    csv_path = review_dir / "qualitative_side_by_side_v2.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    # ── Write JSON ────────────────────────────────────────────────────────────
    json_path = review_dir / "qualitative_side_by_side_v2.json"
    json_path.write_text(
        json.dumps({"manifest": manifest, "rows": rows}, indent=2),
        encoding="utf-8",
    )

    # ── Per-variant aggregate metrics (final checkpoint only) ─────────────────
    def _mean(subset: list, key: str, ndigits: int) -> float:
        # Booleans sum as 0/1, so this also yields claim-emission rates.
        return round(sum(r[key] for r in subset) / len(subset), ndigits)

    variant_summary = {}
    for variant in v2_variants:
        final_rows = [
            r for r in rows
            if r["variant"] == variant and r["checkpoint_label"] == "final"
        ]
        if final_rows:
            variant_summary[variant] = {
                "bleu1":                  _mean(final_rows, "bleu1", 4),
                "rouge_l":                _mean(final_rows, "rouge_l", 4),
                "time_claim_rate":        _mean(final_rows, "emits_time_claim", 3),
                "space_claim_rate":       _mean(final_rows, "emits_space_claim", 3),
                "correctness_claim_rate": _mean(final_rows, "emits_correctness_claim", 3),
                "n": len(final_rows),
            }

    # ── Write Markdown ────────────────────────────────────────────────────────
    md_path = review_dir / "qualitative_side_by_side_v2.pplx.md"

    # V2 variant labels for human-readable display
    VARIANT_LABELS = {
        "no_claim_to_claim_attention":   "No Claim→Claim Attn",
        "claims_from_explanation_only":  "Claims from Expl Only",
        "surface_bottleneck_consistency": "Surface Bottleneck",
        "surface_bottleneck_no_expl_lm": "Surface + No Expl LM",
    }

    def _epochs(label: str) -> str:
        # Epoch numbers actually loaded for this label, in first-seen order;
        # they differ from 5/20 when the fallback checkpoint was used.
        seen = dict.fromkeys(
            str(m["epoch_number"]) for m in manifest if m["ckpt_label"] == label
        )
        return ", ".join(seen) or "n/a"

    if include_epoch_initial:
        epochs_desc = (
            f"initial (epoch {_epochs('initial')}) + final (epoch {_epochs('final')})"
        )
    else:
        epochs_desc = f"final (epoch {_epochs('final')}) only"

    md_lines = [
        "# V2 Qualitative Side-by-Side Review",
        "",
        f"**Run:** `{run_dir_name}`  ",
        "**Variants:** V2 stronger ablation ladder (4 variants)  ",
        f"**Validation examples:** {len(selected)}  ",
        f"**Checkpoint epochs reviewed:** {epochs_desc}",
        "",
        "Scores below are automatic proxies. Manual review of the prose column",
        "is the primary purpose of this document.",
        "",
        "## Aggregate Metrics (final checkpoint, automatic)",
        "",
        "| Variant | BLEU-1 | ROUGE-L | Time claim % | Space claim % | Correct claim % |",
        "|---|---|---|---|---|---|",
    ]
    for variant in v2_variants:
        s = variant_summary.get(variant, {})
        label = VARIANT_LABELS.get(variant, variant)
        md_lines.append(
            f"| `{label}` "
            f"| {s.get('bleu1', 'n/a')} "
            f"| {s.get('rouge_l', 'n/a')} "
            f"| {s.get('time_claim_rate', 'n/a')} "
            f"| {s.get('space_claim_rate', 'n/a')} "
            f"| {s.get('correctness_claim_rate', 'n/a')} |"
        )
    md_lines.append("")

    # ── Per-sample section ────────────────────────────────────────────────────
    md_lines += [
        "## Per-Sample Generations",
        "",
        "Each sample shows: reference explanation, mismatched training explanation,",
        "and per-variant generated prose at initial and final checkpoints.",
        "Columns: Variant | Ckpt | BLEU-1 | ROUGE-L | Emits claims? | Prose",
        "",
    ]

    by_sample: dict = {}
    for row in rows:
        by_sample.setdefault(row["sample"], []).append(row)

    for sample, sample_rows in sorted(by_sample.items()):
        first = sample_rows[0]
        md_lines.extend([
            f"### Sample {sample}: `{first['template_name']}`",
            "",
            f"Ground truth: "
            f"time=`{first['time_complexity']}`, "
            f"space=`{first['space_complexity']}`, "
            f"correct=`{first['correctness']}`",
            "",
            f"**Reference explanation:** {first['true_explanation']}",
            "",
            f"**Mismatched training explanation:** {first['mismatched_explanation']}",
            "",
            "```python",
            first["code_snippet"],
            "```",
            "",
            "| Variant | Ckpt | BLEU-1 | ROUGE-L "
            "| Time✓ | Space✓ | Correct✓ | Generated prose |",
            "|---|---|---|---|---|---|---|---|",
        ])

        for row in sample_rows:
            label = VARIANT_LABELS.get(row["variant"], row["variant"])
            md_lines.append(
                f"| `{label}` | {row['checkpoint_label']} (ep{row['epoch_number']}) "
                f"| {row['bleu1']} | {row['rouge_l']} "
                f"| {_mark(row['emits_time_claim'])} "
                f"| {_mark(row['emits_space_claim'])} "
                f"| {_mark(row['emits_correctness_claim'])} "
                f"| {_md_cell(row['prose_clean'])} |"
            )
        md_lines.append("")

    # ── Scoring guide ─────────────────────────────────────────────────────────
    md_lines += [
        "## Manual Scoring Guide",
        "",
        "For each sample, score each variant's **final-checkpoint prose** on:",
        "",
        "| Criterion | Scale | Notes |",
        "|---|---|---|",
        "| Fluency | 0–2 | 0=incoherent, 1=partial, 2=fluent |",
        "| Factual accuracy | 0–2 | Does prose match ground-truth complexity/correctness? |",
        "| Claim alignment | 0–1 | Do emitted `<claim>` tags match ground truth? |",
        "| Initial→Final improvement | 0–1 | Did the prose improve from initial to final ckpt? |",
        "",
        "Aggregate per-variant totals and compare across the V2 ablation ladder.",
        "Key question: does `surface_bottleneck_consistency` or `claims_from_explanation_only`",
        "show better factual accuracy than `no_claim_to_claim_attention` (closest to V1 baseline)?",
        "",
        "Files in this directory:",
        f"- `{csv_path.name}` — machine-readable, one row per (variant, checkpoint, example)",
        f"- `{json_path.name}` — same data as JSON with manifest header",
        f"- `{md_path.name}` — this document",
    ]

    md_path.write_text("\n".join(md_lines), encoding="utf-8")

    # ── Commit and return ─────────────────────────────────────────────────────
    volume.commit()

    return {
        "run_dir_name":  run_dir_name,
        "review_dir":    str(review_dir),
        "csv_path":      str(csv_path),
        "json_path":     str(json_path),
        "md_path":       str(md_path),
        "n_rows":        len(rows),
        "n_examples":    n_examples,
        "device":        str(device),
        "manifest":      manifest,
        "variant_summary": variant_summary,
        # Volume paths are rooted at the volume, not at the /results mount.
        "download_command": (
            f"modal volume get {VOLUME_NAME} "
            f"/{review_dir.name} "
            f"./v2_qualitative_review"
        ),
    }


@app_v2.function(
    gpu="A10G",
    timeout=60 * 60 * 12,   # 12 h ceiling; V2 (4 variants) should finish in ~4–6 h
    volumes={"/results": volume},
)
def run_full_gpt2_small_v2() -> dict:
    """
    Full GPT-2-small run of the V2 stronger ablation ladder on an A10G.

    Variants trained (all 4 V2):
      - no_claim_to_claim_attention    : claim→claim attention blocked
      - claims_from_explanation_only   : claims attend explanation tokens only
      - surface_bottleneck_consistency : consistency via LM logit distributions
      - surface_bottleneck_no_expl_lm  : surface bottleneck + no LM loss on expl tokens

    Config: gpt2_small (768 d_model, 12 heads, 12 layers, ~117 M params),
            3000 examples, 20 epochs, batch 32, lr 5e-5.
            Only --epochs 20 and --batch 32 are passed explicitly below; the
            remaining values presumably come from run_experiment.py's
            gpt2_small / --full defaults. Confirm there if they matter.

    Outputs (container paths; the volume is mounted at /results):
      /results/full_gpt2_small_stronger_<YYYYMMDD_HHMMSS>/   run outputs
      /results/full_gpt2_small_stronger_<timestamp>.zip      archive of the above
      /results/full_gpt2_small_stronger_<timestamp>.log      subprocess stdout/stderr

    The function refuses to start if its output directory or log file already
    exists (e.g. a V2 run launched in the same second), and never targets the
    V1 run at /results/full_gpt2_small_20260430_174439.

    Launch command (from project root, authenticated Modal workspace):
        modal run modal_run.py::run_full_gpt2_small_v2

    Detached launch (returns immediately, runs in background):
        modal run --detach modal_run.py::run_full_gpt2_small_v2

    Monitor progress:
        modal app logs consistency-loss-full-gpt2-small-v2

    Download outputs after completion (remote paths are relative to the
    volume root, not the /results mount point):
        # First find the timestamped zip name on the volume:
        modal volume ls consistency-loss-experiment-results /
        # Then download:
        modal volume get consistency-loss-experiment-results \\
            /full_gpt2_small_stronger_<timestamp>.zip .
        # Or download the raw directory tree:
        modal volume get consistency-loss-experiment-results \\
            /full_gpt2_small_stronger_<timestamp> ./local_v2_results

    Returns:
        dict with run identifiers, container paths, the variant list, and
        volume-root download commands for the zip and the directory.

    Raises:
        RuntimeError: if the target output dir/log already exists, or if the
            experiment exits non-zero. In the latter case the log is
            committed first.
        subprocess.CalledProcessError: if zipping the outputs fails.
    """

    workdir    = Path("/root/consistency_loss_experiment")
    run_id     = time.strftime("%Y%m%d_%H%M%S")
    run_tag    = f"full_gpt2_small_stronger_{run_id}"
    output_dir = Path("/results") / run_tag
    zip_path   = Path("/results") / f"{run_tag}.zip"
    log_path   = Path("/results") / f"{run_tag}.log"

    # Guard against overwriting earlier results: the V1 run directory, or a
    # V2 run started within the same second. This uses an explicit raise rather
    # than `assert`, so the check survives `python -O`.
    v1_sentinel = Path("/results/full_gpt2_small_20260430_174439")
    if output_dir == v1_sentinel or output_dir.exists() or log_path.exists():
        raise RuntimeError(
            f"Refusing to overwrite existing results: {output_dir} / {log_path}"
        )

    # V2 variant names, explicit rather than relying on --v2-only so the exact
    # set is visible in the log and not silently extended if VARIANTS_V2 grows.
    v2_variants = [
        "no_claim_to_claim_attention",
        "claims_from_explanation_only",
        "surface_bottleneck_consistency",
        "surface_bottleneck_no_expl_lm",
    ]

    cmd = [
        "python",
        "run_experiment.py",
        "--full",
        "--model",   "gpt2_small",
        "--epochs",  "20",
        "--batch",   "32",
        "--output-dir", str(output_dir),
        "--variants",
    ] + v2_variants

    try:
        with log_path.open("w", encoding="utf-8") as log_file:
            log_file.write(f"run_tag   : {run_tag}\n")
            log_file.write(f"output_dir: {output_dir}\n")
            log_file.write(f"variants  : {v2_variants}\n")
            log_file.write("command   : " + " ".join(cmd) + "\n")
            log_file.write("-" * 72 + "\n")
            log_file.flush()

            proc = subprocess.run(
                cmd,
                cwd=workdir,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                text=True,
            )
    finally:
        # Always commit so the log is accessible even on failure, including
        # when subprocess.run itself raises.
        volume.commit()

    if proc.returncode != 0:
        raise RuntimeError(
            f"V2 experiment failed (exit code {proc.returncode}).  "
            f"Inspect log: modal volume get {VOLUME_NAME} /{log_path.name} ."
        )

    # Zip the output directory for single-file download convenience.
    # The zipfile CLI stores a directory under its basename, so the archive
    # contains one top-level folder named run_tag.
    subprocess.run(
        ["python", "-m", "zipfile", "-c", str(zip_path), str(output_dir)],
        check=True,
        cwd="/results",
    )
    volume.commit()

    return {
        "run_tag":    run_tag,
        "run_id":     run_id,
        "output_dir": str(output_dir),
        "zip_path":   str(zip_path),
        "log_path":   str(log_path),
        "variants":   v2_variants,
        # Volume paths are rooted at the volume, not at the /results mount.
        "download_zip_command": (
            f"modal volume get {VOLUME_NAME} /{run_tag}.zip ."
        ),
        "download_dir_command": (
            f"modal volume get {VOLUME_NAME} /{run_tag} ./local_v2_results"
        ),
    }
