import abc
import dataclasses
import math
import random
import re
from typing import Any, Dict, List, Optional, Sequence

from bd_mcts.tasks.base import ProbResult, Task

# Matches the opening of `\boxed{` or `\fbox{` (whitespace allowed before the
# brace). Only the opening is matched; the closing brace is found by a
# brace-depth scan in extract_last_boxed.
_BOXED_OPEN_RE = re.compile(r"\\(?:boxed|fbox)\s*\{")
# Case-insensitive "answer: X" / "final answer = X"; group(1) is the rest of the
# line. Callers use .search(), so the keyword may appear mid-line.
_ANSWER_LINE_RE = re.compile(r"(?i)(?:final\s+answer|answer)\s*[:=]\s*(.+)$")
# Case-insensitive "answer is X" / "final answer is X"; group(1) is the rest of
# the line (also used with .search()).
_ANSWER_IS_RE = re.compile(r"(?i)(?:final\s+answer|answer)\s+is\s+(.+)$")


# (before, after) pairs applied in order with plain str.replace by
# normalize_final_answer. They are substring replacements, not word- or
# command-aware: e.g. "an " also matches inside "\tan x" and "a " inside
# "\theta x".
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings deleted (in order, with str.replace) by normalize_final_answer:
# unit words, filler words, and LaTeX decorations.
# NOTE(review): deletion is substring-based, so "ft" also removes part of the
# LaTeX commands "\left" (-> "\le") and "\infty" (-> "\iny").
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Copied character for character from appendix D of Lewkowycz et al. (2022)

    Steps, in order:
      1. keep only the text after the last "=";
      2. apply SUBSTITUTIONS, then delete every entry of REMOVED_EXPRESSIONS;
      3. unwrap $...$, \\text{...}, \\textbf{...}, \\overline{...}, \\boxed{...};
      4. expand shorthand \\frac / \\sqrt arguments and drop remaining "$";
      5. remove commas from strings that are otherwise all digits.

    NOTE(review): steps 2's replacements are plain substring operations, so
    LaTeX commands can be damaged (e.g. "\\left" loses its "ft").
    """
    # Keep only what follows the last "=" (e.g. "x = 5" -> " 5").
    final_answer = final_answer.split("=")[-1]

    # Ordered substring replacements / deletions from the module-level lists.
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    # The first pattern keeps only the contents of the first $...$ pair (per
    # line), re-wrapped in "$". The \text/\textbf/\overline patterns are lazy
    # and stop at the first "}", so nested braces are not handled; the \boxed
    # pattern is greedy and runs to the last "}".
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Normalize 100,000 -> 100000
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer


def _wrap_for_math_verify(latex: str) -> str:
    """Return *latex*, trimmed of surrounding whitespace, inside ``$...$``.

    math-verify's LaTeX extractor expects explicit math delimiters. Answer
    extraction in this module usually strips them, so they are added back here.
    Without them math-verify may fall back to plain-number extraction, which can
    produce false positives.
    """
    body = latex.strip()
    return "$" + body + "$"


def remove_boxed(s: str) -> str:
    """Strip the ``\\boxed`` / ``\\fbox`` wrapper from a boxed answer string.

    Intended for the output of ``last_boxed_only_string``. Accepted forms:
      - ``"\\boxed <content>"`` (space form) -> ``"<content>"``
      - ``"\\boxed{<content>}"``             -> ``"<content>"``
      - ``"\\fbox{<content>}"``              -> ``"<content>"``

    Raises:
        AssertionError: if *s* is not in one of the accepted forms. The checks
            are explicit ``raise`` statements rather than ``assert`` so they still
            run under ``python -O``. The exception type is kept for callers that
            already catch ``AssertionError``.
    """
    space_form = "\\boxed "
    if space_form in s:
        if not s.startswith(space_form):
            raise AssertionError("expected '\\boxed ' prefix, got " + repr(s))
        return s[len(space_form) :]

    # last_boxed_only_string can also return "\fbox{...}", so accept it here too.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left):
            if not s.endswith("}"):
                raise AssertionError("expected closing '}', got " + repr(s))
            return s[len(left) : -1]

    raise AssertionError(
        "expected a \\boxed{...} or \\fbox{...} string, got " + repr(s)
    )


def last_boxed_only_string(string: str) -> Optional[str]:
    """Return the last boxed expression in *string*, wrapper included, or None.

    - If the space form ``"\\boxed "`` occurs anywhere, return ``"\\boxed "``
      followed by the text after its last occurrence, cut at the first ``"$"``.
    - Otherwise locate the last ``\\boxed`` (or, if absent, the last ``\\fbox``)
      and return it through its matching closing brace.
    - Return None when no marker exists or the braces never balance.
    """
    space_form = "\\boxed "
    if space_form in string:
        tail = string.split(space_form)[-1]
        return space_form + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the content of the LAST ``\\boxed{...}`` or ``\\fbox{...}`` in *text*.

    Nested braces inside the box are supported. Returns None when there is no
    box opening, or when the last opening brace is never closed.
    """
    content_starts = [m.end() for m in _BOXED_OPEN_RE.finditer(text)]
    if not content_starts:
        return None

    begin = content_starts[-1]  # index just after the opening "{"
    depth = 1
    for pos in range(begin, len(text)):
        char = text[pos]
        if char == "}":
            depth -= 1
            if depth == 0:
                return text[begin:pos]
        elif char == "{":
            depth += 1
    return None


