from typing import (
    List,
    Optional,
)

from src.generator import BaseGenerator
from src.schema import (
    Document,
    ReasoningGraph,
)


# Prompt used to have an LLM walk a reasoning graph and emit the final answer.
# It is filled with str.format using the fields: document, query,
# reasoning_graph_json, memories, options. Literal braces in the JSON example
# are doubled ({{ / }}) so that str.format emits them as single braces.
REASONING_GRAPH_EXECUTION_PROMPT_TEMPLATE = """
# Role
You are a reasoning graph executor.

# Task
- You will be given a document, a query, a reasoning graph, memories, and options.
- Your task is to understand the reasoning graph to arrive at the final answer.
- The reasoning graph is a directed acyclic graph.
- Each non-leaf node represents a sub-query.
- Each leaf node represents an answer, which is one of the options.
- Each edge represents a condition that leads from one node to another.
- Follow the reasoning graph, derive the final answer, and output it.
- The final answer must be written right after `## Final Answer`, and your response must end immediately after the final answer.
- If the options text says 'Empty', it means no options are provided.
- If the options are not empty, simply output one of the answers listed in the options without any additional explanation.
- Never output any other explanation. Just output the final answer.
- If an option follows a format like '[A] something', output 'something' as the answer instead of 'A'.

# Example

## Document
- Malaria is a mosquito-borne disease that is common in tropical regions such as sub-Saharan Africa.

## Query
Hi doctor, I am 35 y.o., female, I have fatigue and night sweats. What is causing this?

## Reasoning Graph
{{ 
    "nodes": [ 
        {{ 
            "node_id": 0, 
            "is_root": true, 
            "is_leaf": false, 
            "sub_query": "Have you traveled to a tropical region recently?", 
            "answer": null, 
            "edges": [ 
                {{ "edge_id": 0, "src_node_id": 0, "dst_node_id": 1, "condition": "Yes" }}, 
                {{ "edge_id": 1, "src_node_id": 0, "dst_node_id": 2, "condition": "No" }} 
            ] 
        }}, 
        {{ 
            "node_id": 1, 
            "is_root": false, 
            "is_leaf": false, 
            "sub_query": "Do you have a history of mosquito bites?", 
            "answer": null, 
            "edges": [ 
                {{ "edge_id": 2, "src_node_id": 1, "dst_node_id": 3, "condition": "Yes" }}, 
                {{ "edge_id": 3, "src_node_id": 1, "dst_node_id": 4, "condition": "No" }} 
            ] 
        }}, 
        {{ 
            "node_id": 2, 
            "is_root": false, 
            "is_leaf": true, 
            "sub_query": null, 
            "answer": "HIV", 
            "edges": [] 
        }}, 
        {{ 
            "node_id": 3, 
            "is_root": false, 
            "is_leaf": true, 
            "sub_query": null, 
            "answer": "Malaria", 
            "edges": [] 
        }}, 
        {{ 
            "node_id": 4, 
            "is_root": false, 
            "is_leaf": true, 
            "sub_query": null, 
            "answer": "Flu", 
            "edges": [] 
        }} 
    ] 
}}

## Memories
- I have traveled to Africa recently.
- I got many mosquito bites during that trip.
- I have visited Japan before.
- I love to travel.

## Options
- HIV
- Malaria
- Flu

## Final Answer
Malaria

# Test Input

## Document
{document}

## Query
{query}

## Reasoning Graph
{reasoning_graph_json}

## Memories
{memories}

## Options
{options}

## Final Answer
""".strip()


class ReasoningGraphExecutionService:
    """Answer a query by asking an LLM to walk a reasoning graph.

    The service renders ``prompt_template`` with the supporting documents, the
    query, the serialized reasoning graph, retrieved memories and the answer
    options, sends the prompt to ``generator`` and returns the extracted answer.
    """

    # Template rendered by ``_preprocess`` via ``str.format``; subclasses may
    # override it. Keyword fields supplied: document, query,
    # reasoning_graph_json, memories, options.
    prompt_template: str = REASONING_GRAPH_EXECUTION_PROMPT_TEMPLATE

    # Heading after which the prompt expects the answer. The prompt already
    # ends with it, but a model may echo it back in its response.
    final_answer_header: str = "## Final Answer"

    def __init__(self, generator: BaseGenerator, default_document: str) -> None:
        """Create the service.

        Args:
            generator: Backend called as ``generator.generate(prompt=...)``.
            default_document: Text substituted for the document section when
                no document has non-blank content.
        """
        self.generator = generator
        self.default_document = default_document

    def _preprocess(
        self,
        query: str,
        documents: List[Document],
        reasoning_graph: ReasoningGraph,
        memories: List[str],
        options: Optional[List[str]] = None,
    ) -> str:
        """Render ``prompt_template`` for a single query.

        Args:
            query: The user query.
            documents: Supporting documents; each non-blank ``content`` becomes
                one ``- `` bullet line.
            reasoning_graph: Graph serialized with ``model_dump_json(indent=4)``.
            memories: Retrieved memory strings, one bullet line each.
            options: Candidate answers, one bullet line each; ``None`` or an
                empty list is rendered as ``"Empty"``.

        Returns:
            The fully formatted prompt.
        """
        # Skip blank documents: a lone "- " bullet survives .strip() as "-",
        # which would otherwise bypass the default-document fallback.
        document_lines = [
            f"- {doc.content}" for doc in documents if str(doc.content).strip()
        ]
        documents_str = "\n".join(document_lines).strip() or self.default_document
        memories_str = "\n".join(f"- {memory}" for memory in memories).strip()
        options_str = (
            "\n".join(f"- {option}" for option in options).strip()
            if options
            else "Empty"
        )
        return self.prompt_template.format(
            document=documents_str,
            query=query,
            reasoning_graph_json=reasoning_graph.model_dump_json(indent=4),
            memories=memories_str,
            options=options_str,
        )

    def _postprocess(self, response: str) -> str:
        """Extract the final answer from the raw model response.

        If the response repeats the final-answer heading, only the text after
        its last occurrence is kept. Surrounding whitespace is stripped.
        """
        # rpartition returns ("", "", response) when the heading is absent, so
        # responses without it are handled exactly as a plain strip().
        _, _, answer = response.rpartition(self.final_answer_header)
        return answer.strip()

    def execute_reasoning_graph(
        self,
        query: str,
        documents: List[Document],
        reasoning_graph: ReasoningGraph,
        ret_memory_contents: List[str],
        answer_element_universe: Optional[List[str]] = None,
    ) -> str:
        """Run the reasoning graph for ``query`` and return the model's answer.

        Args:
            query: The user query.
            documents: Supporting documents shown to the model.
            reasoning_graph: The graph the model is asked to follow.
            ret_memory_contents: Retrieved memory strings shown as memories.
            answer_element_universe: Candidate answers shown as options;
                ``None`` or empty means no options are provided.

        Returns:
            The final answer text extracted from the generator's response.
        """
        prompt = self._preprocess(
            query=query,
            documents=documents,
            reasoning_graph=reasoning_graph,
            memories=ret_memory_contents,
            options=answer_element_universe,
        )
        response = self.generator.generate(prompt=prompt)
        return self._postprocess(response=response)
