import json
from typing import Tuple

from llm_mcts.llm_generation_interface import GenerationRequest, Message, Model
from llm_mcts.prompts.math_vista.math_vista_code import (
    create_one_query,
    SHOT_EXAMPLES,
    create_extract_answer_prompt,
    normalize_extracted_answer,
    safe_equal,
)
from llm_mcts.tasks.math_vista.problem import MathVistaProblem


class MathVistaLLMJudge:
    """Judge MathVista answers with an LLM.

    The wrapped ``model`` is used for three things:

    * ``check_test_score``: extract the final answer from a response and
      compare it with the problem's ground truth.
    * ``estimate_score``: ask the model to grade a response on a 0-5 scale
      (no ground truth involved).
    * ``extract_answer``: extract and normalize the final answer only.
    """

    def __init__(self, model: Model):
        self.model = model

    def _generate_single(self, content):
        """Send ``content`` as one user message and return the single result.

        Args:
            content: Message content passed through to ``Message`` unchanged
                (a prompt string, or a list mixing an image and a prompt).

        Returns:
            The only result produced by ``self.model.generate``.

        Raises:
            AssertionError: If the model does not yield exactly one result.
        """
        results = list(
            self.model.generate(
                [GenerationRequest(messages=[Message(role="user", content=content)])]
            )
        )

        # We don't use multi-model for the judge, so one request must yield
        # exactly one result.
        assert len(results) == 1
        return results[0]

    def _generate_extraction(self, problem: MathVistaProblem, answer: str) -> str:
        """Ask the model to extract the final answer from ``answer``; return its raw generation."""
        prompt = create_extract_answer_prompt(
            response=answer, problem=problem, quick_extract=False
        )
        return self._generate_single(prompt).generation

    @staticmethod
    def _normalize(problem: MathVistaProblem, extraction: str):
        """Normalize a raw extraction using the problem's metadata (may raise)."""
        return normalize_extracted_answer(
            extraction=extraction,
            choices=problem.choices,
            question_type=problem.question_type,
            answer_type=problem.answer_type,
            precision=problem.precision,
            ignore_empty_extractions=False,
        )

    def check_test_score(
        self, problem: MathVistaProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Check if the LLM answer is correct against the groundtruth

        Args:
            problem: The problem, providing the ground-truth ``answer`` and
                the metadata used for normalization.
            answer: The full LLM response to be judged.

        Returns:
            A ``(details, score)`` tuple. ``details`` is a JSON string with
            the keys ``generation``, ``prediction``, ``ground_truth``,
            ``true_false`` and ``score``; ``score`` is ``1.0`` when the
            normalized prediction equals the ground truth and ``0.0``
            otherwise (including when normalization or comparison fails).
        """
        generation = self._generate_extraction(problem, answer)

        # Bind the fallbacks *before* the try block: previously a failure in
        # normalization/comparison left these names unbound and building the
        # result dict below raised UnboundLocalError instead of scoring 0.
        prediction = None
        true_false = False
        try:
            prediction = self._normalize(problem, generation)
            # verify the prediction is true or false
            true_false = safe_equal(prediction, problem.answer)
        except Exception:
            # Best effort: an answer that cannot be normalized or compared
            # is simply counted as incorrect.
            true_false = False

        score = 1.0 if true_false else 0.0

        answer_dict = {
            "generation": generation,
            "prediction": prediction,
            "ground_truth": problem.answer,
            "true_false": true_false,
            "score": score,
        }  # NOTE: output as JSON for saving all the information
        return json.dumps(answer_dict), score

    def estimate_score(
        self, problem: MathVistaProblem, answer: str
    ) -> Tuple[str, float]:
        """
        Estimate the correctness of LLM answer

        The model is shown the problem image and text together with the
        answer and asked for an integer grade from 0 to 5.

        Args:
            problem: The problem whose ``decoded_image`` and query text are
                shown to the judge model.
            answer: The LLM response to be graded.

        Returns:
            A ``(generation, score)`` tuple: the judge's raw generation and
            the parsed grade divided by 5. ``score`` is ``0.0`` when no grade
            can be parsed.
        """
        problem_prompt = create_one_query(
            problem=problem,
            examples=SHOT_EXAMPLES,
            shot_num=0,
            shot_type="solution",
            use_caption=False,
            use_ocr=False,
        )
        prompt = f"""
Your task is to judge the correctness of the other AI assistant's answer to the given mathematical problem.
You should first elaborate the detailed reasoning behind your scoring, and then provide your score
for the given answer, based on how much is the given answer correct for the given mathematical problem.
Your score should be a single integer from 0 to 5, i.e. 0, 1, 2, 3, 4 or 5. The lower score means that the answer is not correct, and the higher score means
that the answer looks correct. If you are certain that the answer is correct, give it a score 5.
You should output your score in the block surrounded by "```score' and "```" in the end of your response.
Problem:
{problem_prompt}
Answer:
{answer}
"""

        result = self._generate_single([problem.decoded_image, prompt])

        try:
            # NOTE(review): the grade is not clamped, so a model reply outside
            # 0-5 yields a score outside [0, 1] -- confirm callers tolerate it.
            score = int(result.parse_score_block()) / 5
        except Exception:
            # Missing or non-integer score block: treat as the lowest grade.
            score = 0.0

        return result.generation, score

    def extract_answer(self, problem: MathVistaProblem, answer: str) -> str:
        """
        Extract the answer from the response

        Args:
            problem: The problem providing the metadata used for
                normalization.
            answer: The full LLM response to extract the final answer from.

        Returns:
            The normalized prediction, or ``""`` when normalization fails.
        """
        generation = self._generate_extraction(problem, answer)

        try:
            return self._normalize(problem, generation)
        except Exception:
            return ""
