import math
import re
from typing import Dict, List, Optional, Union

import numpy as np


def _is_missing(x) -> bool:
    if x is None:
        return True
    if isinstance(x, float):
        return math.isnan(x)
    if isinstance(x, str):
        return x.strip() == ""
    return False


def _norm_text(x) -> Optional[str]:
    """Return ``str(x)`` with surrounding whitespace removed, or None if ``x`` is missing."""
    return None if _is_missing(x) else str(x).strip()


# A signed decimal number with optional exponent: "12", "-3.5", ".25", "1e-3".
_num_pat = re.compile(r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")
# The "/ denominator" part of a simple fraction that directly follows a number.
_frac_tail_pat = re.compile(r"\s*/\s*(" + _num_pat.pattern + r")")


def _parse_float(x) -> Optional[float]:
    """Parse the first number found in ``x`` as a float.

    Thousands separators (",") and percent signs are removed before parsing.
    When the first number is immediately followed by ``/ <number>`` the pair is
    read as a fraction, so "1/2" yields 0.5 instead of 1.0 (which previously
    made "1/2" and "1/3" compare as equal).

    Args:
        x: Any value; it is normalized with ``_norm_text`` first.

    Returns:
        The parsed value, or None when ``x`` is missing, contains no number,
        or is a fraction with a zero denominator.
    """
    s = _norm_text(x)
    if s is None:
        return None
    s = s.replace(",", "").replace("%", "").strip()
    m = _num_pat.search(s)
    if not m:
        return None
    try:
        value = float(m.group(0))
    except ValueError:
        return None

    # Look for a denominator starting exactly where the numerator ended.
    tail = _frac_tail_pat.match(s, m.end())
    if tail is None:
        return value
    try:
        denominator = float(tail.group(1))
    except ValueError:
        return None
    if denominator == 0:
        # Undefined value: let the caller fall back to a textual comparison.
        return None
    return value / denominator


class ResultEvaluator:

    # ---------------- Extract only within the text end window (English trigger word) ----------------
    @staticmethod
    def _tail_window(text: str, max_chars: int = 800) -> str:
        """Cut off from the last occurrence of the English trigger word; otherwise, take the last max_chars characters"""
        s = (text or "").replace("**", "").replace("\u200b", " ").strip()
        lower = s.lower()
        triggers = [
            "final answer", "the answer is", "answer is", "answer",
            "final label", "final prediction", "final decision", "final verdict",
            "final result", "final", "label", "prediction", "decision",
            "verdict", "result", "overall", "conclusion",
            "\\boxed", "hence", "therefore", "thus", "so in conclusion"
        ]
        pos = -1
        for t in triggers:
            i = lower.rfind(t)
            if i > pos:
                pos = i
        if pos != -1:
            start = max(0, pos - 60)
            return s[start:]
        return s[-max_chars:]

    @staticmethod
    def _extract_label_from_tail(text: str, label_regex: str, normalize_fn) -> str:
        """
        Only extract English labels in the final window:
        1) Prioritize strong anchor points (\\boxed / Final Answer / Label / Prediction / Decision / Verdict / Conclusion, etc.);
        2) If none are found, scan line by line from the bottom, restricting to short lines and avoiding terms from the reasoning context.
        """
        tail = ResultEvaluator._tail_window(text)
        lower = tail.lower()

        primary_patterns = [
            r'\\boxed\{\s*(' + label_regex + r')\s*\}',
            r'(?:final\s*(?:answer|label|prediction|decision|verdict|result))\s*[:\-]?\s*(' + label_regex + r')\b',
            r'(?:the\s*)?(?:answer|label|prediction|decision|verdict|result)\s*(?:is|:)\s*(' + label_regex + r')\b',
            r'(?:conclusion|therefore|thus|hence)[^.\n]{0,80}?(' +
            label_regex + r')\b',
        ]
        last_match = None
        for pat in primary_patterns:
            for m in re.finditer(pat, lower, flags=re.IGNORECASE):
                last_match = m.group(1)
        if last_match:
            label = normalize_fn(last_match)
            return label if label else ""

        noise_keywords = [
            "premise", "hypothesis", "option", "choice", "choices", "context",
            "reason", "because", "assume", "suppose", "if", "then", "proof",
            "explain", "consider", "step", "reasoning", "analysis"
        ]
        bare_pat = re.compile(
            r'\b(' + label_regex + r')\b', flags=re.IGNORECASE)
        for line in reversed(lower.splitlines()):
            l = line.strip()
            if not l:
                continue
            if len(l) > 120:
                continue
            if any(k in l for k in noise_keywords):
                continue
            found = bare_pat.findall(l)
            if len(found) == 1:
                label = normalize_fn(found[0])
                if label:
                    return label

        return ""

    @staticmethod
    def extract_answer(text: str, problem_type: str = "math") -> str:
        """extract answer from the model output text"""
        if problem_type == "math":
            # --------- tools ---------
            def _strip(s: str) -> str:
                return s.strip().replace("\u200b", "").replace("\xa0", " ")

            def _gcd(a: int, b: int) -> int:
                while b:
                    a, b = b, a % b
                return abs(a)

            def _simplify_fraction(num_str: str, den_str: str) -> str:
                try:
                    n = int(num_str.strip())
                    d = int(den_str.strip())
                    if d == 0:
                        return f"{n}/{d}"
                    g = _gcd(n, d)
                    n //= g
                    d //= g
                    if d < 0:
                        n, d = -n, -d
                    return f"{n}/{d}"
                except:
                    return f"{_strip(num_str)}/{_strip(den_str)}"

            NUM = r"[+-]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
            SIMPLE_FRAC = r"\b\d+\s*/\s*\d+\b"

            RE_TRAILING_HASH = re.compile(
                r"####\s*([^\n#]+)", re.IGNORECASE)  # #### 260
            RE_BOXED = re.compile(
                r"\\boxed\{([^}]+)\}", re.IGNORECASE)  # \boxed{...}
            # answer is ...
            RE_ANSWER_IS = re.compile(
                r"answer\s*is[:\s]*([^\n]+)", re.IGNORECASE)
            RE_MONEY = re.compile(r"\$(" + NUM + r")")  # $123.45
            RE_PERCENT = re.compile(r"(" + NUM + r")\s*\\?%")  # 90% / 90\%
            RE_NUMBER = re.compile(r"(" + NUM + r")")  # 123
            RE_LATEX_FRAC = re.compile(
                r"\\frac\{([^}]+)\}\{([^}]+)\}")  # \frac{a}{b}
            RE_SIMPLE_FRAC = re.compile(SIMPLE_FRAC)  # 1/4
            # shot algebraic expression
            RE_ALGEBRA = re.compile(r"^[\s()\-+*/^0-9a-zA-Z._]+$")

            def _post_normalize(ans: str) -> str:
                ans = _strip(ans)
                m = RE_PERCENT.search(ans)
                if m:
                    return m.group(1).replace(",", "") + "%"
                if RE_NUMBER.fullmatch(ans):
                    return ans.replace(",", "")
                m = RE_SIMPLE_FRAC.search(ans)
                if m and ans == m.group(0):
                    a, b = ans.split("/")
                    return _simplify_fraction(a, b)
                return ans

            # 1) "#### 260"
            m = RE_TRAILING_HASH.search(text)
            if m:
                return _post_normalize(m.group(1))

            # 2) \boxed{...}
            m = RE_BOXED.search(text)
            if m:
                return _post_normalize(m.group(1))

            # 3) "answer is ..."
            m = RE_ANSWER_IS.search(text)
            if m:
                seg = _post_normalize(m.group(1))
                for pat in (RE_LATEX_FRAC, RE_SIMPLE_FRAC, RE_PERCENT, RE_MONEY, RE_NUMBER):
                    mm = pat.findall(seg)
                    if mm:
                        if pat is RE_LATEX_FRAC:
                            a, b = mm[-1]
                            return _simplify_fraction(a, b)
                        if pat is RE_SIMPLE_FRAC:
                            last = mm[-1]
                            a, b = last.split("/")
                            return _simplify_fraction(a, b)
                        last = mm[-1]
                        if isinstance(last, tuple):
                            last = last[0]
                        return _post_normalize(str(last))
                if RE_ALGEBRA.fullmatch(seg):
                    return seg
                return seg

            # 4) LaTeX
            mm = RE_LATEX_FRAC.findall(text)
            if mm:
                a, b = mm[-1]
                return _simplify_fraction(a, b)

            # 5)  1/4
            mm = RE_SIMPLE_FRAC.findall(text)
            if mm:
                a, b = mm[-1].split("/")
                return _simplify_fraction(a, b)

            # 6) %
            mm = RE_PERCENT.findall(text)
            if mm:
                return mm[-1].replace(",", "") + "%"

            # 7) $
            mm = RE_MONEY.findall(text)
            if mm:
                return mm[-1].replace(",", "")

            # 8) last number
            mm = RE_NUMBER.findall(text)
            if mm:
                for val in reversed(mm):
                    a = val.replace(",", "").strip()
                    try:
                        float(a)
                        return a
                    except:
                        continue

            # 9)  "1-x"
            whole = _strip(text)
            if len(whole) <= 64 and RE_ALGEBRA.fullmatch(whole):
                return whole

            return None

        elif problem_type == "folio":
            label_regex = r'true|false|uncertain|unknown|indeterminate|undetermined'
            label = ResultEvaluator._extract_label_from_tail(
                text, label_regex, ResultEvaluator._normalize_folio_label
            )
            return label if label else ""

        elif problem_type == "logiqa":
            label_regex = r'entailment|not[_\- ]?entailment'
            label = ResultEvaluator._extract_label_from_tail(
                text, label_regex, ResultEvaluator._normalize_logiqa_label
            )
            return label if label else ""

        elif problem_type == "abductionr":
            label_regex = r'true|false|entailed|entailment|contradicted|contradiction|yes|no'
            label = ResultEvaluator._extract_label_from_tail(
                text, label_regex, ResultEvaluator._normalize_abductionr_label
            )
            return label if label else ""

        elif problem_type == "fld":
            label_regex = r'proved|disproved|unknown|true|false|uncertain'
            label = ResultEvaluator._extract_label_from_tail(
                text, label_regex, ResultEvaluator._normalize_fld_label
            )
            return label if label else ""

        elif problem_type == "proofwriter":
            label_regex = r'true|false|uncertain|unknown|indeterminate|undetermined'
            label = ResultEvaluator._extract_label_from_tail(
                text, label_regex, ResultEvaluator._normalize_proofwriter_label
            )
            return label if label else ""

        elif problem_type == "ruletaker":
            label_regex = r'entailment|not[_\- ]?entailment|true|false|uncertain|unknown'
            label = ResultEvaluator._extract_label_from_tail(
                text, label_regex, ResultEvaluator._normalize_ruletaker_label
            )
            return label if label else ""

        elif problem_type == "arlsat":
            pattern = r'\b([0-4])\b'
            matches = re.findall(pattern, text.strip())
            if matches:
                return matches[0]
            return ""

        elif problem_type == "reclor":
            pattern = r'\b([0-3])\b'
            matches = re.findall(pattern, text.strip())
            if matches:
                return matches[0]
            return ""

        elif problem_type == "rulearena-airline":
            text = text.replace("**", "")
            start_id = text.find("The total cost is")
            if start_id != -1:
                conclusion = text[start_id:]
                value_idx = conclusion.find("$")
                if value_idx != -1:
                    value = conclusion[value_idx +
                                       1:].replace(",", "").replace(".", "")
                    if value.isnumeric():
                        value = int(value)
                    return value
        elif problem_type == "rulearena-nba":
            return text.replace("**", "")
        elif problem_type == "rulearena-tax":
            text = text.replace("**", "")
            pattern = r"The total tax (owed|overpaid) is \$((?:\d{1,3}(?:,\d{3})*|\d+)(\.\d+)?)."
            match = re.search(pattern, text)
            if match:
                status = match.group(1)
                value = float(match.group(2).replace(",", ""))
                value = -value if status == "overpaid" else value
                return value

        # If a specific format is not found, return the original text.
        return text.strip()

    @staticmethod
    def evaluate_mathdata(predictions: List[str], ground_truths: List[str],
                          rel_tol: float = 1e-9, abs_tol: float = 1e-9) -> Dict[str, float]:
        """Score math predictions against references.

        Both sides are passed through ``extract_answer(..., "math")``. When both
        extracted answers parse as numbers they are compared with
        ``math.isclose``; otherwise they are compared as case-insensitive text.

        Args:
            predictions: Model outputs. Missing entries (``None``, NaN, blank)
                are scored as wrong instead of raising.
            ground_truths: Reference answers; non-string values (e.g. numbers)
                are converted with ``str`` before extraction.
            rel_tol: Relative tolerance for the numeric comparison.
            abs_tol: Absolute tolerance for the numeric comparison.

        Returns:
            Dict with "accuracy", "correct" and "total". Samples without a
            usable ground truth are skipped and not counted in "total".
        """
        def _answer_of(raw) -> Optional[str]:
            # Guard before extraction: the regex-based extractor only accepts
            # strings, so missing or numeric inputs would otherwise raise.
            if _is_missing(raw):
                return None
            raw_text = raw if isinstance(raw, str) else str(raw)
            return _norm_text(ResultEvaluator.extract_answer(raw_text, "math"))

        correct = 0
        total = 0
        for pred, truth in zip(predictions, ground_truths):
            p = _answer_of(pred)
            t = _answer_of(truth)

            # Samples without ground truth are skipped and not included in the denominator.
            if t is None:
                continue

            total += 1
            # Missing prediction: counted as an error, no exception.
            if p is None:
                continue

            pn, tn = _parse_float(p), _parse_float(t)
            if pn is not None and tn is not None:
                if math.isclose(pn, tn, rel_tol=rel_tol, abs_tol=abs_tol):
                    correct += 1
            elif p.lower() == t.lower():
                correct += 1

        accuracy = (correct / total) if total > 0 else 0.0
        return {"accuracy": accuracy, "correct": correct, "total": total}

    @staticmethod
    def _normalize_folio_label(s: str) -> str:
        """
        'true' / 'false' / 'uncertain'
        """
        s = (s or "").strip().lower()
        true_set = {"true", "entailed", "entailment", "valid", "yes"}
        false_set = {"false", "contradicted", "contradiction", "no"}
        uncertain_set = {
            "uncertain", "unknown", "indeterminate", "undetermined",
            "cannot be determined", "can't be determined", "not sure", "both"
        }
        if s in true_set:
            return "true"
        if s in false_set:
            return "false"
        if s in uncertain_set:
            return "uncertain"
        return ""

    @staticmethod
    def _normalize_logiqa_label(s: str) -> str:
        """
        'entailment' /  'not-entailment'
        """
        s = (s or "").strip().lower()

        if s in {"entailment"}:
            return "entailment"
        if s in {"not_entailment", "not-entailment", "notentailment"}:
            return "not-entailment"

        alias_map = {
            "entails": "entailment",
            "entailed": "entailment",
            "not entails": "not-entailment",
            "not entailed": "not-entailment",
            "contradicted": "not-entailment",
            "inconsistent": "not-entailment",
        }
        return alias_map.get(s, "")

    @staticmethod
    def _normalize_abductionr_label(s: str) -> str:
        """
        'true' / 'false'
        """
        s = (s or "").strip().lower()
        true_set = {"true", "entailed", "entailment", "valid", "yes"}
        false_set = {"false", "contradicted", "contradiction", "no"}
        if s in true_set:
            return "true"
        if s in false_set:
            return "false"
        return ""

    @staticmethod
    def _normalize_fld_label(s: str) -> str:
        """
        'proved' / 'disproved' / 'unknown'
        """
        s = (s or "").strip().lower()
        proved_set = {"proved", "true", "valid",
                      "yes", "entailed", "entailment"}
        disproved_set = {"disproved", "false", "invalid",
                         "no", "contradicted", "contradiction"}
        unknown_set = {"unknown", "uncertain", "indeterminate", "undetermined",
                       "cannot be determined", "can't be determined", "not sure", "both"}
        if s in proved_set:
            return "proved"
        if s in disproved_set:
            return "disproved"
        if s in unknown_set:
            return "unknown"
        return ""

    @staticmethod
    def _normalize_proofwriter_label(s: str) -> str:
        """
        'true' / 'false' / 'unknown'
        """
        s = (s or "").strip().lower()
        true_set = {"true", "valid", "yes", "entailed", "entailment"}
        false_set = {"false", "invalid", "no", "contradicted", "contradiction"}
        unknown_set = {"unknown", "uncertain", "indeterminate", "undetermined",
                       "cannot be determined", "can't be determined", "not sure", "both"}
        if s in true_set:
            return "true"
        if s in false_set:
            return "false"
        if s in unknown_set:
            return "unknown"
        return ""

    @staticmethod
    def _normalize_ruletaker_label(s: str) -> str:
        """
        'entailment' / 'not entailment'
        """
        s = (s or "").strip().lower()
        entailment_set = {"entailment", "true",
                          "valid", "yes", "entailed", "entails"}
        not_entailment_set = {"not_entailment", "not-entailment", "notentailment",
                              "false", "invalid", "no", "contradicted", "contradiction",
                              "uncertain", "unknown", "indeterminate", "undetermined",
                              "cannot be determined", "can't be determined", "not sure", "both"}
        if s in entailment_set:
            return "entailment"
        if s in not_entailment_set:
            return "not entailment"
        return ""

    @staticmethod
    def evaluate_folio(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        ‘true/false/uncertain’
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "folio")
            truth_label = ResultEvaluator.extract_answer(truth, "folio")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_arlsat(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "arlsat")
            truth_label = ResultEvaluator.extract_answer(truth, "arlsat")

            if pred_label and truth_label and pred_label == truth_label:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_logiqa(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        entailment / not-entailment
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "logiqa")
            truth_label = ResultEvaluator.extract_answer(truth, "logiqa")

            pred_label = ResultEvaluator._normalize_logiqa_label(pred_label)
            truth_label = ResultEvaluator._normalize_logiqa_label(truth_label)

            if pred_label and truth_label and pred_label == truth_label:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_reclor(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "reclor")
            truth_label = ResultEvaluator.extract_answer(truth, "reclor")

            if pred_label and truth_label and pred_label == truth_label:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_abductionr(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        true/false
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "abductionr")
            truth_label = ResultEvaluator.extract_answer(truth, "abductionr")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_rulearena(predictions: List[str], ground_truths: List[str], data_split: str) -> Dict[str, float]:
        assert data_split in ["airline", "nba", "tax"]
        correct = 0
        total = len(predictions)
        assert total == len(
            ground_truths), f"Predictions and ground truths must have the same length, for rulearena, len of predictions = {total}, len of ground_truths = {len(ground_truths)}"
        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(
                pred, f"rulearena-{data_split}")
            truth_label = truth

            if pred_label and truth_label:
                if data_split == "airline":
                    if pred_label == truth_label or (isinstance(pred_label, str) and str(truth_label) in pred_label):
                        correct += 1
                elif data_split == "nba" and truth_label in pred_label:
                    correct += 1
                elif data_split == "tax":
                    if (not isinstance(pred_label, str) and np.isclose(pred_label, truth_label)) or (isinstance(pred_label, str) and str(truth_label) in pred_label):
                        correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_fld(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        proved/disproved/unknown
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "fld")
            truth_label = ResultEvaluator.extract_answer(truth, "fld")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_proofwriter(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        true/false/unknown
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "proofwriter")
            truth_label = ResultEvaluator.extract_answer(truth, "proofwriter")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_ruletaker(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        entailment/not entailment
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "ruletaker")
            truth_label = ResultEvaluator.extract_answer(truth, "ruletaker")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }
