# Minimal utilities to reproduce your 4-stage verifier inside lm-eval.

import re
import os
import math
from functools import lru_cache
from typing import List, Dict, Any
try:
    from datasets import Dataset
except Exception:
    Dataset = None  # we'll guard below


# ---- Debug toggle ----
# PASSK_DEBUG=1 turns on dprint output; the variable is read once, at import.
DEBUG = os.getenv("PASSK_DEBUG", "0") == "1"


def dprint(*args, **kwargs):
    """Print *args* with a "[passk-debug]" prefix (flushed), only when DEBUG is on."""
    if not DEBUG:
        return
    print("[passk-debug]", *args, **kwargs, flush=True)


# ---------------- Extraction ----------------
# Answer cue ("Final answer:", "The answer is", "**Answer** =", "Ans:", ...) at the
# start of the text or of a line (optionally after a '>' quote marker and
# markdown emphasis). The `\b` after the cue word stops prefixes of longer words
# ("Answering", "Results", "Options") from being taken as cues. "is" may be
# followed by a colon ("The answer is: X"), so the tail does not start with ":".
ANSWER_CUE_RE = re.compile(
    r'(?i)(?:^|[\n\r>]\s*)(?:\*\*|\*|__|~~|`|>\s*)*'
    r'(?:final\s+answer|the\s+final\s+answer|the\s+answer|answer|ans(?:wer)?|result|output|prediction|choice|option)\b'
    r'\s*(?:is\b\s*:?|=|:|->)?\s*'
)

def first_sentence(s: str) -> str:
    """
    Return the text before the first sentence terminator in *s*.

    A terminator is '.', '!' or '?' followed by whitespace or the end of the
    string (so "3.5" is not split). The terminator itself is dropped. None or
    an empty string gives "".
    """
    text = (s or "").strip()
    match = re.search(r'(.+?)(?:[.!?](?:\s|$)|$)', text, flags=re.S)
    if match is None:
        return text
    return match.group(1).strip()

def last_sentence(s: str) -> str:
    """
    Return the last non-empty sentence of *s*, terminator included.

    Sentences are split on whitespace that follows '.', '!' or '?'. If every
    piece is blank, the stripped input is returned ("" for None).
    """
    text = (s or "").strip()
    pieces = (seg.strip() for seg in reversed(re.split(r'(?<=[.!?])\s+', text)))
    return next((seg for seg in pieces if seg), text)

def contains_as_word(haystack_norm: str, needle_norm: str) -> bool:
    """
    Return True if *needle_norm* occurs in *haystack_norm* as whole token(s).

    Both arguments are expected to be normalize_text() output. The needle must
    not be glued to a word character on either side. Explicit lookarounds are
    used instead of a regex word boundary, because a word boundary next to a
    non-word character (a needle such as "€5" or "5°") demands a word
    character on the other side and so never matches a free-standing token.
    Edge whitespace on the needle is ignored; an empty needle never matches.
    """
    needle = (needle_norm or "").strip()
    if not needle:
        return False
    pattern = r'(?<!\w)' + re.escape(needle) + r'(?!\w)'
    return re.search(pattern, haystack_norm or "") is not None

def extract_final_answer(gen: str):
    """
    Return the short answer that follows the first answer cue in *gen*.

    The cue is located with ANSWER_CUE_RE. The result is the first line after
    the cue that is still non-empty once surrounding whitespace and trailing
    punctuation/quotes are removed. If no such line exists, the whole remainder
    gets the same clean-up and is returned (possibly "") so callers can fall back.
    Returns None when *gen* contains no answer cue.
    """
    cue = ANSWER_CUE_RE.search(gen or "")
    if cue is None:
        return None
    junk = " .;,:“”'`\""
    remainder = gen[cue.end():] or ""
    cleaned_lines = (ln.strip().rstrip(junk) for ln in remainder.splitlines())
    first = next((ln for ln in cleaned_lines if ln), None)
    if first is not None:
        return first
    return remainder.strip().rstrip(junk)


def extract_final_answer_or_last_sentence(text: str) -> str:
    """Return the cue-based short answer of *text*, or its last sentence when there is none (or it is blank)."""
    candidate = extract_final_answer(text)
    if candidate is not None and candidate.strip():
        return candidate
    return last_sentence(text or "")

# ---------------- Normalization & token F1 ----------------
_PUNC = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
TRANS = str.maketrans({c: " " for c in _PUNC})
STOP = {
    "a","an","the","and","or","of","to","in","on","for","with","at","by","from",
    "is","are","was","were","be","been","being","as","that","this","these","those",
    "it","its","their","his","her","they","them","he","she","we","you","i"
}

def normalize_text(s: str) -> str:
    """
    Lower-case *s*, turn ASCII punctuation into spaces, collapse whitespace
    runs to single spaces and strip both ends.

    The final strip matters: punctuation at either end ("Paris.", ": Paris")
    becomes a space, and a leftover edge space made otherwise-equal answers
    compare unequal in exact and containment matching.
    """
    s = (s or "").lower().strip()
    s = s.translate(TRANS)
    return re.sub(r"\s+", " ", s).strip()

def token_set(s: str, remove_stopwords: bool = True):
    """
    Return the normalized tokens of *s* as a list (order and duplicates kept,
    despite the name), dropping STOP words unless remove_stopwords is False.
    """
    toks = normalize_text(s).split()
    if remove_stopwords:
        toks = [t for t in toks if t not in STOP]
    return toks

def f1_tokens(pred_tokens, gold_tokens) -> float:
    """
    Token-level F1 between two token sequences (multiset overlap, SQuAD style).

    Two empty sequences score 1.0; exactly one empty sequence scores 0.0.
    Each gold token can be matched at most as many times as it occurs.
    """
    pred, gold = list(pred_tokens), list(gold_tokens)
    if not (pred and gold):
        return float(not pred and not gold)

    remaining = {}
    for tok in gold:
        remaining[tok] = remaining.get(tok, 0) + 1

    overlap = 0
    for tok in pred:
        left = remaining.get(tok, 0)
        if left:
            remaining[tok] = left - 1
            overlap += 1

    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)

# ---------------- SBERT (lazy / optional) ----------------
@lru_cache(maxsize=1)
def _load_sbert(model_name: str):
    """
    Load a SentenceTransformer for *model_name*, or return None on any failure.

    Best-effort: a missing `sentence_transformers` package or a model load
    error yields None. lru_cache remembers that None for the name. With
    maxsize=1, only the most recently requested model stays cached.
    """
    try:
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer(model_name)
    except Exception:
        model = None
    return model

@lru_cache(maxsize=4096)
def _embed(text: str, model_name: str):
    """
    Return the SBERT embedding of *text*, or None when the model could not be loaded.

    The embedding is encoded with normalize_embeddings=True. Results are memoized
    per (text, model_name); the cached object is shared between callers, so
    treat it as read-only.
    """
    mdl = _load_sbert(model_name)
    if mdl is None:
        return None
    return mdl.encode([text], normalize_embeddings=True)[0]

def _cos(a, b):
    """
    Cosine similarity of two embedding vectors, as a Python float.

    Computes dot(a, b) / (|a| * |b|) instead of assuming unit-norm inputs, so
    the result is still a cosine when the vectors were not normalized.
    Returns 0.0 if either vector has zero norm.
    """
    import numpy as np
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float((a * b).sum() / denom)

# ---------------- Core 4-stage check ----------------
def four_stage_ok_extracted(pred_full: str, gold: str,
                            f1_threshold: float = 0.8,
                            sim_threshold: float = 0.82,
                            remove_stopwords: bool = True,
                            sbert_model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> bool:
    """
    4-stage verifier applied ONLY to the short answer extracted from pred_full
    (first line after an answer cue, else the last sentence). The rest of the
    generation is never compared.

    Stages, first success wins:
      1) exact match of the normalized strings;
      2) gold occurs in the extraction as whole token(s) (contains_as_word);
      3) token F1 >= f1_threshold (stopwords removed unless disabled);
      4) SBERT cosine >= sim_threshold, only if the model can be loaded.

    Returns False when the normalized gold or the normalized extraction is
    empty, so a missing reference or an empty answer never counts as correct.
    Stage 3 is skipped when either side has no content tokens (e.g. only
    stopwords). f1_tokens scores two empty lists as 1.0, which would accept
    any stopword-only answer for a stopword-only gold.
    """
    pred_full = pred_full or ""
    gold = gold or ""

    extracted = extract_final_answer_or_last_sentence(pred_full)
    # Strip: punctuation at either end turns into spaces during normalization,
    # and edge whitespace must not defeat the exact / containment checks.
    extracted_norm = normalize_text(extracted).strip()
    gold_norm = normalize_text(gold).strip()

    if not gold_norm or not extracted_norm:
        return False

    # 1) exact match (normalized)
    if extracted_norm == gold_norm:
        return True

    # 2) token-bounded containment (normalized)
    if contains_as_word(extracted_norm, gold_norm):
        return True

    # 3) token F1, only when both sides still have content tokens
    pred_toks = token_set(extracted, remove_stopwords)
    gold_toks = token_set(gold, remove_stopwords)
    if pred_toks and gold_toks and f1_tokens(pred_toks, gold_toks) >= f1_threshold:
        return True

    # 4) SBERT cosine (if available)
    gv = _embed(extracted, sbert_model_name)
    tv = _embed(gold, sbert_model_name)
    return gv is not None and tv is not None and _cos(gv, tv) >= sim_threshold


# ---------- Robust numeric extraction / normalization ----------

from decimal import Decimal, InvalidOperation, getcontext
# NOTE: getcontext() returns the current thread's decimal context, so this line
# is an import-time side effect on whichever thread imports the module. 28 is
# the decimal module's documented default precision. It bounds the significant
# digits kept by the fraction -> Decimal divisions in
# canonicalize_number_in_text (e.g. 1/3 -> 0.3333333333333333333333333333).
getcontext().prec = 28  # safe precision for fractions → decimals

# Patterns in descending priority:
#  1) mixed numbers: "a b/c" (e.g., "1 1/2 hours")
#  2) simple fractions: "a/b" (e.g., "3/4 kg")
#  3) plain numbers (ints/decimals, allow thousands commas)
#
# Integer part: the comma-grouped form ("1,234") is tried first. It must have at
# least one ",ddd" group and may not run into a further digit; otherwise a plain
# digit run is used. The earlier `[+-]?\d{1,3}(?:,\d{3})*|\d+` let the first
# branch stop after three digits ("1234" was read as "123", so 1234 and 1235
# canonicalized equal), and its sign applied to the comma branch only.
_INT_PART = r'[+-]?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)'
_MIXED_RE = re.compile(r'(' + _INT_PART + r')\s+(\d+)\s*/\s*(\d+)')
_FRAC_RE = re.compile(r'([+-]?\d+)\s*/\s*(\d+)')
_NUM_RE  = re.compile(r'(' + _INT_PART + r')(\.\d+)?')

def _strip_commas(s: str) -> str:
    """Drop every ',' from *s* (thousands separators)."""
    return "".join(s.split(","))

def _canon_from_decimal(d: Decimal) -> str:
    """
    Render *d* as plain positional text without trailing zeros or a trailing '.'.

    Examples: 12.00 -> "12", 1E+3 -> "1000", -0.750 -> "-0.75". Any zero,
    including negative zero, renders as "0", so "-0" and "0" compare equal.
    """
    if d.is_zero():
        return "0"
    s = format(d.normalize(), "f")
    if "." in s:
        s = s.rstrip("0").rstrip(".")
    return s or "0"

def _try_decimal(s: str) -> "Decimal | None":
    """
    Parse *s* as a Decimal, returning None if it is not a valid number.

    The return annotation is a string so the `X | None` form is not evaluated
    at definition time (on Python 3.9 that raised TypeError at import).
    A None or otherwise non-convertible argument also yields None.
    """
    try:
        return Decimal(s)
    except (InvalidOperation, ValueError, TypeError):
        return None

def canonicalize_number_in_text(s: str) -> str:
    """
    Reduce *s* to one canonical numeric string (see _canon_from_decimal).

    Forms are tried by kind, not by position. A mixed number ("a b/c") found
    anywhere wins, then a simple fraction ("a/b"), then the first plain
    integer/decimal (thousands commas removed). Zero denominators are skipped.
    Examples:
      - "36 minutes"          → "36"
      - "1,200 cm"            → "1200"
      - "40%"                 → "40"     (numeric part kept, not 0.4)
      - "1 1/2 hours"         → "1.5"
      - "-3/4"                → "-0.75"
      - "+12.00 USD"          → "12"
    If nothing numeric is recognized, the stripped input is returned unchanged
    ("" for None).
    """
    text = (s or "").strip()
    if not text:
        return text

    # Mixed number: the fractional part takes the sign of the whole part.
    hit = _MIXED_RE.search(text)
    if hit:
        whole_s, num_s, den_s = hit.groups()
        whole = _try_decimal(_strip_commas(whole_s))
        num, den = _try_decimal(num_s), _try_decimal(den_s)
        if whole is not None and num is not None and den not in (None, Decimal(0)):
            frac = num / den
            return _canon_from_decimal(whole + (frac if whole >= 0 else -frac))

    # Simple fraction.
    hit = _FRAC_RE.search(text)
    if hit:
        num, den = (_try_decimal(g) for g in hit.groups())
        if num is not None and den not in (None, Decimal(0)):
            return _canon_from_decimal(num / den)

    # Plain integer or decimal.
    hit = _NUM_RE.search(text)
    if hit:
        int_s, frac_s = hit.groups()
        value = _try_decimal(_strip_commas(int_s) + (frac_s or ""))
        if value is not None:
            return _canon_from_decimal(value)

    return text


# -------- Robust extraction from a generation → list of canonical numbers -----

# Number-ish token finder; group 1 is the raw token, canonicalized later.
# Flags are passed as arguments: an inline "(?ix)" that is not the very first
# thing in the pattern (the old one followed a newline) is a re.error on
# Python >= 3.11. Alternatives go most specific first, so a mixed number or
# fraction is not cut short by the plain-number branch. The plain branch only
# takes a comma-grouped form when it really has ",ddd" groups, so "1234" is not
# read as "123". No answer-cue prefix is needed here: it was optional and never
# captured, so it did not change which numbers are found.
_NUMLIKE_RE = re.compile(
    r"""
    (
        [+-]?\d+[ \t]+\d+/\d+                              # 1 1/2      (mixed number)
      | [+-]?\d+/\d+                                       # 3/4        (fraction)
      | [+-]?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?    # 1,234.56 / 1234 / -3.5
    )
    """,
    re.VERBOSE,
)

def _dedupe_keep_order(items):
    """Return *items* as a list without duplicates, keeping each first occurrence in order."""
    return list(dict.fromkeys(items))

def _primary_slice_after_cue(text: str) -> str:
    """Return the text after the first answer cue, or the whole text ("" for None) when there is no cue."""
    source = text or ""
    cue = ANSWER_CUE_RE.search(source)
    return source[cue.end():] if cue else source

def extract_numeric_candidates(gen: str) -> list[str]:
    """
    Return a ranked, de-duplicated list of canonical numeric strings from a generation.

    Ranking:
      1) numbers after the first answer cue, front to back (the whole text when
         there is no cue);
      2) then every number in the whole generation, front to back.
    If no number-like token exists at all, the single candidate is the
    canonicalized last sentence, which may be non-numeric text. Empty
    strings are dropped.
    """
    text = gen or ""
    raw_tokens = [
        hit.group(1)
        for region in (_primary_slice_after_cue(text), text)
        for hit in _NUMLIKE_RE.finditer(region)
    ]
    candidates = [canonicalize_number_in_text(tok) for tok in raw_tokens]
    if not candidates:
        candidates = [canonicalize_number_in_text(last_sentence(text))]
    return [c for c in _dedupe_keep_order(candidates) if c]

# ---- lm-eval filter shim: takes a raw string, returns a list of strings ----
def filter_numeric_from_generation(s: str) -> list[str]:
    """
    lm-eval filter hook: map one generation string to its ranked numeric
    candidates (extract_numeric_candidates). Pair it with 'take_first' to keep
    only the top-ranked candidate.
    """
    candidates = extract_numeric_candidates(s)
    return candidates



# --- Strict numeric extraction: FIRST number after an answer cue only ---

def extract_number_after_cue(gen: str) -> str:
    """
    Return the canonicalized FIRST number that follows an explicit answer cue.

    Strict: returns "" when the generation has no answer cue (ANSWER_CUE_RE)
    or no number after it. It no longer falls back to the first number
    anywhere in the text; that fallback made the "strict" extractor pick
    numbers from the reasoning, and left extract_number_after_cue_or_last's
    last-sentence fallback almost unreachable.
    """
    text = gen or ""
    cue = ANSWER_CUE_RE.search(text)
    out = ""
    if cue is not None:
        hit = _NUMLIKE_RE.search(text[cue.end():])
        if hit is not None:
            out = canonicalize_number_in_text(hit.group(1))
    dprint("extract_number_after_cue:",
           {"has_cue": cue is not None,
            "found": bool(out), "value": out[:40]})
    return out

def extract_number_after_cue_or_last(gen: str) -> str:
    """
    Return extract_number_after_cue(gen). If that is empty, return the
    canonicalized last sentence instead, which may be non-numeric text.
    """
    strict = extract_number_after_cue(gen)
    if not strict:
        fallback = canonicalize_number_in_text(last_sentence(gen or ""))
        dprint("fallback_last_sentence_value:", fallback[:40])
        return fallback
    return strict


# ---------------- pass@k metrics over a list of 0/1 values ----------------

def passk_stderr(per_gen: List[float]) -> float:
    """
    Binomial standard error of the per-generation success rate.

    Despite the name, this is sqrt(p * (1 - p) / n), where p is the fraction of
    entries > 0 and n = len(per_gen). It is the stderr of per-sample accuracy,
    not of a pass@k estimator. Returns nan for an empty list and 0.0 when all
    entries agree.
    """
    total = len(per_gen)
    if not total:
        return float("nan")
    hits = sum(v > 0 for v in per_gen)
    if hits in (0, total):
        return 0.0
    rate = hits / total
    stderr = math.sqrt(rate * (1 - rate) / total)
    dprint("passk_stderr:", {"n": total, "c": hits, "p": rate, "stderr": stderr})
    return stderr


def number_em(references=None, predictions=None, *, fallback_to_last=False, **_):
    """
    Any-of-k numeric exact match for one document.

    references[0] is the gold answer, canonicalized with
    canonicalize_number_in_text. If predictions is a list, only its first
    element is used; that element (or predictions itself otherwise) is either
    the sequence of k generations or a single generation string.

    Each generation's number comes from extract_number_after_cue (first number
    after an answer cue). With fallback_to_last=True it comes from
    extract_number_after_cue_or_last (falls back to the last sentence).

    Returns 1.0 if any generation's number equals the gold (both non-empty),
    else 0.0. This is a single float; per-generation scores are only emitted
    through dprint.
    """
    gold = canonicalize_number_in_text(references[0] or "")
    preds = predictions[0] if isinstance(predictions, list) else predictions
    if not isinstance(preds, (list, tuple)):  # single string
        preds = [preds]

    # fallback_to_last picks the extractor (it used to be accepted but ignored).
    extract = extract_number_after_cue_or_last if fallback_to_last else extract_number_after_cue
    extracted = [extract(gen or "") for gen in preds]
    scores = [1.0 if (gold and p and p == gold) else 0.0 for p in extracted]
    dprint("number_em:",
           {"k": len(preds), "gold": gold, "extracted": extracted, "scores": scores})
    return 1.0 if 1.0 in scores else 0.0

def number_em_shadow(references=None, predictions=None, *, fallback_to_last=False, **_):
    """
    Same result as number_em, exposed under a second name.

    Presumably this lets the metric be registered twice in a task config;
    confirm against the task YAML.
    """
    return number_em(references, predictions, fallback_to_last=fallback_to_last, **_)

def fs4_per_gen_list(references=None, predictions=None, **_):
    """
    Any-of-k 4-stage check for one document.

    Each generation is scored with four_stage_ok_extracted against
    references[0]. Returns 1.0 if at least one generation passes, else 0.0.
    Despite the name, this is a single float; the per-generation scores and
    extracted answers are only emitted through dprint. If predictions is a
    list, only its first element is used (as in number_em).
    """
    gold = references[0] or ""
    preds = predictions[0] if isinstance(predictions, list) else predictions
    if not isinstance(preds, (list, tuple)):
        preds = [preds]
    scores, extracted = [], []
    for gen in preds:
        gen = gen or ""
        # Pass the FULL generation: four_stage_ok_extracted does the extraction
        # itself. Handing it an already-extracted answer extracted twice, e.g.
        # "Answer: Paris. It is the capital." -> "Paris. It is the capital"
        # -> "It is the capital", which no longer contains the gold.
        ok = four_stage_ok_extracted(gen, gold)
        scores.append(1.0 if ok else 0.0)
        extracted.append(extract_final_answer_or_last_sentence(gen))  # debug output only
    dprint("fs4_per_gen_list:",
           {"k": len(preds), "gold": gold[:60], "extracted": extracted, "scores": scores})
    return 1.0 if 1.0 in scores else 0.0

def fs4_per_gen_list_shadow(references=None, predictions=None, **_):
    """
    Same result as fs4_per_gen_list, exposed under a second name.

    Presumably this lets the metric be registered twice in a task config;
    confirm against the task YAML.
    """
    return fs4_per_gen_list(references, predictions, **_)


# -------- process_docs for MuSiQue: keep only supporting paragraph(s) --------

def musique_process_docs(dataset):
    """
    Reduce each MuSiQue example to its question, answer and supporting paragraph text.

    Input split: HF Dataset (or any iterable of dicts) with fields:
      - "Paragraphs" / "paragraphs": [{idx,title,paragraph_text,is_supporting}, ...]
      - "question": str
      - "answer": str
      - "question_decomposition": [{..., paragraph_support_idx: int}, ...]
    The list-valued fields may also arrive column-oriented, as a dict of
    parallel lists (the layout hotpot_process_docs already accepts). Both
    layouts are handled.

    A paragraph is supporting if its is_supporting flag is True or its idx is
    some step's paragraph_support_idx. Supporting paragraphs are joined with
    spaces in their original order. If none are found, the first paragraph is
    used so the prompt context is not empty.

    Output: HF Dataset with fields: question, answer, support_text
    """
    def as_text(x):
        if isinstance(x, list):
            return " ".join(str(s) for s in x)
        return "" if x is None else str(x)

    def as_rows(x):
        # list of dicts -> unchanged; dict of parallel lists -> list of dicts;
        # a single dict of scalars -> one row; None/empty -> [].
        if isinstance(x, dict):
            cols = list(x.values())
            if cols and all(isinstance(c, (list, tuple)) for c in cols):
                n = min(len(c) for c in cols)
                return [{k: x[k][i] for k in x} for i in range(n)]
            return [x]
        return x or []

    records: List[Dict[str, Any]] = []

    for d in dataset:
        q = d.get("question", "")
        ans = d.get("answer", "")

        paras = as_rows(d.get("Paragraphs") or d.get("paragraphs"))

        # idx -> text, plus the original paragraph order
        idx_to_text = {}
        order = []
        for p in paras:
            if not isinstance(p, dict) or p.get("idx") is None:
                continue
            order.append(p["idx"])
            idx_to_text[p["idx"]] = as_text(p.get("paragraph_text"))

        # supporting indices from the is_supporting flag ...
        support_idxs = {
            p["idx"]
            for p in paras
            if isinstance(p, dict) and p.get("is_supporting") is True and p.get("idx") is not None
        }
        # ... and from the question decomposition
        for item in as_rows(d.get("question_decomposition")):
            if isinstance(item, dict) and isinstance(item.get("paragraph_support_idx"), int):
                support_idxs.add(item["paragraph_support_idx"])

        support_text = " ".join(
            idx_to_text[i] for i in order if i in support_idxs and idx_to_text.get(i)
        ).strip()

        # fallback to first paragraph if nothing matched (avoid empty context)
        if not support_text and order:
            support_text = idx_to_text.get(order[0], "")

        records.append({
            "question": q,
            "answer": ans,
            "support_text": support_text,
        })

    return _to_hf_dataset(records)


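# -------- process_docs for HotpotQA: keep only supporting sentences --------
# (_to_hf_dataset, defined first, is the shared records -> HF Dataset converter
#  that musique_process_docs above also uses.)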
def _to_hf_dataset(records: List[Dict[str, Any]]):
    """
    Convert a list of row dicts into a HF `datasets.Dataset` (lm-eval reads `.features`).

    There is no fallback to the raw list.

    Raises:
        RuntimeError: if the optional `datasets` import at module load failed
            (Dataset is None).
    """
    if Dataset is None:
        raise RuntimeError(
            "utils._to_hf_dataset: `datasets` library not found. "
            "Install with `pip install datasets`."
        )
    return Dataset.from_list(records)


def _hotpot_sentence_map(ctx) -> Dict[str, List[str]]:
    """
    Build {title: [sentence, ...]} from a HotpotQA context in either layout:
      A) dict of lists: {"title": [...], "sentences": [[...], ...]}
      B) list of [title, sentences] pairs or {"title", "sentences"/"text"} dicts.
    A bare string where a sentence list is expected counts as one sentence.
    Unrecognized layouts give an empty map.
    """
    def as_sentences(sents):
        if isinstance(sents, str):
            return [sents]
        return [str(s) for s in (sents or [])]

    sent_map: Dict[str, List[str]] = {}
    if isinstance(ctx, dict) and "title" in ctx and "sentences" in ctx:
        for title, sents in zip(ctx.get("title") or [], ctx.get("sentences") or []):
            sent_map[str(title)] = as_sentences(sents)
    elif isinstance(ctx, list):
        for item in ctx:
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                sent_map[str(item[0])] = as_sentences(item[1])
            elif isinstance(item, dict):
                sents = item.get("sentences") or item.get("text") or []
                sent_map[str(item.get("title", ""))] = as_sentences(sents)
    return sent_map


def _hotpot_support_pairs(sf) -> list[tuple[str, int]]:
    """
    Normalize HotpotQA supporting_facts into [(title, sentence_index), ...].

    Accepts a dict of parallel lists ({"title": [...], "sent_id"|"sent_idx": [...]}),
    a list of [title, sent_id] pairs, or a list of {"title", "sent_id"|"sent_idx"}
    dicts. Entries whose index int() cannot convert are skipped, in every
    layout.
    """
    raw: list = []
    if isinstance(sf, dict) and "title" in sf and ("sent_id" in sf or "sent_idx" in sf):
        raw = list(zip(sf.get("title") or [], sf.get("sent_id", sf.get("sent_idx")) or []))
    elif isinstance(sf, list):
        for item in sf:
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                raw.append((item[0], item[1]))
            elif isinstance(item, dict):
                raw.append((item.get("title", ""), item.get("sent_id", item.get("sent_idx", -1))))

    pairs = []
    for title, idx in raw:
        try:
            pairs.append((str(title), int(idx)))
        except (TypeError, ValueError, OverflowError):
            continue
    return pairs


def hotpot_process_docs(dataset):
    """
    Reduce each HotpotQA example to question, answer and its supporting sentences.

    Supports HotpotQA contexts as either:
      A) dict of lists: {"title": [...], "sentences": [[...], ...]}
      B) list of pairs/dicts: [[title, [sents...]], ...] or [{"title":..., "sentences":[...]}]

    Supports supporting_facts as:
      1) dict of parallel lists: {"title":[...], "sent_id":[...] }
      2) list of [title, sent_id]
      3) list of dicts: {"title":..., "sent_id":...} (or "sent_idx")

    Supporting sentences are joined with spaces in supporting_facts order;
    facts naming an unknown title or an out-of-range index are ignored. If
    none remain, the first sentence of the first context entry is used so the
    prompt is not empty.

    Output: HF Dataset with fields: question, answer, support_text
    """
    records = []

    for d in dataset:
        sent_map = _hotpot_sentence_map(d.get("context", None))

        support_sents = [
            sent_map[t][i]
            for t, i in _hotpot_support_pairs(d.get("supporting_facts", []))
            if t in sent_map and 0 <= i < len(sent_map[t])
        ]
        support_text = " ".join(support_sents).strip()

        # Minimal fallback to keep the prompt usable.
        if not support_text and sent_map:
            first_sents = next(iter(sent_map.values()))
            if first_sents:
                support_text = first_sents[0]

        records.append({
            "question": d.get("question", ""),
            "answer": d.get("answer", ""),
            "support_text": support_text,
        })

    return _to_hf_dataset(records)


# --- helper to normalize AQUA options into {A..E: text} ---
_AQUA_LETTERS = ["A","B","C","D","E"]
# Option letter A-E (optionally written "(A)") followed by a REQUIRED separator
# ")", ".", ":" or "-", then the option text. The separator used to be optional,
# so any option whose text merely starts with a-e ("Cannot be determined",
# "dog") was misread as lettered, with its first character dropped.
_AQUA_OPT_RE = re.compile(r"^\s*\(?([A-E])\s*[\)\.\:\-]\s*(.*)$", re.IGNORECASE)

def _aqua_letter_map(options):
    """
    Returns dict {A..E: text} for AQuA options, stripping leading 'A) ', 'B. ', etc.

    An option with an explicit letter prefix is stored under that letter (a
    later duplicate letter overwrites an earlier one). Any of A–E still
    missing is then filled from the option at the same position, with its
    prefix stripped if present. At most five positions are used.
    """
    opts = list(options or [])
    mapping = {}
    for raw in opts:
        m = _AQUA_OPT_RE.match("" if raw is None else str(raw))
        if m:
            mapping[m.group(1).upper()] = m.group(2).strip()
    # If some letters are missing, fill them by position.
    for letter, raw in zip(_AQUA_LETTERS, opts):
        if letter not in mapping:
            s = "" if raw is None else str(raw)
            m = _AQUA_OPT_RE.match(s)
            mapping[letter] = m.group(2).strip() if m else s.strip()
    return mapping


def aquarat_extract_target(doc):
    """
    Return the canonical numeric string of the correct AQuA option.

    doc["correct"] names the option letter; doc["options"] holds the option
    texts. Example: options ["A) 21", ..., "C) 36 minutes", ...], correct "C"
    → "36". An unknown letter yields "".
    """
    letter = str(doc.get("correct", "")).strip().upper()
    option_text = _aqua_letter_map(doc.get("options")).get(letter, "")
    return canonicalize_number_in_text(option_text)



# --- helper to extract GSM8K target answer ---
def gsm8k_extract_target(doc):
    """
    Return the GSM8K gold answer: the text after the LAST '####', stripped.

    GSM8K 'answer' fields look like "Let's reason... 12 * 2 = 24. #### 24".
    Without a '####' marker the whole stripped field is returned; a missing
    field gives "".
    """
    answer = str(doc.get("answer", "")).strip()
    # rpartition leaves the whole string in the last slot when the marker is absent.
    return answer.rpartition("####")[2].strip()


# ---------------- Helpers for lm-eval I/O ----------------
def _get_text_from_results(results):
    """
    Return the first decoded text from an lm-eval result.

    results may be a str, list[str], list[list[str]] (at most two list/tuple
    levels are unwrapped, taking element 0, with "" for an empty level), or an
    object with a `.text` attribute. Anything else is passed through str().
    """
    item = results
    for _ in range(2):
        if not isinstance(item, (list, tuple)):
            break
        item = item[0] if item else ""
    text = getattr(item, "text", None)
    if text is not None:
        return text
    return item if isinstance(item, str) else str(item)

def _gold_from_mc(doc) -> str:
    """
    Return the text of the gold choice for OpenBookQA/QASC-style docs.

    doc["choices"] holds parallel "label" and "text" lists; doc["answerKey"]
    is the correct label. Returns "" if no label matches.
    """
    choices = doc["choices"]
    labels, texts = choices["label"], choices["text"]
    key = doc["answerKey"]
    return next((texts[i] for i, lab in enumerate(labels) if lab == key), "")


# ---------------- process_results entry points ----------------
def _fs4_result(pred_full: str, gold: str) -> Dict[str, float]:
    """
    Score one generation against *gold* with four_stage_ok_extracted and
    return lm-eval's per-document result {"acc_fs4": 1.0 or 0.0}.

    Settings are read from the environment on every call:
      FS_F1 (default 0.80), FS_SIM (default 0.82),
      FS_RM_STOP ("0" disables stopword removal; default on),
      FS_SBERT (sentence-transformers model name).
    """
    ok = four_stage_ok_extracted(
        pred_full, gold,
        f1_threshold=float(os.environ.get("FS_F1", 0.80)),
        sim_threshold=float(os.environ.get("FS_SIM", 0.82)),
        remove_stopwords=os.environ.get("FS_RM_STOP", "1") != "0",
        sbert_model_name=os.environ.get("FS_SBERT", "sentence-transformers/all-MiniLM-L6-v2"),
    )
    return {"acc_fs4": 1.0 if ok else 0.0}


def process_results_openbookqa_freeform(doc, results):
    """lm-eval process_results for free-form OpenBookQA: the gold is the text of the answerKey choice."""
    return _fs4_result(_get_text_from_results(results), _gold_from_mc(doc))


def process_results_qasc_freeform(doc, results):
    """lm-eval process_results for free-form QASC: the gold is the text of the answerKey choice."""
    return _fs4_result(_get_text_from_results(results), _gold_from_mc(doc))


# common process results for free-form QA like musique, hotpotqa, etc.
def process_results_freeform(doc, results):
    """lm-eval process_results for free-form QA whose gold is doc["answer"] (MuSiQue, HotpotQA, ...)."""
    return _fs4_result(_get_text_from_results(results), doc.get("answer", ""))
