from __future__ import annotations
from pathlib import Path
import os, re, json
from ..cfg import CIRBenchConfig
from ..registry import discover_cases
from ..utils.api.base import make_runner
from ..utils.case_select import select_cases
from ..utils.runner_common import (
    get_run_dir, get_case_dir, report_mode, resume_mode, next_missing_shot, shot_paths,
    gen_with_retries, persist_model_io, write_text, write_json, fmt_cmd_block,
    extract_ir_and_meta, atk, sha256,
    write_repro_sh, maybe_short_circuit_prompt_only, maybe_materialize_external_io
)
from ..utils.logging_utils import get_logger
from ..utils.metrics import exact_match, edit_similarity, bleu, rouge_l
from ..utils.evalers import assemble_ok, verify_ok, alive_equiv

# Candidate LLVM pass names offered to the model in the reverse
# (pass-identification) prompt. The order is the order shown in that prompt.
PASS_LIST = [
    "CorrelatedValuePropagation",
    "DeadStoreElim",
    "EarlyCSE",
    "GlobalOpt",
    "GVN",
    "IndVarSimplify",
    "InstCombine",
    "InstSimplify",
    "IPSCCP",
    "JumpThreading",
    "LCSSA",
    "LICM",
    "LoopRotate",
    "LoopSimplify",
    "LoopUnroll",
    "MemCpyOpt",
    "Reassociate",
    "SimplifyCFG",
    "SROA",
    "TailCallElim",
]

def _prompt_normal(ir: str, pass_name: str, hint: str) -> str:
    return (
        "You are an LLVM IR refactor assistant.\n"
        f"Apply exactly this optimization pass, no others: {pass_name}\n"
        "IR:\n<IR>\n" + ir + "\n</IR>\n\n"
        "If no transformations are needed, reply: No; else output ONLY <IR_OUT>...</IR_OUT>"
    )

def _prompt_reverse(before_ir: str, after_ir: str) -> str:
    """Build the "which pass was applied?" prompt for reverse mode.

    Args:
        before_ir: IR text placed inside ``<BEFORE_IR>`` tags.
        after_ir: IR text placed inside ``<AFTER_IR>`` tags.

    Returns:
        A prompt that lists every entry of ``PASS_LIST`` as a bullet and asks
        for the answer wrapped as ``<CIR_JSON>{"pass": ...}</CIR_JSON>``.
    """
    bullets = "\n".join("- " + name for name in PASS_LIST)
    segments = [
        "I have applied ONE LLVM pass to transform BEFORE LLVM IR into AFTER LLVM IR.\n",
        "Identify which pass was applied from the following list.\n",
        "List:\n",
        bullets,
        "\n",
        'Return ONLY the pass name in this format without any explanation: <CIR_JSON>{"pass":"<ONE_OF_LIST>"}</CIR_JSON>\n',
        "BEFORE:\n<BEFORE_IR>\n",
        before_ir,
        "\n</BEFORE_IR>\n",
        "AFTER:\n<AFTER_IR>\n",
        after_ir,
        "\n</AFTER_IR>",
    ]
    return "".join(segments)

def _summary_flags(shots):
    """Summarise a case's shots as ``<metric>_at_<k>`` booleans.

    For each of the shot keys ``pass``, ``valid`` and ``equiv`` and each
    ``k`` in (1, 5), the result of ``atk(shots, k, key)`` is coerced to
    ``bool``.

    Args:
        shots: List of per-shot result dicts produced by ``run_task``.

    Returns:
        Dict with keys ``pass_at_1``, ``pass_at_5``, ``valid_at_1``,
        ``valid_at_5``, ``equiv_at_1`` and ``equiv_at_5``.
    """
    # NOTE(review): the exact "@k" semantics live in ``atk``; presumably
    # "any success within the first k shots" -- confirm in runner_common.
    summary = {}
    for metric in ("pass", "valid", "equiv"):
        for k in (1, 5):
            summary[f"{metric}_at_{k}"] = bool(atk(shots, k, metric))
    return summary

# Values of CIRBENCH_REFACTOR_MODE that run_task knows how to execute.
_REFACTOR_MODES = ("normal", "reverse")
# Upper bound on the number of shots attempted per case in normal mode.
_MAX_SHOTS = 5


def _expected_pass(case_id: str, default: str) -> str:
    """Return the second ``_``-separated field of *case_id*, or *default* if the id has no ``_``."""
    return case_id.split("_")[1] if "_" in case_id else default


