import string

from typing import (
    List,
    Optional,
)

from dotenv import load_dotenv
from loguru import logger

from src.generator import BaseGenerator
from src.schema import Document


# Load variables from a local .env file into the process environment
# (python-dotenv) as a side effect of importing this module.
# NOTE(review): presumably so the generator backend can read credentials or
# config from os.environ — confirm against the generator implementation.
load_dotenv()


# Few-shot prompt asking the model for 3-5 retrieval sub-queries, one quoted
# sub-query per line. Rendered with ``str.format`` using the ``query`` and
# ``options`` fields, so any other literal brace added here must be doubled.
# The example and the test input share the same ``###`` section headings so
# the model sees one consistent layout.
SUB_QUERY_GENERATION_PROMPT_TEMPLATE = """
You are a sub-query generator.

1. You are given a query and a list of possible options.
2. Your task is to generate 3 to 5 sub-queries that help retrieve personal context relevant to answering the query.
3. Each sub-query should be answerable based on the user's personal context.
4. Ensure the sub-queries cover different aspects or angles of the query.
5. If the options text says 'Empty', it means no options are provided.

Please output one sub-query per line, in the following format:
"Sub-query 1 here"
"Sub-query 2 here"
"Sub-query 3 here"

Example 1)

### Query
I have a fever and a cough. What disease do I have?

### Options
Common cold
Flu
Strep throat

### Sub-queries
"Has the user visited any countries in Africa recently?"
"Has the user eaten any cold food recently?"
"Has the user been in contact with anyone who had COVID-19 recently?"

Test Input)

### Query
{query}

### Options
{options}

### Sub-queries
""".strip()


class SubQueryGenerationService:
    """Ask an LLM to decompose a query into sub-queries for personal-context retrieval.

    The service fills ``prompt_template`` with the query and its candidate
    options, sends the prompt to ``generator``, and splits the reply into one
    sub-query per non-empty line.
    """

    # Prompt with ``{query}`` and ``{options}`` placeholders. It is read via
    # ``self``, so a subclass or instance may override it.
    prompt_template: str = SUB_QUERY_GENERATION_PROMPT_TEMPLATE

    def __init__(self, generator: BaseGenerator, default_document: str) -> None:
        """Store the LLM backend and the fallback document text.

        Args:
            generator: Backend whose ``generate(prompt=...)`` produces the raw reply.
            default_document: Fallback text for when no documents are available.
                NOTE(review): currently unused — the prompt does not include
                documents; kept for interface compatibility.
        """
        self.generator = generator
        self.default_document = default_document

    def _preprocess(self, query: str, documents: List[Document], options: Optional[List[str]] = None) -> str:
        """Render the prompt for ``query`` and its candidate ``options``.

        Args:
            query: The user's question.
            documents: Accepted for interface compatibility; not rendered into
                the prompt.
            options: Candidate answers, placed one per line in the prompt.
                Blank entries are dropped. When nothing remains, the literal
                ``"Empty"`` is used, which the prompt defines as "no options".

        Returns:
            The fully formatted prompt string.
        """
        cleaned_options = [option.strip() for option in options or [] if option.strip()]
        # An empty join must still produce the sentinel; otherwise an all-blank
        # options list would render an empty Options section.
        options_str = "\n".join(cleaned_options) or "Empty"
        return self.prompt_template.format(query=query, options=options_str)

    def _postprocess(self, response: str) -> List[str]:
        """Split the model's reply into individual sub-queries.

        Each line is stripped of surrounding whitespace and double quotes (the
        prompt asks for one quoted sub-query per line). Lines that end up empty,
        such as blank separators or a lone quote, are discarded so callers never
        receive an empty sub-query.
        """
        strip_chars = string.whitespace + '"'
        stripped_lines = (line.strip(strip_chars) for line in response.splitlines())
        return [line for line in stripped_lines if line]

    def generate_sub_queries(
        self,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> List[str]:
        """Generate retrieval sub-queries for ``query``.

        Args:
            query: The user's question.
            documents: Forwarded to ``_preprocess``; not currently used in the prompt.
            answer_element_universe: Candidate answers shown to the model as options.

        Returns:
            Non-empty sub-query strings parsed from the model's reply, in reply order.
        """
        prompt = self._preprocess(query=query, documents=documents, options=answer_element_universe)
        logger.debug(f"[SubQueryGenerationService] prompt: {prompt}")

        response = self.generator.generate(prompt=prompt)
        logger.debug(f"[SubQueryGenerationService] response: {response}")

        sub_queries = self._postprocess(response=response)
        logger.debug(f"[SubQueryGenerationService] sub_queries: {sub_queries}")

        return sub_queries
