import os, re, json
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch
from torch import nn
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application

# Project root; every input/output location below is derived from it.
BASED_DIR = Path(r"....")
# JSON file with the entries to evaluate (read by main()).
IN_PATH   = BASED_DIR / "output" / "example.json"
# Directory that receives the evaluation artefacts.
OUT_DIR   = BASED_DIR / "evaluation_report"
# Full report (per-item results plus aggregates) written by main().
OUT_JSON  = OUT_DIR / "metrics_report.json"
# Per-item JSONL path; not referenced by the code visible in this file.
OUT_JSONL = OUT_DIR / "per_item.jsonl"

# Embedding checkpoints tried first by TextEncoder, in order.
MATH_EMBED_CANDIDATES = [
    "tbs17/MathBERT",
    "witiko/mathberta",
    "tbs17/MathBERT-custom",
    "AnReu/math_pretrained_bert",
]
# Embedding checkpoints tried after the math-specific candidates above.
GEN_EMBED_FALLBACKS = [
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/all-MiniLM-L6-v2",
]
# Sequence-classification (NLI) checkpoints tried by NLIModel, in order.
GEN_NLI_FALLBACKS = [
    "roberta-large-mnli",
    "typeform/distilbert-base-uncased-mnli",
]

# Batch sizes and the pairwise-NLI cap can be overridden via environment variables.
BATCH_SIZE_EMB = int(os.getenv("BATCH_SIZE_EMB", "32"))
BATCH_SIZE_NLI = int(os.getenv("BATCH_SIZE_NLI", "16"))
MAX_STEPS_FOR_FULL_NLI = int(os.getenv("MAX_STEPS_FOR_FULL_NLI", "30"))

# Availability flags consulted throughout the file. They are hard-coded to True
# here; torch and transformers are imported unconditionally at the top.
has_torch = True
has_tf = True

# Use the GPU when torch reports one, otherwise the CPU.
DEVICE = "cuda" if has_torch and hasattr(torch, "cuda") and torch.cuda.is_available() else "cpu"

# SymPy parser transformations: the standard set plus implicit multiplication/application.
SYM_TRANSFORMS = standard_transformations + (implicit_multiplication_application,)

def normalize_unicode_math(s: str) -> str:
    """Rewrite Unicode/shorthand math notation into a parser-friendly ASCII form.

    Steps, in order: replace ``π``/``Π`` with ``pi`` and ``√`` with ``sqrt``,
    wrap a bare ``sqrt`` argument in parentheses, insert explicit ``*`` between
    a number (or closing parenthesis) and ``sqrt(`` and between a digit and a
    letter, rewrite ``ln(`` as ``log(``, and collapse whitespace.

    Args:
        s: Text to normalise.

    Returns:
        The normalised string, or ``""`` when ``s`` is not a string.
    """
    if not isinstance(s, str):
        return ""
    t = s.replace("π", "pi").replace("Π", "pi").replace("√", "sqrt")
    # A "not preceded by a letter" look-behind is used instead of ``\b``:
    # ``\b`` does not fire between a digit and "s", so "2√3" (-> "2sqrt3")
    # used to skip this rule and end up as the bogus symbol "sqrt3".
    t = re.sub(r"(?<![A-Za-z_])sqrt\s*([0-9a-zA-Z]+)", r"sqrt(\1)", t)
    t = re.sub(r"(\d|\))\s*(sqrt\s*\()", r"\1*\2", t)
    t = re.sub(r"(\d)\s*([a-zA-Z])", r"\1*\2", t)
    t = re.sub(r"\bln\s*\(", "log(", t, flags=re.I)
    return re.sub(r"\s+", " ", t).strip()