def strip_math_delimiters(s: str) -> str:
    """Remove one pair of surrounding math delimiters and trailing punctuation.

    Recognized delimiters: ``$$...$$``, ``$...$``, ``\\(...\\)`` and ``\\[...\\]``.
    Trailing whitespace and ``.``/``;``/``:``/``,`` are stripped both before and
    after the delimiter check. A sentence-final answer such as ``"$5$."``
    therefore yields ``"5"`` instead of keeping its delimiters.

    Args:
        s: Candidate answer text.

    Returns:
        The unwrapped, trimmed answer.
    """

    def _trim(text: str) -> str:
        # remove surrounding whitespace and trailing punctuation commonly added
        # after answers
        return re.sub(r"[\s\.;:,]+$", "", text.strip())

    s = _trim(s)
    # remove one layer of surrounding math delimiters ("$$" checked before "$")
    if len(s) >= 4 and s.startswith("$$") and s.endswith("$$"):
        s = s[2:-2]
    elif len(s) >= 2 and s[0] == "$" and s[-1] == "$":
        s = s[1:-1]
    elif len(s) >= 4 and s.startswith("\\(") and s.endswith("\\)"):
        s = s[2:-2]
    elif len(s) >= 4 and s.startswith("\\[") and s.endswith("\\]"):
        s = s[2:-2]
    # the inner text may itself end with punctuation, e.g. "$5.$"
    return _trim(s)


def _strip_answer_prefix(line: str) -> str:
    """Remove leading answer boilerplate ("Answer:", "the answer is", ...) from *line*.

    Each pattern is applied once, in the order listed. A line such as
    "Answer: the answer is 5" therefore loses both prefixes and becomes "5".
    """
    prefix_patterns = (
        r"(?i)^(?:final\s+answer|answer)\s*[:=]\s*",
        r"(?i)^(?:final\s+answer|answer)\s+is\s+",
        r"(?i)^(?:therefore|thus|so|hence),?\s*(?:the\s+)?answer\s+is\s+",
        r"(?i)^(?:the\s+)?answer\s+is\s+",
    )
    text = line.strip()
    for pattern in prefix_patterns:
        text = re.sub(pattern, "", text)
    return text.strip()


def extract_final_answer(text: str) -> str:
    """Best-effort extraction of the final answer from a solution or LM response.

    Preference order:
      1. contents of the last ``\\boxed{...}`` / ``\\fbox{...}``;
      2. scanning non-empty lines bottom-up, the text after an
         "answer:" / "answer =" / "answer is" marker (the first hit wins);
      3. the last non-empty line with any leading answer phrase removed.
    The result has math delimiters and trailing punctuation stripped. Input
    without any non-empty line yields "".
    """
    boxed_content = extract_last_boxed(text)
    if boxed_content is not None:
        return strip_math_delimiters(boxed_content)

    non_empty = [
        piece for piece in (raw.strip() for raw in text.splitlines()) if piece
    ]
    for candidate in reversed(non_empty):
        for marker in (_ANSWER_LINE_RE, _ANSWER_IS_RE):
            hit = marker.search(candidate)
            if hit:
                return strip_math_delimiters(hit.group(1))

    if not non_empty:
        return ""
    return strip_math_delimiters(_strip_answer_prefix(non_empty[-1]))


