from __future__ import annotations
import os
import re
import string
import numpy as np
import torch
from typing import List, Optional, Callable, Dict, Iterable
from collections import OrderedDict
from dataclasses import dataclass, fields
from pathlib import Path
from math_verify import parse, verify
import logging


def _dataclass_from_dict(cls, dct, **overrides):
    d = dict(dct or {})
    d.update(overrides)
    allowed = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in d.items() if k in allowed})

# ----------------------------
# DDP helpers
# ----------------------------
def is_main_process() -> bool:
    """Return True on rank 0, or whenever torch.distributed is not in use."""
    dist = torch.distributed
    if not dist.is_available():
        return True
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0

# ----------------------------
# SBERT backend (offline-friendly, batched, CPU by default)
# ----------------------------
@dataclass
class SBERTConfig:
    """Settings for the lazily-loaded SentenceTransformer backend."""

    # Local directory holding the (pre-downloaded) model.
    model_path: str
    # Device passed to SentenceTransformer; CPU by default to avoid VRAM spikes.
    device: str = "cpu"
    # Batch size forwarded to ``SentenceTransformer.encode``.
    batch_size: int = 128
    # Forwarded to ``encode(normalize_embeddings=...)``.
    normalize_embeddings: bool = True
    # Forwarded to ``encode(show_progress_bar=...)``.
    show_progress_bar: bool = False

    # --- config-friendly extras, applied when the model is first loaded ---
    # When True, HF_HUB_OFFLINE / TRANSFORMERS_OFFLINE default to "1".
    force_offline: bool = True
    # Default value for the TOKENIZERS_PARALLELISM environment variable.
    tokenizers_parallelism: str = "false"
    # Passed to ``torch.set_num_threads``.
    num_threads: int = 1

    @classmethod
    def from_dict(cls, d: dict, **overrides):
        """Build a config from a plain dict; unknown keys are ignored."""
        return _dataclass_from_dict(cls, d, **overrides)

class SBERTBackend:
    """Thin wrapper around a SentenceTransformer that is loaded on first use."""

    def __init__(self, cfg: SBERTConfig):
        self._model = None  # created lazily by _ensure()
        self.cfg = cfg

    def _ensure(self):
        """Load the model once, applying the environment/thread settings first."""
        if self._model is None:
            cfg = self.cfg
            if cfg.force_offline:
                # setdefault: never override a value the user exported explicitly
                for var in ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE"):
                    os.environ.setdefault(var, "1")
            os.environ.setdefault("TOKENIZERS_PARALLELISM", cfg.tokenizers_parallelism)
            torch.set_num_threads(cfg.num_threads)

            # Imported here so the dependency is only needed when SBERT is used.
            from sentence_transformers import SentenceTransformer
            self._model = SentenceTransformer(cfg.model_path, device=cfg.device)
            self._model.eval()

    @torch.inference_mode()
    def encode(self, texts: List[str]) -> torch.Tensor:
        """Embed ``texts`` and return the embeddings as a detached CPU tensor."""
        self._ensure()
        encode_kwargs = dict(
            batch_size=self.cfg.batch_size,
            convert_to_tensor=True,
            normalize_embeddings=self.cfg.normalize_embeddings,
            show_progress_bar=self.cfg.show_progress_bar,
        )
        vectors = self._model.encode(texts, **encode_kwargs)
        # Results always live on CPU, whatever device the model runs on.
        return vectors.detach().cpu()

# ----------------------------
# Text normalization & metrics
# ----------------------------

# def _strip_answer_prefix(s: str) -> str:
#     return re.sub(r"^\s*(final\s+answer|answer|ans)\s*[:\-]\s*", "", s, flags=re.IGNORECASE).strip()

def _strip_answer_prefix(s: str) -> str:
    # remove common prefixes at the START only
    return re.sub(
        r"^\s*(?:final\s+answer|the\s+final\s+answer|the\s+answer|answer|ans)\s*(?:is|=|:|-)?\s*",
        "",
        s,
        flags=re.IGNORECASE,
    ).strip()


