# cirbench/reporting/csv.py
from __future__ import annotations
from pathlib import Path
import os, json, csv
from typing import Dict, Any, List, Optional

def _glob_task_modes(run_dir: Path, task: str) -> List[Path]:
    """Return the ``<task>.<mode>`` directories under ``<run_dir>/raw``.

    Only directories whose name starts with ``task + "."`` are returned,
    sorted by name. An absent ``raw`` directory yields an empty list.
    """
    raw_root = run_dir / "raw"
    wanted_prefix = task + "."
    if not raw_root.exists():
        return []
    matches = [
        entry
        for entry in raw_root.iterdir()
        if entry.is_dir() and entry.name.startswith(wanted_prefix)
    ]
    matches.sort(key=lambda entry: entry.name)
    return matches

def _read_metrics_json(case_dir: Path) -> Optional[Dict[str, Any]]:
    """Load ``<case_dir>/01_artifacts/metrics.json`` as a dict (best effort).

    Returns ``None`` when the file is missing or unreadable, is not valid
    UTF-8 JSON, or does not decode to a JSON object. Callers use the result
    as a mapping (``.get``), so a top-level list/string/number/null is
    rejected here instead of crashing the report writer later.
    """
    metrics_path = case_dir / "01_artifacts" / "metrics.json"
    try:
        data = json.loads(metrics_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file (also covers the file vanishing
        # between directory listing and read). ValueError: covers both
        # UnicodeDecodeError and json.JSONDecodeError.
        return None
    return data if isinstance(data, dict) else None

def _ensure_tables_dir(run_dir: Path) -> Path:
    """Create ``<run_dir>/tables`` (and any missing parents); return its path.

    Safe to call repeatedly: an existing directory is left untouched.
    """
    tables_path = run_dir / "tables"
    tables_path.mkdir(exist_ok=True, parents=True)
    return tables_path

def _write_csv(path: Path, header: List[str], rows: List[List[Any]]) -> None:
    """Write *header* followed by *rows* to *path* as UTF-8 CSV.

    Parent directories are created as needed and an existing file is
    overwritten. ``newline=""`` is passed so the csv module controls line
    endings itself; ``None`` cells come out as empty fields.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)

def _best_shot(shots: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Choose the most successful shot from *shots*.

    Criteria are tried in order, each returning the earliest shot that meets
    it: ``equiv.ok is True``, then ``valid is True``, then ``pass is True``.
    Only the literal ``True`` qualifies (``1`` or other truthy values do not).
    If nothing qualifies, the first shot is returned. Empty input gives
    ``None``.
    """
    if not shots:
        return None
    criteria = (
        lambda shot: (shot.get("equiv") or {}).get("ok") is True,
        lambda shot: shot.get("valid") is True,
        lambda shot: shot.get("pass") is True,
    )
    for meets in criteria:
        for shot in shots:
            if meets(shot):
                return shot
    return shots[0]

def write_analysis_csvs(run_dir: Path) -> None:
    """Write ``tables/analysis.<mode>.csv`` for each ``raw/analysis.<mode>`` dir.

    Each case sub-directory whose ``01_artifacts/metrics.json`` loads to a
    JSON object contributes one row. Cases with a missing or unparseable file,
    or with non-object JSON (list/null/...), are skipped.
    """
    tables = _ensure_tables_dir(run_dir)
    header = ["case_id","difficulty","em","unparseable","latency_ms","tokens_in","tokens_out"]
    for mode_dir in _glob_task_modes(run_dir, "analysis"):
        mode = mode_dir.name.split(".", 1)[1]
        rows: List[List[Any]] = []
        for case_dir in sorted(mode_dir.iterdir()):
            if not case_dir.is_dir():
                continue
            jj = _read_metrics_json(case_dir)
            # Reuse the shared loader. The isinstance check (rather than
            # truthiness) keeps writing a row for an empty ``{}`` as before,
            # and skips non-object payloads that would crash on ``.get``.
            if not isinstance(jj, dict):
                continue
            tokens = jj.get("tokens") or {}
            rows.append([
                # NOTE(review): this writer reads "id" while the shot/best
                # writers read "case_id". Confirm the analysis schema.
                jj.get("id") or case_dir.name,
                jj.get("difficulty"),
                bool(jj.get("em")),
                bool(jj.get("unparseable")),
                jj.get("latency_ms"),
                tokens.get("in"),
                tokens.get("out"),
            ])
        _write_csv(tables / f"analysis.{mode}.csv", header, rows)

def write_shot_csv(task: str, mode_dir: Path, out_csv: Path) -> None:
    """Write a per-shot CSV for one ``<task>.<mode>`` directory.

    Every case sub-directory of *mode_dir* whose metrics load to a non-empty
    mapping contributes one row per entry in its ``shots`` list. Absent nested
    sections (equiv/runtime/codesize/tokens/mca) produce empty cells.
    ``pass``, ``valid``, ``alive_timeout`` and ``unparseable`` are coerced to
    bool. ``task`` is accepted but not used here.
    """
    header = [
        "case_id","difficulty","k","pass","valid",
        "equiv_status","equiv_ok","equiv_method","alive_timeout","checksum_equal",
        "t_raw_ms","t_golden_ms","t_variant_ms",
        "speedup_model_raw","speedup_model_golden","speedup_golden_raw",
        "codesize_baseline","codesize_variant","codesize_ratio",
        "unparseable","latency_ms","tokens_in","tokens_out",
        # MCA: raw
        "mca_raw_status","mca_raw_iterations","mca_raw_cycles","mca_raw_instructions",
        "mca_raw_uops","mca_raw_ipc","mca_raw_block_rthroughput",
        # MCA: golden
        "mca_golden_status","mca_golden_iterations","mca_golden_cycles","mca_golden_instructions",
        "mca_golden_uops","mca_golden_ipc","mca_golden_block_rthroughput",
        # MCA: variant
        "mca_variant_status","mca_variant_iterations","mca_variant_cycles","mca_variant_instructions",
        "mca_variant_uops","mca_variant_ipc","mca_variant_block_rthroughput",
    ]
    # Field names in the order their columns appear in ``header``.
    runtime_fields = (
        "t_raw_ms", "t_golden_ms", "t_variant_ms",
        "speedup_model_raw", "speedup_model_golden", "speedup_golden_raw",
    )
    codesize_fields = ("baseline", "variant", "ratio")
    mca_sections = ("raw", "golden", "variant")
    mca_fields = (
        "status", "iterations", "cycles", "instructions",
        "uops", "ipc", "block_rthroughput",
    )

    table: List[List[Any]] = []
    for case_dir in sorted(mode_dir.iterdir()):
        if not case_dir.is_dir():
            continue
        metrics = _read_metrics_json(case_dir)
        if not metrics:
            continue
        case_id = metrics.get("case_id") or case_dir.name
        difficulty = metrics.get("difficulty")
        for shot in metrics.get("shots") or []:
            equiv = shot.get("equiv") or {}
            runtime = shot.get("runtime") or {}
            codesize = shot.get("codesize") or {}
            tokens = shot.get("tokens") or {}
            mca = shot.get("mca") or {}

            row: List[Any] = [
                case_id,
                difficulty,
                shot.get("k"),
                bool(shot.get("pass")),
                bool(shot.get("valid")),
                equiv.get("status"),
                equiv.get("ok"),
                equiv.get("method"),
                bool(equiv.get("alive_timeout")),
                equiv.get("checksum_equal"),
            ]
            row.extend(runtime.get(name) for name in runtime_fields)
            row.extend(codesize.get(name) for name in codesize_fields)
            row.extend((
                bool(shot.get("unparseable")),
                shot.get("latency_ms"),
                tokens.get("in"),
                tokens.get("out"),
            ))
            for section in mca_sections:
                stats = mca.get(section) or {}
                row.extend(stats.get(name) for name in mca_fields)
            table.append(row)
    _write_csv(out_csv, header, table)

def write_best_csv(task: str, mode_dir: Path, out_csv: Path) -> None:
    """Write one row per case: its @1/@5 summary flags plus best-shot metrics.

    For every case sub-directory of *mode_dir* whose metrics load to a
    non-empty mapping, the row starts with the case's ``summary`` flags
    (pass/valid/equiv at 1 and 5, coerced to bool). It continues with the
    metrics of the shot chosen by ``_best_shot``: first equiv-ok, else first
    valid, else first pass, else the first shot.

    A case with no shots still gets a row, in which every ``best_*`` cell is
    empty except ``best_alive_timeout`` (``False``).

    ``best_mca_golden_over_variant_cycles`` is golden cycles divided by
    variant cycles, when both are numbers and the variant count is non-zero.

    ``task`` is accepted but not used here.
    """
    rows: List[List[Any]] = []
    header = [
        "case_id","difficulty",
        "pass_at_1","pass_at_5","valid_at_1","valid_at_5","equiv_at_1","equiv_at_5",
        "best_k",
        "best_equiv_status","best_equiv_ok","best_equiv_method","best_alive_timeout","best_checksum_equal",
        "best_t_raw_ms","best_t_golden_ms","best_t_variant_ms",
        "best_speedup_model_raw","best_speedup_model_golden","best_speedup_golden_raw",
        "best_codesize_baseline","best_codesize_variant","best_codesize_ratio",
        # MCA (best shot)
        "best_mca_raw_status","best_mca_raw_iterations","best_mca_raw_cycles","best_mca_raw_instructions",
        "best_mca_raw_uops","best_mca_raw_ipc","best_mca_raw_block_rthroughput",
        "best_mca_golden_status","best_mca_golden_iterations","best_mca_golden_cycles","best_mca_golden_instructions",
        "best_mca_golden_uops","best_mca_golden_ipc","best_mca_golden_block_rthroughput",
        "best_mca_variant_status","best_mca_variant_iterations","best_mca_variant_cycles","best_mca_golden_over_variant_cycles","best_mca_variant_instructions",
        "best_mca_variant_uops","best_mca_variant_ipc","best_mca_variant_block_rthroughput",
    ]
    for case_dir in sorted(mode_dir.iterdir()):
        if not case_dir.is_dir():
            continue
        jj = _read_metrics_json(case_dir)
        if not jj:
            continue
        case_id = jj.get("case_id") or case_dir.name
        difficulty = jj.get("difficulty")
        summ = jj.get("summary") or {}

        # Use the shared selector rather than a private copy of its rules.
        # With no shots it returns None. Substituting {} makes every lookup
        # below yield None, and bool(None) -> False for alive_timeout, which
        # is exactly the placeholder row. So the column count cannot drift
        # away from ``header``.
        best = _best_shot(jj.get("shots") or []) or {}
        eq = best.get("equiv") or {}
        rt = best.get("runtime") or {}
        cd = best.get("codesize") or {}
        mca = best.get("mca") or {}
        mca_raw = mca.get("raw") or {}
        mca_golden = mca.get("golden") or {}
        mca_variant = mca.get("variant") or {}

        # Golden/variant cycle ratio. Left empty unless both are numeric and
        # the divisor is non-zero.
        gv_cycles_ratio = None
        gc = mca_golden.get("cycles")
        vc = mca_variant.get("cycles")
        if isinstance(gc, (int, float)) and isinstance(vc, (int, float)) and vc:
            gv_cycles_ratio = gc / vc

        rows.append([
            case_id,
            difficulty,
            bool(summ.get("pass_at_1")), bool(summ.get("pass_at_5")),
            bool(summ.get("valid_at_1")), bool(summ.get("valid_at_5")),
            bool(summ.get("equiv_at_1")), bool(summ.get("equiv_at_5")),
            best.get("k"),
            eq.get("status"), eq.get("ok"), eq.get("method"), bool(eq.get("alive_timeout")), eq.get("checksum_equal"),
            rt.get("t_raw_ms"), rt.get("t_golden_ms"), rt.get("t_variant_ms"),
            rt.get("speedup_model_raw"), rt.get("speedup_model_golden"), rt.get("speedup_golden_raw"),
            cd.get("baseline"), cd.get("variant"), cd.get("ratio"),
            # MCA best: raw
            mca_raw.get("status"),
            mca_raw.get("iterations"),
            mca_raw.get("cycles"),
            mca_raw.get("instructions"),
            mca_raw.get("uops"),
            mca_raw.get("ipc"),
            mca_raw.get("block_rthroughput"),
            # MCA best: golden
            mca_golden.get("status"),
            mca_golden.get("iterations"),
            mca_golden.get("cycles"),
            mca_golden.get("instructions"),
            mca_golden.get("uops"),
            mca_golden.get("ipc"),
            mca_golden.get("block_rthroughput"),
            # MCA best: variant (ratio column sits after variant cycles,
            # matching the header order)
            mca_variant.get("status"),
            mca_variant.get("iterations"),
            mca_variant.get("cycles"),
            gv_cycles_ratio,
            mca_variant.get("instructions"),
            mca_variant.get("uops"),
            mca_variant.get("ipc"),
            mca_variant.get("block_rthroughput"),
        ])
    _write_csv(out_csv, header, rows)

def write_task_csvs(run_dir: Path) -> None:
    """Emit every CSV table for a run into ``<run_dir>/tables``.

    The analysis tables are written first. Then, for every transform, repair
    and refactor mode directory present under ``<run_dir>/raw``, two tables
    are written: a per-shot table (``<task>.<mode>.per_shot.csv``) and a
    best-shot table (``<task>.<mode>.best_at_5.csv``).
    """
    tables = _ensure_tables_dir(run_dir)
    write_analysis_csvs(run_dir)

    shot_based_tasks = ("transform", "repair", "refactor")
    for task in shot_based_tasks:
        for mode_dir in _glob_task_modes(run_dir, task):
            # Directory names are "<task>.<mode>"; the prefix match in
            # _glob_task_modes guarantees the dot is present.
            _task_prefix, mode = mode_dir.name.split(".", 1)
            write_shot_csv(task, mode_dir, tables / f"{task}.{mode}.per_shot.csv")
            write_best_csv(task, mode_dir, tables / f"{task}.{mode}.best_at_5.csv")