def normalize_for_fallback_compare(s: str) -> str:
    """Canonicalize an answer string for the plain string-equality fallback.

    Removes math delimiters and trailing punctuation, the ``\\left`` /
    ``\\right`` sizing commands, LaTeX spacing commands, and all whitespace.
    """
    s = strip_math_delimiters(s)
    # Drop \left / \right sizing commands, but not longer commands that merely
    # start with those letters (\leftarrow, \rightarrow, \leftrightarrow, ...).
    # A plain str.replace turned both \leftarrow and \rightarrow into "arrow",
    # making opposite arrows compare equal.
    s = re.sub(r"\\(?:left|right)(?![A-Za-z])", "", s)
    # Drop LaTeX spacing commands: \, \; \: \! and the control space "\ ".
    s = re.sub(r"\\[,;:! ]", "", s)
    s = re.sub(r"\s+", "", s)
    return s


# ===== PRM Scorer (Qwen/Qwen2.5-Math-PRM-7B) =====


class QwenPRMScorer:
    """
    Qwen2.5-Math-PRM-7B scoring.
    - Split response into steps (default: by double newlines)
    - Join with "<extra_0>" and add trailing "<extra_0>"
    - For each <extra_0> token position, take P(positive)
    - Aggregate by product (official BoN practice)

    Tokenizer and model are loaded lazily on the first `score` call, so creating
    an instance does not load any weights.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-Math-PRM-7B",
        device_map: str = "auto",
        torch_dtype: str = "bfloat16",
        trust_remote_code: bool = True,
    ) -> None:
        """
        Args:
            model_name: model id or path passed to `from_pretrained`.
            device_map: forwarded to `AutoModel.from_pretrained`.
            torch_dtype: name of a `torch` dtype attribute (e.g. "bfloat16").
            trust_remote_code: forwarded to tokenizer and model loading.
        """
        self.model_name = model_name
        self.device_map = device_map
        self.torch_dtype = torch_dtype
        self.trust_remote_code = trust_remote_code
        self._tokenizer = None
        self._model = None

    def _lazy_load(self) -> None:
        """Load tokenizer and model on first use; a no-op once both exist."""
        if self._tokenizer is not None and self._model is not None:
            return
        import torch
        from transformers import AutoModel, AutoTokenizer

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=self.trust_remote_code,
        )
        dtype = getattr(torch, self.torch_dtype)
        self._model = AutoModel.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=dtype,
            trust_remote_code=self.trust_remote_code,
        ).eval()

    @staticmethod
    def _split_steps(response: str) -> List[str]:
        """Split a response into reasoning steps.

        Uses blank-line ("\\n\\n") separated chunks. If there are none, falls
        back to non-empty lines, and finally to the single stripped response.
        """
        # Qwen's recommended formatting splits steps by double line breaks.
        steps = [c.strip() for c in response.split("\n\n") if c.strip()]
        if steps:
            return steps
        # fallback: non-empty lines
        lines = [ln.strip() for ln in response.splitlines() if ln.strip()]
        return lines if lines else [response.strip()]

    @staticmethod
    def _assistant_with_separators(steps: Sequence[str]) -> str:
        """Join steps with "<extra_0>" and append a trailing "<extra_0>"."""
        return "<extra_0>".join(steps) + "<extra_0>"

    @staticmethod
    def _make_step_rewards(logits, token_masks):
        """Return, per batch row, P(positive) at each separator position.

        Args:
            logits: model logits indexed as (batch, seq_len, num_labels). The
                positive class is index 1 (the original code assumed
                num_labels == 2).
            token_masks: (batch, seq_len) tensor, truthy at separator tokens.

        Returns:
            One list of floats per batch row, in token order.
        """
        import torch.nn.functional as F

        # Softmax in float32 so bf16 logits do not yield coarse probabilities.
        probabilities = F.softmax(logits.float(), dim=-1)
        masks = token_masks.bool()

        all_scores: List[List[float]] = []
        for row_probs, row_mask in zip(probabilities, masks):
            # Select separator positions with the mask itself. The previous
            # `sample[sample != 0].view(-1, 2)` dropped any probability that
            # was exactly 0 (e.g. underflow on a confident step). That
            # mis-paired the remaining values or made the view fail.
            positive_probs = row_probs[row_mask][:, 1]
            all_scores.append(positive_probs.detach().cpu().tolist())
        return all_scores

    @staticmethod
    def _aggregate(step_scores: Sequence[float], mode: str = "prod") -> float:
        """Combine per-step scores with "prod", "mean" or "min" (0.0 if empty).

        Raises:
            ValueError: for an unknown `mode`.
        """
        if not step_scores:
            return 0.0
        if mode == "prod":
            return float(math.prod(float(s) for s in step_scores))
        if mode == "mean":
            return float(sum(step_scores) / len(step_scores))
        if mode == "min":
            return float(min(step_scores))
        raise ValueError(f"Unknown aggregation mode: {mode}")

    def score(
        self, system: str, query: str, response: str, agg: str = "prod"
    ) -> Dict[str, Any]:
        """Score `response` to `query` with the PRM.

        Args:
            system: system prompt placed in the chat template.
            query: user question.
            response: assistant response; split into steps by `_split_steps`.
            agg: aggregation mode passed to `_aggregate` ("prod", "mean", "min").

        Returns:
            dict with "reward" (aggregated score), "step_rewards" (per-step
            P(positive)), "n_steps" and "aggregation".
        """
        self._lazy_load()
        import torch

        tokenizer = self._tokenizer
        model = self._model
        assert tokenizer is not None and model is not None

        steps = self._split_steps(response)
        assistant_content = self._assistant_with_separators(steps)

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": query},
            {"role": "assistant", "content": assistant_content},
        ]
        conversation_str = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        with torch.inference_mode():
            inputs = tokenizer(conversation_str, return_tensors="pt").to(model.device)
            outputs = model(input_ids=inputs["input_ids"], use_cache=False)

            step_sep_id = tokenizer.encode("<extra_0>")[0]
            token_masks = inputs["input_ids"] == step_sep_id

            step_rewards = self._make_step_rewards(outputs[0], token_masks)[0]

        reward = self._aggregate(step_rewards, mode=agg)
        return {
            "reward": reward,  # [0,1]
            "step_rewards": step_rewards,  # per-step [0,1]
            "n_steps": len(step_rewards),
            "aggregation": agg,
        }


# ===== Task implementation =====


class FastMathTask(Task):
    """
    Math QA task with gold-answer checking and PRM scoring.

    samples: [{question_key: str, answer_key: str, ...}, ...]

    parse_mode:
      - "response": `parse_answer` returns the raw LM response and `evaluate`
        scores it with the PRM.
      - "final": `parse_answer` extracts the final answer and `evaluate`
        delegates to `submit` (correctness against the gold answer).
    """

    def __init__(
        self,
        samples: Sequence[dict[str, Any]],
        *,
        question_key: str = "question",
        answer_key: str = "answer",
        system_prompt: str = "Please reason step by step, and put your final answer within \\boxed{}.",
        use_math_verify: bool = True,
        prm_scorer: Optional[QwenPRMScorer] = None,
        parse_mode: str = "response",
    ) -> None:
        if parse_mode not in ("response", "final"):
            raise ValueError("parse_mode must be 'response' or 'final'")
        self.samples = list(samples)
        # Alias of `samples`; presumably read by the Task base class or other
        # callers (not visible here) — keep both in sync.
        self._samples = self.samples
        self.question_key = question_key
        self.answer_key = answer_key
        self.system_prompt = system_prompt
        self._parse_mode = parse_mode

        # QwenPRMScorer loads its model lazily, so a default instance is cheap
        # even if the PRM is never used.
        self.prm = prm_scorer or QwenPRMScorer()

        # Math-Verify (optional)
        self._mv_parse = None
        self._mv_verify = None
        self._gold_parsed_cache: dict[int, Any] = {}

        if use_math_verify:
            try:
                from math_verify import parse as mv_parse
                from math_verify import verify as mv_verify

                self._mv_parse = mv_parse
                self._mv_verify = mv_verify
            except Exception:
                # math-verify not installed / incompatible runtime
                self._mv_parse = None
                self._mv_verify = None

    def _question(self, sample_id: int) -> str:
        """Return the question text, trying `question_key`, "problem", "question"."""
        sample = self.samples[sample_id]
        for key in (self.question_key, "problem", "question"):
            if key in sample:
                return str(sample[key])
        raise KeyError(f"Missing question key in sample {sample_id}")

    def get_prompt(self, sample_id: int) -> str:
        return self._question(sample_id)

    def _gold(self, sample_id: int) -> str:
        """Return the gold final answer for `sample_id`.

        Reads the first present key among `answer_key`, "solution",
        "final_answer", "answer" and extracts its final answer (falling back
        to the raw value without math delimiters).
        """
        sample = self.samples[sample_id]
        for key in (self.answer_key, "solution", "final_answer", "answer"):
            if key in sample:
                gold_raw = str(sample[key])
                break
        else:
            raise KeyError(f"Missing answer key in sample {sample_id}")
        gold = extract_final_answer(gold_raw)
        return gold if gold else strip_math_delimiters(gold_raw)

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        if self._parse_mode == "response":
            return lm_response.strip()
        return extract_final_answer(lm_response)

    @staticmethod
    def _math_verify_input(text: str) -> str:
        """Normalize `text` with normalize_final_answer and wrap it for math-verify.

        normalize_final_answer deletes the substring "ft" (a unit), which would
        turn "\\left(" into "\\le(" and "\\infty" into "\\iny". The
        \\left / \\right sizing commands are therefore dropped first, and
        "\\infty" is shielded behind a placeholder the normalizer leaves intact.
        """
        placeholder = "\\INFTY"
        text = re.sub(r"\\(?:left|right)(?![A-Za-z])", "", text)
        text = normalize_final_answer(text.replace("\\infty", placeholder))
        return _wrap_for_math_verify(text.replace(placeholder, "\\infty"))

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Check `answer` against the gold answer of `sample_id`.

        The final answer is extracted from `answer`, falling back to the whole
        text without math delimiters. Math-Verify decides equivalence when it is
        available and parsed both sides; otherwise a normalized string comparison
        is used.

        Returns:
            ProbResult with metric 1.0 (correct) or 0.0 and a detail dict with
            "gold", "pred", "correct", "method" and, when applicable, "error"
            or "mv_fallback_reason".
        """
        gold = self._gold(sample_id)
        pred = extract_final_answer(answer)
        if not pred:
            pred = strip_math_delimiters(answer)

        detail: dict[str, Any] = {
            "gold": gold,
            "pred": pred,
        }

        # 1) Equivalence check via Math-Verify.
        if self._mv_parse is not None and self._mv_verify is not None:
            try:
                if sample_id not in self._gold_parsed_cache:
                    self._gold_parsed_cache[sample_id] = self._mv_parse(
                        self._math_verify_input(gold)
                    )
                gold_parsed = self._gold_parsed_cache[sample_id]
                pred_parsed = self._mv_parse(self._math_verify_input(pred))
                if gold_parsed and pred_parsed:
                    ok = bool(
                        self._mv_verify(
                            gold_parsed,
                            pred_parsed,
                        )
                    )  # order important
                    return ProbResult(
                        metric=1.0 if ok else 0.0,
                        sample_detail={**detail, "correct": ok, "method": "math-verify"},
                    )
                # math-verify extracted nothing from one side, so its verdict
                # would always be False; use the string comparison instead.
                detail["mv_fallback_reason"] = "empty-parse"
            except Exception as e:
                detail["error"] = repr(e)

        # 2) Fallback: weak, but minimally normalized string comparison.
        ok = normalize_for_fallback_compare(gold) == normalize_for_fallback_compare(
            pred
        )
        return ProbResult(
            metric=1.0 if ok else 0.0,
            sample_detail={**detail, "correct": ok, "method": "string-normalize"},
        )

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """
        PRM score: `answer` is assumed to be the LM raw output (including reasoning).
        The returned metric is in [0, 1].

        In "final" parse mode this delegates to `submit` instead.
        """
        if self._parse_mode != "response":
            return self.submit(sample_id, answer)
        q = self._question(sample_id)
        res = self.prm.score(
            system=self.system_prompt, query=q, response=answer, agg="prod"
        )
        return ProbResult(metric=float(res["reward"]), sample_detail=res)


