import re
from typing import Optional, Tuple

from llm_mcts.llm_generation_interface import GenerationRequest, Message, Model
from llm_mcts.tasks.omni_math.omni_judge import (
    OmniJudge,
    MODEL_NAME as OMNI_JUDGE_MODEL_NAME,
)
from llm_mcts.tasks.omni_math.problem import OmniMathProblem
from llm_mcts.tasks.omni_math.reward_model import RewardModel
from llm_mcts.tasks.omni_math.template import (
    GPT_EVALUATION_PROMPT,
    GPT_EVALUATION_PROMPT_WITH_SCORE_WITHOUT_REFERENCE,
)


def extract_student_final_answer(text: str) -> Optional[str]:
    """
    Extract the body of the '## Student Final Answer' section from ``text``.

    The body runs from just after the header up to the next '##' marker, or
    to the end of ``text`` when no further '##' follows (i.e. the section is
    the last one). Leading and trailing whitespace is stripped.

    Returns None if the header is not present.
    """
    # Lazy capture so we stop at the FIRST following '##'; the `\Z` alternative
    # lets a trailing section (no later header) still be extracted.
    pattern = r"## Student Final Answer\s*(.*?)(?=##|\Z)"
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        return None
    return match.group(1).strip()


def extract_equivalence_judgement(text: str) -> Optional[str]:
    """
    Extract the line following the '## Equivalence Judgement' header.

    The header must be followed by a newline (optionally preceded by other
    whitespace). The captured value is the next line, stripped of surrounding
    whitespace; it is returned verbatim, so callers decide how to interpret
    it (e.g. compare against "TRUE" / "FALSE").

    Returns None if the header is not present.
    """
    # The judgement line may be terminated by a newline OR by the end of the
    # text, so a response that ends right after the verdict is still parsed.
    pattern = r"## Equivalence Judgement\s*\n(.*?)(?:\n|\Z)"
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        return None
    return match.group(1).strip()


def extract_score(text: str) -> Optional[int]:
    """
    Extract the integer that follows a '## Score' header in ``text``.

    The rest of the header line is ignored (so '## Score (0-10)' works), and
    any blank lines or indentation between the header and the number are
    skipped. Only the leading run of digits is read, so '8/10' yields 8.
    The value is not range-checked here.

    Returns None if no such header/number pair is found.
    """
    # `[^\n]*` swallows the remainder of the header line; `\s*` then tolerates
    # blank lines / indentation before the digits.
    pattern = r"## Score[^\n]*\n\s*(\d+)"
    match = re.search(pattern, text)
    if match is None:
        return None
    return int(match.group(1))