def simple_latex_to_ascii(s: str) -> str:
    r"""Convert a small subset of LaTeX math into an ASCII expression string.

    Handles math delimiters, ``\left``/``\right``, absolute-value bars,
    ``\text{...}`` (dropped), ``\frac``/``\dfrac``, ``\sqrt``, braced
    exponents, ``\cdot``/``\times``, ``\pi``, the basic trig/log commands,
    ``^`` (rewritten to ``**``), implicit digit-letter multiplication and a
    trailing ``+ C`` integration constant (removed).

    Args:
        s: LaTeX (or already plain) math text.

    Returns:
        The rewritten expression with whitespace collapsed.
    """
    s = s.strip().replace("\\(", "").replace("\\)", "")
    s = s.strip("$").replace("\\left", "").replace("\\right", "")
    s = re.sub(r"\|([^|]+)\|", r"Abs(\1)", s)
    s = re.sub(r"\\vert\s*([^\\]+?)\s*\\vert", r"Abs(\1)", s)
    s = re.sub(r"\\lvert\s*([^\\]+?)\s*\\rvert", r"Abs(\1)", s)

    s = re.sub(r"\\text\{[^}]*\}", "", s)

    # The brace patterns only match innermost groups, so repeat them until the
    # string stops changing; this resolves nesting such as \frac{\sqrt{2}}{2}.
    # Every substitution removes braces, so the loop terminates.
    previous = None
    while previous != s:
        previous = s
        s = re.sub(r"\\d?frac\{([^{}]+)\}\{([^{}]+)\}", lambda m: f"({m.group(1)})/({m.group(2)})", s)
        s = re.sub(r"\\sqrt\{([^{}]+)\}", r"sqrt(\1)", s)
        # x^{2} -> x^(2); the caret itself is rewritten to ** further down.
        s = re.sub(r"\^\{([^{}]+)\}", r"^(\1)", s)

    s = s.replace("\\cdot", "*").replace("\\times", "*")
    s = s.replace("\\pi", "pi").replace("π", "pi").replace("√", "sqrt")

    s = re.sub(r"\\sin\b", "sin", s)
    s = re.sub(r"\\cos\b", "cos", s)
    s = re.sub(r"\\tan\b", "tan", s)
    s = re.sub(r"\\ln\b",  "log", s)
    s = re.sub(r"\\log\b", "log", s)
    # Brace-less "\sqrt 2": drop the backslash so the bare-sqrt rule below applies.
    s = re.sub(r"\\sqrt\b", "sqrt", s)
    s = s.replace("^", "**")
    # Capture the whole alphanumeric run (not just one character) so that
    # "sqrt25" becomes "sqrt(25)", and use a look-behind rather than \b so a
    # preceding digit ("2sqrt3") does not prevent the match.
    s = re.sub(r"(?<![A-Za-z_])sqrt\s*([0-9a-zA-Z]+)", r"sqrt(\1)", s)
    s = re.sub(r"(\d)\s*([a-zA-Z])", r"\1*\2", s)
    s = re.sub(r"\+\s*C\b\.?$", "", s)

    s = re.sub(r"\s+", " ", s).strip()
    return s


def to_sympy(text: Optional[str]) -> Optional[sp.Expr]:
    """Parse math text into a SymPy expression, or return None on failure.

    The text is first passed through ``simple_latex_to_ascii``. If the result
    contains a single ``=`` (and no ``==``) it is treated as an equation and
    the simplified difference ``lhs - rhs`` is returned; otherwise the parsed
    expression itself is returned. Any exception yields ``None``.
    """
    if not (isinstance(text, str) and text.strip()):
        return None
    # NOTE(review): SymPy documents that parse_expr relies on eval(); the text
    # parsed here can originate from model output, so treat it as untrusted.
    try:
        ascii_expr = simple_latex_to_ascii(text)
        looks_like_equation = "=" in ascii_expr and "==" not in ascii_expr
        if not looks_like_equation:
            return parse_expr(ascii_expr, transformations=SYM_TRANSFORMS)
        lhs_src, rhs_src = ascii_expr.split("=", 1)
        lhs = parse_expr(lhs_src, transformations=SYM_TRANSFORMS)
        rhs = parse_expr(rhs_src, transformations=SYM_TRANSFORMS)
        return sp.simplify(lhs - rhs)
    except Exception:
        return None

def sympy_equivalent(a: Optional[str], b: Optional[str]) -> Optional[bool]:
    """Tell whether two math strings simplify to the same SymPy expression.

    Returns:
        True/False when both inputs parse and the comparison completes;
        None when either input is empty, fails to parse, or an exception
        is raised along the way.
    """
    if not (a and b):
        return None
    try:
        first = to_sympy(a)
        second = to_sympy(b)
        if first is None or second is None:
            return None
        difference = sp.simplify(first - second)
        return bool(difference == 0)
    except Exception:
        return None

