from __future__ import annotations
from pathlib import Path
import os, shutil, json, time
from ..cfg import CIRBenchConfig
from ..registry import discover_cases
from ..utils.api.base import make_runner
from ..utils.runner_common import (
    get_run_dir, get_case_dir, report_mode, resume_mode, next_missing_shot, shot_paths,
    gen_with_retries, persist_model_io, write_text, write_json, fmt_cmd_block,
    extract_ir_and_meta, atk, build_equiv_dict, sha256,  # no billing
    write_repro_sh, maybe_short_circuit_prompt_only, maybe_materialize_external_io
)
from ..utils.logging_utils import get_logger
from ..utils.metrics import exact_match, edit_similarity, bleu, rouge_l
from ..utils.evalers import assemble_ok, verify_ok, alive_equiv
from ..utils.case_select import select_cases

def _format_hint(meta: dict) -> str:
    cand = (meta or {}).get("hint") or meta.get("errors") or meta.get("error") or meta.get("desc") or meta.get("description")
    if cand is None: return "(no hint provided)"
    if isinstance(cand, list): return "\n".join(f"- {str(x)}" for x in cand)
    if isinstance(cand, dict):
        import json as _j
        try: return _j.dumps(cand, ensure_ascii=False, indent=2)
        except Exception: return str(cand)
    return str(cand)

def _prompt(ir: str, hint: str | None) -> str:
    head = (
        "You are an LLVM IR repair assistant.\n"
        "Fix the IR if it fails assembly or verifier; otherwise reply exactly: No\n"
    )
    # Only include hint when present (e.g., not in hard mode)
    if hint:
        head += "Hint:\n" + hint + "\n"
    return (
        head
        + "Begin IR:\n<IR>\n" + ir + "\n</IR>\n\n"
        "Output ONLY the repaired IR as <IR_OUT>...</IR_OUT> (or No)."
    )

def _recover_ir_out_loose(text: str) -> str | None:
    """
    Fallback parser: if the model output contains "<IR_OUT>" but lacks a closing
    tag, treat everything after the first occurrence as the IR. Also trims common
    code‑fence wrappers.
    """
    if not text:
        return None
    try:
        s = str(text)
    except Exception:
        return None
    marker = "<IR_OUT>"
    pos = s.find(marker)
    if pos < 0:
        return None
    body = s[pos + len(marker):]
    # If a proper closing tag exists, keep only the content before it.
    end = body.find("</IR_OUT>")
    if end >= 0:
        body = body[:end]
    body = body.strip()
    # Strip surrounding triple backticks and optional language hints
    if body.startswith("```"):
        # remove the first fence line
        nl = body.find("\n")
        if nl >= 0:
            body = body[nl+1:]
        # strip a trailing fence if present
        if body.rstrip().endswith("```"):
            body = body.rstrip()
            body = body[: body.rfind("```")]
    body = body.strip()
    return body if body else None

def _summary_flags(shots):
    """Build the ``<metric>_at_<k>`` summary booleans for one case.

    For each metric in ``pass``, ``valid``, ``equiv`` and each k in 1, 5 the
    result of ``atk(shots, k, metric)`` is coerced to ``bool``.
    """
    return {
        f"{metric}_at_{k}": bool(atk(shots, k, metric))
        for metric in ("pass", "valid", "equiv")
        for k in (1, 5)
    }

# Number of shots attempted per case.
_MAX_SHOTS = 5


def _skipped_step() -> dict:
    """Return the placeholder result recorded when assemble/verify is not run."""
    return {"ok": False, "stdout": "", "stderr": "", "rc": 1, "cmd": ""}


def _parse_response(text):
    """Extract repaired IR from a raw response, with loose ``<IR_OUT>`` recovery.

    Returns ``(ir_out, ierr, parse_note)``; ``parse_note`` is non-empty only
    when the loose fallback parser supplied the IR.
    """
    ir_out, ierr, _meta_json = extract_ir_and_meta(text)
    parse_note = ""
    if (not ir_out) or ierr:
        recovered = _recover_ir_out_loose(text)
        if recovered:
            ir_out, ierr = recovered, False
            parse_note = "loose_IR_OUT_recovery=1"
    return ir_out, ierr, parse_note


def _shot_log(parse_note, p_ok, v_ok, ae) -> str:
    """Render the per-shot log text (Parse / Assemble / Verify / Alive2 sections)."""
    lines = []
    if parse_note:
        lines += ["## Parse", parse_note]
    for title, res in (("## Assemble", p_ok), ("## Verify", v_ok)):
        lines.append(title)
        lines.append(fmt_cmd_block(
            (res.get("cmd") or "").split(), res.get("rc"),
            res.get("stdout", ""), res.get("stderr", ""),
        ))
    # The Alive2 section is only emitted when a command was actually recorded.
    if isinstance(ae, dict) and ae.get("cmd"):
        lines.append("## Alive2")
        lines.append(fmt_cmd_block(ae.get("cmd"), ae.get("exit"), ae.get("out", ""), ae.get("err", "")))
    return "\n".join(lines) + "\n"


