import json
from typing import List

from src.generator.base_generator import BaseGenerator


# Prompt rendered by SubQueryService.generate via str.format: {context} and
# {question} are the only placeholders, so any other literal brace added to
# this text must be doubled ("{{" / "}}") or formatting will fail.
# The reply is parsed with json.loads, hence the explicit "JSON array only"
# rule and the clear split between the worked example and the real task.
PROMPT_TEMPLATE = """
You are a sub-query generator.

1. You are given a user context and a question.
2. Generate sub-queries that help identify relevant personal history for answering the question.
3. Sub-queries should be concise, factual, and directly related to retrieving personal context.
4. Generate between 1 and 3 sub-queries.
5. Respond with only a JSON array of strings. Do not add explanations or Markdown code fences.

# Example

### Context

I've got a fever, feel nauseous, and have been throwing up all day.

### Question

What's the most likely cause?

### Sub-Queries (JSON)

[
    "Has the user recently traveled to a malaria-endemic country such as Rwanda?",
    "Did the user consume raw shellfish or other risky foods recently?",
    "Has the user been in close contact with someone who was sick or vomiting?"
]

# Task

### Context

{context}

### Question

{question}

### Sub-Queries (JSON)
""".strip()


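# NOTE(review): disabled alternative to json_mode (it pairs with the
# commented-out json_schema argument in SubQueryService.generate). If this is
# passed as an OpenAI-style response_format, strict Structured Outputs require
# the root schema to be an object, so this array-rooted schema would need to be
# wrapped (e.g. an object with a "sub_queries" array property) before it can be
# re-enabled, and the caller would then need to unwrap that property.
# SUB_QUERY_SCHEMA = {
#     "type": "json_schema",
#     "json_schema": {
#         "name": "SubQueries",
#         "strict": True,
#         "schema": {
#             "type": "array",
#             "items": {"type": "string"}
#         }
#     }
# }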


class SubQueryService:
    """Generate retrieval sub-queries for a (context, question) pair via an LLM.

    The prompt is rendered from ``prompt_template`` and sent to the injected
    generator with ``json_mode=True``; the reply is parsed into a list of
    query strings.
    """

    prompt_template: str = PROMPT_TEMPLATE

    def __init__(
        self,
        generator: BaseGenerator,
    ) -> None:
        """Store the text generator used to produce sub-queries.

        Args:
            generator: Backend whose ``generate(prompt, json_mode=...)`` call
                returns the raw model reply text.
        """
        self.generator = generator

    def generate(
        self,
        context: str,
        question: str,
    ) -> List[str]:
        """Return the sub-queries the model proposes for ``context``/``question``.

        Args:
            context: Free-form user context substituted into the prompt.
            question: The user's question substituted into the prompt.

        Returns:
            Whitespace-stripped, non-empty query strings in the order the model
            produced them (an empty list if it produced none).

        Raises:
            json.JSONDecodeError: If the reply is not valid JSON.
            ValueError: If the JSON is neither an array nor an object holding
                exactly one array.
        """
        prompt = self.prompt_template.format(
            context=context,
            question=question,
        )
        raw_response = self.generator.generate(
            prompt,
            json_mode=True,
            # json_schema=SUB_QUERY_SCHEMA,
        )
        return self._parse_sub_queries(raw_response)

    @staticmethod
    def _parse_sub_queries(raw_response: str) -> List[str]:
        """Parse a raw model reply into a clean list of sub-query strings."""
        text = raw_response.strip()
        # Models sometimes wrap JSON in a Markdown fence (```json ... ```)
        # even when told not to; json.loads would reject the fence markers.
        if text.startswith("```"):
            _, _, text = text.partition("\n")  # drop the opening fence line
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        parsed = json.loads(text)

        # JSON modes such as OpenAI's json_object format emit a top-level
        # object, so the array may arrive wrapped, e.g. {"sub_queries": [...]}.
        if isinstance(parsed, dict):
            arrays = [value for value in parsed.values() if isinstance(value, list)]
            if len(arrays) != 1:
                raise ValueError(
                    "Expected exactly one JSON array in sub-query response, "
                    f"got: {raw_response!r}"
                )
            parsed = arrays[0]

        if not isinstance(parsed, list):
            raise ValueError(
                f"Expected a JSON array of sub-queries, got: {raw_response!r}"
            )

        # Keep only usable query strings; blank or non-string items would be
        # useless (or crash) as retrieval queries downstream.
        return [
            item.strip()
            for item in parsed
            if isinstance(item, str) and item.strip()
        ]