def normalize_text(s: str) -> str:
    s = _strip_answer_prefix(s)
    s = s.lower().strip()
    s = s.translate(str.maketrans("", "", string.punctuation))
    s = re.sub(r"\s+", " ", s)
    return s

def first_sentence(text: str) -> str:
    """Return the first non-blank line of ``text``, stripped.

    Non-string input yields "". Leading whitespace and blank lines are skipped
    first (mirroring ``last_sentence``), so a text that begins with a newline
    does not produce an empty result.
    """
    if not isinstance(text, str):
        return ""
    return text.strip().split("\n", 1)[0].strip()
    
def last_sentence(text: str) -> str:
    """Return the last line of ``text`` (after trimming), or "" for non-strings."""
    if isinstance(text, str):
        # rpartition yields the whole string as the tail when there is no newline
        _, _, tail = text.strip().rpartition("\n")
        return tail.strip()
    return ""

_STOPWORDS = {
    "a","an","the","is","are","was","were","be","been","being","of","for","to","in","on","at",
    "by","and","or","but","if","then","that","this","these","those","with","as","from","it","its"
}

def token_set(s: str, remove_stopwords: bool = True) -> set:
    """Return the set of whitespace-separated tokens of ``normalize_text(s)``.

    With ``remove_stopwords`` the entries of ``_STOPWORDS`` are excluded.
    """
    tokens = set(normalize_text(s).split())
    return tokens - _STOPWORDS if remove_stopwords else tokens

def f1(gs: set, ts: set) -> float:
    """Token-set F1 between ``gs`` and ``ts``; 0.0 if either set is empty.

    Precision is measured against ``gs`` and recall against ``ts``. A tiny
    epsilon in the denominator keeps the zero-overlap case at 0.0 instead of
    dividing by zero.
    """
    if not (gs and ts):
        return 0.0
    overlap = len(gs & ts)
    precision = overlap / len(gs)
    recall = overlap / len(ts)
    return 2 * precision * recall / (precision + recall + 1e-12)

# --- Robust "final answer" extractor -----------------------------------------
# Matches things like:
#  "The answer is 42", "Answer: 42", "Final Answer — 42", "**Answer:** 42",
#  "Ans = 42", "Result: 42", "Prediction: 42", etc.
ANSWER_CUE_RE = re.compile(
    r'(?i)(?:^|[\n\r>]\s*)(?:\*\*|\*|__|~~|`|>\s*)*'
    r'(?:final\s+answer|the\s+final\s+answer|the\s+answer|answer|ans(?:wer)?|result|output|prediction|choice|option)'
    r'\s*(?:is|=|:|->)?\s*'
)

def extract_final_answer(text: str) -> str:
    """
    Return the content that comes AFTER a final-answer cue. Prefers the last cue in the text.
    Falls back to "" (caller can then use last_sentence or other logic).
    """
    if not isinstance(text, str):
        return ""
    t = text.strip()
    if not t:
        return ""

    matches = list(ANSWER_CUE_RE.finditer(t))
    if not matches:
        return ""

    # Use the last occurrence to avoid early drafts like "Answer: ..." followed by a later "Final Answer: ..."
    m = matches[-1]
    tail = t[m.end():]

    # Take the first non-empty line after the cue
    parts = [p.strip() for p in re.split(r'[\r\n]+', tail) if p.strip()]
    cand = parts[0] if parts else tail.strip()

    # Prefer content inside LaTeX \boxed{...} if present
    boxed = re.search(r'\\boxed\{([^}]*)\}', cand)
    if boxed:
        cand = boxed.group(1).strip()

    # Heuristic cut: keep up to sentence end if it exists and the segment is long.
    # Avoid breaking decimals; only cut on punctuation that isn't part of a number.
    m2 = re.match(r'(.{1,200}?)(?:(?<!\d)[.?!](?:\s|$)|$)', cand)
    if m2:
        cand = m2.group(1).strip()

    # Trim common leading/trailing cruft
    cand = cand.strip(" -*_`>\"'“”‘’:")
    # Drop leading option markers like "(A) " / "A. "
    cand = re.sub(r'^[\(\[]?[A-Da-d]\)?\.?\s+', '', cand)
    return cand

