import re
import math
from typing import List, Dict, Optional


def _is_missing(x) -> bool:
    """Tell whether *x* should be treated as an absent value.

    A value counts as missing when it is ``None``, a float NaN, or a string
    that is empty once surrounding whitespace is removed. Every other value
    (including ``0`` and empty containers) is considered present.
    """
    if isinstance(x, str):
        # Whitespace-only strings carry no content.
        return x.strip() == ""
    if isinstance(x, float):
        return math.isnan(x)
    return x is None


def _norm_text(x) -> Optional[str]:
    """Return ``str(x)`` with surrounding whitespace removed.

    Values that ``_is_missing`` reports as missing yield ``None`` instead.
    """
    return None if _is_missing(x) else str(x).strip()


# Signed decimal number with an optional fractional part and exponent,
# e.g. "42", "-3.", ".5", "+1.25e-3".
_num_pat = re.compile(
    r"[+-]?"                      # optional sign
    r"(?:\d+(?:\.\d*)?|\.\d+)"    # "12", "12.", "12.5" or ".5"
    r"(?:[eE][+-]?\d+)?"          # optional exponent
)


def _parse_float(x) -> Optional[float]:
    """Best-effort conversion of *x* to a float.

    The value is normalized with ``_norm_text``; thousands separators (",")
    and percent signs are removed before parsing.

    * A string that is exactly an integer fraction (``"3/4"``, ``"-1 / 2"``)
      is evaluated as its quotient.  Previously only the numerator was read,
      so ``"1/2"`` and ``"1/4"`` both parsed as ``1.0`` and compared equal.
    * Otherwise the first number found anywhere in the string is returned
      (so ``"260 dollars"`` gives ``260.0``).

    Returns:
        The parsed float, or ``None`` when the value is missing, contains no
        number, or is a fraction with a zero denominator.
    """
    s = _norm_text(x)
    if s is None:
        return None
    s = s.replace(",", "").replace("%", "").strip()

    frac = re.fullmatch(r"([+-]?\d+)\s*/\s*([+-]?\d+)", s)
    if frac:
        try:
            return int(frac.group(1)) / int(frac.group(2))
        except (ValueError, OverflowError, ZeroDivisionError):
            # No finite numeric value; callers fall back to text comparison.
            return None

    m = _num_pat.search(s)
    if not m:
        return None
    try:
        return float(m.group(0))
    except ValueError:
        return None


