import typing
from networkx import DiGraph
from typing import Any

from util.interface import RulesDict

import logging
import os
# Map VACT_LOGGER_LEVEL (case-sensitive, default "INFO") to a logging level;
# any unrecognised value falls back to logging.INFO.
logger_level = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}.get(os.environ.get("VACT_LOGGER_LEVEL", "INFO"), logging.INFO)

# Configure the root logger once at import time, then take a module-scoped logger.
logging.basicConfig(format='%(asctime)s - %(filename)s - [%(levelname)s] - %(message)s', level=logger_level)
logger = logging.getLogger(__name__)

class ExtractAnswerError(Exception):
    """Signals that the expected content could not be extracted from an answer."""

class PydanticFormatError(Exception):
    """Signals that a parsed model violates a structural constraint (e.g. a repeated variable in a clause)."""

def load_env(env_file: str = ".env") -> None:
    """
    Load environment variables from a dotenv file.

    Args:
        env_file: path to the file; it is resolved to an absolute path first.
            Defaults to ``.env`` (relative to the current working directory).

    If the file does not exist, a warning is logged and nothing is loaded.
    ``python-dotenv`` is imported only once a file has been found, so a missing
    file does not require that package to be installed.
    """
    env_file = os.path.abspath(env_file)
    if not os.path.exists(env_file):
        logger.warning(f"The env file {env_file} does not exist. Skip loading environment variables from it.")
        return
    # Optional dependency: only needed when there is actually a file to read.
    from dotenv import load_dotenv
    load_dotenv(env_file)
    logger.debug(f"Loaded environment variables from {env_file}")

# def extract_dot(answer: str) -> str:
#     """
#     Extract the dot format from the answer.
#     """
#     start = answer.find("<dot>")
#     end = answer.find("</dot>")
#     if start == -1 or end == -1:
#         raise ExtractAnswerError("No dot format found.")
#     return answer[start + 5: end]

# def extract_json_str(answer: str) -> str:
#     """
#     Extract the json format from the answer.
#     """
#     start = answer.find("<json>")
#     end = answer.find("</json>")
#     if start == -1 or end == -1:
#         raise ExtractAnswerError("No json format found.")
#     return answer[start + 6: end]

# def extract_json(answer: str) -> dict:
#     """
#     Extract the json format from the answer.
#     """
#     json_str = extract_json_str(answer)
#     try:
#         f = json.loads(json_str)
#     except json.JSONDecodeError as e:
#         raise ExtractAnswerError(f"The generated json file is not valid (Error infor: {e})")
#     return f

# def extract_graph(answer: str) -> DiGraph:
#     """
#     Extract the graph from the answer.
#     """
#     from tempfile import NamedTemporaryFile
#     dot = extract_dot(answer)
#     with NamedTemporaryFile(mode='w', delete_on_close=False) as tempfile:
#         tempfile.write(dot)
#         tempfile.close()
#         from networkx.drawing.nx_pydot import read_dot
#         try:
#             graph = read_dot(tempfile.name)
#         except TypeError:
#             raise ExtractAnswerError("The generated dot file is not valid.")
#     return graph

from pydantic import BaseModel, Field, model_validator
from typing import Literal, Self

class Factor(BaseModel):
    # A node of the causal graph: either an independently settable `factor` or a derived `result`.
    # NOTE: intentionally no class docstring -- pydantic would publish it as the model's
    # JSON-schema description. The Field descriptions below are schema/prompt text.
    name: str = Field(
        ...,
        description = "A short name of the factor or result, used as the node id. Must be consistent with its description. No more than 5 words."
    )
    type: Literal['factor', 'result'] = Field(
        ...,
        description = "The type of the node, which can be either `factor` or `result`. A `factor` is an independently settable variable that can be either true or false, while a `result` is a variable that is determined by other factors or results. Notice that one factor could be the reason of many results, and a result could be the outcome of many factors (or other results)."
    )
    description: str = Field(
        ...,
        description = "A detailed description of the factor or result, describing (1) what the factor is; (2) the detailed definition of the obvious `true` or `false` value of the factors; (3) how to determine the Boolean value from a piece of video."
    )
    explanation: str = Field(
        ...,
        description = "The explanation of how it affects the scenario and why (1) this factor is important and common, and (2) if it has `factor` type, why it can be independently set True or False."
    )

    def remove_type(self) -> "FactorWithoutType":
        """
        Return a FactorWithoutType copy of this factor.

        All four fields are copied; the subclass marks `type` with
        ``Field(exclude=True)`` so it is omitted when the copy is serialized.
        """
        return FactorWithoutType(name=self.name, type=self.type, description=self.description, explanation=self.explanation)
    
