# cirbench/utils/runner_common.py
from __future__ import annotations
from pathlib import Path
import os, json, time, hashlib, subprocess, shlex, tempfile
from typing import Any, Dict, List, Optional, Tuple
import random
try:
    import fcntl  # POSIX file locking; on non-POSIX this will fail gracefully
except Exception:  # pragma: no cover
    fcntl = None

# ----------------- run/session helpers -----------------

# Fallback run id, generated once per process on first use when
# CIRBENCH_RUN_ID is unset, so every caller in this process agrees on it.
_FALLBACK_RUN_ID: Optional[str] = None

def run_id() -> str:
    """Return the identifier of the current benchmark run.

    CIRBENCH_RUN_ID wins when set (and non-empty). Otherwise a UTC timestamp
    of the form ``YYYY-mm-ddTHH-MM-SSZ`` is generated on the first call and
    reused for the rest of the process. Re-stamping on each call would make
    calls that straddle a second boundary disagree, which would split run
    directories and rate-limiter state files.
    """
    global _FALLBACK_RUN_ID
    env_rid = os.getenv("CIRBENCH_RUN_ID")
    if env_rid:
        return env_rid
    if _FALLBACK_RUN_ID is None:
        _FALLBACK_RUN_ID = time.strftime("%Y-%m-%dT%H-%M-%SZ", time.gmtime())
    return _FALLBACK_RUN_ID

def get_run_dir(proj_root: Path) -> Path:
    """Return ``<proj_root>/runs/<run_id()>``, creating the directory if needed."""
    run_dir = proj_root.joinpath("runs", run_id())
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def _task_mode_suffix(task: str) -> str:
    t = (task or "").strip().lower()
    if t == "repair":
        return (os.getenv("CIRBENCH_REPAIR_MODE") or "normal").strip().lower()
    if t == "refactor":
        return (os.getenv("CIRBENCH_REFACTOR_MODE") or "normal").strip().lower()
    if t == "transform":
        return (os.getenv("CIRBENCH_TRANSFORM_MODE") or "normal").strip().lower()
    # analysis or others
    return (os.getenv("CIRBENCH_ANALYSIS_MODE") or "default").strip().lower()

def get_case_dir(run_dir: Path, task: str, case_id: str) -> Path:
    """Return ``<run_dir>/raw/<task>.<mode>/<case_id>``, creating it if needed."""
    case_dir = run_dir.joinpath("raw", f"{task}.{_task_mode_suffix(task)}", case_id)
    case_dir.mkdir(parents=True, exist_ok=True)
    return case_dir

def report_mode() -> bool:
    """Return True only when CIRBENCH_REPORT_RUN is exactly ``"1"``."""
    flag = os.environ.get("CIRBENCH_REPORT_RUN", "0")
    return flag == "1"

def resume_mode() -> bool:
    """Return True only when CIRBENCH_RESUME is exactly ``"1"``."""
    flag = os.environ.get("CIRBENCH_RESUME", "0")
    return flag == "1"

def mkdirs(p: Path) -> None:
    """Ensure the *parent* directory of file path ``p`` exists (``p`` itself is not created)."""
    parent_dir = p.parent
    parent_dir.mkdir(exist_ok=True, parents=True)

def write_text(p: Path, txt: str) -> None:
    """Write ``txt`` to ``p`` as UTF-8, creating parent directories first."""
    mkdirs(p)
    p.write_text(txt, encoding="utf-8")

def write_json(p: Path, obj: Any) -> None:
    """Serialize ``obj`` as pretty-printed (indent=2), non-ASCII-preserving JSON into ``p``.

    Parent directories are created before serialization is attempted.
    """
    mkdirs(p)
    payload = json.dumps(obj, ensure_ascii=False, indent=2)
    p.write_text(payload, encoding="utf-8")