def _skipped_step() -> dict:
    """Return a fresh failed-step record for an evaluation stage that was not run."""
    return {"ok": False, "stdout": "", "stderr": "", "rc": 1, "cmd": ""}


def _read_external_response(shot_dir: Path) -> str:
    """Read ``model.resp.txt`` from *shot_dir*; an unreadable file counts as an empty response."""
    try:
        return (shot_dir / "model.resp.txt").read_text(encoding="utf-8")
    except (OSError, UnicodeError):
        return ""


def _liberal_extract_ir(text):
    """Extract IR from a model response, tolerating sloppy formatting.

    ``extract_ir_and_meta`` is tried first. If it yields no IR, two fallbacks
    apply: text after an unclosed ``<IR_OUT>`` tag, then the whole response
    if it contains ``define``.

    Returns:
        ``(ir_out, unparseable, parse_note)``. ``ir_out`` is whatever
        ``extract_ir_and_meta`` returned when no fallback matched.
    """
    ir_out, unparseable, _meta_json = extract_ir_and_meta(text)
    note = "ok"
    if not ir_out:
        raw = text or ""
        opened = re.search(r"<IR_OUT>\s*", raw)
        if opened:
            ir_out = raw[opened.end():].strip()
            # Heuristic: the recovered tail counts as parsed only if it looks like IR.
            unparseable = "define" not in ir_out
            note = "recovered_from_unclosed_ir_out"
        elif "define" in raw:
            # Last resort: treat the whole text as IR if it looks like IR.
            ir_out = raw.strip()
            unparseable = False
            note = "recovered_entire_text"
    return ir_out, unparseable, note


def _at_k_flag(shots, key, k) -> str:
    """Render ``atk(shots, k, key)`` as ``"T"`` or ``"F"`` for the console summary."""
    return "T" if atk(shots, k, key) else "F"


def _run_reverse_shot(cfg, runner, case_id, case_dir, sp, prompt) -> dict:
    """Run the single classification shot of reverse mode and return its record.

    The response comes from an external ``model.resp.txt`` when
    ``maybe_materialize_external_io`` reports ``"resp"``, otherwise from the
    model. Scoring is a case-insensitive substring check for the expected
    pass name; no IR is parsed or verified.
    """
    if maybe_materialize_external_io(cfg, case_id, sp["dir"]) == "resp":
        # External response text stands in for model output.
        src_text = _read_external_response(sp["dir"])
        src_meta = {"source": "from-files", "kind": "resp"}
    else:
        outc = gen_with_retries(runner, prompt)
        src_text = outc.text
        src_meta = outc.meta or {}

    # Persist a model-like IO record for consistency with normal mode.
    persist_model_io(case_dir, 1, src_text or "", src_meta)

    expected = _expected_pass(case_id, "")
    em = bool(expected) and expected.lower() in (src_text or "").lower()

    # For reverse mode, the shot schema is intentionally minimal.
    shot = {
        "k": 1,
        "em": bool(em),
        "unparseable": False,
        "latency_ms": src_meta.get("latency_ms"),
        "tokens": {"in": src_meta.get("prompt_tokens"), "out": src_meta.get("out_tokens")},
        "raw_meta": src_meta,
    }
    write_repro_sh(sp, "refactor", cfg, func=None)
    return shot


