from typing import Tuple

from llm_mcts.llm_generation_interface import GenerationRequest, Message, Model
from llm_mcts.tasks.swe_bench.problem import SWEBenchProblem


class SWEBenchLLMJudge:
    """LLM-based judge for answers to SWE-bench problems.

    Two scoring modes are offered: a comparison against the ground-truth
    patch (``check_test_score``) and a reference-free correctness estimate
    (``estimate_score``). Both return the judge's raw generation together
    with a score normalized to the range ``[0.0, 1.0]``.
    """

    def __init__(self, model: Model):
        self.model = model

    def check_test_score(
        self, problem: SWEBenchProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Check if the LLM answer is correct against the groundtruth patch.

        The judge model is asked to output 1 (correct) or 0 (incorrect).

        Returns:
            A ``(generation, score)`` tuple: the judge's raw generation and
            a score of 1.0 or 0.0. A missing or unparsable score block
            yields 0.0.
        """
        prompt = f"""
Your task is to compare the other AI assistant's answer to the software engineering problem, and the groundtruth answer, and check if the assistant's answer is correct.

You only need to check if the assistant's answer is semantically the same as the groundtruth.

In your response, please include a score of the assistant's answer. If it is correct, just output 1, and if the answer is incorrect, output 0.

You should output your score in the block surrounded by "```score" and "```" in the end of your response.

Assistant's Answer:
{answer}

Groundtruth Answer:
{problem.patch}
"""

        result = self._ask_judge(prompt)
        return result.generation, self._normalized_score(result, max_score=1)

    def estimate_score(
        self, problem: SWEBenchProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Estimate the correctness of the LLM answer without the groundtruth.

        The judge model is asked for an integer score from 0 to 5, which is
        then divided by 5.

        Returns:
            A ``(generation, score)`` tuple: the judge's raw generation and
            a score in ``[0.0, 1.0]``. A missing or unparsable score block
            yields 0.0.
        """
        problem_str = problem.text
        prompt = f"""
Your task is to judge the correctness of the other AI assistant's answer to the given software engineering problem.
You should first elaborate the detailed reasoning behind your scoring, and then provide your score
for the given answer, based on how much is the given answer correct for the given software engineering problem.

Your score should be a single integer from 0 to 5, i.e. 0, 1, 2, 3, 4 or 5. The lower score means that the answer is not correct, and the higher score means
that the answer looks correct. If you are certain that the answer is correct, give it a score 5.

You should output your score in the block surrounded by "```score" and "```" in the end of your response.

Problem:
{problem_str}

Answer:
{answer}
"""

        result = self._ask_judge(prompt)
        return result.generation, self._normalized_score(result, max_score=5)

    def _ask_judge(self, prompt: str):
        """Send ``prompt`` as a single user message and return the sole result."""
        results = list(
            self.model.generate(
                [GenerationRequest(messages=[Message(role="user", content=prompt)])]
            )
        )

        # We don't use multimodel for the judge, so one request must yield
        # exactly one generation.
        assert (
            len(results) == 1
        ), f"expected exactly one judge generation, got {len(results)}"
        return results[0]

    @staticmethod
    def _normalized_score(result, max_score: int) -> float:
        """Parse the score block of ``result`` and scale it into [0.0, 1.0].

        The parsed integer is clamped to ``[0, max_score]`` before dividing,
        so a judge that ignores the requested scale (e.g. answers 7 on a
        0-5 scale, or a negative number) cannot push the score out of range.
        """
        try:
            raw_score = int(result.parse_score_block())
        except Exception:
            # Deliberate best-effort: any failure to extract an integer score
            # (missing block, non-numeric text, ...) counts as score 0.
            return 0.0

        clamped = min(max(raw_score, 0), max_score)
        return clamped / max_score