def sha256(txt: str) -> str:
    """Return the hex SHA-256 digest of ``txt`` encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(txt.encode("utf-8"))
    return digest.hexdigest()

def shot_prefix(k: int) -> str:
    """Return the shot number zero-padded to at least two digits (e.g. 3 -> "03")."""
    return format(k, "02d")

def shot_paths(case_dir: Path, k: int) -> Dict[str, Path]:
    """Return the well-known artifact paths for shot ``k`` of a case.

    Keys: ``dir`` (the ``<NN>_artifacts`` directory) plus ``raw``, ``meta``,
    ``variant``, ``shotlog`` and ``alive_merged`` files inside it. Nothing is
    created on disk.
    """
    artifacts = case_dir / (shot_prefix(k) + "_artifacts")
    file_names = (
        ("raw", "model.resp.txt"),
        ("meta", "model_meta.json"),
        ("variant", "variant.ll"),
        ("shotlog", "shot.log"),
        ("alive_merged", "alive_input.ll"),
    )
    paths: Dict[str, Path] = {"dir": artifacts}
    for key, fname in file_names:
        paths[key] = artifacts / fname
    return paths

def case_done_flag(case_dir: Path) -> Path:
    """Return the path of the case's early-stop marker file (``_early_stop.json``)."""
    return case_dir.joinpath("_early_stop.json")

def case_has_early_stop(case_dir: Path) -> bool:
    """Return True if the case's early-stop marker file exists."""
    flag_path = case_done_flag(case_dir)
    return flag_path.exists()

def mark_early_stop(case_dir: Path, k: int, reason: str) -> None:
    """Write the early-stop marker recording shot ``k``, ``reason`` and a Unix timestamp."""
    flag_path = case_done_flag(case_dir)
    record = {"k": k, "reason": reason, "ts": time.time()}
    write_json(flag_path, record)

def shot_exists(case_dir: Path, k: int) -> bool:
    """Return True if shot ``k`` already has a persisted raw model response."""
    raw_response = shot_paths(case_dir, k)["raw"]
    return raw_response.exists()

def next_missing_shot(case_dir: Path, max_k: int) -> int:
    """Return the lowest shot number in ``1..max_k`` without a raw response.

    Returns ``max_k + 1`` when every shot in that range is present.
    """
    candidates = (shot for shot in range(1, max_k + 1) if not shot_exists(case_dir, shot))
    return next(candidates, max_k + 1)

# ----------------- rate limit helpers (cross‑process) -----------------

def _upper(s: str | None) -> str:
    return (s or "unknown").upper()

def rpm_from_env(provider: str, default: int) -> int:
    """
    Read the per-minute request cap from the environment.

    Priority:
      CIRBENCH_RATE_<PROVIDER>_PER_MIN   (provider upper-cased; None -> UNKNOWN)
      CIRBENCH_RATE_PER_MIN
      default (function argument)

    A variable is used only when it holds a non-negative base-10 integer
    (surrounding whitespace allowed). Anything else is ignored and the next
    source is consulted. The result is always at least 1.
    """
    def _parse(raw: Optional[str]) -> Optional[int]:
        # ``str.isdigit`` alone is not a safe gate: it is True for characters
        # such as "²" that ``int`` rejects with ValueError.
        if raw is None:
            return None
        try:
            value = int(raw.strip())
        except ValueError:
            return None
        # Negative values are treated as "not set" so the next source applies.
        return value if value >= 0 else None

    for var_name in (f"CIRBENCH_RATE_{_upper(provider)}_PER_MIN", "CIRBENCH_RATE_PER_MIN"):
        parsed = _parse(os.getenv(var_name))
        if parsed is not None:
            return max(1, parsed)
    return max(1, int(default))

