import json
from typing import (
    List,
    Dict,
)

from src.generator.base_generator import BaseGenerator


# Few-shot prompt for the personalization-graph generator.
# It is filled in with ``str.format(context=..., question=...)``, so the
# literal JSON braces in the worked example are doubled (``{{`` / ``}}``) to
# survive formatting. Leading/trailing whitespace is removed by ``.strip()``.
# The expected model output is a JSON object mapping each possible answer to a
# list of sub-query strings, as illustrated by the example.
PROMPT_TEMPLATE = """
You are a personalization graph generator.

1. You are given a context and a question.
2. Generate a personalization graph that helps in answering the question.
3. The personalization graph consists of possible answers and the sub-queries associated with each answer.
4. Possible answers represent candidate responses to the question.
5. Sub-queries are designed to retrieve relevant personal history that supports each possible answer.
6. Generate between 1 and 5 possible answers.
7. For each possible answer, generate between 1 and 3 sub-queries.

# Example

### Context
I’ve got a fever, feel nauseous, and have been throwing up all day.

### Question
What’s the most likely cause?

### Personalization Graph (in json)
{{
  "Malaria": [
    "Has the user recently traveled to a malaria-endemic country such as Rwanda?",
    "Has the user been exposed to mosquito bites during their travel?"
  ],
  "Stomach flu": [
    "Did the user consume raw shellfish or other risky foods recently?",
    "Has the user been in close contact with someone who was sick or vomiting?"
  ],
  "Food poisoning": [
    "Did the user eat food from an unsanitary place recently?",
    "Has the user consumed any expired or questionable food items?"
  ]
}}

### Context

{context}

### Question

{question}

### Personalization Graph (in json)
""".strip()


# PERS_GRAPH_SCHEMA = {
#     "type": "json_schema",
#     "json_schema": {
#         "name": "PersGraph",
#         "strict": True,
#         "schema": {
#             "type": "object",
#             "additionalProperties": {
#                 "type": "array",
#                 "items": {"type": "string"}
#             }
#         }
#     }
# }


class PersGraphService:
    """Build a personalization graph for a (context, question) pair via an LLM.

    A personalization graph maps each candidate answer to the question onto a
    list of sub-queries meant to retrieve personal history supporting it.
    """

    # Prompt sent to the generator; formatted with ``context`` and ``question``.
    # Declared on the class so subclasses can override it.
    prompt_template: str = PROMPT_TEMPLATE

    def __init__(
        self,
        generator: BaseGenerator,
    ) -> None:
        """Store the generator backend.

        Args:
            generator: Object whose ``generate(prompt, json_mode=...)`` call
                returns the raw model output to be parsed as JSON.
        """
        self.generator = generator

    def generate(
        self,
        context: str,
        question: str,
    ) -> Dict[str, List[str]]:
        """Generate, parse, and validate a personalization graph.

        Args:
            context: Free-text context to insert into the prompt.
            question: Question the graph should help answer.

        Returns:
            Mapping from each possible answer to its list of sub-queries.

        Raises:
            ValueError: If the generator output is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass), or
                if it does not have the ``{str: [str, ...]}`` shape.
        """
        prompt = self.prompt_template.format(
            context=context,
            question=question,
        )
        raw_response = self.generator.generate(
            prompt,
            json_mode=True,
            # json_schema=PERS_GRAPH_SCHEMA,
        )
        pers_graph = json.loads(raw_response)
        return self._validate(pers_graph)

    @staticmethod
    def _validate(pers_graph: object) -> Dict[str, List[str]]:
        """Return ``pers_graph`` unchanged if it is a ``Dict[str, List[str]]``.

        Raises:
            ValueError: If the value is not a JSON object, or if any of its
                values is not a list of strings.
        """
        # No schema is passed to the generator (see the commented-out
        # json_schema argument), so the parsed JSON may have any shape.
        # Fail here instead of deep inside a downstream consumer.
        if not isinstance(pers_graph, dict):
            raise ValueError(
                "Expected a JSON object for the personalization graph, "
                f"got {type(pers_graph).__name__}"
            )
        # JSON object keys are always strings, so only the values need checking.
        for answer, sub_queries in pers_graph.items():
            if not isinstance(sub_queries, list) or not all(
                isinstance(query, str) for query in sub_queries
            ):
                raise ValueError(
                    "Expected a list of strings as sub-queries for answer "
                    f"{answer!r}, got {sub_queries!r}"
                )
        return pers_graph
