import json

from typing import (
    Dict,
    List,
    Optional,
)

from loguru import logger

from src.generator import BaseGenerator
from src.schema import (
    Document,
    ReasoningGraph,
)


# Structured-output (JSON Schema, strict mode) constraint passed to the generator.
# It mirrors the ``ReasoningGraph`` / ``ReasoningNode`` / ``ReasoningEdge`` models
# described in the prompt: ids are integers, ``sub_query``/``answer`` are nullable
# strings, and every property is listed in ``required`` because strict mode
# demands it (optional values are expressed via the ``null`` type instead).
RESPONSE_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "name": "reasoning_graph",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {
                "nodes": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            # "integer" (not "number") so the model cannot emit fractional ids.
                            "node_id": {"type": "integer"},
                            "is_root": {"type": "boolean"},
                            "is_leaf": {"type": "boolean"},
                            "sub_query": {"type": ["string", "null"]},
                            "answer": {"type": ["string", "null"]},
                            "edges": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "edge_id": {"type": "integer"},
                                        "src_node_id": {"type": "integer"},
                                        "dst_node_id": {"type": "integer"},
                                        "condition": {"type": "string"},
                                    },
                                    "required": ["edge_id", "src_node_id", "dst_node_id", "condition"],
                                    "additionalProperties": False,
                                },
                            },
                        },
                        "required": ["node_id", "is_root", "is_leaf", "sub_query", "answer", "edges"],
                        "additionalProperties": False,
                    },
                },
            },
            "required": ["nodes"],
            "additionalProperties": False,
        },
    },
}


# Prompt rendered with ``str.format`` using the ``{document}``, ``{query}`` and
# ``{options}`` placeholders. Literal braces in the JSON example are doubled
# (``{{`` / ``}}``) so that ``format`` emits them as single braces.
REASONING_GRAPH_GENERATION_PROMPT_TEMPLATE = """
# Role
You are a reasoning graph constructor.

# Task
- You will be given a document, a query, and multiple options.
- Options can be empty. If options are empty, you should generate plausible options yourself and use them to construct the reasoning graph.
- Your task is to construct a reasoning graph to arrive at the final answer.
- The reasoning graph should be a directed acyclic graph.
- Each non-leaf node represents a sub-query, which is a question to retrieve relevant memories from the user's memory database.
- Each leaf node represents an answer, which is one of the options.
- Each edge represents a condition that leads from one node to another.
- The node_id of the root node, which is the first node in the list, must be 0.
- You must output the graph, and empty responses are not allowed.
- Make sure the graph contains no cycles: no node may be reachable from itself.
- If the options section says 'Empty', it means no options are provided.

## Reasoning Graph Schema
```python
from typing import List, Optional
from pydantic import BaseModel

class ReasoningEdge(BaseModel):
    edge_id: int
    src_node_id: int
    dst_node_id: int
    condition: str

class ReasoningNode(BaseModel):
    node_id: int
    is_root: bool
    is_leaf: bool
    sub_query: Optional[str]
    answer: Optional[str]
    edges: List[ReasoningEdge]

class ReasoningGraph(BaseModel):
    nodes: List[ReasoningNode]
```

# Example

## Document
Some document here.

## Query
Hi doctor, I am 35 y.o., female, I have fatigue and night sweats. What is causing this?

## Options
- Flu
- HIV
- Malaria

## Graph
{{ 
    "nodes": [ 
        {{ 
            "node_id": 0, 
            "is_root": true, 
            "is_leaf": false, 
            "sub_query": "Have you traveled to a tropical region recently?", 
            "answer": null, 
            "edges": [ 
                {{ "edge_id": 0, "src_node_id": 0, "dst_node_id": 1, "condition": "Yes" }}, 
                {{ "edge_id": 1, "src_node_id": 0, "dst_node_id": 2, "condition": "No" }} 
            ] 
        }}, 
        {{ 
            "node_id": 1, 
            "is_root": false, 
            "is_leaf": false, 
            "sub_query": "Do you have a history of mosquito bites?", 
            "answer": null, 
            "edges": [ 
                {{ "edge_id": 2, "src_node_id": 1, "dst_node_id": 3, "condition": "Yes" }}, 
                {{ "edge_id": 3, "src_node_id": 1, "dst_node_id": 4, "condition": "No" }} 
            ] 
        }}, 
        {{ 
            "node_id": 2, 
            "is_root": false, 
            "is_leaf": true, 
            "sub_query": null, 
            "answer": "HIV", 
            "edges": [] 
        }}, 
        {{ 
            "node_id": 3, 
            "is_root": false, 
            "is_leaf": true, 
            "sub_query": null, 
            "answer": "Malaria", 
            "edges": [] 
        }}, 
        {{ 
            "node_id": 4, 
            "is_root": false, 
            "is_leaf": true, 
            "sub_query": null, 
            "answer": "Flu", 
            "edges": [] 
        }} 
    ] 
}}

# Test Input

## Document
{document}

## Query
{query}

## Options
{options}

## Graph
""".strip()