class LLMJudge:
    """
    Scores candidate answers to Omni-MATH problems.

    Two scoring paths are offered:

    * ``check_test_score`` compares an answer against the problem's reference
      answer, using either ``OmniJudge`` or a prompt sent through
      ``self.model.generate``.
    * ``estimate_score`` scores an answer without the reference, using a
      prompted 0-10 score from ``self.model`` and/or a ``RewardModel``.
    """

    def __init__(
        self,
        model: Model | str,
        reward_model_name: Optional[str] = None,
        only_reward_model: bool = False,
        is_sigmoid: bool = True,
    ):
        """
        Args:
            model: Either the OmniJudge model name (in which case an
                ``OmniJudge`` instance is created) or an object that is
                stored as-is and later used via its ``generate`` method.
            reward_model_name: If given, a ``RewardModel`` with this name is
                created and its reward is added in ``estimate_score``.
            only_reward_model: When a reward model is configured, skip the
                prompted score in ``estimate_score`` and use the reward
                alone. Ignored (forced to False) without a reward model.
            is_sigmoid: Forwarded to ``RewardModel``.
        """
        self.model = OmniJudge() if model == OMNI_JUDGE_MODEL_NAME else model

        if reward_model_name is not None:
            self.reward_model = RewardModel(reward_model_name, is_sigmoid=is_sigmoid)
            self.only_reward_model = only_reward_model
        else:
            self.reward_model = None
            self.only_reward_model = False

    @staticmethod
    def _fill_template(template: str, values: dict[str, str]) -> str:
        """
        Replace every ``{{key}}`` placeholder in ``template`` in a single pass.

        A single pass (instead of chained ``str.replace`` calls) guarantees
        that text inserted for one placeholder is never re-scanned, so a
        problem statement that happens to contain e.g. ``{{Solution}}`` is
        left intact. Placeholders whose key is not in ``values`` are untouched.
        """
        alternatives = "|".join(re.escape(key) for key in values)
        return re.sub(
            r"\{\{(" + alternatives + r")\}\}",
            # A callable replacement inserts the value literally (no
            # backslash/group-reference processing of LaTeX-heavy text).
            lambda placeholder: values[placeholder.group(1)],
            template,
        )

    @staticmethod
    def _judgement_to_score(judgement: Optional[str]) -> float:
        """Map a judgement to a score: 1.0 for exactly "TRUE", else 0.0."""
        return 1.0 if judgement == "TRUE" else 0.0

    def _generate_single(self, prompt: str) -> str:
        """Send ``prompt`` as one user message and return the single generation."""
        request = GenerationRequest(messages=[Message(role="user", content=prompt)])
        results = list(self.model.generate([request]))

        # We don't use multimodel for the judge!
        assert len(results) == 1

        return results[0].generation

    def check_test_score(
        self, problem: OmniMathProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Check whether ``answer`` is correct against the problem's ground truth.

        Dispatches to OmniJudge when ``self.model`` is an ``OmniJudge``,
        otherwise to the prompt-based judge.

        Returns:
            ``(judge_output_text, score)`` where score is 1.0 or 0.0.
        """
        if isinstance(self.model, OmniJudge):
            return self.check_test_score_with_omni_judge(problem, answer)
        return self.check_test_score_with_api(problem, answer)

    def check_test_score_with_omni_judge(
        self, problem: OmniMathProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Judge ``answer`` with OmniJudge.

        Returns:
            ``(result_text, score)``; score is 1.0 only when the judgement is
            exactly "TRUE". "FALSE", None and anything else score 0.0.
        """
        result_text, _, judgement, _ = self.model.get_judge(
            problem.problem, problem.answer, answer
        )
        return result_text, self._judgement_to_score(judgement)

    def check_test_score_with_api(
        self, problem: OmniMathProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Judge ``answer`` by prompting ``self.model`` with GPT_EVALUATION_PROMPT.

        Returns:
            ``(generation, score)``; score is 1.0 only when the extracted
            equivalence judgement is exactly "TRUE", otherwise 0.0.
        """
        prompt = self._fill_template(
            GPT_EVALUATION_PROMPT,
            {
                "Problem": problem.problem,
                "Reference Answer": problem.answer,
                "Solution": answer,
            },
        )
        generation = self._generate_single(prompt)

        try:
            judgement = extract_equivalence_judgement(generation)
        except Exception:
            # Best-effort: a response that cannot be parsed counts as incorrect.
            judgement = None

        return generation, self._judgement_to_score(judgement)

    def estimate_score(
        self, problem: OmniMathProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Estimate the correctness of ``answer`` without the reference answer.

        The prompted score (extracted integer divided by 10) is used unless a
        reward model is configured with ``only_reward_model`` set. If a reward
        model is configured, its reward is added on top.

        Returns:
            ``(generation_text, score)``; ``generation_text`` is "" when the
            prompted score is skipped.

        Raises:
            ValueError: if the prompted score is needed but ``self.model`` is
                an ``OmniJudge``.
        """
        generation_text = ""
        score = 0.0

        if not self.only_reward_model or self.reward_model is None:
            if isinstance(self.model, OmniJudge):
                raise ValueError("OmniJudge is not supported for score estimation")

            prompt = self._fill_template(
                GPT_EVALUATION_PROMPT_WITH_SCORE_WITHOUT_REFERENCE,
                {"Problem": problem.problem, "Solution": answer},
            )
            generation_text = self._generate_single(prompt)

            prompted_score = 0.0
            try:
                raw_score = extract_score(generation_text)
                # A missing score is an expected outcome, not an error.
                if raw_score is not None:
                    prompted_score = float(raw_score) / 10
            except Exception:
                # Best-effort: an unparseable response contributes nothing.
                prompted_score = 0.0
            score += prompted_score

        if self.reward_model is not None:
            score += self.reward_model.get_reward(problem.problem, answer)

        return generation_text, score

    def extract_answer(self, answer: str) -> Optional[str]:
        """Return the '## Student Final Answer' section of ``answer``, or None."""
        return extract_student_final_answer(answer)