# ----------------------------
# Math verification (using math-verify package)
# ----------------------------
def is_correct_math(generated: str, truth: str) -> bool:
    """
    Return True if generated answer is mathematically equivalent to truth.
    Uses math-verify package.

    Any failure while parsing or comparing is treated as "not correct" (best
    effort), but is logged at DEBUG level so problems remain diagnosable.
    """
    try:
        g_expr = parse(generated)
        t_expr = parse(truth)
        # math-verify's verify() is order-sensitive: gold first, prediction second.
        return bool(verify(t_expr, g_expr))
    except Exception:
        logging.getLogger(__name__).debug(
            "math verification failed for generated=%r truth=%r",
            generated, truth, exc_info=True,
        )
        return False

# ----------------------------
# AnswerVerifier: exact → containment → F1 → SBERT
# Truth comes from dataset['answer'] (falls back to 'completion' if missing)
# ----------------------------


# Verifier config (constructible from a plain dict via `from_dict`)
@dataclass
class VerifierConfig:
    """Thresholds and switches for the AnswerVerifier cascade."""

    # Minimum SBERT similarity for the last stage to accept an answer.
    sim_threshold: float = 0.75
    # Minimum token-level F1 for the F1 stage to accept an answer.
    f1_threshold: float = 0.90
    # Drop stopwords before computing token sets.
    remove_stopwords: bool = True
    # Also run the containment check against the full model response.
    containment_use_full_response: bool = False
    # Maximum number of generated-answer embeddings kept in the LRU cache.
    gen_emb_cache_cap: int = 20000

    @classmethod
    def from_dict(cls, d: dict, **overrides):
        """Build a config from a plain dict; unknown keys are ignored."""
        return _dataclass_from_dict(cls, d, **overrides)

