import os
import sys
# Prepend the parent of this file's directory to sys.path so the project-local
# `synthesizer` package imported below can be resolved (presumably needed when
# this file is executed directly rather than imported from the project root).
project_root = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from synthesizer.fact import AttributeFact, RelationFact
from collections import defaultdict, deque


class ReasoningNode:
    """A node in a reasoning graph whose identity is its ``conclusion``.

    Two nodes compare equal when their conclusions compare equal; ``support``
    and ``true_conditions`` take no part in equality or hashing.

    Attributes:
        conclusion: The fact this node concludes. Hashing supports only
            ``AttributeFact`` and ``RelationFact`` conclusions.
        true_conditions: Initialized empty here; presumably filled in by
            callers elsewhere in the project.
        support: Supporting data passed to the constructor (its type is not
            constrained by this class).
    """

    def __init__(self, conclusion, support):
        self.conclusion = conclusion
        self.true_conditions = []
        self.support = support

    def __repr__(self):
        return f"{type(self).__name__}(conclusion={self.conclusion!r})"

    def __eq__(self, other):
        # Return NotImplemented for foreign types so the other operand gets a
        # chance to answer; Python then falls back to identity, so e.g.
        # ``node == 5`` is still False.
        if not isinstance(other, ReasoningNode):
            return NotImplemented
        return self.conclusion == other.conclusion

    def __hash__(self):
        """Hash on the conclusion's identifying fields.

        NOTE(review): this is consistent with ``__eq__`` only if equal facts
        always share these fields -- confirm against ``synthesizer.fact``.

        Raises:
            TypeError: if ``conclusion`` is neither an ``AttributeFact`` nor a
                ``RelationFact``.
        """
        fact = self.conclusion
        if isinstance(fact, AttributeFact):
            return hash((fact.entity, fact.attribute))
        if isinstance(fact, RelationFact):
            return hash((fact.entity1, fact.entity2, fact.relation))
        raise TypeError("Unsupported conclusion type for hashing")


class ReasoningGraph:
    """Directed graph of nodes with forward and reverse adjacency lists.

    Nodes are deduplicated by equality: adding a node equal to one already
    stored returns the stored (canonical) instance instead of inserting a
    duplicate.

    Attributes:
        root: The node registered via ``add_edge(source, None)``, or ``None``.
        nodes: All nodes, in insertion order.
        edges: Maps each node to the list of its successors (``source -> target``).
        redges: Maps each node to the list of its predecessors.
    """

    def __init__(self):
        self.root = None
        self.nodes = []
        self.edges = {}
        self.redges = {}

    def add_node(self, node):
        """Insert *node* if no equal node exists; return the canonical node.

        When a node equal to *node* is already stored, that stored instance is
        returned, so callers always hold the object used as the key in
        ``edges``/``redges``.
        """
        # Single scan; the ``is``-then-``==`` test mirrors ``list.__contains__``.
        for existing in self.nodes:
            if existing is node or existing == node:
                return existing
        self.nodes.append(node)
        self.edges[node] = []
        self.redges[node] = []
        return node

    def add_edge(self, source, target):
        """Add the edge ``source -> target``, inserting both nodes if needed.

        If *target* is ``None``, *source* is instead registered as the graph's
        root; only one root is allowed (enforced with an assert). Cycles and
        duplicate edges are not checked here -- see ``test_if_dag``.
        """
        source = self.add_node(source)
        if target is None:
            assert self.root is None, "Root node already exists"
            self.root = source
            return
        target = self.add_node(target)
        self.edges[source].append(target)
        self.redges[target].append(source)

    def test_if_dag(self, source, target):
        """Return whether adding the edge ``source -> target`` keeps the graph acyclic.

        The graph is never modified. Returns ``False`` in three cases: for a
        self-loop, when the edge already exists (duplicate edges are rejected),
        or when the graph plus the proposed edge contains any cycle. Returns
        ``True`` immediately if *source* is not yet in the graph, because a
        new node has no incoming edges and so cannot close a cycle.

        Raises:
            AssertionError: if *target* is ``None`` or not in the graph.
        """
        assert target is not None and target in self.nodes, "Target node must be valid and exist in the graph"
        if source == target:
            return False
        if source not in self.nodes:
            return True
        if target in self.edges[source]:
            return False

        def successors(node):
            # Treat the proposed edge as present without mutating the graph,
            # so an exception mid-search can never leave a stray edge behind.
            children = self.edges[node]
            return [*children, target] if node == source else children

        # Iterative DFS, so deep chains cannot hit the recursion limit. It
        # covers every node, because a cycle anywhere in the augmented graph
        # makes it non-DAG.
        visited = set()
        on_path = set()
        for start in self.edges:
            if start in visited:
                continue
            visited.add(start)
            on_path.add(start)
            stack = [(start, iter(successors(start)))]
            while stack:
                node, children = stack[-1]
                for child in children:
                    if child in on_path:
                        return False  # back edge => cycle
                    if child not in visited:
                        visited.add(child)
                        on_path.add(child)
                        stack.append((child, iter(successors(child))))
                        break
                else:
                    # Every child of ``node`` has been explored: leave the path.
                    stack.pop()
                    on_path.discard(node)
        return True

    def find_node(self, node):
        """Return the first stored node equal to *node*, or ``None`` if absent."""
        for n in self.nodes:
            if n == node:
                return n
        return None

    def topo_sort(self):
        """Return all nodes in topological order (Kahn's algorithm).

        Ties are broken by insertion order: zero in-degree nodes are enqueued
        in ``nodes`` order and processed first-in, first-out. Duplicate edges
        are handled because ``edges`` and ``redges`` record them symmetrically.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        in_degree = defaultdict(int)
        queue = deque()
        for node in self.nodes:
            in_degree[node] = len(self.redges[node])
            if in_degree[node] == 0:
                queue.append(node)

        result = []
        while queue:
            node = queue.popleft()
            result.append(node)
            for child in self.edges[node]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)

        # Nodes on a cycle never reach in-degree zero and are left out.
        if len(result) != len(self.nodes):
            raise ValueError("the graph is not a DAG, topological sort failed")

        return result