class ResultEvaluator:

    @staticmethod
    def extract_answer(text: str, problem_type: str = "math") -> str:
        if problem_type == "math":
            def _strip(s: str) -> str:
                return s.strip().replace("\u200b", "").replace("\xa0", " ")

            def _gcd(a: int, b: int) -> int:
                while b:
                    a, b = b, a % b
                return abs(a)

            def _simplify_fraction(num_str: str, den_str: str) -> str:
                try:
                    n = int(num_str.strip())
                    d = int(den_str.strip())
                    if d == 0:
                        return f"{n}/{d}"
                    g = _gcd(n, d)
                    n //= g
                    d //= g
                    if d < 0:
                        n, d = -n, -d
                    return f"{n}/{d}"
                except:
                    return f"{_strip(num_str)}/{_strip(den_str)}"

            NUM = r"[+-]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
            SIMPLE_FRAC = r"\b\d+\s*/\s*\d+\b"

            RE_TRAILING_HASH = re.compile(
                r"####\s*([^\n#]+)", re.IGNORECASE)  # #### 260
            RE_BOXED = re.compile(
                r"\\boxed\{([^}]+)\}", re.IGNORECASE)  # \boxed{...}
            # answer is ...
            RE_ANSWER_IS = re.compile(
                r"answer\s*is[:\s]*([^\n]+)", re.IGNORECASE)
            RE_MONEY = re.compile(r"\$(" + NUM + r")")  # $123.45
            RE_PERCENT = re.compile(r"(" + NUM + r")\s*\\?%")  # 90% / 90\%
            RE_NUMBER = re.compile(r"(" + NUM + r")")  # 1234
            RE_LATEX_FRAC = re.compile(
                r"\\frac\{([^}]+)\}\{([^}]+)\}")  # \frac{a}{b}
            RE_SIMPLE_FRAC = re.compile(SIMPLE_FRAC)  # 1/4
            # shot algebra expression
            RE_ALGEBRA = re.compile(r"^[\s()\-+*/^0-9a-zA-Z._]+$")

            def _post_normalize(ans: str) -> str:
                ans = _strip(ans)
                m = RE_PERCENT.search(ans)
                if m:
                    return m.group(1).replace(",", "") + "%"
                if RE_NUMBER.fullmatch(ans):
                    return ans.replace(",", "")
                m = RE_SIMPLE_FRAC.search(ans)
                if m and ans == m.group(0):
                    a, b = ans.split("/")
                    return _simplify_fraction(a, b)
                return ans

            # 1)  "#### 260"
            m = RE_TRAILING_HASH.search(text)
            if m:
                return _post_normalize(m.group(1))

            # 2) \boxed{...}
            m = RE_BOXED.search(text)
            if m:
                return _post_normalize(m.group(1))

            # 3) "answer is ..."
            m = RE_ANSWER_IS.search(text)
            if m:
                seg = _post_normalize(m.group(1))
                for pat in (RE_LATEX_FRAC, RE_SIMPLE_FRAC, RE_PERCENT, RE_MONEY, RE_NUMBER):
                    mm = pat.findall(seg)
                    if mm:
                        if pat is RE_LATEX_FRAC:
                            a, b = mm[-1]
                            return _simplify_fraction(a, b)
                        if pat is RE_SIMPLE_FRAC:
                            last = mm[-1]
                            a, b = last.split("/")
                            return _simplify_fraction(a, b)
                        last = mm[-1]
                        if isinstance(last, tuple):
                            last = last[0]
                        return _post_normalize(str(last))
                if RE_ALGEBRA.fullmatch(seg):
                    return seg
                return seg

            # 4) LaTeX
            mm = RE_LATEX_FRAC.findall(text)
            if mm:
                a, b = mm[-1]
                return _simplify_fraction(a, b)

            # 5)  1/4
            mm = RE_SIMPLE_FRAC.findall(text)
            if mm:
                a, b = mm[-1].split("/")
                return _simplify_fraction(a, b)

            # 6) %
            mm = RE_PERCENT.findall(text)
            if mm:
                return mm[-1].replace(",", "") + "%"

            # 7) $
            mm = RE_MONEY.findall(text)
            if mm:
                return mm[-1].replace(",", "")

            # 8)
            mm = RE_NUMBER.findall(text)
            if mm:
                for val in reversed(mm):
                    a = val.replace(",", "").strip()
                    try:
                        float(a)
                        return a
                    except:
                        continue

            # 9) "1-x"
            whole = _strip(text)
            if len(whole) <= 64 and RE_ALGEBRA.fullmatch(whole):
                return whole

            # 10) all failed
            return None

        elif problem_type == "folio":
            lower = (text or "").lower()

            patterns = [
                r'\\boxed\{\s*(true|false|uncertain)\s*\}',
                r'\b(?:final\s*)?answer\s*[:：-]?\s*(true|false|uncertain)\b',
                r'\blabel\s*[:：-]?\s*(true|false|uncertain)\b',
                r'\bprediction\s*[:：-]?\s*(true|false|uncertain)\b',
                r'\b(true|false|uncertain)\b',
            ]

            for pattern in patterns:
                matches = re.findall(pattern, lower, flags=re.IGNORECASE)
                if matches:
                    raw = matches[-1] if isinstance(matches[-1],
                                                    str) else matches[-1][0]
                    label = ResultEvaluator._normalize_folio_label(raw)
                    if label:
                        return label

            return text.strip()

        elif problem_type == "arlsat":
            pattern = r'\b([0-4])\b'
            matches = re.findall(pattern, text.strip())
            if matches:
                return matches[0]
            return ""

        elif problem_type == "logiqa":
            lower = (text or "").lower()
            patterns = [
                r'\\boxed\{\s*(entailment|not[_\- ]?entailment)\s*\}',
                r'\b(?:final\s*)?answer\s*[:：-]?\s*(entailment|not[_\- ]?entailment)\b',
                r'\blabel\s*[:：-]?\s*(entailment|not[_\- ]?entailment)\b',
                r'\bprediction\s*[:：-]?\s*(entailment|not[_\- ]?entailment)\b',
                r'\b(entailment|not[_\- ]?entailment)\b',
            ]
            for pattern in patterns:
                matches = re.findall(pattern, lower, flags=re.IGNORECASE)
                if matches:
                    raw = matches[-1] if isinstance(matches[-1],
                                                    str) else matches[-1][0]
                    label = ResultEvaluator._normalize_logiqa_label(raw)
                    if label:
                        return label

            rough = ResultEvaluator._normalize_logiqa_label(lower.strip())
            return rough if rough else text.strip()

        elif problem_type == "reclor":
            pattern = r'\b([0-3])\b'
            matches = re.findall(pattern, text.strip())
            if matches:
                return matches[0]
            return ""

        elif problem_type == "abductionr":
            lower = (text or "").lower()

            patterns = [
                r'\\boxed\{\s*(true|false)\s*\}',
                r'\b(?:final\s*)?answer\s*[:：-]?\s*(true|false)\b',
                r'\blabel\s*[:：-]?\s*(true|false)\b',
                r'\bprediction\s*[:：-]?\s*(true|false)\b',
                r'\b(true|false)\b',
            ]

            for pattern in patterns:
                matches = re.findall(pattern, lower, flags=re.IGNORECASE)
                if matches:
                    raw = matches[-1] if isinstance(matches[-1],
                                                    str) else matches[-1][0]
                    label = ResultEvaluator._normalize_abductionr_label(raw)
                    if label:
                        return label
            return text.strip()

        return text.strip()

    @staticmethod
    def evaluate_mathdata(predictions: List[str], ground_truths: List[str],
                          rel_tol: float = 1e-9, abs_tol: float = 1e-9) -> Dict[str, float]:
        """Score math predictions against reference answers.

        Both sides are passed through ``extract_answer(..., "math")`` and
        ``_norm_text``.  Pairs whose reference has no usable answer are left
        out of ``total``; pairs whose prediction has none count as wrong.
        When both answers parse as floats they are compared with
        ``math.isclose`` using ``rel_tol`` / ``abs_tol``; otherwise a
        case-insensitive string comparison is used.

        Returns:
            Dict with ``accuracy`` (0.0 when nothing was scored), ``correct``
            and ``total``.
        """
        n_scored = 0
        n_right = 0
        for raw_pred, raw_truth in zip(predictions, ground_truths):
            guess = _norm_text(ResultEvaluator.extract_answer(raw_pred, "math"))
            gold = _norm_text(ResultEvaluator.extract_answer(raw_truth, "math"))

            # No reference answer: the pair is not scored at all.
            if gold is None:
                continue
            n_scored += 1

            # Reference exists but nothing was predicted: scored as wrong.
            if guess is None:
                continue

            guess_num = _parse_float(guess)
            gold_num = _parse_float(gold)
            if guess_num is None or gold_num is None:
                matched = guess.lower() == gold.lower()
            else:
                matched = math.isclose(guess_num, gold_num,
                                       rel_tol=rel_tol, abs_tol=abs_tol)
            if matched:
                n_right += 1

        ratio = (n_right / n_scored) if n_scored > 0 else 0.0
        return {"accuracy": ratio, "correct": n_right, "total": n_scored}

    @staticmethod
    def _normalize_folio_label(s: str) -> str:
        """
        Normalization to: 'true' / 'false' / 'uncertain'
        Accepts common synonyms such as entailed/contradicted/unknown, among others.
        """
        s = (s or "").strip().lower()
        true_set = {"true", "entailed", "entailment", "valid", "yes"}
        false_set = {"false", "contradicted", "contradiction", "no"}
        uncertain_set = {
            "uncertain", "unknown", "indeterminate", "undetermined",
            "cannot be determined", "can't be determined", "not sure", "both"
        }
        if s in true_set:
            return "true"
        if s in false_set:
            return "false"
        if s in uncertain_set:
            return "uncertain"
        return ""

    @staticmethod
    def _normalize_logiqa_label(s: str) -> str:
        """
         'entailment' /  'not-entailment'
        """
        s = (s or "").strip().lower()

        if s in {"entailment"}:
            return "entailment"
        if s in {"not_entailment", "not-entailment", "notentailment"}:
            return "not-entailment"

        alias_map = {
            "entails": "entailment",
            "entailed": "entailment",
            "not entails": "not-entailment",
            "not entailed": "not-entailment",
            "contradicted": "not-entailment",
            "inconsistent": "not-entailment",
        }
        return alias_map.get(s, "")

    @staticmethod
    def _normalize_abductionr_label(s: str) -> str:
        """
         'true' / 'false'
        """
        s = (s or "").strip().lower()
        true_set = {"true", "entailed", "entailment", "valid", "yes"}
        false_set = {"false", "contradicted", "contradiction", "no"}
        if s in true_set:
            return "true"
        if s in false_set:
            return "false"
        return ""

    @staticmethod
    def evaluate_folio(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """ FOLIO (true/false/uncertain)"""
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "folio")
            truth_label = ResultEvaluator.extract_answer(truth, "folio")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_arlsat(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "arlsat")
            truth_label = ResultEvaluator.extract_answer(truth, "arlsat")

            if pred_label and truth_label and pred_label == truth_label:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_logiqa(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        (entailment / not-entailment)
        """
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "logiqa")
            truth_label = ResultEvaluator.extract_answer(truth, "logiqa")

            pred_label = ResultEvaluator._normalize_logiqa_label(pred_label)
            truth_label = ResultEvaluator._normalize_logiqa_label(truth_label)

            if pred_label and truth_label and pred_label == truth_label:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_reclor(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "reclor")
            truth_label = ResultEvaluator.extract_answer(truth, "reclor")

            if pred_label and truth_label and pred_label == truth_label:
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }

    @staticmethod
    def evaluate_abductionr(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """ AbductionR (true/false)"""
        correct = 0
        total = len(predictions)

        for pred, truth in zip(predictions, ground_truths):
            pred_label = ResultEvaluator.extract_answer(pred, "abductionr")
            truth_label = ResultEvaluator.extract_answer(truth, "abductionr")
            if pred_label.strip().lower() == truth_label.strip().lower():
                correct += 1

        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total
        }