class AnswerVerifier:
    """
    Reusable, batched verifier for NL answers. The model reply is first reduced
    to an extracted answer (`extract_final_answer`, falling back to
    `last_sentence`), which is then compared with the reference via a cascade:
    1) normalized exact match
    2) normalized substring containment (extracted answer; optionally full response)
    3) token-level F1
    4) SBERT cosine similarity (skipped when no SBERT backend is configured)
    """
    def __init__(self, sbert: "Optional[SBERTBackend]", cfg: "VerifierConfig"):
        self.sbert = sbert
        self.cfg = cfg

        # Truth caches (aligned to positions in the eval dataset)
        self._truth_norm: Optional[List[str]] = None      # normalized `answer`
        self._truth_tok: Optional[List[set]] = None       # tokens of `answer`
        self._truth_emb: Optional[torch.Tensor] = None    # embeddings of raw `answer` strings

        # Generated-side caches, both bounded by cfg.gen_emb_cache_cap
        self._gen_emb_cache: "OrderedDict[str, torch.Tensor]" = OrderedDict()  # LRU
        self._gen_tok_cache: Dict[str, set] = {}                               # FIFO

    def _cache_get(self, key: str):
        """Return the cached embedding for ``key`` (marking it recently used) or None."""
        v = self._gen_emb_cache.get(key)
        if v is not None:
            self._gen_emb_cache.move_to_end(key)
        return v

    def _cache_put(self, key: str, value: torch.Tensor):
        """Insert ``key`` as most recently used, evicting the oldest entry when over capacity."""
        self._gen_emb_cache[key] = value
        self._gen_emb_cache.move_to_end(key)
        if len(self._gen_emb_cache) > self.cfg.gen_emb_cache_cap:
            self._gen_emb_cache.popitem(last=False)

    def _tokens_for(self, text: str) -> set:
        """Return the (cached) token set of a generated answer."""
        toks = self._gen_tok_cache.get(text)
        if toks is None:
            toks = token_set(text, remove_stopwords=self.cfg.remove_stopwords)
            self._gen_tok_cache[text] = toks
            # Keep the cache bounded: drop the oldest inserted entry.
            if len(self._gen_tok_cache) > self.cfg.gen_emb_cache_cap:
                self._gen_tok_cache.pop(next(iter(self._gen_tok_cache)))
        return toks

    def _embed_generated(self, keys: List[str]) -> torch.Tensor:
        """
        Return one embedding row per entry of ``keys`` (same order; ``keys`` must
        be non-empty). Cache misses are encoded in a single SBERT call, each
        distinct string once.

        Embeddings are collected in a local dict rather than re-read from the
        LRU afterwards, so an eviction triggered by this very batch cannot make
        a needed entry disappear.
        """
        resolved: Dict[str, torch.Tensor] = {}
        missing: List[str] = []
        seen = set()
        for key in keys:
            if key in seen:
                continue
            seen.add(key)
            emb = self._cache_get(key)
            if emb is None:
                missing.append(key)
            else:
                resolved[key] = emb

        if missing:
            new_embs = self.sbert.encode(missing)  # [M, d]
            for key, emb in zip(missing, new_embs):
                resolved[key] = emb
                self._cache_put(key, emb)

        return torch.stack([resolved[key] for key in keys], dim=0)

    def build_truth_cache(self, val_dataset) -> None:
        """
        Build caches for a given validation split (call once per split).
        Uses `answer` if present and non-empty; otherwise the answer is derived
        from `completion` (final-answer extractor, then first sentence).

        The cache is built once per verifier instance: later calls return
        immediately, even when a different dataset is passed.
        """
        # Keyed on _truth_norm (always populated), not on _truth_emb, which
        # stays None when no SBERT backend is configured.
        if self._truth_norm is not None:
            return

        answers: List[str] = []
        for i in range(len(val_dataset)):
            row = val_dataset[i]  # fetch each row once
            a = row.get('answer', None)
            a = (str(a).strip() if a is not None else "")
            if not a:
                comp = row.get('completion', "") or ""
                parsed = extract_final_answer(comp)
                a = parsed if parsed else first_sentence(comp)
            answers.append(a)

        # Compute everything first and publish at the end, so a failure in the
        # SBERT encoding cannot leave a half-built cache behind.
        truth_norm = [normalize_text(a) for a in answers]
        truth_tok = [token_set(a, remove_stopwords=self.cfg.remove_stopwords) for a in answers]
        truth_emb = self.sbert.encode(answers) if self.sbert is not None else None  # [N, d] or None

        self._truth_tok = truth_tok
        self._truth_emb = truth_emb
        self._truth_norm = truth_norm

    def verify_batch(
        self,
        generated_texts: List[str],
        eval_indices: List[int],
        *,
        also_check_full_response: Optional[bool] = None
    ) -> List[bool]:
        """
        Compare model outputs (extracted answers) against cached ground-truth answers at eval_indices.
        Returns list[bool] of correctness flags (same order as generated_texts).

        `also_check_full_response` overrides cfg.containment_use_full_response
        when given. Raises ValueError if the two input lists differ in length.
        """
        assert self._truth_norm is not None, "call build_truth_cache(val_dataset) first"
        if len(generated_texts) != len(eval_indices):
            raise ValueError(
                f"generated_texts ({len(generated_texts)}) and eval_indices "
                f"({len(eval_indices)}) must have the same length"
            )

        use_full = (self.cfg.containment_use_full_response
                    if also_check_full_response is None else also_check_full_response)

        # Prepare generated variants
        gen_extract = [extract_final_answer(g) or last_sentence(g) for g in generated_texts]
        gen_extract_norm = [normalize_text(g) for g in gen_extract]
        gen_full_norm = [normalize_text(g) for g in generated_texts] if use_full else None

        ok = [False] * len(generated_texts)

        # 1) Exact (normalized) extracted-answer match
        unresolved = []
        for k, j in enumerate(eval_indices):
            if gen_extract_norm[k] == self._truth_norm[j]:
                ok[k] = True
            else:
                unresolved.append(k)

        # 2) Containment (normalized)
        if unresolved:
            still = []
            for k in unresolved:
                t_norm = self._truth_norm[eval_indices[k]]
                # An empty truth is a substring of everything; never count it as a hit.
                hit = bool(t_norm) and (t_norm in gen_extract_norm[k])
                if not hit and use_full and t_norm:
                    hit = (t_norm in gen_full_norm[k])
                if hit:
                    ok[k] = True
                else:
                    still.append(k)
            unresolved = still

        # 3) Token F1 (extracted vs answer)
        if unresolved:
            still = []
            for k in unresolved:
                g_tok = self._tokens_for(gen_extract[k])
                t_tok = self._truth_tok[eval_indices[k]]
                if f1(g_tok, t_tok) >= self.cfg.f1_threshold:
                    ok[k] = True
                else:
                    still.append(k)
            unresolved = still

        # 4) SBERT cosine (batched; extracted vs answer)
        if unresolved and self._truth_emb is not None:
            gen_embs = self._embed_generated([gen_extract[k] for k in unresolved])      # [U, d]
            truth_rows = torch.tensor([eval_indices[k] for k in unresolved], dtype=torch.long)
            truth_embs = self._truth_emb.index_select(dim=0, index=truth_rows)          # [U, d]
            # Explicit cosine: also correct when embeddings were not L2-normalized.
            cos = torch.nn.functional.cosine_similarity(gen_embs, truth_embs, dim=1)
            flags = (cos >= self.cfg.sim_threshold).tolist()
            for k, flag in zip(unresolved, flags):
                ok[k] = bool(flag)

        return ok