def _normalize_dataset_configs(dataset_config: Any) -> list[str] | None:
    if dataset_config is None:
        return None
    if isinstance(dataset_config, str):
        value = dataset_config.strip()
        if not value:
            return None
        lowered = value.lower()
        if lowered in ("all", "all_categories", "*", "none", "null"):
            return None
        if "," in value:
            parts = [part.strip() for part in value.split(",") if part.strip()]
            return parts or None
        return [value]
    if isinstance(dataset_config, Sequence):
        configs = [str(item).strip() for item in dataset_config]
        configs = [item for item in configs if item]
        return configs or None
    value = str(dataset_config).strip()
    return [value] if value else None


def _resolve_dataset_configs(dataset: str, dataset_config: Any) -> list[str] | None:
    """Return explicit configs if given, else every config the dataset exposes.

    When `dataset_config` does not name specific configs, the names are looked
    up with `datasets.get_dataset_config_names`. None is returned if that
    lookup is unavailable, fails, or finds nothing.
    """
    explicit = _normalize_dataset_configs(dataset_config)
    if explicit is not None:
        return explicit
    try:
        from datasets import get_dataset_config_names

        available = get_dataset_config_names(dataset)
    except Exception:
        return None
    return list(available) if available else None


def _sample_dataset(
    ds: Any,
    limit: int,
    rng: random.Random | None,
    shuffle: bool,
) -> list[dict[str, Any]]:
    if limit <= 0:
        return []
    total = len(ds)
    if limit >= total:
        return list(ds)
    if shuffle:
        if rng is None:
            rng = random.Random()
        indices = list(range(total))
        rng.shuffle(indices)
        indices = indices[:limit]
    else:
        indices = list(range(limit))
    return list(ds.select(indices))


