from typing import (
    List,
    Optional,
)

from loguru import logger

from src.generator import BaseGenerator
from src.schema import Document


# Prompt used by QuestionAnsweringService to ask the generator for an answer.
#
# Placeholders filled via str.format():
#   {personal_context} - newline-joined, stripped memory snippets
#   {query}            - the user's question
#   {options}          - "- option" bullet lines, or the literal "Empty" when
#                        no options are supplied
# The surrounding newlines are removed by .strip(), so the rendered prompt
# begins with "You are a question answering model." and ends with "### Answer".
#
# NOTE(review): the "Test Input)" header looks like a leftover from a removed
# few-shot example section, and rule 4 quotes 'Empty,' (comma inside the
# quotes) while the sentinel actually emitted is "Empty". Both are kept as-is
# because changing the prompt text changes model behavior; confirm intent
# before editing.
QUESTION_ANSWERING_PROMPT_TEMPLATE = """
You are a question answering model.

1. You are given a personal context, a query, and a list of possible options.
2. Your task is to generate an answer to the query based on the user's personal context.
3. You should generate an answer to the query by referring to the personal context where relevant.
4. If the options text says 'Empty,' it means no options are provided.
5. If the options are not empty, simply output one of the answers listed in the options without any additional explanation.
6. Never output any other explanation. Just output the answer.
7. If option follows a format like '[A] something', then output something as the answer instead of A.

Test Input)

### Personal Context
{personal_context}

### Query
{query}

### Options
{options}

### Answer
""".strip()


class QuestionAnsweringService:
    """Answer a query from retrieved personal memories using a text generator.

    The prompt is rendered from ``prompt_template`` (a class attribute, so a
    subclass can swap in a different template with the same placeholders),
    sent to ``generator.generate(prompt=...)``, and the stripped response is
    returned.
    """

    # Template with ``{personal_context}``, ``{query}`` and ``{options}`` fields.
    prompt_template: str = QUESTION_ANSWERING_PROMPT_TEMPLATE

    def __init__(self, generator: BaseGenerator, default_document: str) -> None:
        """Store the generator backend and the default document text.

        Args:
            generator: Backend whose ``generate(prompt=...)`` produces the answer.
            default_document: Stored on the instance. NOTE(review): it was the
                fallback for an empty documents section, which has since been
                removed from the prompt; nothing in this class reads it now.
        """
        self.generator = generator
        self.default_document = default_document

    def _preprocess(
        self,
        query: str,
        documents: List[Document],
        ret_memory_contents: List[str],
        options: Optional[List[str]] = None,
    ) -> str:
        """Render the prompt for ``query``.

        Args:
            query: The question to answer.
            documents: Accepted for interface compatibility but not rendered;
                only ``ret_memory_contents`` forms the personal context.
            ret_memory_contents: Retrieved memory snippets. Each is stripped and
                they are joined with newlines; ``None`` means no memories.
            options: Candidate answers, rendered as ``- option`` lines. When
                ``None`` or empty, the literal ``"Empty"`` is used, which the
                prompt tells the model means "no options".

        Returns:
            The fully formatted prompt string.
        """
        # `documents` is intentionally not part of the prompt any more; the
        # personal context comes solely from the retrieved memories.
        memories_str = "\n".join(
            item.strip() for item in (ret_memory_contents or ())
        ).strip()
        options_str = (
            "\n".join(f"- {option}" for option in options) if options else "Empty"
        )
        # str.format only interprets braces in the template itself, so braces
        # inside the query or memories are inserted verbatim.
        return self.prompt_template.format(
            personal_context=memories_str,
            query=query,
            options=options_str,
        )

    def _postprocess(self, response: str) -> str:
        """Return the generator response with surrounding whitespace removed."""
        return response.strip()

    def generate_answer(
        self,
        query: str,
        documents: List[Document],
        ret_memory_contents: List[str],
        answer_element_universe: Optional[List[str]] = None,
    ) -> str:
        """Build the prompt, call the generator, and return the cleaned answer.

        Args:
            query: The question to answer.
            documents: Forwarded to ``_preprocess`` (not rendered into the prompt).
            ret_memory_contents: Retrieved memory snippets used as personal context.
            answer_element_universe: Optional list of allowed answers, rendered
                as the prompt's options.

        Returns:
            The generator's response, stripped of surrounding whitespace.
        """
        prompt = self._preprocess(
            query=query,
            documents=documents,
            ret_memory_contents=ret_memory_contents,
            options=answer_element_universe,
        )
        # Pass values as format args so loguru only builds the (possibly large)
        # message when DEBUG is actually enabled.
        logger.debug("[QuestionAnsweringService] prompt: {}", prompt)

        response = self.generator.generate(prompt=prompt)
        logger.debug("[QuestionAnsweringService] response: {}", response)

        raw_answer = self._postprocess(response=response)
        logger.debug("[QuestionAnsweringService] raw_answer: {}", raw_answer)

        return raw_answer
