from typing import (
    Dict,
    List,
    Optional,
)

from loguru import logger

from src.generator import BaseGenerator
from src.schema import Document


# Prompt sent to the generator to obtain a step-by-step reasoning guide.
# It is rendered with ``str.format(query=..., options=...)`` in
# COTGenerationService._preprocess, so it must keep exactly the ``{query}``
# and ``{options}`` fields and contain no other literal braces (they would be
# parsed as format fields). ``.strip()`` removes the newline after the opening
# quotes and the trailing whitespace/newline after the last header.
# Rule 6 corresponds to the "Empty" placeholder that _preprocess substitutes
# when no options are supplied.
# NOTE(review): "Test Input)" looks like a section label left over from a
# few-shot variant of this prompt -- confirm it is intended before editing,
# since any change to this text changes the model input.
COT_GENERATION_PROMPT_TEMPLATE = """
You are a chain-of-thought generator.

1. You are given a query and a list of possible options.
2. Your task is to provide a step-by-step reasoning guide to help a student answer the query.
3. The reasoning guide should clearly show your reasoning process so that the student can easily apply it to their query.
4. Analyze the query and write a reasoning guide for the student to follow.
5. If there is a lack of information relevant to the query, you must identify the missing elements as "VARIABLES" and write the guide on a case-by-case basis.
6. If the options text says 'Empty,' it means no options are provided.

Test Input)

### Query  
{query}

### Options  
{options}

### Chain-of-Thought  
""".strip()


class COTGenerationService:
    """Generate a free-text chain-of-thought reasoning guide for a query.

    The service renders ``prompt_template`` with the query and the optional
    answer options, sends the prompt to the injected generator, and returns
    the post-processed response (the raw response by default).
    """

    # Template rendered via ``str.format(query=..., options=...)``; a subclass
    # or instance may override it with another template using the same fields.
    prompt_template: str = COT_GENERATION_PROMPT_TEMPLATE

    # Text substituted for ``{options}`` when no options are supplied. The
    # default prompt tells the model what this placeholder means.
    empty_options_text: str = "Empty"

    def __init__(self, generator: BaseGenerator, default_document: str) -> None:
        """
        Args:
            generator: Backend called as ``generator.generate(prompt=...)``.
            default_document: Fallback text for when no documents are
                available. It is stored but currently unused, because document
                grounding is disabled in ``_preprocess``.
        """
        self.generator = generator
        self.default_document = default_document

    def _preprocess(
        self,
        query: str,
        documents: List[Document],
        options: Optional[List[str]] = None,
    ) -> str:
        """Render the prompt for ``query`` and ``options``.

        Args:
            query: The question, inserted verbatim into the template.
            documents: Accepted for interface compatibility but currently
                ignored -- the default template has no ``{documents}`` field.
            options: Candidate answers, rendered one per line as ``- option``.
                ``None`` or an empty sequence renders ``empty_options_text``.

        Returns:
            The formatted prompt string.
        """
        # NOTE: document grounding is disabled. To re-enable it, render
        # ``documents`` (falling back to ``self.default_document`` when there
        # are none) and add a ``{documents}`` placeholder to the template.
        #
        # Falling back on the *rendered* text (rather than on the truthiness
        # of ``options``) also yields the placeholder for empty iterators.
        options_str = (
            "\n".join(f"- {option}" for option in options or ()).strip()
            or self.empty_options_text
        )
        return self.prompt_template.format(query=query, options=options_str)

    def _postprocess(self, response: str) -> str:
        """Convert the raw generator output into the returned chain-of-thought.

        Currently the identity: the chain-of-thought is returned as plain text.
        """
        return response

    def generate_cot(
        self,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> str:
        """Produce a chain-of-thought reasoning guide for ``query``.

        Args:
            query: The question to reason about.
            documents: Retrieved documents; currently not included in the
                prompt.
            answer_element_universe: Optional candidate answers, shown to the
                model in the "Options" section of the prompt.

        Returns:
            The result of ``_postprocess`` -- by default the generator's
            response unchanged.
        """
        name = self.__class__.__name__

        prompt = self._preprocess(
            query=query,
            documents=documents,
            options=answer_element_universe,
        )
        logger.debug(f"[{name}] prompt: {prompt}")

        response = self.generator.generate(prompt=prompt)
        logger.debug(f"[{name}] response: {response}")

        cot = self._postprocess(response=response)
        logger.debug(f"[{name}] cot: {cot}")

        return cot