def _load_hendrycks_math_samples(
    *,
    dataset: str,
    split: str,
    dataset_config: Any,
    per_category_limit: int | None,
    per_category_shuffle: bool,
    per_category_seed: int | None,
) -> list[dict[str, Any]]:
    """Load `split` of `dataset` via `datasets.load_dataset`, one config at a time.

    Configs come from `_resolve_dataset_configs`. When none resolve, the
    dataset is loaded without a config name. `per_category_limit` caps the rows
    taken per config. With `per_category_shuffle`, rows are sampled using a
    single `random.Random(per_category_seed)` shared across configs.

    Raises:
        ValueError: if `per_category_limit` or `per_category_seed` cannot be
            converted to int.
    """
    from datasets import load_dataset

    def _as_int(value: Any, name: str) -> int | None:
        # None passes through; anything else must be int()-convertible.
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"{name} must be an integer") from exc

    configs = _resolve_dataset_configs(dataset, dataset_config)
    limit = _as_int(per_category_limit, "per_category_limit")
    seed = _as_int(per_category_seed, "per_category_seed")
    use_rng = limit is not None and per_category_shuffle
    rng = random.Random(seed) if use_rng else None

    def _take(ds: Any) -> list[dict[str, Any]]:
        # Whole split when uncapped, otherwise a (possibly shuffled) subset.
        if limit is None:
            return list(ds)
        return _sample_dataset(ds, limit, rng, per_category_shuffle)

    if configs is None:
        return _take(load_dataset(dataset, split=split))

    samples: list[dict[str, Any]] = []
    for config in configs:
        samples.extend(_take(load_dataset(dataset, config, split=split)))
    return samples


