from typing import Optional

from logger import logger
from utils import custom_xml_parser, tag

from ..tasks import TaskInstance
from ..tasks.aqua_rat_task import AquaRatTaskInstance
from .llm_paraphraser import LLMParaphraser


class CoTQALLMParaphraser(LLMParaphraser):
    """A paraphraser for use on the AquaRat task.

    Rewrites only the chain-of-thought found inside the ``reasoning`` tag and
    re-emits the content of the ``answer`` tag untouched.
    """

    def paraphrase(self, text: str, task: TaskInstance) -> Optional[str]:
        """Paraphrase the reasoning part of ``text``, keeping the answer as is.

        When ``text`` cannot be parsed, or when either the ``reasoning`` or the
        ``answer`` tag is absent from the parse result, the whole of ``text``
        is handed to the parent paraphraser instead.

        Args:
            text: The message to paraphrase.
            task: The task instance, forwarded to the parent paraphraser.

        Returns:
            The paraphrased reasoning and the original answer, each wrapped in
            its tag and separated by a newline; the parent's result for the
            whole text in the fallback cases; or ``None`` when the parent
            returned ``None`` for the reasoning.
        """
        try:
            parsed = custom_xml_parser(text, ["reasoning", "answer"])
        except Exception:
            # Best effort: an unparseable message is still paraphrased, just
            # without the reasoning/answer split.
            logger.exception("Could not extract reasoning tags from message")
            return super().paraphrase(text, task)

        has_both_tags = "reasoning" in parsed and "answer" in parsed
        if not has_both_tags:
            return super().paraphrase(text, task)

        rewritten = super().paraphrase(parsed["reasoning"], task)
        if rewritten is None:
            logger.warning("Paraphraser failed.")
            return None

        original_answer = parsed["answer"]
        reasoning_block = tag(rewritten, "reasoning")
        answer_block = tag(original_answer, "answer")
        return f"{reasoning_block}\n{answer_block}"


class CoTQAContextAwareLLMParaphraser(LLMParaphraser):
    """A paraphraser that shows the model the task's question as context.

    Only the content of the ``reasoning`` tag is paraphrased; the content of
    the ``answer`` tag is re-emitted unchanged.
    """

    def _context_aware_paraphrase(
        self, text: str, task: AquaRatTaskInstance
    ) -> Optional[str]:
        """Ask the model to paraphrase ``text`` given the task's question.

        Args:
            text: The text to paraphrase; it is placed in an ``answer`` tag in
                the prompt.
            task: The task instance whose ``question`` is placed in the prompt.

        Returns:
            Whatever ``self.model.query_with_retry`` returns for the system
            prompt followed by the constructed user prompt.
        """
        user_prompt = {
            "role": "user",
            "content": f"""The model has been asked the following question:

{tag(task.question, "question")}

Their answer was:
{tag(text, "answer")}

Return a paraphrased version of the answer. Remove any information which isn't directly related to answering the question.""",
        }

        return self.model.query_with_retry(
            [{"role": "system", "content": self.system_prompt}, user_prompt]
        )

    def paraphrase(self, text: str, task: AquaRatTaskInstance) -> Optional[str]:
        """Paraphrase the reasoning part of ``text``, keeping the answer as is.

        When ``text`` cannot be parsed, or when either the ``reasoning`` or the
        ``answer`` tag is absent from the parse result, the whole of ``text``
        is paraphrased instead.

        Args:
            text: The message to paraphrase.
            task: The task instance supplying the question used as context.

        Returns:
            The paraphrased reasoning and the original answer, each wrapped in
            its tag and separated by a newline; the paraphrase of the whole
            text in the fallback cases; or ``None`` when paraphrasing the
            reasoning returned ``None``.
        """
        try:
            extracted_message = custom_xml_parser(text, ["reasoning", "answer"])
        except Exception:
            # A parser failure should not abort the run: log it and fall back
            # to paraphrasing the whole message, as when the tags are missing.
            logger.exception("Could not extract reasoning tags from message")
            return self._context_aware_paraphrase(text, task)

        if "reasoning" not in extracted_message or "answer" not in extracted_message:
            return self._context_aware_paraphrase(text, task)

        reasoning = extracted_message["reasoning"]
        paraphrased_reasoning = self._context_aware_paraphrase(reasoning, task)
        if paraphrased_reasoning is None:
            logger.warning("Paraphraser failed.")
            return None

        answer = extracted_message["answer"]
        return f'{tag(paraphrased_reasoning, "reasoning")}\n{tag(answer, "answer")}'
