from typing import Literal
from networkx import DiGraph
from networkx.drawing.nx_pydot import write_dot
import os

class EasyWriteConverter:
    def __init__(self, neg_symbol: str = "<neg>"):
        self.neg_symbol = neg_symbol
        
    def _check(self, easy_write_dict: dict[str, list[list[str]]]) -> None:
        """
        Check if the input dictionary is valid.
        """
        if not isinstance(easy_write_dict, dict):
            raise TypeError(f"The input should be a dictionary but get {type(easy_write_dict)}.")
        for key, value in easy_write_dict.items():
            if not isinstance(key, str):
                raise TypeError(f"The key should be a string but get {type(key)}.")
            if not isinstance(value, list):
                raise TypeError(f"The value should be a list but get {type(value)}.")
            for sub_value in value:
                if not isinstance(sub_value, list):
                    raise TypeError(f"The sub_value should be a list but get {type(sub_value)}.")
                for item in sub_value:
                    if not isinstance(item, str):
                        raise TypeError(f"The item should be a string but get {type(item)}.")
                    
    def convert_to_standard_dnf(self, easy_write_dict: dict[str, list[list[str]]]) -> dict[str, list[dict[str, bool]]]:
        """
        Convert the easy write dictionary to standard DNF.
        """
        self._check(easy_write_dict)
        result = {}
        for key, value in easy_write_dict.items():
            result[key] = []
            for sub_value in value:
                cc = {}
                for item in sub_value:
                    if item.startswith(self.neg_symbol):
                        cc[item[len(self.neg_symbol):]] = False
                    else:
                        cc[item] = True
                result[key].append(cc)
        return result

    def convert_to_math_symbol(self, easy_write_dict: dict[str, list[list[str]]]) -> list[str]:
        """
        given a easy write DNF where each key is a variable and each value is a list of conjunction clauses, convert it to a readable math expression.
        """
        return MathSymbolConverter()(self.convert_to_standard_dnf(easy_write_dict))
    
    def __call__(self, easy_write_dict: dict[str, list[list[str]]], mode: Literal["math", "standard"] = "standard") -> list[str] | dict[str, list[dict[str, bool]]]:
        if mode not in ["math", "standard"]:
            raise ValueError(f"mode should be one of ['math', 'standard'] but get {mode}.")
        return self.convert_to_math_symbol(easy_write_dict) if mode == "math" else self.convert_to_standard_dnf(easy_write_dict)

class MathSymbolConverter:
    """
    Render a standard DNF (``{target: [{var: polarity, ...}, ...]}``) as
    human-readable strings, e.g. ``"A =  ( B ∧ ¬C )  ∨  ( D ) "``.
    """

    def __init__(self, and_symbol: str = " ∧ ", or_symbol: str = " ∨ ", neg_symbol: str = "¬"):
        # Separators/prefix used when rendering; spacing is part of the symbol.
        self.and_symbol = and_symbol
        self.or_symbol = or_symbol
        self.neg_symbol = neg_symbol

    def _convert_literal(self, key: str, value: bool) -> str:
        """Render one literal; falsy polarity means the variable is negated."""
        # Negated and positive literals are rendered the same way (no quoting),
        # so ``¬B`` and ``B`` read consistently in the output.
        return key if value else f"{self.neg_symbol}{key}"

    def _convert_cc_to_math_symbol(self, cc: dict[str, bool]) -> str:
        """
        Convert a single conjunction clause to math symbol.
        """
        return self.and_symbol.join(self._convert_literal(key, value) for key, value in cc.items())

    def convert_to_math_symbol(self, logic_dict: dict[str, list[dict[str, bool]]]) -> list[str]:
        """
        Given a standard DNF, where each key is a variable and each value is a
        list of conjunction clauses, convert it to readable math expressions.

        Returns:
            One ``"<target> = <clause> <or> <clause> ..."`` string per key of
            ``logic_dict``, in iteration order.
        """
        result: list[str] = []
        for target_node, dnf in logic_dict.items():
            # Each clause is wrapped as " ( ... ) "; the padding is kept as in
            # the original output format.
            clauses = [f" ( {self._convert_cc_to_math_symbol(conjunction)} ) " for conjunction in dnf]
            result.append(f"{target_node} = {self.or_symbol.join(clauses)}")
        return result

    def __call__(self, logic_dict: dict[str, list[dict[str, bool]]]) -> list[str]:
        """Alias for ``convert_to_math_symbol``."""
        return self.convert_to_math_symbol(logic_dict)
    
class GraphConverter:
    """
    Turn a standard DNF (``{target: [{var: polarity, ...}, ...]}``) into a
    dependency graph, and optionally write it as DOT or render it with Graphviz.
    """

    def convert_standard_dnf_to_graph(self, logic_dict: dict[str, list[dict[str, bool]]]) -> DiGraph:
        """
        Build a directed graph with an edge ``var -> target`` for every
        variable appearing in any conjunction clause of ``target``.

        Clause grouping and literal polarity are not encoded in the graph, and
        repeated ``var -> target`` pairs collapse into a single edge.
        """
        graph = DiGraph()
        # Add a node for each target of a logic expression.
        for key, conjunctions in logic_dict.items():
            graph.add_node(key, label=key)  # node is the logic variable (e.g. A, B, C)
            # Create edges from the variables in each conjunction clause of the DNF.
            for conjunction in conjunctions:
                for var in conjunction:
                    # Variables that are never targets are created implicitly
                    # here (without a ``label`` attribute).
                    graph.add_edge(var, key)
        return graph

    def convert_standard_dnf_to_dot(self, logic_dict: dict[str, list[dict[str, bool]]], save_file: str) -> None:
        """Write the dependency graph of ``logic_dict`` to ``save_file`` in DOT format."""
        write_dot(self.convert_standard_dnf_to_graph(logic_dict=logic_dict), path=save_file)

    def convert_standard_dnf_to_image(self, logic_dict: dict[str, list[dict[str, bool]]], save_file: str) -> None:
        """
        Render the dependency graph of ``logic_dict`` to an image via the
        Graphviz ``dot`` executable, which must be available on PATH.

        The output format is taken from the extension of ``save_file``
        (e.g. ``graph.png`` -> ``dot -Tpng``).

        Raises:
            ValueError: if ``save_file`` has no extension.
            FileNotFoundError: if the ``dot`` executable cannot be found.
            subprocess.CalledProcessError: if ``dot`` exits with an error.
        """
        import subprocess
        import tempfile

        # splitext keeps the leading dot (".png"); dot expects just "png".
        graph_img_type = os.path.splitext(save_file)[1].lstrip(".").lower()
        if not graph_img_type:
            raise ValueError(f"Cannot infer the image format from {save_file!r}; use an extension such as '.png'.")

        # A temporary directory works on every platform/Python version, unlike
        # NamedTemporaryFile(delete_on_close=...) which needs Python 3.12+.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dot_path = os.path.join(tmp_dir, "graph.dot")
            self.convert_standard_dnf_to_dot(logic_dict=logic_dict, save_file=dot_path)
            # Argument list (no shell): safe for paths with spaces or shell
            # metacharacters; check=True surfaces rendering failures.
            subprocess.run(["dot", f"-T{graph_img_type}", dot_path, "-o", save_file], check=True)