class FactorWithoutType(Factor):
    # Same fields as Factor, but `type` is marked exclude=True so serialization drops it.
    type: typing.Annotated[Literal['factor', 'result'], Field(exclude=True)]
    
class FactorList(BaseModel):
    # Container model holding a list of Factor entries.
    factors: typing.List[Factor]
    
class Edge(BaseModel):
    # One directed causal link `source -> target` between two named nodes.
    # Field descriptions are schema/prompt text and must stay verbatim.
    source: typing.Annotated[
        str,
        Field(description="The source node of the edge. Must be one of the node names in the factor list."),
    ]
    target: typing.Annotated[
        str,
        Field(description="The target node of the edge. Must be one of the node names in the factor list. Notice that the factor with type `factor` cannot be the target node."),
    ]
    description: typing.Annotated[
        str,
        Field(description="A detailed description of the causal relation between the source and target nodes, describing (1) how the source node affects the target node; (2) why you believe it is an important relation (3) under what conditions this causal relation holds or not."),
    ]
    
class Graph(BaseModel):
    """
    A graph describe ALL the causal relations between the factors and results. As long as one node affects another node, there should be an edge between them. Because the `factor` nodes are independently settable, they cannot be the target node of any edge.
    """
    # NOTE: pydantic publishes the class docstring above as this model's JSON-schema
    # description, so its wording is part of the schema -- edit it deliberately.
    edges: list[Edge]

    @property
    def nodes(self) -> list[str]:
        """
        Every node that appears as an edge endpoint, without duplicates.

        Sources come before targets, each in edge order, keeping the first
        occurrence of each name. ``dict.fromkeys`` de-duplicates while
        preserving that order, so the result is the same on every run
        (``list(set(...))`` would depend on string hash seeding).
        """
        endpoints = [edge.source for edge in self.edges] + [edge.target for edge in self.edges]
        return list(dict.fromkeys(endpoints))

    def to_digraph(self) -> DiGraph:
        """Build a ``DiGraph`` from the (source, target) pairs; edge descriptions are not copied."""
        g = DiGraph()
        edges = [(edge.source, edge.target) for edge in self.edges]
        g.add_edges_from(edges) # type: ignore
        return g

    def has_edge(self, source: str, target: str) -> bool:
        """Return True if at least one edge goes from ``source`` to ``target``."""
        return any(edge.source == source and edge.target == target for edge in self.edges)

    @property
    def root_and_non_root(self) -> tuple[list[str], list[str]]:
        """
        This function should only be used after the graph is validated.
        We assume that the graph is a valid DAG without isolated nodes.

        Returns ``(root, non_root)``: ``non_root`` holds every edge target in
        first-appearance order; ``root`` holds the remaining nodes in the
        order given by ``nodes``.
        """
        non_root = list(dict.fromkeys(edge.target for edge in self.edges))
        targets = set(non_root)  # O(1) membership instead of scanning the list per node
        root = [node for node in self.nodes if node not in targets]
        return root, non_root
    
class VariableValue(BaseModel):
    # A single literal of a conjunction clause: variable `var` is required to equal `val`.
    var: typing.Annotated[
        str,
        Field(description="The name of the variable. Must be one of the node names in the factor list. The variables in one clause must be unique"),
    ]
    val: typing.Annotated[
        bool,
        Field(description="The Boolean value of the variable."),
    ]
    
class ConjunctionClause(BaseModel):
    # An AND of VariableValue literals; each variable may appear at most once.
    clause: typing.Annotated[
        list[VariableValue],
        Field(description="A conjunction clause (AND) is represented as a list of VariableValue. For example, the conjunction clause `[VariableValue(var='B', val=True), VariableValue(var='C', val=False)]` represents \"B is true AND C is false\"."),
    ]

    def to_dict(self) -> dict[str, bool]:
        """Map each variable name in the clause to its Boolean value."""
        return dict((literal.var, literal.val) for literal in self.clause)

    @model_validator(mode='after')
    def check_unique_variables(self) -> Self:
        """
        Reject clauses that mention the same variable more than once.

        Raises:
            PydanticFormatError: if any variable name is repeated.
        """
        seen: set[str] = set()
        for literal in self.clause:
            if literal.var in seen:
                raise PydanticFormatError("The variables in one conjunction clause must be unique.")
            seen.add(literal.var)
        return self
    
