import json

from typing import (
    List,
    Optional,
)

from loguru import logger

from src.generator import BaseGenerator


# Prompt used by COTExecutionService to have the model answer a query by
# following a chain-of-thought "decision guide" plus personal context.
#
# It is rendered with str.format using the placeholders {personal_context},
# {cot}, {query} and {options}. Because str.format parses this template, any
# literal brace added to the text below must be doubled ("{{" / "}}").
#
# After .strip() the template ends exactly at the "### Answer" header, so the
# model's completion presumably starts with the answer itself (the model may
# still echo the header) — confirm against the generator's behavior.
COT_EXECUTION_PROMPT_TEMPLATE = """
You are a question answering model.

1. Your task is to answer the query based on the teacher's chain-of-thought decision guide, using additional personal context.
2. Read the chain-of-thought decision guide carefully.
3. If the decision guide contains "VARIABLES" that may affect the outcome, extract them and determine their values based on the personal context.
4. Then, follow the decision guide and apply the extracted variables appropriately to derive the final answer.
5. The final answer must be preceded by '### Answer', and your response must end immediately after the answer.
6. If the options text says 'Empty,' it means no options are provided.
7. If the options are not empty, simply output one of the answers listed in the options without any additional explanation.
8. Never output any other explanation. Just output the answer.
9. If option follows a format like '[A] something', then output something as the answer instead of A.

### Personal Context
{personal_context}

### Chain-of-Thought
{cot}

### Query
{query}

### Options
{options}

### Answer
""".strip()


class COTExecutionService:
    """Answer a query by having an LLM execute a chain-of-thought decision guide.

    The service renders ``prompt_template`` with the query, the retrieved
    personal-context memories, the candidate options and the CoT guide, sends
    the prompt to ``generator.generate`` and returns the cleaned-up completion.
    """

    prompt_template: str = COT_EXECUTION_PROMPT_TEMPLATE
    # Section headers the model may echo in front of its answer. The prompt
    # asks for the answer to be preceded by "### Answer"; "### Final Answer"
    # is the header the earlier post-processing stripped, kept so those
    # responses are still cleaned the same way.
    answer_markers = ("### Final Answer", "### Answer")

    def __init__(self, generator: BaseGenerator) -> None:
        """Store the generator used to produce completions.

        Args:
            generator: Object exposing ``generate(prompt=...)``; its return
                value is treated as the raw completion string.
        """
        self.generator = generator

    def _preprocess(
        self,
        query: str,
        memories: List[str],
        options: Optional[List[str]] = None,
        cot: str = "",
    ) -> str:
        """Render the prompt for a single query.

        Args:
            query: The question to answer.
            memories: Personal-context snippets; serialized as a pretty-printed
                JSON array with non-ASCII characters kept as-is.
            options: Candidate answers, rendered one per line as ``- option``.
                ``None`` or an empty list is rendered as ``"Empty"``, which the
                prompt instructs the model to read as "no options provided".
            cot: The chain-of-thought decision guide to follow.

        Returns:
            The fully formatted prompt string.
        """
        memories_json = json.dumps(memories, ensure_ascii=False, indent=2)
        if options:
            options_str = "\n".join(f"- {option}" for option in options).strip()
        else:
            options_str = "Empty"
        # str.format only parses the template itself; braces inside the
        # substituted values (e.g. JSON memories) are inserted verbatim.
        return self.prompt_template.format(
            query=query,
            personal_context=memories_json,
            options=options_str,
            cot=cot,
        )

    def _postprocess(self, response: str) -> str:
        """Extract the final answer from the raw model completion.

        The prompt ends with a ``### Answer`` header and models frequently
        echo it (or a ``### Final Answer`` variant) in their output, sometimes
        after stray text. The response is split on every header listed in
        ``answer_markers`` and the last non-empty, stripped segment is
        returned. A response without any header is simply returned stripped;
        a blank response yields ``""``.

        Args:
            response: Raw completion returned by the generator.

        Returns:
            The answer text with surrounding headers and whitespace removed.
        """
        segments = [response]
        for marker in self.answer_markers:
            segments = [piece for segment in segments for piece in segment.split(marker)]
        non_empty = [segment.strip() for segment in segments if segment.strip()]
        # The answer is required to come after the header, so the last
        # non-empty piece is the answer; earlier pieces are echoes or noise.
        return non_empty[-1] if non_empty else ""

    def execute_cot(
        self,
        query: str,
        ret_memory_contents: List[str],
        answer_element_universe: Optional[List[str]] = None,
        cot: str = "",
    ) -> str:
        """Run the full prompt -> generate -> post-process pipeline.

        Args:
            query: The question to answer.
            ret_memory_contents: Retrieved memory texts used as personal context.
            answer_element_universe: Optional list of allowed answers; when
                ``None`` or empty the prompt shows ``"Empty"``.
            cot: The chain-of-thought decision guide the model should follow.

        Returns:
            The extracted final answer string.
        """
        prompt = self._preprocess(
            query=query,
            memories=ret_memory_contents,
            options=answer_element_universe,
            cot=cot,
        )
        logger.debug(f"[{self.__class__.__name__}] prompt: {prompt}")

        response = self.generator.generate(prompt=prompt)
        logger.debug(f"[{self.__class__.__name__}] response: {response}")

        final_answer = self._postprocess(response=response)
        logger.debug(f"[{self.__class__.__name__}] final_answer: {final_answer}")

        return final_answer