def _evaluate_shot(cfg, case_dir, sp, k, ir_gold, ir_out, ierr, parse_note,
                   meta, raw_text, persist_output):
    """Assemble, verify and equivalence-check one candidate IR; build its shot record.

    Args:
        cfg: Benchmark configuration forwarded to the evaluators.
        case_dir: Output directory of the case.
        sp: Path mapping for shot ``k`` as returned by ``shot_paths``.
        k: 1-based shot index.
        ir_gold: Reference IR used for equivalence and text metrics.
        ir_out: Candidate IR (may be empty/None when nothing was parsed).
        ierr: Parse-error flag, stored as ``unparseable``.
        parse_note: Optional note written to the shot log's Parse section.
        meta: Metadata dict stored as ``raw_meta``; latency/token fields are
            read from it and are ``None`` when absent.
        raw_text: Raw response text to persist (ignored unless ``persist_output``).
        persist_output: When true and reporting is on, persist ``raw_text`` and
            write ``ir_out`` to the variant path. False for externally supplied
            predictions, whose variant file already exists.

    Returns:
        ``(shot, equiv_true)`` where ``equiv_true`` is True only when the
        equivalence check ran successfully and reported equivalence.
    """
    # Each stage only runs if the previous one succeeded.
    p_ok = assemble_ok(ir_out, cfg) if ir_out else _skipped_step()
    v_ok = verify_ok(ir_out, cfg) if p_ok.get("ok") else _skipped_step()
    if v_ok.get("ok") and ir_out:
        ae = alive_equiv(ir_gold, ir_out, cfg, dump_path=sp["alive_merged"])
    else:
        ae = {"status": "skip"}

    if report_mode():
        if persist_output:
            persist_model_io(case_dir, k, raw_text, meta)
            if ir_out:
                write_text(sp["variant"], ir_out)
        write_text(sp["shotlog"], _shot_log(parse_note, p_ok, v_ok, ae))
        write_repro_sh(sp, "repair", cfg, func=None)

    # Text-similarity metrics against the reference IR.
    gold, pred = ir_gold or "", ir_out or ""
    metrics = {
        "em": exact_match(gold, pred),
        "edit_sim": edit_similarity(gold, pred),
        "bleu2": bleu(gold, pred, max_n=2),
        "bleu4": bleu(gold, pred, max_n=4),
        "rougeL": rouge_l(gold, pred),
    }

    status = ae.get("status")
    if status == "ok":
        equiv_status, equiv_ok = "ok", bool(ae.get("equiv"))
    elif status == "err":
        equiv_status, equiv_ok = "err", None
    else:
        equiv_status, equiv_ok = "skip", None

    shot = {
        "k": k,
        "pass": bool(p_ok.get("ok")),
        "valid": bool(v_ok.get("ok")),
        "equiv": build_equiv_dict(equiv_status, equiv_ok),
        "metrics": metrics,
        "unparseable": bool(ierr),
        "latency_ms": meta.get("latency_ms"),
        "tokens": {"in": meta.get("prompt_tokens"), "out": meta.get("out_tokens")},
        "raw_meta": meta,
    }
    return shot, (equiv_status == "ok" and equiv_ok is True)