class FileRateLimiter:
    """
    Sliding-window (60s) file-backed request limiter shared by all processes of one run.

    State file: runs/<run_id>/.ratelimit/<provider>.json (relative to the current
    working directory), holding ``{"ts": [<unix timestamps of granted slots>]}``.
    A slot is granted while fewer than ``per_min`` stored timestamps fall inside
    the last 60 seconds. Cross-process exclusion uses ``fcntl.flock`` when the
    module is available; without it (non-POSIX) access is unsynchronized and
    the cap is best-effort only.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_S = 60.0
    # Bounds on one polling sleep while saturated. The lower bound avoids a
    # sleep(0) busy loop when the computed wait is ~0 at the window edge.
    _MIN_POLL_S = 0.05
    _MAX_POLL_S = 1.5

    def __init__(self, run_id: str, provider: str, per_min: int):
        base = Path("runs") / run_id / ".ratelimit"
        base.mkdir(parents=True, exist_ok=True)
        self.path = base / f"{(provider or 'unknown')}.json"
        self.per_min = max(1, int(per_min))

    def _read_ts(self, f) -> list[float]:
        """Return the stored timestamps that are still inside the window.

        Unreadable, corrupt or wrongly-shaped state is treated as an empty window.
        """
        try:
            f.seek(0)
            raw = f.read()
            data = json.loads(raw) if raw else {}
        except Exception:
            data = {}
        # Valid JSON of the wrong shape (e.g. a top-level list, or a non-list
        # "ts") must not raise out of acquire().
        ts = data.get("ts") if isinstance(data, dict) else None
        if not isinstance(ts, list):
            ts = []
        cutoff = time.time() - self._WINDOW_S
        return [float(x) for x in ts if isinstance(x, (int, float)) and x >= cutoff]

    def _lock(self, fh):
        """Take an exclusive flock on ``fh`` (no-op without fcntl; failures ignored)."""
        if fcntl is not None:
            try:
                fcntl.flock(fh, fcntl.LOCK_EX)
            except Exception:
                pass

    def _unlock(self, fh):
        """Release the flock on ``fh`` (no-op without fcntl; failures ignored)."""
        if fcntl is not None:
            try:
                fcntl.flock(fh, fcntl.LOCK_UN)
            except Exception:
                pass

    def acquire(self, max_wait_s: float = 120.0) -> float:
        """
        Block until a slot is available or ``max_wait_s`` elapses.

        Returns 0.0 when a slot was acquired (its timestamp is recorded).
        Otherwise it returns a strictly positive estimate of the seconds until
        the next slot frees up, once ``max_wait_s`` has elapsed without success.
        """
        start = time.time()
        while True:
            with open(self.path, "a+b") as f:
                self._lock(f)
                try:
                    ts = self._read_ts(f)
                    now = time.time()
                    if len(ts) < self.per_min:
                        ts.append(now)
                        f.seek(0)
                        f.truncate()
                        f.write(json.dumps({"ts": ts}).encode())
                        f.flush()
                        return 0.0
                    wait = max(0.0, self._WINDOW_S - (now - min(ts)))
                finally:
                    self._unlock(f)
            if (time.time() - start) > max_wait_s:
                # Never report 0.0 here: callers read 0.0 as "slot acquired".
                return max(wait, self._MIN_POLL_S)
            time.sleep(min(max(wait, self._MIN_POLL_S), self._MAX_POLL_S))

def _is_rate_limit_err(ex: Exception) -> tuple[bool, float | None]:
    """
    Best‑effort detection of rate‑limit errors. Returns (is429, retry_after_seconds?)
    Tries `ex.retry_after` first if present.
    """
    try:
        ra = getattr(ex, "retry_after", None)
        if isinstance(ra, (int, float)) and ra > 0:
            return True, float(ra)
    except Exception:
        pass
    msg = (str(ex) or "").lower()
    if ("429" in msg) or ("rate limit" in msg) or ("quota" in msg):
        return True, None
    return False, None

# ----------------- model IO -----------------
class Completion:
    """Result of one generation: the model ``text`` plus a free-form ``meta`` dict."""

    def __init__(self, text: str, meta: dict):
        self.text, self.meta = text, meta

def gen_with_retries(
    runner,
    prompt: str,
    *,
    provider: str | None = None,
    rpm: int | None = None,
    retries: int | None = None,
    sleep_s: float = 1.0,
    max_attempts: int | None = None,
) -> Completion:
    """
    Robust generation with:
      - Cross-process rate limiting (per-minute), keyed by (run_id, provider)
      - Automatic backoff on HTTP 429 / rate-limit errors
      - Transport/other errors retry up to CIRBENCH_MODEL_RETRIES
    Environment:
      CIRBENCH_RATE_<PROVIDER>_PER_MIN   per-minute cap for given provider (e.g. GEMINI)
      CIRBENCH_RATE_PER_MIN                global cap fallback
      CIRBENCH_RATE_MAX_WAIT_S             max wait for local limiter (default 120s)
      CIRBENCH_MODEL_RETRIES               non-429 retry count (default 0)
      CIRBENCH_MODEL_MAX_ATTEMPTS          hard cap on total attempts including 429 retries (default 5)
    Unparseable numeric env values fall back to their defaults.

    Args:
      runner: object whose ``generate([prompt])`` returns a sequence whose first
        item has ``.text`` and ``.meta``.
      prompt: prompt text, sent as a single-element batch.
      provider: rate-limit bucket; inferred from the runner's class name if None.
      rpm: explicit per-minute cap; falsy means "read env / built-in default".
      retries: budget for NON-rate-limit failures; None reads CIRBENCH_MODEL_RETRIES.
        Rate-limit (429) retries never consume this budget.
      sleep_s: base delay unit for both backoff schedules.
      max_attempts: hard cap on total attempts; None reads CIRBENCH_MODEL_MAX_ATTEMPTS.

    Returns:
      ``Completion(text, meta)`` on success. Generation errors are not raised.
      On give-up it returns ``Completion("", meta)``, where ``meta["error"]``
      starts with ``"gen_failed:"`` and ``meta`` also has ``attempts``,
      ``max_attempts`` and, when available, ``exc_type``/``exc_msg``.
    """
    from .logging_utils import get_logger, debug_on

    logger = get_logger()

    def _infer_provider_from_runner(obj) -> Optional[str]:
        # Auto-infer provider from the runner class name.
        try:
            n = obj.__class__.__name__.lower()
        except Exception:
            return None
        for key in ("gemini", "qwen", "claude", "gpt", "openai", "deepseek", "grok", "rule"):
            if key in n:
                # map "openai" class names to "gpt" bucket for rate limit purposes
                return "gpt" if key == "openai" else key
        return None

    def _env_int(name: str, default: int) -> int:
        try:
            return int(os.getenv(name, str(default)))
        except (TypeError, ValueError):
            return default

    def _env_float(name: str, default: float) -> float:
        raw = os.getenv(name)
        if not raw:
            return default
        try:
            return float(raw)
        except ValueError:
            return default

    def _fail(error: str, attempts: int, exc: Optional[BaseException]) -> Completion:
        # Keep the historical "error" string and add structured diagnostics.
        meta: Dict[str, Any] = {"error": error, "attempts": attempts, "max_attempts": max_attempts}
        if exc is not None:
            meta["exc_type"] = type(exc).__name__
            meta["exc_msg"] = str(exc)[:200]
        return Completion("", meta)

    if provider is None:
        provider = _infer_provider_from_runner(runner)

    if debug_on():
        logger.info(f"DEBUG: gen_with_retries using runner={getattr(runner, '__class__', type(runner)).__name__} provider={provider or '-'}")

    if retries is None:
        retries = _env_int("CIRBENCH_MODEL_RETRIES", 0)

    # Hard cap on total attempts (includes 429/backoff retries)
    if max_attempts is None:
        max_attempts = _env_int("CIRBENCH_MODEL_MAX_ATTEMPTS", 5)
    if not isinstance(max_attempts, int) or max_attempts < 1:
        max_attempts = 1

    # Construct limiter if provider specified
    limiter = None
    if provider:
        default_map = {"gemini": 9, "gpt": 60, "claude": 60, "qwen": 120, "deepseek": 60}
        eff_rpm = rpm or rpm_from_env(provider, default_map.get(provider.lower(), 60))
        limiter = FileRateLimiter(run_id(), provider, eff_rpm)

    # Clamp at 0: a negative value would make time.sleep() raise below.
    max_wait = max(0.0, _env_float("CIRBENCH_RATE_MAX_WAIT_S", 120.0))

    attempt = 0
    # Count of NON-rate-limit failures; only these consume the ``retries`` budget.
    failures = 0
    while True:
        attempt += 1

        # Pre-emptive local limit to avoid tripping remote quota
        if limiter:
            leftover = limiter.acquire(max_wait_s=max_wait)
            if leftover > 0 and debug_on():
                logger.warning(f"[ratelimit] waited ~{max_wait}s but still saturated; sending request anyway")

        try:
            outc = runner.generate([prompt])[0]
            text = outc.text or ""
            meta = outc.meta or {}
        except Exception as ex:
            is429, retry_after = _is_rate_limit_err(ex)
            if is429:
                if attempt >= max_attempts:
                    # hard cap reached
                    return _fail(
                        f"gen_failed:max_attempts_exceeded:attempts={attempt} max={max_attempts}",
                        attempt,
                        ex,
                    )
                # Exponential backoff with jitter. Prefer server's Retry-After if present.
                base = retry_after if (retry_after and retry_after > 0) else sleep_s * (2 ** min(attempt, 6))
                backoff = max(0.0, min(max_wait, base + random.uniform(0.2, 0.6)))
                if debug_on():
                    logger.warning(f"[ratelimit] 429 detected; backoff {backoff:.2f}s (attempt {attempt}/{max_attempts})")
                time.sleep(backoff)
                continue

            # Non-rate-limit errors: consume retry budget AND respect max_attempts
            failures += 1
            if failures <= (retries or 0) and attempt < max_attempts:
                time.sleep(max(0.0, sleep_s * failures))
                continue
            return _fail(
                f"gen_failed:retries_exhausted_or_max_attempts:attempts={attempt} retries={retries or 0} max={max_attempts}",
                attempt,
                ex,
            )

        # Logged outside the try so a logging problem can't trigger a regeneration.
        if debug_on():
            shown = meta if isinstance(meta, dict) else {}
            logger.info(
                f"LLM meta: lat={shown.get('latency_ms','?')} in/out={shown.get('prompt_tokens','?')}/{shown.get('out_tokens','?')}"
            )
        return Completion(text, meta)

def persist_model_io(case_dir: Path, k: int, text: str, meta: dict) -> None:
    """Save shot ``k``'s raw model response and its meta JSON under the case's artifacts dir.

    None/empty ``text`` is written as ``""`` and None/empty ``meta`` as ``{}``.
    """
    paths = shot_paths(case_dir, k)
    write_text(paths["raw"], text if text else "")
    write_json(paths["meta"], meta if meta else {})

# ----------------- tool invocations -----------------

def run_cmd(cmd: List[str], *, timeout_s: int = 180) -> Dict[str, Any]:
    """Run ``cmd`` (no shell) and capture its output. Never raises.

    Returns a dict with ``ok`` (exit code 0), ``exit`` (None if the process
    could not run or timed out), ``out``/``err`` (decoded, undecodable bytes
    dropped) and ``cmd``. On timeout the dict also carries ``timeout: True``.
    """
    try:
        proc = subprocess.run(cmd, capture_output=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return {"ok": False, "exit": None, "out": "", "err": f"timeout({timeout_s}s)", "cmd": cmd, "timeout": True}
    except Exception as ex:
        return {"ok": False, "exit": None, "out": "", "err": f"{type(ex).__name__}:{str(ex)[:200]}", "cmd": cmd}
    stdout_text = proc.stdout.decode(errors="ignore")
    stderr_text = proc.stderr.decode(errors="ignore")
    return {
        "ok": proc.returncode == 0,
        "exit": proc.returncode,
        "out": stdout_text,
        "err": stderr_text,
        "cmd": cmd,
    }

def fmt_cmd_block(cmd: List[str] | None, exitc: int | None, out: str, err: str = "") -> str:
    line = "$ " + " ".join(shlex.quote(x) for x in (cmd or []))
    body = out or ""
    if err:
        body += ("\n--- stderr ---\n" + err)
    return f"{line}\n[exit={exitc}]\n--- stdout+stderr ---\n{body}"

# ----------------- IR helpers -----------------

def extract_ir_and_meta(text: str) -> Tuple[Optional[str], bool, dict]:
    """Parse a model response into ``(ir_text, failed, meta)``.

    ``failed`` is True when the IR extractor reported an error or produced no
    IR. ``meta`` is the structured JSON block found in the response, or ``{}``.
    """
    from .parse import extract_structured, extract_ir_out
    source = text or ""
    meta_json, _unused_err = extract_structured(source)
    ir_out, ir_err = extract_ir_out(source)
    failed = bool(ir_err) or not ir_out
    return ir_out, failed, meta_json or {}

def exact_match_str(a: str, b: str) -> bool:
    """Return True when the ``str()`` forms of ``a`` and ``b`` are identical."""
    lhs, rhs = str(a), str(b)
    return lhs == rhs

# ----------------- Aggregation helpers -----------------

def atk(shots: List[dict], k: int, key: str) -> bool:
    """Return True if any of the first ``k`` shots satisfies metric ``key``.

    ``k`` below 1 is treated as 1. ``pass``/``valid``/``em`` require the shot's
    field to be literally True. ``equiv`` requires ``shot["equiv"]["ok"] is True``.
    Unknown keys yield False.
    """
    window = shots[:max(1, min(k, len(shots)))]
    if key in ("pass", "valid", "em"):
        return any(shot.get(key) is True for shot in window)
    if key == "equiv":
        return any((shot.get("equiv") or {}).get("ok") is True for shot in window)
    return False

def build_equiv_dict(status: str, ok: Optional[bool], *, method: Optional[str] = None, alive_timeout: bool = False, checksum_equal: Optional[bool] = None) -> dict:
    """Build the per-shot equivalence record.

    ``checksum_equal`` is included (coerced to bool) only when it is not None.
    """
    record = dict(status=status, ok=ok, method=method, alive_timeout=bool(alive_timeout))
    if checksum_equal is not None:
        record.update(checksum_equal=bool(checksum_equal))
    return record

# ----------------- Model selection helpers -----------------

def _norm_model_entry(m: Any) -> Dict[str, Any]:
    """
    Normalize a model entry from CIRBenchConfig into a plain dict:
    {kind, name, params}. Accepts pydantic model, plain dict, or object with
    attributes. Falls back to "rule/golden" if fields are missing.
    """
    if hasattr(m, "model_dump"):
        d = m.model_dump()
    elif isinstance(m, dict):
        d = dict(m)
    else:
        d = {
            "kind": getattr(m, "kind", None),
            "name": getattr(m, "name", None),
            "params": getattr(m, "params", None),
        }
    kind = (d.get("kind") or d.get("provider") or d.get("type") or "rule")
    name = (d.get("name") or d.get("model") or "golden")
    params = d.get("params") or {}
    return {"kind": str(kind).lower(), "name": str(name), "params": params}

def _env_model_override() -> Tuple[Optional[str], Optional[str]]:
    """
    Read model override from env in the following precedence:
      - CIRBENCH_MODEL (format: provider:name)
      - CIRBENCH_MODEL_OVERRIDE (format: provider:name)
      - CIRBENCH_MODEL_PROVIDER + CIRBENCH_MODEL_NAME
    Returns (provider, name); any of them may be None.
    """
    combo = os.getenv("CIRBENCH_MODEL") or os.getenv("CIRBENCH_MODEL_OVERRIDE") or ""
    prov, name = None, None
    if combo and ":" in combo:
        p, n = combo.split(":", 1)
        prov = (p or None)
        name = (n or None)
    prov = os.getenv("CIRBENCH_MODEL_PROVIDER") or prov
    name = os.getenv("CIRBENCH_MODEL_NAME") or name
    if isinstance(prov, str): prov = prov.strip() or None
    if isinstance(name, str): name = name.strip() or None
    return prov, name

def choose_model_cfg(cfg) -> Dict[str, Any]:
    """
    Select a model config dict from CIRBenchConfig, honoring env overrides.

    When a provider and/or name override is set, the first entry in
    ``cfg.models`` matching every given field is returned. The provider
    comparison is case-insensitive; the name comparison is exact. If nothing
    matches (or no override is set), the first entry is returned. Note that a
    non-matching override is silently ignored. With no entries at all, returns
    rule/golden.
    """
    entries = [_norm_model_entry(m) for m in (getattr(cfg, "models", []) or [])]
    want_kind, want_name = _env_model_override()

    def _as_cfg(entry: Dict[str, Any]) -> Dict[str, Any]:
        return {"kind": entry["kind"], "name": entry["name"], "params": entry["params"]}

    if want_kind or want_name:
        kind_lc = want_kind.lower() if want_kind else None
        for entry in entries:
            kind_ok = kind_lc is None or entry["kind"].lower() == kind_lc
            name_ok = not want_name or entry["name"] == want_name
            if kind_ok and name_ok:
                return _as_cfg(entry)

    if not entries:
        return {"kind": "rule", "name": "golden", "params": {}}
    return _as_cfg(entries[0])

def model_tag(model_cfg: Dict[str, Any]) -> str:
    """Return a short ``kind:name`` tag (e.g. 'qwen:qwen-plus') for logging/run-id composition.

    Falls back to provider/model keys, then to 'rule'/'golden'.
    """
    def _pick(keys, fallback):
        return next((model_cfg.get(key) for key in keys if model_cfg.get(key)), fallback)

    kind = _pick(("kind", "provider"), "rule")
    name = _pick(("name", "model"), "golden")
    return f"{kind}:{name}"

# ----------------- Billing helpers (no pricing applied; kept for compatibility) -----------------

def with_billing(meta: dict, provider: str, name: str, pricing) -> dict:
    """Return a copy of ``meta`` with ``provider``/``model`` filled in if absent.

    ``pricing`` is accepted for interface compatibility and is not used.
    """
    enriched = dict(meta) if meta else {}
    for key, value in (("provider", provider), ("model", name)):
        enriched.setdefault(key, value)
    return enriched

def write_run_billing(run_dir: Path, bills: List[Dict[str, Any]]) -> None:
    """Write ``billing.<task>.shard-<id>.json`` under ``run_dir``.

    Task and shard come from CIRBENCH_TASK (default "unknown") and
    CIRBENCH_SHARD_ID (default "1"). Each entry contributes ``total_usd``
    (or ``usd``, else 0). The file holds the rounded total and the raw entries.
    """
    entries = bills if bills else []
    task_name = os.getenv("CIRBENCH_TASK", "unknown")
    shard_id = os.getenv("CIRBENCH_SHARD_ID", "1")
    running_usd = 0.0
    for entry in entries:
        running_usd += float(entry.get("total_usd") or entry.get("usd") or 0.0)
    summary = {"total_usd": round(running_usd, 6), "entries": entries}
    write_json(run_dir / f"billing.{task_name}.shard-{shard_id}.json", summary)

def write_repro_sh(sp: Dict[str, Path], task: str, cfg, *, func: str | None = None) -> None:
    ad = sp["dir"]
    ad.mkdir(parents=True, exist_ok=True)
    variant = sp.get("variant")
    alive_in = sp.get("alive_merged")
    llvm_as = getattr(getattr(cfg, "toolchain", object()), "llvm_as", None) or "llvm-as"
    opt     = getattr(getattr(cfg, "toolchain", object()), "opt", None) or "opt"
    alive   = getattr(getattr(cfg, "toolchain", object()), "alive_tv", None) or "alive-tv"

    lines = [
        "#!/usr/bin/env bash",
        "set -euo pipefail",
        'cd "$(dirname "$0")"',
        'echo "[repro] artifacts dir: $(pwd)"',
        "",
        f'VAR="{variant.name if variant else "variant.ll"}"',
        'if [[ -f "$VAR" ]]; then',
        f'  echo "== llvm-as"; {shlex.quote(llvm_as)} "$VAR" -o a.bc || true',
        f'  echo "== opt verify"; {shlex.quote(opt)} -passes=verify "$VAR" -disable-output || true',
        "fi",
    ]
    if alive_in and alive_in.exists():
        lines += [
            f'ALIVE="{alive_in.name}"',
            'if [[ -f "$ALIVE" ]]; then',
            f'  echo "== alive-tv"; {shlex.quote(alive)} "$ALIVE" -src-fn=src -tgt-fn=tgt --quiet || true',
            "fi",
        ]
    lines += [f'echo "[done] task={task} func={func or "-"}"']
    sh = ad / "repro.sh"
    write_text(sh, "\n".join(lines) + "\n")
    try:
        os.chmod(sh, 0o755)
    except Exception:
        pass

def _parse_shot_k(shot_dir: Path) -> int:
    try:
        base = shot_dir.name
        pre = base.split("_", 1)[0]
        return int(pre)
    except Exception:
        return 1

def _ensure_prompt_files(shot_dir: Path, prompt_text: str) -> None:
    """Write ``prompt.txt`` and its SHA-256 (``prompt.sha256``) into ``shot_dir``."""
    body = prompt_text or ""
    write_text(shot_dir / "prompt.txt", body)
    write_text(shot_dir / "prompt.sha256", sha256(body))

def maybe_short_circuit_prompt_only(cfg, shot_dir: Path, prompt_text: str) -> bool:
    """In prompt-only mode, persist the prompt and mark the case stopped before any LLM call.

    Returns False when ``cfg.prompt_only`` is falsy or missing, so the caller
    should proceed. Otherwise it writes the prompt files and a
    ``STOPPED_BEFORE_LLM`` marker, records an early stop for the case (the
    parent of ``shot_dir``), and returns True. Any exception on this path also
    returns True: once we are here the LLM call is skipped, even if persisting
    the artifacts failed.
    """
    try:
        if not getattr(cfg, "prompt_only", False):
            return False
        _ensure_prompt_files(shot_dir, prompt_text or "")
        marker = shot_dir / "STOPPED_BEFORE_LLM"
        marker.write_text("1\n", encoding="utf-8")
        mark_early_stop(shot_dir.parent, _parse_shot_k(shot_dir), reason="prompt_only")
        return True
    except Exception:
        return True

def maybe_materialize_external_io(cfg, case_id: str, shot_dir: Path) -> str:
    """Copy a precomputed model response or prediction into ``shot_dir``.

    Active only when ``cfg.from_files`` is set. ``cfg.from_kind`` selects the
    mode (default ``"resp"``):
      - ``"resp"``: the first existing ``model.resp.txt``/``<case>.resp.txt``
        candidate is written to ``shot_dir/model.resp.txt``. Returns "resp".
      - ``"pred"``: the first existing prediction candidate is copied as
        ``pred.json`` (``.json`` sources) or ``variant.ll`` (anything else).
        Returns "pred".
    Per-shot locations (``<root>/<case>/shot-<k>/...``) are preferred over
    per-case ones. Returns "none" when disabled, nothing matches, or the kind is
    unknown. Only regular files are considered candidates.
    """
    root = getattr(cfg, "from_files", None)
    if not root:
        return "none"
    kind = (getattr(cfg, "from_kind", "resp") or "resp").strip().lower()
    k = _parse_shot_k(shot_dir)
    root = Path(root)
    case_root = root / case_id
    shot_root = case_root / f"shot-{k}"

    if kind == "resp":
        resp_candidates = [
            shot_root / "model.resp.txt",
            case_root / "model.resp.txt",
            root / f"{case_id}.resp.txt",
        ]
        for p in resp_candidates:
            if p.is_file():
                write_text(shot_dir / "model.resp.txt", p.read_text(encoding="utf-8"))
                return "resp"
        return "none"

    if kind == "pred":
        import shutil

        pred_candidates = [
            shot_root / "pred.ll",
            shot_root / "variant.ll",
            shot_root / "pred.json",
            case_root / "pred.ll",
            case_root / "variant.ll",
            case_root / "pred.json",
            root / f"{case_id}.ll",
            root / f"{case_id}.pred.ll",
            root / f"{case_id}.pred.json",
        ]
        src = next((p for p in pred_candidates if p.is_file()), None)
        if src is None:
            return "none"
        # The destination depends only on the source format, not on the task.
        dest_name = "pred.json" if src.suffix == ".json" else "variant.ll"
        # The resp branch creates parents via write_text; do the same here so
        # copyfile does not fail on a not-yet-created shot directory.
        shot_dir.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(src, shot_dir / dest_name)
        return "pred"

    return "none"