def _first_boxed_content(text: str) -> Optional[str]:
    """Return the brace-balanced, non-empty argument of the first complete \\boxed{...}."""
    for opener in re.finditer(r"\\boxed\s*\{", text):
        depth = 1
        for pos in range(opener.end(), len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    content = text[opener.end():pos].strip()
                    if content:
                        return content
                    break
    return None


def extract_final_answer(model_answer: Optional[str]) -> Optional[str]:
    """Pull the final answer expression out of a free-form model response.

    Strategies, in priority order:
      1. the argument of the first ``\\boxed{...}`` (braces are matched, so
         nested groups such as ``\\boxed{\\frac{1}{2}}`` are kept whole);
      2. the first non-blank line following a "final answer" marker;
      3. the last line that contains a math-looking token (right-hand side
         only if the line has a single ``=``);
      4. the last non-blank line.

    Args:
        model_answer: Raw model output.

    Returns:
        The answer normalised by ``normalize_unicode_math``, or ``None`` when
        the input is not a non-blank string.
    """
    if not isinstance(model_answer, str) or not model_answer.strip():
        return None
    text = model_answer.strip()

    boxed = _first_boxed_content(text)
    if boxed is not None:
        return normalize_unicode_math(boxed)

    m = re.search(r"(?:final\s*answer\s*[:\-]?\s*)(.*)$", text, flags=re.I|re.S)
    if m:
        tail = m.group(1).strip()
        tail = re.sub(r"\\\((.+?)\\\)", r"\1", tail, flags=re.S)
        tail = re.sub(r"\$(.+?)\$", r"\1", tail, flags=re.S)
        # Take the first non-blank line; an empty tail ("Final answer:" with
        # nothing after it) previously raised IndexError on splitlines()[0].
        first_line = next((ln.strip() for ln in tail.splitlines() if ln.strip()), "")
        # Trailing punctuation is dropped, but a leading "." is kept so that
        # a decimal such as ".5" is not turned into "5".
        first_line = first_line.lstrip(";: ").rstrip(".;: ")
        if first_line:
            return normalize_unicode_math(first_line)
        # Marker present but nothing usable after it: fall through to the
        # line-based heuristics below.

    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    for ln in reversed(lines):
        if any(tok in ln for tok in ["=", "/", "^", "x", "(", ")", "+", "-", "\\frac", "\\dfrac", "sqrt", "π", "pi", "sin", "cos", "tan", "log", "ln"]):
            expr = ln.split("=")[-1].strip() if "=" in ln and "==" not in ln else ln
            return normalize_unicode_math(expr)

    return normalize_unicode_math(lines[-1]) if lines else None

STEP_HEADER_RE = re.compile(r"(?i)step\s*\d+\s*[:\.\)\-]")


def split_solution_steps(text: str) -> List[str]:
    """Split a solution into its steps.

    When "Step k:"-style headers are present, the result is the optional
    preamble before the first header followed by the body after each header
    (headers themselves are dropped). Without headers, each non-blank line
    is a step. Empty pieces are discarded; non-string or blank input gives
    an empty list.
    """
    if not isinstance(text, str):
        return []
    content = text.strip()
    if not content:
        return []

    headers = list(STEP_HEADER_RE.finditer(content))
    if not headers:
        return [line.strip() for line in content.splitlines() if line.strip()]

    # Each body runs from the end of its header to the start of the next
    # header (or to the end of the text for the last one).
    stops = [h.start() for h in headers[1:]] + [len(content)]
    pieces = [content[:headers[0].start()]]
    pieces.extend(content[h.end():stop] for h, stop in zip(headers, stops))
    return [piece.strip() for piece in pieces if piece.strip()]

def cosine_torch(a, b):
    """Return the matrix of cosine similarities between the rows of ``a`` and ``b``.

    Both arguments are row-normalised (with a 1e-9 term added to the norm to
    avoid division by zero) and then multiplied as ``a @ b.T``.
    """
    def _unit_rows(mat):
        return mat / (mat.norm(dim=1, keepdim=True) + 1e-9)

    return _unit_rows(a) @ _unit_rows(b).T

def cosine_any(A, B):
    """Row-wise cosine-similarity matrix for torch tensors or array-likes.

    Dispatches to ``cosine_torch`` when ``A`` is a torch tensor; otherwise the
    same computation is done with NumPy (norm + 1e-9, then ``A_unit @ B_unit.T``).
    """
    if not (has_torch and isinstance(A, torch.Tensor)):
        import numpy as np

        a_unit = A / (np.linalg.norm(A, axis=1, keepdims=True) + 1e-9)
        b_unit = B / (np.linalg.norm(B, axis=1, keepdims=True) + 1e-9)
        return a_unit @ b_unit.T
    return cosine_torch(A, B)

class TextEncoder:
    """Sentence/token embedder backed by the first loadable Hugging Face model.

    Candidates are tried in the order ``MATH_EMBED_CANDIDATES`` then
    ``GEN_EMBED_FALLBACKS``. If none loads, ``available`` is set to False and
    both encode methods fall back to simple character-hash vectors.
    """

    # Dimension of the character-hash fallback vectors. Sentence and token
    # fallbacks must share it: srs_score compares the two with a cosine
    # (matrix product), which fails on mismatched widths.
    FALLBACK_DIM = 64

    def __init__(self, device: str = DEVICE):
        """Load the first embedding model that can be instantiated on ``device``."""
        self.device = device if has_torch else "cpu"
        self.available = has_torch and has_tf
        self.tok = None
        self.model = None
        self.hdim = 768
        if not self.available:
            return
        for name in (MATH_EMBED_CANDIDATES + GEN_EMBED_FALLBACKS):
            try:
                tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                model = AutoModel.from_pretrained(name)
                model.to(self.device).eval()
            except Exception as exc:
                # Best effort: report why this candidate was skipped, then try the next.
                print(f"[ENC] failed to load {name}: {exc}")
                continue
            self.tok, self.model = tok, model
            self.hdim = getattr(getattr(model, "config", None), "hidden_size", self.hdim)
            print(f"[ENC] {name}")
            return
        self.available = False
        # Make the degraded mode visible: similarity metrics computed from
        # hash vectors are not comparable to model-based ones.
        print("[ENC] no embedding model could be loaded; using character-hash fallback embeddings")

    def encode_sentences(self, texts: List[str], batch_size: int = BATCH_SIZE_EMB):
        """Embed each text as one vector.

        Returns:
            With a model: a CPU tensor with one attention-mask-weighted mean of
            the last hidden states per text (an empty ``(0, hdim)`` tensor for
            no texts). Without a model: a list of L2-normalised NumPy
            character-count vectors of length ``FALLBACK_DIM``.
        """
        if not self.available:
            import numpy as np
            vecs = []
            for t in texts:
                v = np.zeros(self.FALLBACK_DIM, dtype="float32")
                for ch in (t or ""):
                    v[ord(ch) % self.FALLBACK_DIM] += 1.0
                v = v / (np.linalg.norm(v) + 1e-9)
                vecs.append(v)
            return vecs
        outs = []
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i+batch_size]
                x = self.tok(batch, padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device)
                last = self.model(**x).last_hidden_state
                # Zero out padding positions before averaging over the sequence.
                mask = x["attention_mask"].unsqueeze(-1)
                pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                outs.append(pooled.detach().cpu())
        return torch.cat(outs, dim=0) if outs else torch.zeros((0, self.hdim))

    def encode_tokens(self, text: str):
        """Embed ``text`` as one vector per token.

        Returns:
            With a model: the last hidden states of the (truncated) input as a
            CPU tensor. Without a model: a list of one-hot NumPy vectors of
            length ``FALLBACK_DIM``, one per character of the first 256.
        """
        if not self.available:
            import numpy as np
            embs = []
            for ch in (text or "")[:256]:
                v = np.zeros(self.FALLBACK_DIM, dtype="float32")
                v[ord(ch) % self.FALLBACK_DIM] = 1.0
                embs.append(v)
            return embs
        with torch.no_grad():
            x = self.tok(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            return self.model(**x).last_hidden_state[0].detach().cpu()

class NLIModel:
    """Natural-language-inference classifier backed by the first loadable model.

    Candidates come from ``GEN_NLI_FALLBACKS``. If none loads, ``available``
    is set to False and ``probs`` returns a constant distribution.

    Attributes ``contr_idx``, ``neutral_idx`` and ``entail_idx`` give the
    column of each class in the output of ``probs``.
    """

    def __init__(self, device: str = DEVICE):
        """Load the first NLI model that can be instantiated on ``device``."""
        self.device = device if has_torch else "cpu"
        self.available = has_torch and has_tf
        self.tok = None
        self.model = None
        # Defaults; replaced by the checkpoint's own mapping when it names its labels.
        self.contr_idx, self.neutral_idx, self.entail_idx = 0, 1, 2
        if not self.available:
            return
        for name in GEN_NLI_FALLBACKS:
            try:
                tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                model = AutoModelForSequenceClassification.from_pretrained(name)
                model.to(self.device).eval()
            except Exception as exc:
                # Best effort: report why this candidate was skipped, then try the next.
                print(f"[NLI] failed to load {name}: {exc}")
                continue
            self.tok, self.model = tok, model
            self._resolve_label_indices()
            print(f"[NLI] {name}")
            return
        self.available = False
        print("[NLI] no NLI model could be loaded; using constant fallback probabilities")

    def _resolve_label_indices(self) -> None:
        """Set the class indices from ``model.config.id2label`` when it is informative.

        The label order is checkpoint-specific, so it is read from the config
        instead of being assumed. The defaults are kept unless all three
        classes are found at distinct positions (e.g. generic ``LABEL_k`` names
        leave them untouched).
        """
        id2label = getattr(getattr(self.model, "config", None), "id2label", None)
        if not isinstance(id2label, dict):
            return
        found = {}
        for idx, label in id2label.items():
            try:
                position = int(idx)
            except (TypeError, ValueError):
                continue
            lowered = str(label).lower()
            for prefix in ("contra", "neutral", "entail"):
                if lowered.startswith(prefix):
                    found[prefix] = position
        if len(found) == 3 and len(set(found.values())) == 3:
            self.contr_idx = found["contra"]
            self.neutral_idx = found["neutral"]
            self.entail_idx = found["entail"]

    def probs(self, pairs: List[Tuple[str, str]], batch_size: int = BATCH_SIZE_NLI):
        """Return class probabilities for each (first, second) text pair.

        Returns:
            With a model: a CPU tensor of softmaxed logits, one row per pair
            (an empty ``(0, 3)`` tensor for no pairs). Without a model: a NumPy
            array whose rows are all ``[0.2, 0.6, 0.2]``.
        """
        if not self.available:
            import numpy as np
            return np.tile([0.2, 0.6, 0.2], (len(pairs), 1)).astype("float32")
        outs = []
        with torch.no_grad():
            for i in range(0, len(pairs), batch_size):
                s1 = [a for a, _ in pairs[i:i+batch_size]]
                s2 = [b for _, b in pairs[i:i+batch_size]]
                inputs = self.tok(s1, s2, padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device)
                logits = self.model(**inputs).logits
                outs.append(torch.softmax(logits, dim=-1).detach().cpu())
        return torch.cat(outs, dim=0) if outs else torch.zeros((0, 3))

def f1_exact_string(ref_steps: List[str], hyp_steps: List[str]) -> Tuple[float, float, float]:
    """Precision, recall and F1 of exact step matches (multiset overlap).

    Steps are compared after trimming, lower-casing and collapsing runs of
    whitespace to a single space. Duplicates count as many times as they
    appear on both sides. Empty denominators are treated as 1, so empty
    inputs give zeros.
    """
    from collections import Counter

    def _canonical(step: str) -> str:
        return re.sub(r"\s+", " ", step.strip().lower())

    ref_counts = Counter(_canonical(step) for step in (ref_steps or []))
    hyp_counts = Counter(_canonical(step) for step in (hyp_steps or []))
    overlap = sum((ref_counts & hyp_counts).values())

    precision = overlap / max(1, sum(hyp_counts.values()))
    recall = overlap / max(1, sum(ref_counts.values()))
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1

def f1_semantic(encoder: TextEncoder, ref_steps: List[str], hyp_steps: List[str], sim_threshold: float = 0.7) -> Tuple[float, float, float]:
    """Precision, recall and F1 of embedding-based step matches.

    Blank steps are ignored. Reference and hypothesis steps are embedded with
    ``encoder.encode_sentences`` and paired greedily, highest cosine
    similarity first, each step being used at most once; pairing stops at
    the first similarity below ``sim_threshold``.

    Returns:
        ``(precision, recall, f1)``. Both sides empty gives ``(1, 1, 1)``;
        only the reference empty gives ``(0, 1, 0)``; only the hypothesis
        empty gives ``(1, 0, 0)``.
    """
    refs = [s for s in (ref_steps or []) if s.strip()]
    hyps = [s for s in (hyp_steps or []) if s.strip()]
    if not refs:
        return (1.0, 1.0, 1.0) if not hyps else (0.0, 1.0, 0.0)
    if not hyps:
        return 1.0, 0.0, 0.0

    ref_vectors = encoder.encode_sentences(refs)
    hyp_vectors = encoder.encode_sentences(hyps)
    sims = cosine_any(ref_vectors, hyp_vectors)

    # Read every cell as a plain float, whichever backend produced the matrix.
    if has_torch and isinstance(sims, torch.Tensor):
        n_rows, n_cols = sims.size(0), sims.size(1)

        def _cell(i, j):
            return float(sims[i, j].item())
    else:
        n_rows, n_cols = sims.shape[0], sims.shape[1]

        def _cell(i, j):
            return float(sims[i, j])

    ranked = sorted(
        ((_cell(i, j), i, j) for i in range(n_rows) for j in range(n_cols)),
        reverse=True,
    )

    taken_refs, taken_hyps = set(), set()
    for score, ref_i, hyp_j in ranked:
        if score < sim_threshold:
            break
        if ref_i in taken_refs or hyp_j in taken_hyps:
            continue
        taken_refs.add(ref_i)
        taken_hyps.add(hyp_j)
    matched = len(taken_refs)

    precision = matched / max(1, len(hyps))
    recall = matched / max(1, len(refs))
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1

def srs_score(encoder: TextEncoder, nli: NLIModel, question: str, steps: List[str]) -> Dict[str, float]:
    """Compute six reasoning-quality components for a chain of steps.

    Every component is initialised or defaulted to 0.5 when the data it
    needs is missing (no steps, no question tokens, fewer than two steps).

    Components, as computed below:
      * faithfulness: mean over steps of ``(1 + max cosine(step, question token)) / 2``.
      * info_step: mean over question tokens of ``(1 + max cosine(token, step)) / 2``.
      * info_chain: ``(1 + cosine(mean step embedding, question sentence embedding)) / 2``.
      * repetition_step: ``(1 - max cosine between two different steps) / 2``.
      * discourse: ``1 - max contradiction probability`` over (step, question) pairs.
      * coherence: ``1 - max contradiction probability`` over pairs of steps.

    Each computation has a torch branch and a NumPy/list branch, because the
    encoder returns tensors when a model is loaded and lists of arrays otherwise.

    Args:
        encoder: Provides ``encode_sentences`` and ``encode_tokens``.
        nli: Provides ``probs`` and ``contr_idx``.
        question: The problem statement.
        steps: The solution steps to score.

    Returns:
        A dict with the six component names as keys and floats as values.
    """
    # With no steps, use an empty (0, hdim) tensor so the tensor branches below apply.
    step_embs = encoder.encode_sentences(steps) if steps else (torch.zeros((0, getattr(encoder, "hdim", 768))) if has_torch else [])
    q_sent_emb = encoder.encode_sentences([question])
    q_tok_embs = encoder.encode_tokens(question)

    def step_to_q(step_emb, q_tok) -> float:
        """Best cosine between one step embedding and any question token, mapped to [0, 1]."""
        if (has_torch and isinstance(q_tok, torch.Tensor) and q_tok.size(0) == 0) or (isinstance(q_tok, list) and not q_tok):
            return 0.5
        if has_torch and isinstance(step_emb, torch.Tensor) and isinstance(q_tok, torch.Tensor):
            sims = cosine_torch(step_emb.unsqueeze(0), q_tok).squeeze(0)
            return float((1.0 + float(torch.max(sims).item())) / 2.0)
        import numpy as np
        se = step_emb.reshape(1, -1)
        qt = q_tok if not isinstance(q_tok, list) else (np.stack(q_tok, axis=0) if q_tok else np.zeros((0, 1), dtype="float32"))
        sims = cosine_any(se, qt).flatten()
        return float((1.0 + float(np.max(sims))) / 2.0)

    def qtok_to_chain(q_tok, step_mat) -> float:
        """Best cosine between one question token and any step embedding, mapped to [0, 1]."""
        if (has_torch and isinstance(step_mat, torch.Tensor) and step_mat.size(0) == 0) or (isinstance(step_mat, list) and not step_mat):
            return 0.5
        if has_torch and isinstance(q_tok, torch.Tensor) and isinstance(step_mat, torch.Tensor):
            sims = cosine_torch(q_tok.unsqueeze(0), step_mat).squeeze(0)
            return float((1.0 + float(torch.max(sims).item())) / 2.0)
        import numpy as np
        qt = q_tok.reshape(1, -1)
        se = step_mat if not isinstance(step_mat, list) else (np.stack(step_mat, axis=0) if step_mat else np.zeros((0, 1), dtype="float32"))
        sims = cosine_any(qt, se).flatten()
        return float((1.0 + float(np.max(sims))) / 2.0)

    # --- faithfulness and info_chain ---
    if has_torch and isinstance(step_embs, torch.Tensor):
        faithfulness = 0.5 if step_embs.size(0) == 0 else float(sum(step_to_q(step_embs[j], q_tok_embs) for j in range(step_embs.size(0))) / step_embs.size(0))
        info_chain = 0.5 if step_embs.size(0) == 0 else (1.0 + float(cosine_torch(step_embs.mean(dim=0, keepdim=True), q_sent_emb).item())) / 2.0
    else:
        import numpy as np
        if isinstance(step_embs, list) and step_embs:
            step_mat = np.stack(step_embs, axis=0)
            faithfulness = float(sum(step_to_q(step_mat[j], q_tok_embs) for j in range(step_mat.shape[0])) / step_mat.shape[0])
            q_emb = q_sent_emb[0] if isinstance(q_sent_emb, list) else q_sent_emb
            # Cosine between the mean step vector and the question vector, written out by hand.
            num = float((step_mat.mean(axis=0) @ q_emb) / ((np.linalg.norm(step_mat.mean(axis=0)) + 1e-9) * (np.linalg.norm(q_emb) + 1e-9)))
            info_chain = (1.0 + num) / 2.0
        else:
            faithfulness, info_chain = 0.5, 0.5

    # --- info_step ---
    if has_torch and not isinstance(q_tok_embs, list) and hasattr(q_tok_embs, "size") and q_tok_embs.size(0) > 0:
        vals = [qtok_to_chain(q_tok_embs[t], step_embs) for t in range(q_tok_embs.size(0))]
        info_step = float(sum(vals) / len(vals))
    else:
        if isinstance(q_tok_embs, list) and q_tok_embs:
            import numpy as np
            if isinstance(step_embs, list) and step_embs:
                step_mat = np.stack(step_embs, axis=0)
                vals = [qtok_to_chain(qt, step_mat) for qt in q_tok_embs]
                info_step = float(sum(vals) / len(vals))
            else:
                info_step = 0.5
        else:
            info_step = 0.5

    # --- repetition_step: needs at least two steps ---
    repetition_step = 0.5
    if has_torch and isinstance(step_embs, torch.Tensor) and step_embs.size(0) >= 2:
        sims = cosine_torch(step_embs, step_embs)
        # Mask the diagonal so a step's similarity with itself is never the maximum.
        eye = torch.eye(step_embs.size(0), dtype=torch.bool)
        sims = sims.masked_fill(eye, -1.0)
        repetition_step = (1.0 - float(torch.max(sims).item())) / 2.0
    elif isinstance(step_embs, list) and len(step_embs) >= 2:
        import numpy as np
        mat = np.stack(step_embs, axis=0)
        S = cosine_any(mat, mat)
        np.fill_diagonal(S, -1.0)
        repetition_step = (1.0 - float(S.max())) / 2.0

    # --- discourse: worst-case contradiction between any step and the question ---
    discourse = 0.5
    if steps:
        probs_q = nli.probs([(s, question) for s in steps])
        if has_torch and isinstance(probs_q, torch.Tensor):
            discourse = 1.0 - float(torch.max(probs_q[:, nli.contr_idx]).item())
        else:
            discourse = 1.0 - float(probs_q[:, nli.contr_idx].max())

    # --- coherence: worst-case contradiction between any two steps ---
    coherence = 0.5
    if len(steps) >= 2:
        from itertools import combinations
        idxs = list(range(len(steps)))
        pairs_idx = list(combinations(idxs, 2))
        if len(steps) > MAX_STEPS_FOR_FULL_NLI:
            # Caps the number of pairs. The slice keeps the first pairs in
            # combinations() order, so it is a limit on pair count rather than
            # a restriction to the first MAX_STEPS_FOR_FULL_NLI steps.
            max_pairs = (MAX_STEPS_FOR_FULL_NLI * (MAX_STEPS_FOR_FULL_NLI - 1)) // 2
            pairs_idx = pairs_idx[:max_pairs]
        pairs = [(steps[j], steps[k]) for (j, k) in pairs_idx]
        probs_s = nli.probs(pairs)
        if has_torch and isinstance(probs_s, torch.Tensor):
            coherence = 1.0 - float(torch.max(probs_s[:, nli.contr_idx]).item())
        else:
            coherence = 1.0 - float(probs_s[:, nli.contr_idx].max())

    return {
        "faithfulness": float(faithfulness),
        "info_step": float(info_step),
        "info_chain": float(info_chain),
        "repetition_step": float(repetition_step),
        "discourse": float(discourse),
        "coherence": float(coherence),
    }

def vr_score(nli: NLIModel, steps: List[str], question: str) -> float:
    """Score a chain as ``min(p_entail + p_neutral) - max(p_neutral)`` over its steps.

    Probabilities come from ``nli.probs`` on each (step, question) pair.

    Args:
        nli: NLI model wrapper; its ``neutral_idx`` / ``entail_idx`` attributes
            select the probability columns (columns 1 and 2 are assumed if the
            attributes are missing).
        steps: Solution steps.
        question: The problem statement.

    Returns:
        The score as a float, or 0.0 when there are no steps.
    """
    if not steps:
        return 0.0
    probs = nli.probs([(s, question) for s in steps])
    # Use the model's own label positions rather than hard-coded columns, in
    # line with how srs_score reads nli.contr_idx.
    neutral_idx = getattr(nli, "neutral_idx", 1)
    entail_idx = getattr(nli, "entail_idx", 2)
    p_neu, p_ent = probs[:, neutral_idx], probs[:, entail_idx]
    s_valid = p_ent + p_neu
    s_red = p_neu
    # Column slicing, .min() and .max() behave the same on torch tensors and
    # NumPy arrays, so one code path covers both return types of nli.probs.
    return float(s_valid.min()) - float(s_red.max())

def extract_ref_answer_expr(answer_field: Optional[str]) -> Optional[str]:
    """Derive a comparable expression string from a reference answer field.

    Tried in order:
      1. text after "exact form:" (up to a standalone "or" or the end),
         normalised, with spaces around ``/`` removed;
      2. the argument of the first ``ln(...)``, returned as ``log(...)``;
      3. the argument of the first ``ln|...|``, returned as ``log(Abs(...))``;
      4. the last number appearing in the text.

    Returns:
        The extracted expression, or ``None`` if the field is not a non-blank
        string or none of the patterns match.
    """
    if not (isinstance(answer_field, str) and answer_field.strip()):
        return None
    body = answer_field.strip()

    exact = re.search(r"exact\s*form\s*:\s*(.+?)(?:\bor\b|$)", body, flags=re.I)
    if exact is not None:
        expr = normalize_unicode_math(exact.group(1).strip().strip("."))
        return expr.replace(") / (", ")/(").replace(" / ", "/")

    ln_parens = re.search(r"\\?ln\s*\(\s*([^)]+)\s*\)", body, flags=re.I)
    if ln_parens is not None:
        return f"log({normalize_unicode_math(ln_parens.group(1))})"

    ln_bars = re.search(r"\\?ln\s*\|\s*([^|]+)\s*\|\s*(?:\+\s*C)?", body, flags=re.I)
    if ln_bars is not None:
        return f"log(Abs({normalize_unicode_math(ln_bars.group(1))}))"

    numbers = re.findall(r"[-+]?\d+(?:\.\d+)?", body)
    return numbers[-1] if numbers else None

def evaluate_single_entry(entry: Dict[str, Any], encoder: TextEncoder, nli: NLIModel, f1_sem_threshold: float = 0.70) -> Dict[str, Any]:
    """Compute all metrics for one dataset entry.

    Reads ``question``, ``answer``, ``steps``, ``model_answer`` and ``id``
    from ``entry``; missing question/model answer default to ``""`` and
    missing steps to an empty list.

    Returns:
        A dict with string and symbolic accuracy (0/1), exact and semantic
        step precision/recall/F1, the SRS components and their mean, the VR
        score, the extracted predicted/reference answers and step counts.
    """
    question = entry.get("question", "") or ""
    reference_steps = [str(step) for step in (entry.get("steps") or [])]
    raw_output = entry.get("model_answer", "") or ""

    reference_expr = extract_ref_answer_expr(entry.get("answer", None))
    predicted_steps = split_solution_steps(raw_output)
    predicted = extract_final_answer(raw_output)

    # String accuracy: both sides present and equal after trimming.
    string_hit = (
        predicted is not None
        and reference_expr is not None
        and str(predicted).strip() == str(reference_expr).strip()
    )
    # Symbolic accuracy: only a definite True from SymPy counts.
    symbolic_hit = reference_expr is not None and sympy_equivalent(predicted, reference_expr) is True

    exact_p, exact_r, exact_f1 = f1_exact_string(reference_steps, predicted_steps)
    sem_p, sem_r, sem_f1 = f1_semantic(encoder, reference_steps, predicted_steps, sim_threshold=f1_sem_threshold)

    components = srs_score(encoder, nli, question, predicted_steps)
    if components:
        srs_mean = sum(components.values()) / len(components)
    else:
        srs_mean = 0.0

    validity = vr_score(nli, predicted_steps, question)

    result: Dict[str, Any] = {}
    result["id"] = entry.get("id")
    result["accuracy_string"] = 1 if string_hit else 0
    result["accuracy_symbolic"] = 1 if symbolic_hit else 0
    result["f1_exact"] = {"precision": exact_p, "recall": exact_r, "f1": exact_f1}
    result["f1_semantic"] = {"precision": sem_p, "recall": sem_r, "f1": sem_f1, "sim_threshold": f1_sem_threshold}
    result["srs_components"] = components
    result["srs"] = srs_mean
    result["vr_score"] = validity
    result["pred_answer"] = predicted
    result["ref_answer_expr_used"] = reference_expr
    result["ref_steps_count"] = len(reference_steps)
    result["hyp_steps_count"] = len(predicted_steps)
    return result

def aggregate_metrics(per_item: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Average the per-item metrics over all evaluated entries.

    Top-level and nested metrics only include values that are ints or
    floats; SRS components include every item that carries the component.
    A metric with no contributing values is reported as ``None``.

    Returns:
        A dict with ``count`` followed by one mean per metric.
    """
    def _mean(values):
        return sum(values) / len(values) if values else None

    def _top_level(key):
        return _mean([row[key] for row in per_item if isinstance(row.get(key), (int, float))])

    def _nested(group, field):
        return _mean([
            row[group][field]
            for row in per_item
            if isinstance(row.get(group, {}).get(field), (int, float))
        ])

    summary: Dict[str, Any] = {"count": len(per_item)}
    summary["accuracy_string"] = _top_level("accuracy_string")
    summary["accuracy_symbolic"] = _top_level("accuracy_symbolic")
    summary["f1_exact_precision"] = _nested("f1_exact", "precision")
    summary["f1_exact_recall"] = _nested("f1_exact", "recall")
    summary["f1_exact_f1"] = _nested("f1_exact", "f1")
    summary["f1_semantic_precision"] = _nested("f1_semantic", "precision")
    summary["f1_semantic_recall"] = _nested("f1_semantic", "recall")
    summary["f1_semantic_f1"] = _nested("f1_semantic", "f1")
    summary["srs"] = _top_level("srs")
    summary["vr_score"] = _top_level("vr_score")

    for component in ("faithfulness", "info_step", "info_chain", "repetition_step", "discourse", "coherence"):
        summary[f"srs_{component}"] = _mean([
            row["srs_components"][component]
            for row in per_item
            if "srs_components" in row and component in row["srs_components"]
        ])
    return summary

def load_entries(path: Path) -> List[Dict[str, Any]]:
    """Load evaluation entries from a UTF-8 JSON file.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed list as-is, or a one-element list when the file holds a
        single JSON object.

    Raises:
        ValueError: If the top-level JSON value is neither a list nor an object.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the bare open() call previously left that to the garbage collector).
    with open(path, "r", encoding="utf-8") as handle:
        raw = json.load(handle)
    if isinstance(raw, list):
        return raw
    if isinstance(raw, dict):
        return [raw]
    raise ValueError("Input must be a JSON object or a list of objects.")

def save_report(report: dict, path: str = "./output/metrics_report.json") -> str:
    """Write ``report`` as indented UTF-8 JSON and return the resolved file path.

    Missing parent directories are created. Non-ASCII characters are written
    unescaped.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(report, handle, ensure_ascii=False, indent=2)
    resolved = target.resolve()
    return str(resolved)

def main():
    """Evaluate every entry in ``IN_PATH`` and write the report to ``OUT_JSON``.

    Raises:
        FileNotFoundError: If the input file does not exist.
    """
    in_path = IN_PATH
    if not in_path.exists():
        raise FileNotFoundError(f"Input not found: {in_path.resolve()}")

    records = load_entries(in_path)
    # Models are loaded once and shared across all entries.
    encoder = TextEncoder(DEVICE)
    nli = NLIModel(DEVICE)

    per_item = []
    for record in records:
        per_item.append(evaluate_single_entry(record, encoder, nli, f1_sem_threshold=0.70))

    report = {}
    report["file"] = str(in_path.resolve())
    report["per_item"] = per_item
    report["aggregates"] = aggregate_metrics(per_item)

    out_file = save_report(report, str(OUT_JSON))
    print(f"[WRITE] {out_file}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()