# ----------------------------
# GenerationEvaluator: prompt formatting + generation + scoring
# ----------------------------

@dataclass
class GenConfig:
    """Generation settings used by GenerationEvaluator."""

    # Prompts are truncated to this many tokens by the tokenizer.
    max_input_len: int = 1024
    # Upper bound on newly generated tokens per prompt.
    max_new_tokens: int = 256
    # Values > 0 enable sampling; 0.0 means greedy decoding.
    temperature: float = 0.0
    # Appended on its own line to prompts when no chat template is available.
    answer_cue: str = ""

    @classmethod
    def from_dict(cls, d: dict, **overrides):
        """Build a config from a plain dict; unknown keys are ignored."""
        return _dataclass_from_dict(cls, d, **overrides)
   
   
class GenerationEvaluator:
    """Format prompts, generate completions with ``model`` and score them.

    Scoring is delegated to either a math checker (``verify_mode="math"``) or
    the NL ``AnswerVerifier`` cascade (any other mode).
    """

    def __init__(self, model, tokenizer, verifier: "Optional[AnswerVerifier]" = None,
                 gen_cfg: "Optional[GenConfig]" = None):
        self.model = model
        self.tok = tokenizer
        self.verifier = verifier
        # A fresh config per evaluator: a default instance created in the
        # signature would be shared (and mutable) across all evaluators.
        self.cfg = gen_cfg if gen_cfg is not None else GenConfig()

        # pad/eos guards
        if getattr(self.tok, "pad_token", None) is None and getattr(self.tok, "eos_token", None) is not None:
            self.tok.pad_token = self.tok.eos_token
        # Left padding keeps generated tokens contiguous after each prompt,
        # which _generate relies on when slicing off the prompt.
        if getattr(self.tok, "padding_side", None) != "left":
            self.tok.padding_side = "left"

    def _format_prompts(self, questions: List[str]) -> List[str]:
        """Wrap questions in the chat template if one exists, else append the answer cue."""
        if hasattr(self.tok, "apply_chat_template") and getattr(self.tok, "chat_template", None):
            return [
                self.tok.apply_chat_template(
                    [{"role": "user", "content": q}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
                for q in questions
            ]
        # non-chat: add explicit cue to reduce echoing
        cue = self.cfg.answer_cue
        cue = cue.strip() if isinstance(cue, str) else ""
        suffix = f"\n{cue}" if cue else ""
        return [q.rstrip() + suffix for q in questions]

    @torch.inference_mode()
    def _generate(self, prompts: List[str]) -> List[str]:
        """Generate one completion per prompt and return only the new text."""
        inputs = self.tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.cfg.max_input_len,
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        eos_id = getattr(self.tok, "eos_token_id", None)
        pad_id = getattr(self.tok, "pad_token_id", None)
        if pad_id is None:
            # the attribute may exist but be unset; fall back to EOS
            pad_id = eos_id

        sampling = self.cfg.temperature > 0.0
        out = self.model.generate(
            **inputs,
            max_new_tokens=self.cfg.max_new_tokens,
            do_sample=sampling,
            temperature=(self.cfg.temperature if sampling else 1.0),
            eos_token_id=eos_id,
            pad_token_id=pad_id,
        )
        # decode only new tokens
        gen_only = out[:, inputs["input_ids"].shape[1]:]
        gens = self.tok.batch_decode(gen_only, skip_special_tokens=True)
        return gens

    def evaluate_indices(
        self,
        dataset,
        indices: List[int],
        verify_mode: str = "nl",              # "nl" or "math"
        is_correct_math: Optional[Callable[[str, str], bool]] = is_correct_math,
        return_details: bool = False,
    ):
        """Return accuracy (0-100) and optional per-sample details.

        With ``return_details`` the result is ``(accuracy, flags, generations)``;
        otherwise just the accuracy. An empty ``indices`` yields accuracy 0.0
        without calling the model.

        Raises RuntimeError when NL verification is requested but no verifier
        was supplied.
        """
        math_mode = (verify_mode == "math")

        # Validate before the (expensive) generation step.
        if math_mode:
            assert is_correct_math is not None, "Provide is_correct_math for math mode"
        elif self.verifier is None:
            raise RuntimeError(
                "NL verification requires an AnswerVerifier; pass `verifier=` to GenerationEvaluator"
            )

        if len(indices) == 0:
            return (0.0, [], []) if return_details else 0.0

        rows = [dataset[i] for i in indices]  # fetch each row once
        prompts = self._format_prompts([row["prompt"] for row in rows])
        gens = self._generate(prompts)

        if math_mode:
            truths = []
            for row in rows:
                a = row.get("answer", None)
                a = (str(a).strip() if a is not None else "")
                if not a:
                    comp = row.get("completion", "") or ""
                    parsed = extract_final_answer(comp)
                    a = parsed if parsed else last_sentence(comp) or comp
                truths.append(a)

            # The checker receives the full generation and is expected to
            # locate the answer itself.
            flags = [bool(is_correct_math(g, t)) for g, t in zip(gens, truths)]
        else:
            # NL mode: build truth cache from `answer`, then cascade verify
            self.verifier.build_truth_cache(dataset)
            flags = self.verifier.verify_batch(gens, indices)

        acc = 100.0 * (sum(flags) / max(1, len(flags)))
        if return_details:
            return acc, flags, gens
        return acc


def build_validation_from_cfg(
    model,
    tokenizer,
    val_cfg: dict,
    *,
    default_mode: str = "math"
):
    """
    Build the validation objects from a plain dict like config["validation"].

    Returns a 4-tuple ``(mode, sbert_backend, verifier, gen_evaluator)``:
      - mode: lower-cased ``val_cfg["mode"]`` (``default_mode`` if missing/null)
      - sbert_backend: SBERTBackend, or None (non-"nl" mode or no `model_path`)
      - verifier: AnswerVerifier in "nl" mode, otherwise None
      - gen_evaluator: GenerationEvaluator

    Missing or null sections (including ``val_cfg`` itself) fall back to defaults.
    """
    val_cfg = val_cfg or {}
    # `or` also covers an explicit null, e.g. "mode: null" in a YAML config.
    mode = str(val_cfg.get("mode") or default_mode).lower()
    sbert_backend = None
    verifier = None

    # Generation config
    gen_cfg = GenConfig.from_dict(val_cfg.get("generation") or {})

    if mode == "nl":
        sbert_dict = val_cfg.get("sbert") or {}
        verifier_dict = val_cfg.get("verifier") or {}

        # SBERT is optional: if model_path missing, we still build the verifier (it will skip the SBERT stage)
        if sbert_dict.get("model_path", ""):
            sbert_backend = SBERTBackend(SBERTConfig.from_dict(sbert_dict))

        verifier = AnswerVerifier(
            sbert=sbert_backend,
            cfg=VerifierConfig.from_dict(verifier_dict),
        )

    gen_eval = GenerationEvaluator(model=model, tokenizer=tokenizer, verifier=verifier, gen_cfg=gen_cfg)
    return mode, sbert_backend, verifier, gen_eval