def _run_normal_shot(cfg, runner, case_id, case_dir, k, prompt, after_ir) -> dict:
    """Produce and evaluate shot *k* of normal mode and return its record.

    Candidate IR is taken, in order of preference, from an external response
    file, an external ``variant`` file, or a live generation. It is then
    assembled, verified, checked against *after_ir* with ``alive_equiv``, and
    scored with the text-similarity metrics.
    """
    sp = shot_paths(case_dir, k)
    sp["dir"].mkdir(parents=True, exist_ok=True)

    ir_out = None
    ierr = False
    meta = {}
    parse_note = "ok"

    # Try external artifacts first.
    mode_ext = maybe_materialize_external_io(cfg, case_id, sp["dir"])
    if mode_ext == "resp":
        resp_text = _read_external_response(sp["dir"])
        ir_out, ierr, parse_note = _liberal_extract_ir(resp_text)
        meta = {"source": "from-files", "kind": "resp"}
        persist_model_io(case_dir, k, resp_text or "", meta)
    elif mode_ext == "pred":
        # External IR is already in place as the shot's variant file.
        var_ll = sp["variant"]
        if var_ll.exists():
            ir_out = var_ll.read_text(encoding="utf-8")
            ierr = False
            parse_note = "from_pred_variant"
            meta = {"source": "from-files", "kind": "pred"}

    if ir_out is None:
        # No usable external artifact: fall back to live generation.
        outc = gen_with_retries(runner, prompt)
        ir_out, ierr, parse_note = _liberal_extract_ir(outc.text)
        meta = outc.meta or {}
        persist_model_io(case_dir, k, outc.text, meta)

    # Each stage runs only if the previous one succeeded.
    p_ok = assemble_ok(ir_out or "", cfg) if ir_out else _skipped_step()
    v_ok = verify_ok(ir_out or "", cfg) if p_ok.get("ok") else _skipped_step()
    if v_ok.get("ok") and ir_out and after_ir:
        ae = alive_equiv(after_ir, ir_out, cfg, dump_path=sp["alive_merged"])
    else:
        ae = {"status": "skip"}

    em = exact_match(after_ir or "", ir_out or "")
    eds = edit_similarity(after_ir or "", ir_out or "")
    b2 = bleu(after_ir or "", ir_out or "", max_n=2)
    b4 = bleu(after_ir or "", ir_out or "", max_n=4)
    rg = rouge_l(after_ir or "", ir_out or "")

    ae_status = ae.get("status")
    if ae_status == "ok":
        eqv = {"status": "ok", "ok": bool(ae.get("equiv"))}
    elif ae_status == "err":
        eqv = {"status": "err", "ok": None}
    else:
        eqv = {"status": "skip", "ok": None}

    shot = {
        "k": k, "pass": bool(p_ok.get("ok")), "valid": bool(v_ok.get("ok")),
        "equiv": eqv,
        "metrics": {"em": em, "edit_sim": eds, "bleu2": b2, "bleu4": b4, "rougeL": rg},
        "unparseable": bool(ierr),
        "latency_ms": meta.get("latency_ms"),
        "tokens": {"in": meta.get("prompt_tokens"), "out": meta.get("out_tokens")},
        "raw_meta": meta,
        "parse_note": parse_note,
    }

    if report_mode():
        # Save the variant IR and a per-shot log of each evaluation stage.
        if ir_out:
            write_text(sp["variant"], ir_out)
        log = [
            "## Parse",
            parse_note,
            "## Assemble",
            fmt_cmd_block((p_ok.get("cmd") or "").split(), p_ok.get("rc"), p_ok.get("stdout",""), p_ok.get("stderr","")),
            "## Verify",
            fmt_cmd_block((v_ok.get("cmd") or "").split(), v_ok.get("rc"), v_ok.get("stdout",""), v_ok.get("stderr","")),
        ]
        if ae.get("cmd"):
            log.append("## Alive2")
            log.append(fmt_cmd_block(ae.get("cmd"), ae.get("exit"), ae.get("out",""), ae.get("err","")))
        write_text(sp["shotlog"], "\n".join(log) + "\n")
        write_repro_sh(sp, "refactor", cfg, func=None)

    return shot