def run_task(cfg: CIRBenchConfig, proj_root: Path):
    """Run the IR-repair task over every selected case under ``cirbench/repair``.

    For each case up to ``_MAX_SHOTS`` shots are attempted. A shot's candidate
    IR comes from an external response file, an external prediction file, or a
    model call, and is then assembled, verified and equivalence-checked against
    the reference IR. The loop stops early for a case once a shot is proven
    equivalent. A ``metrics.json`` is written per case and a one-line summary
    is printed.

    The ``CIRBENCH_REPAIR_MODE`` environment variable (default ``normal``)
    selects the mode; in ``hard`` mode no hint is included in the prompt.

    Args:
        cfg: Benchmark configuration.
        proj_root: Project root containing the ``cirbench/repair`` directory.
    """
    # NOTE(review): the returned logger is not used here; the call is kept in
    # case it initialises logging as a side effect — confirm.
    get_logger()

    tdir = proj_root / "cirbench" / "repair"
    cases = select_cases(discover_cases(tdir))
    if not cases:
        print("No repair cases found.")
        return

    from ..utils.api.base import select_model_cfg
    runner = make_runner(select_model_cfg(cfg))
    run_dir = get_run_dir(proj_root)
    mode = (os.getenv("CIRBENCH_REPAIR_MODE") or "normal").strip().lower()

    def flag(shots, key, k):
        return "T" if atk(shots, k, key) else "F"

    for case in cases:
        case_dir = get_case_dir(run_dir, "repair", case.id)
        case_meta = getattr(case, "meta", None)

        raw_ir = case.raw_ir.read_text(encoding="utf-8")
        gold_path = getattr(case, "gold", None)
        # Fall back to the raw IR as reference when the case has no gold file.
        ir_gold = gold_path.read_text(encoding="utf-8") if (gold_path and gold_path.exists()) else raw_ir

        # In hard mode, do not expose concrete hints/errors to the model.
        hint = None if mode == "hard" else _format_hint(case_meta or {})
        prompt = _prompt(raw_ir, hint)

        if report_mode():
            write_text(case_dir / "prompt.txt", prompt)
            write_text(case_dir / "prompt.sha256", sha256(prompt))

        start_k = 1
        if resume_mode():
            start_k = next_missing_shot(case_dir, _MAX_SHOTS)
            if start_k > _MAX_SHOTS:
                print(f"[repair] skip {case.id} (already has 1..{_MAX_SHOTS})")
                continue

        # NOTE(review): when resuming mid-case, `shots` only holds the shots
        # produced by this invocation, so the metrics.json written below does
        # not include earlier shots — confirm this is intended.
        shots = []
        for k in range(start_k, _MAX_SHOTS + 1):
            sp = shot_paths(case_dir, k)
            sp["dir"].mkdir(parents=True, exist_ok=True)

            # Best effort: drop golden.ll into every shot dir for convenient diffing.
            try:
                write_text(sp["dir"] / "golden.ll", ir_gold)
            except Exception as exc:
                print(f"[repair] warning: could not write golden.ll for {case.id} shot {k}: {exc}")

            # Prompt-only: the helper writes the prompt into the shot dir; no LLM call.
            if maybe_short_circuit_prompt_only(cfg, sp["dir"], prompt):
                print(f"{case.id}: prompt-only (no LLM)")
                continue

            # From-files fast path: optionally skip the LLM and use external outputs.
            mode_ext = maybe_materialize_external_io(cfg, case.id, sp["dir"])
            source = None
            if mode_ext == "resp":
                # External file is treated as the raw model response.
                source = "resp"
                raw_text = (sp["dir"] / "model.resp.txt").read_text(encoding="utf-8")
                ir_out, ierr, parse_note = _parse_response(raw_text)
                meta = {"source": "from-files", "kind": "resp"}
                persist_output = True
            elif mode_ext == "pred" and sp["variant"].exists():
                # External file already provides the repaired IR as the variant file.
                source = "pred"
                raw_text = None
                ir_out = sp["variant"].read_text(encoding="utf-8")
                ierr, parse_note = False, ""
                meta = {"source": "from-files", "kind": "pred"}
                persist_output = False
            else:
                # Normal path (also used when 'pred' has no usable variant file).
                outc = gen_with_retries(runner, prompt)
                raw_text = outc.text
                ir_out, ierr, parse_note = _parse_response(raw_text)
                meta = outc.meta or {}
                persist_output = True

            shot, equiv_true = _evaluate_shot(
                cfg, case_dir, sp, k, ir_gold, ir_out, ierr, parse_note,
                meta, raw_text, persist_output,
            )
            shots.append(shot)

            if equiv_true:
                from ..utils.runner_common import mark_early_stop
                mark_early_stop(case_dir, k, "equiv_true")
                break

            if source:
                print(f"{case.id}: [from-files:{source}] pass={shot['pass']} valid={shot['valid']} equiv={shot['equiv'].get('ok')}")

        # Per-case metrics.json lives in the first shot's directory.
        mdir = shot_paths(case_dir, 1)["dir"]
        mdir.mkdir(parents=True, exist_ok=True)
        write_json(mdir / "metrics.json", {
            "case_id": case.id,
            "difficulty": (case_meta.get("difficulty") if isinstance(case_meta, dict) else None),
            "mode": mode,
            "shots": shots,
            "summary": _summary_flags(shots),
        })

        print(f"{case.id}: Pass@1={flag(shots,'pass',1)} Pass@5={flag(shots,'pass',5)} | "
              f"Valid@1={flag(shots,'valid',1)} Valid@5={flag(shots,'valid',5)} | "
              f"Equiv@1={flag(shots,'equiv',1)} Equiv@5={flag(shots,'equiv',5)}")