class Rule(BaseModel):
    # The Boolean rule for one head node, expressed as a DNF (OR of AND-clauses).
    head: typing.Annotated[
        str,
        Field(description="The target node of the rule. Must be one of the node names in the factor list. The rule is used to determine the Boolean value of this node."),
    ]
    dnf: typing.Annotated[
        list[ConjunctionClause],
        Field(description="The rule is a disjunction of conjunction clauses (DNF), represented as a list of conjunction clauses. The conjunction clauses listed in the dnf are considered as OR relation. For example, if a factor A is true when (B is true) or (C is true and D is false), the dnf should have two conjunction clauses: [ConjunctionClause(VariableValue(var='B', val=True)), ConjunctionClause(VariableValue(var='C', val=True), VariableValue(var='D', val=False))]."),
    ]
    
class RulesJson(BaseModel):
    # Full structured answer: root factors, non-root nodes, and the rules for the non-roots.
    roots: typing.Annotated[
        list[Factor],
        Field(description="The root nodes (factors) in the causal graph, which are not affected by any other nodes. All the type should be 'factor'."),
    ]
    non_roots: typing.Annotated[
        list[Factor],
        Field(description="The non-root nodes (results) in the causal graph, which are affected by other nodes. The type can be 'result'. Notice that the non-root nodes can also be the reason of other nodes."),
    ]
    rules: typing.Annotated[
        list[Rule],
        Field(description="Each non-root node should have exact one rule to determine its Boolean value, represented as a dnf (a list of conjunction clauses)."),
    ]

    def rules_convert_to_dict(self) -> RulesDict:
        """
        Flatten ``rules`` into ``{head: [clause_dict, ...]}``.

        Each clause is converted with ``ConjunctionClause.to_dict``. If two rules
        share a head, the later rule's clauses replace the earlier ones.
        """
        return {rule.head: [clause.to_dict() for clause in rule.dnf] for rule in self.rules}

    def get_roots_names(self) -> list[str]:
        """Names of the root factors, in list order."""
        return [root.name for root in self.roots]

    def get_non_roots_names(self) -> list[str]:
        """Names of the non-root nodes, in list order."""
        return [node.name for node in self.non_roots]

# def recover_complete_json_from_rules(rules) -> RulesJson:
#     roots, non_roots = root_and_non_root(recover_graph_from_rules(rules))
#     return RulesJson

def recover_graph_from_rules(rules: dict[str, list[dict[str, bool]]]) -> DiGraph:
    """
    Rebuild the causal graph implied by a rules dict.

    For every head and every conjunction clause in its DNF, each variable named
    in the clause becomes the source of an edge into that head (the Boolean
    values are ignored). A head whose clauses name no variables contributes no
    edges and therefore does not appear in the graph.
    """
    graph = DiGraph()
    graph.add_edges_from(
        (source, head)
        for head, clauses in rules.items()
        for clause in clauses
        for source in clause
    )
    return graph

def root_and_non_root(graph: DiGraph) -> tuple[list[str], list[str]]:
    """
    Split the nodes of ``graph`` by in-degree.

    Returns ``(roots, non_roots)``: nodes with in-degree 0 and nodes with a
    positive in-degree, each converted with ``str`` and kept in the order
    ``graph.nodes`` yields them.
    """
    roots: list[str] = []
    non_roots: list[str] = []
    for node in graph.nodes:
        bucket = roots if graph.in_degree(node) == 0 else non_roots  # type: ignore
        bucket.append(str(node))
    return roots, non_roots

from openai.types.responses.parsed_response import ParsedResponse

@typing.no_type_check
def parse_completion(completion: "ParsedResponse[str]") -> tuple[str, Any, str]:
    """
    Split a parsed OpenAI Responses-API result into its parts.

    The code reads ``completion.output[0]`` as the reasoning item (via its
    ``summary`` list) and ``completion.output[1]`` as the message item (via its
    first ``content`` entry).

    Args:
        completion: the parsed response object returned by the OpenAI client.

    Returns:
        ``(thinking, generation_obj, generation_str)``: the text of the first
        summary entry (``""`` when the summary is empty), ``content[0].parsed``
        and ``content[0].text``.

    Raises:
        RuntimeError: if the response does not have that layout (an
            ``AttributeError`` or ``IndexError`` while reading it); the original
            error is chained as ``__cause__``.
    """
    try:
        summary = completion.output[0].summary
        thinking = summary[0].text if summary else ""
        message = completion.output[1].content[0]
        generation_obj = message.parsed
        generation_str = message.text
    except (AttributeError, IndexError) as e:
        # Pass the response as a lazy %s argument. A positional argument without a
        # placeholder makes logging report a formatting error instead of the message.
        logger.error("Failed to parse the completion: %s", completion)
        raise RuntimeError(f"The response from OpenAI api is not valid. Error info: {e}") from e
    return thinking, generation_obj, generation_str