def run_task(cfg: CIRBenchConfig, proj_root: Path):
    """Run the refactor benchmark over every selected case.

    The mode is read from the ``CIRBENCH_REFACTOR_MODE`` environment variable
    (default ``normal``):

    * ``normal``: ask the model to apply one named pass and evaluate up to
      five shots, stopping early once a shot is proven equivalent to the
      reference IR.
    * ``reverse``: show before/after IR and score one shot on whether the
      response contains the expected pass name.

    Any other value falls back to ``normal`` with a printed notice, so the
    prompt and the evaluation always match.

    Args:
        cfg: Benchmark configuration, forwarded to the runner factory and to
            the evaluation helpers.
        proj_root: Project root; cases are discovered under
            ``proj_root / "cirbench" / "refactor"``.

    Returns:
        None. A ``metrics.json`` is written into each case's first shot
        directory and a one-line summary is printed per case.
    """
    # The returned logger is not used below; the call is kept in case it
    # performs logging setup (not visible from this file).
    get_logger()
    tdir = proj_root / "cirbench" / "refactor"
    cases = select_cases(discover_cases(tdir))
    if not cases:
        print("No refactor cases found."); return

    # Build the runner once and reuse it for every generation.
    from ..utils.api.base import make_runner, select_model_cfg
    runner = make_runner(select_model_cfg(cfg))
    run_dir = get_run_dir(proj_root)

    mode = (os.getenv("CIRBENCH_REFACTOR_MODE") or "normal").strip().lower()
    if mode not in _REFACTOR_MODES:
        # An unrecognised mode used to get the reverse prompt but the normal
        # evaluation pipeline; normalise it so both agree.
        print(f"[refactor] unknown CIRBENCH_REFACTOR_MODE={mode!r}; falling back to 'normal'")
        mode = "normal"

    for case in cases:
        case_dir = get_case_dir(run_dir, "refactor", case.id)
        before_ir = (getattr(case, "before_ir", None) or case.raw_ir).read_text(encoding="utf-8")
        after_path = getattr(case, "after_ir", None) or getattr(case, "gold", None)
        after_ir = after_path.read_text(encoding="utf-8") if (after_path and after_path.exists()) else ""

        prompt_path = getattr(case, "prompt", None)
        hint_txt = prompt_path.read_text(encoding="utf-8") if (prompt_path and prompt_path.exists()) else ""

        if mode == "normal":
            prompt = _prompt_normal(before_ir, _expected_pass(case.id, "InstCombine"), hint_txt)
        else:
            prompt = _prompt_reverse(before_ir, after_ir)

        psha = sha256(prompt)
        if report_mode():
            write_text(case_dir/"prompt.txt", prompt); write_text(case_dir/"prompt.sha256", psha)

        # First shot directory: used by prompt-only / from-files workflows and
        # as the home of the per-case metrics.json.
        sp0 = shot_paths(case_dir, 1); sp0["dir"].mkdir(parents=True, exist_ok=True)
        metrics_path = sp0["dir"] / "metrics.json"

        # === Prompt-only: write prompt and stop before calling LLM ===
        if maybe_short_circuit_prompt_only(cfg, sp0["dir"], prompt):
            print(f"{case.id}: prompt-only (no LLM)")
            try:
                write_json(metrics_path, {"case_id": case.id, "mode": mode, "shots": []})
            except Exception as exc:
                # Best-effort: a failed metrics write must not abort the run,
                # but it should not go unnoticed either.
                print(f"[refactor] could not write prompt-only metrics for {case.id}: {exc}")
            continue

        shots = []
        if mode == "reverse":
            # Single shot: classification only, no IR parsing or LLVM tooling.
            shots.append(_run_reverse_shot(cfg, runner, case.id, case_dir, sp0, prompt))
        else:
            # Normal mode: up to _MAX_SHOTS shots with early stop on proven equivalence.
            start_k = 1
            if resume_mode():
                mk = next_missing_shot(case_dir, _MAX_SHOTS)
                if mk > _MAX_SHOTS:
                    print(f"[refactor] skip {case.id} (already 1..5)")
                    # Still write empty metrics for visibility.
                    # NOTE(review): this presumably replaces any metrics.json
                    # left by the earlier run, and a partial resume records only
                    # the newly run shots -- confirm this is intended.
                    write_json(metrics_path, {"case_id": case.id, "mode": mode, "shots": shots, "summary": _summary_flags(shots)})
                    continue
                start_k = mk

            for k in range(start_k, _MAX_SHOTS + 1):
                shot = _run_normal_shot(cfg, runner, case.id, case_dir, k, prompt, after_ir)
                shots.append(shot)
                if shot["equiv"]["status"] == "ok" and shot["equiv"]["ok"] is True:
                    from ..utils.runner_common import mark_early_stop
                    mark_early_stop(case_dir, k, "equiv_true")
                    break

        # Write per-case metrics.json.
        if mode == "reverse":
            summary = {"em": bool(shots and shots[0].get("em"))}
        else:
            summary = _summary_flags(shots)
        write_json(metrics_path, {
            "case_id": case.id,
            "mode": mode,
            "shots": shots,
            "summary": summary,
        })

        if mode == "reverse":
            em_flag = "T" if (shots and shots[0].get("em")) else "F"
            print(f"{case.id}: EM={em_flag}")
        else:
            print(f"{case.id}: Pass@1={_at_k_flag(shots,'pass',1)} Pass@5={_at_k_flag(shots,'pass',5)} | "
                  f"Valid@1={_at_k_flag(shots,'valid',1)} Valid@5={_at_k_flag(shots,'valid',5)} | "
                  f"Equiv@1={_at_k_flag(shots,'equiv',1)} Equiv@5={_at_k_flag(shots,'equiv',5)}")