def make_hendrycks_math_task(
    *,
    dataset: str = "hendrycks/math",
    split: str = "test",
    dataset_config: str | Sequence[str] | None = None,
    question_key: str = "problem",
    answer_key: str = "solution",
    system_prompt: str = "Please reason step by step, and put your final answer within \\boxed{}.",
    use_math_verify: bool = True,
    prm_scorer: Optional[QwenPRMScorer] = None,
    parse_mode: str = "response",
    per_category_limit: int | None = None,
    per_category_shuffle: bool = True,
    per_category_seed: int | None = None,
) -> FastMathTask:
    """Load MATH-style samples and wrap them in a FastMathTask.

    `dataset`, `split`, `dataset_config` and the `per_category_*` arguments
    control sample loading (see `_load_hendrycks_math_samples`). The remaining
    arguments are forwarded to `FastMathTask`.
    """
    loader_options = dict(
        dataset=dataset,
        split=split,
        dataset_config=dataset_config,
        per_category_limit=per_category_limit,
        per_category_shuffle=per_category_shuffle,
        per_category_seed=per_category_seed,
    )
    task_options = dict(
        question_key=question_key,
        answer_key=answer_key,
        system_prompt=system_prompt,
        use_math_verify=use_math_verify,
        prm_scorer=prm_scorer,
        parse_mode=parse_mode,
    )
    rows = _load_hendrycks_math_samples(**loader_options)
    return FastMathTask(rows, **task_options)