class ReasoningGraphGenerationService:
    """Build a ``ReasoningGraph`` for a query by prompting an LLM generator.

    The service renders ``prompt_template`` with the query, the supporting
    documents and the candidate options. It asks ``generator`` for a JSON
    response constrained by ``RESPONSE_FORMAT`` and parses that response into
    a ``ReasoningGraph``. Unparseable responses are retried up to
    ``max_retries`` times.
    """

    prompt_template: str = REASONING_GRAPH_GENERATION_PROMPT_TEMPLATE
    # Number of generate-and-parse attempts before giving up; override on a
    # subclass or instance to tune it.
    max_retries: int = 5

    def __init__(self, generator: BaseGenerator, default_document: str) -> None:
        """
        Args:
            generator: Backend whose ``generate(prompt=..., json_mode=..., json_schema=...)``
                call produces the raw graph JSON.
            default_document: Text placed in the prompt's document section when
                no documents are supplied.
        """
        self.generator = generator
        self.default_document = default_document

    def _preprocess(
        self,
        query: str,
        documents: List[Document],
        options: Optional[List[str]] = None,
    ) -> str:
        """Render the prompt for ``query``.

        Documents and options are each rendered as a ``- item`` bullet list.
        When ``documents`` is empty (or ``None``), ``self.default_document`` is
        used instead. When ``options`` is empty or ``None``, the literal
        ``"Empty"`` is used, which the prompt tells the model means "no options".
        """
        documents_str = "\n".join(f"- {doc.content}" for doc in documents or []).strip()
        if not documents_str:
            documents_str = self.default_document
        options_str = "\n".join(f"- {option}" for option in options).strip() if options else "Empty"
        return self.prompt_template.format(
            query=query,
            document=documents_str,
            options=options_str,
        )

    def _postprocess(self, response: str) -> Optional[ReasoningGraph]:
        """Parse a raw JSON ``response`` into a ``ReasoningGraph``.

        Returns:
            The parsed graph, or ``None`` (after logging) when the response is
            not valid JSON or cannot be used to construct a ``ReasoningGraph``.
            The caller then decides whether to retry.
        """
        try:
            reasoning_graph_json = json.loads(response)
            return ReasoningGraph(**reasoning_graph_json)
        except Exception as e:
            # Deliberately broad: any parse/validation failure should lead to a
            # retry rather than abort generation, but the cause is logged so it
            # is not silently lost.
            logger.error(
                f"[ReasoningGraphGenerationService] Error during postprocessing "
                f"({type(e).__name__}: {e}): {response}"
            )
            return None

    def generate_reasoning_graph(
        self,
        query: str,
        documents: List[Document],
        answer_element_universe: Optional[List[str]] = None,
    ) -> ReasoningGraph:
        """Generate a reasoning graph for ``query``.

        Args:
            query: The question the graph should help answer.
            documents: Supporting documents rendered into the prompt.
            answer_element_universe: Candidate answers, which the prompt asks to
                be used as leaf nodes. When ``None`` or empty, the model is asked
                to propose options itself.

        Returns:
            The parsed ``ReasoningGraph``. If all ``max_retries`` attempts fail,
            an empty graph (``ReasoningGraph(nodes=[])``) is returned, never ``None``.
        """
        prompt = self._preprocess(
            query=query,
            documents=documents,
            options=answer_element_universe,
        )
        # The prompt does not change between attempts, so log it once.
        logger.debug(f"[ReasoningGraphGenerationService] prompt: {prompt}")

        for attempt in range(1, self.max_retries + 1):
            response = self.generator.generate(prompt=prompt, json_mode=True, json_schema=RESPONSE_FORMAT)
            logger.debug(f"[ReasoningGraphGenerationService] response: {response}")

            reasoning_graph = self._postprocess(response=response)
            if reasoning_graph is not None:
                return reasoning_graph

            # Only announce a retry when another attempt will actually follow.
            retry_note = " Retrying..." if attempt < self.max_retries else ""
            logger.error(f"[ReasoningGraphGenerationService] Error on attempt {attempt}.{retry_note}")

        logger.error("[ReasoningGraphGenerationService] Max retries reached. Unable to generate reasoning graph.")
        return ReasoningGraph(nodes=[])
