from typing import List, Tuple, Set, Optional, Dict
import itertools
from collections import deque

def get_dag_signature(graph: 'MCGraph') -> Tuple[Tuple[str, str], ...]:
    """Return a hashable signature of a graph's directed edges.

    Used to compare DAGs produced by different methods: two graphs get the
    same signature exactly when they contain the same directed edges.

    Each directed edge ``u --> v`` is reported once, as ``(u, v)``.
    ``MCGraph.add_edge`` stores that edge with mark ``'-->'`` under ``(u, v)``
    and with mark ``'<--'`` under ``(v, u)``, so keeping only ``'-->'`` marks
    already avoids double counting. No ordering filter on the endpoints is
    applied: such a filter would drop every edge whose source name sorts
    after its target name.

    Bidirected (``'<->'``) edges are not part of the signature.

    Returns:
        A sorted tuple of ``(source, target)`` pairs.
    """
    return tuple(sorted(
        (source, target)
        for (source, target), edge_types in graph.edges.items()
        if '-->' in edge_types
    ))

# Marginalisation table used by ``MCGraph.marginalize_out_node``.
# For a path  a - m - b  through the node m being marginalised out, the key is
# (mark of edge a-m read from a towards m, mark of edge m-b read from m
# towards b) and the value is the mark of the induced edge a-b read from a.
# An empty string means arrowheads meet at m (m is a collider), so no edge
# is induced.
edge_map: Dict[Tuple[str, str], str] = {
    ('-->', '<--'): '',     # a --> m <-- b : collider
    ('<--', '<--'): '<--',  # a <-- m <-- b : b --> a
    ('<->', '<--'): '',     # a <-> m <-- b : collider
    ('-->', '-->'): '-->',  # a --> m --> b : a --> b
    ('<--', '-->'): '<->',  # a <-- m --> b : a <-> b
    ('<->', '-->'): '<->',  # a <-> m --> b : a <-> b
    ('-->', '<->'): '',     # a --> m <-> b : collider
    ('<--', '<->'): '<->',  # a <-- m <-> b : a <-> b
    ('<->', '<->'): '',     # a <-> m <-> b : collider
}

def flip(edge: str) -> str:
    """Return ``edge`` as read from the opposite endpoint.

    The two directed marks swap (``'-->'`` becomes ``'<--'`` and vice versa);
    any other value, including ``'<->'``, is returned unchanged.
    """
    return '<--' if edge == '-->' else ('-->' if edge == '<--' else edge)


class MCGraph:
    """Mixed causal graph with directed (``-->``) and bidirected (``<->``) edges.

    ``nodes`` is a list of node names. ``edges`` maps an ordered pair
    ``(u, v)`` to the set of edge marks read from ``u`` towards ``v``:
    ``'-->'`` (u points to v), ``'<--'`` (v points to u) or ``'<->'``.
    ``add_edge`` stores every edge under both orientations by default, so
    ``u --> v`` appears as ``edges[(u, v)] == {'-->'}`` and
    ``edges[(v, u)] == {'<--'}``.
    """

    def __init__(self, nodes: Optional[List[str]] = None, edge_tuples: Optional[List[Tuple[str, str, str]]] = None):
        """Create a graph.

        Args:
            nodes: Optional node names to add.
            edge_tuples: Optional ``(start, edge_type, end)`` triples, each
                added with ``add_edge`` (so both orientations are stored).
        """
        self.nodes = []
        self.edges = {}
        if nodes is not None:
            for node in nodes:
                self.add_node(node)
        if edge_tuples is not None:
            for start, edge_type, end in edge_tuples:
                self.add_edge(start, edge_type, end)

    def add_node(self, node: str):
        """Add ``node``; adding a node that is already present is a no-op."""
        # Duplicates would make pairwise loops visit (A, A) and leave a stale
        # copy behind after ``remove_node`` (which removes one occurrence).
        if node not in self.nodes:
            self.nodes.append(node)

    def add_edge(self, start_node: str, edge_type: str, end_node: str, add_reverse=True):
        """Record ``start_node edge_type end_node``.

        With ``add_reverse`` (the default) the flipped mark is also stored
        under ``(end_node, start_node)``.
        """
        assert edge_type in ['-->', '<--', '<->']
        self.edges.setdefault((start_node, end_node), set()).add(edge_type)

        if add_reverse:
            self.add_edge(end_node, flip(edge_type), start_node, add_reverse=False)

    def remove_edge(self, start_node: str, edge_type: str, end_node: str):
        """Remove one edge mark in both orientations (missing marks are ignored)."""
        if (start_node, end_node) in self.edges:
            self.edges[(start_node, end_node)].discard(edge_type)
        if (end_node, start_node) in self.edges:
            self.edges[(end_node, start_node)].discard(flip(edge_type))

    def has_directed_edge(self, u: str, v: str) -> bool:
        """Check if there's a directed edge from u to v (u --> v)"""
        return (u, v) in self.edges and '-->' in self.edges[(u, v)]

    def has_bidirected_edge(self, u: str, v: str) -> bool:
        """Check if there's a bidirected edge between u and v (u <-> v)."""
        return (u, v) in self.edges and '<->' in self.edges[(u, v)]

    def remove_node(self, node):
        """Delete ``node`` and every edge entry that has it as an endpoint."""
        if node in self.nodes:
            self.nodes.remove(node)

        for key in [k for k in self.edges if node in k]:
            del self.edges[key]

    def marginalize_out_node(self, node: str) -> 'MCGraph':
        """Applies the method in the 1999 Koster paper in order to marginalize out nodes.

        Returns a new graph; ``self`` is not modified. Every edge not touching
        ``node`` is copied. For every pair of distinct neighbours ``a``, ``b``
        of ``node``, the marks of ``a - node`` and ``node - b`` are combined via
        ``edge_map`` and the induced edge (if any) is added.

        This is able to handle double edges, since each edge mark is stored
        separately in ``removed_edges``. In other words, if A <-> B and
        A --> B, B --> C and we marginalize B, both marks are combined
        independently, adding both A <-> C and A --> C to the new graph.
        """
        new_graph = MCGraph()
        for n in self.nodes:
            if n != node:
                new_graph.add_node(n)

        # Entries (neighbour, mark read from neighbour towards node, node).
        removed_edges = []
        for (u, v), marks in self.edges.items():
            if node not in (u, v):
                # Both orientations already exist in self.edges, so copy verbatim.
                for edge in marks:
                    new_graph.add_edge(u, edge, v, add_reverse=False)
            elif v == node:
                for edge in marks:
                    removed_edges.append((u, edge, v))

        # Each unordered pair once; add_edge stores the reverse orientation.
        for (start, edge1, _), (end, edge2, _) in itertools.combinations(removed_edges, 2):
            if start == end:
                continue
            # Re-orient the second mark so it reads node -> end.
            new_edge = edge_map[(edge1, flip(edge2))]
            if new_edge:
                new_graph.add_edge(start, new_edge, end)

        return new_graph

    def to_ADMG(self, nodes: List[str]):
        """Marginalize out each node in ``nodes`` (in order) and return the result.

        ``self`` is left unchanged; with no nodes a copy is returned.
        """
        new_graph = self.copy()
        for node in nodes:
            new_graph = new_graph.marginalize_out_node(node)
        return new_graph

    def _build_directed_adjacency_list(self, nodes_subset=None):
        """Build adjacency list for directed edges ('-->') only.

        Only edges with both endpoints in ``nodes_subset`` (default: all nodes)
        are included.
        """
        nodes_to_use = nodes_subset if nodes_subset is not None else self.nodes
        adjacency_list = {node: set() for node in nodes_to_use}

        for (start, end), edges in self.edges.items():
            if '-->' in edges and start in adjacency_list and end in adjacency_list:
                adjacency_list[start].add(end)

        return adjacency_list

    def _build_bidir_adjacency_list(self, nodes_subset=None):
        """Build adjacency list for bidirected edges ('<->') only.

        Only edges with both endpoints in ``nodes_subset`` (default: all nodes)
        are included; edges leaving the subset are ignored instead of raising
        KeyError.
        """
        nodes_to_use = nodes_subset if nodes_subset is not None else self.nodes
        adjacency_list = {node: set() for node in nodes_to_use}

        for (start, end), edges in self.edges.items():
            if '<->' in edges and start in adjacency_list and end in adjacency_list:
                adjacency_list[start].add(end)

        return adjacency_list

    def has_cycle(self):
        """Return True if the directed ('-->') edges contain a cycle."""
        adjacency_list = self._build_directed_adjacency_list()

        # DFS; a neighbour still on the recursion stack closes a cycle.
        def dfs(node, visited, rec_stack):
            visited.add(node)
            rec_stack.add(node)
            for neighbor in adjacency_list[node]:
                if neighbor not in visited:
                    if dfs(neighbor, visited, rec_stack):
                        return True
                elif neighbor in rec_stack:
                    return True
            rec_stack.remove(node)
            return False

        visited = set()
        rec_stack = set()
        for node in self.nodes:
            if node not in visited:
                if dfs(node, visited, rec_stack):
                    return True
        return False

    def get_incoming(self, node, include_bidirectional=False):
        """Return nodes with a '-->' into ``node`` (plus '<->' neighbours if requested)."""
        wanted = {'-->', '<->'} if include_bidirectional else {'-->'}
        return {
            start
            for (start, end), edges in self.edges.items()
            if end == node and not wanted.isdisjoint(edges)
        }

    def get_children(self, node: str) -> Set[str]:
        """Gets all of the children of this node"""
        return {end for (start, end), edge_types in self.edges.items()
                if start == node and '-->' in edge_types}

    def get_parents(self, node: str) -> Set[str]:
        """Gets all of the parents of this node"""
        return {start for (start, end), edge_types in self.edges.items()
                if end == node and '-->' in edge_types}

    def copy(self):
        """Return an independent copy (node list and edge mark sets are new objects)."""
        new_graph = MCGraph()
        for node in self.nodes:
            new_graph.add_node(node)
        for (start, end), edges in self.edges.items():
            for edge in edges:
                new_graph.add_edge(start, edge, end)
        return new_graph

    def compute_ancestors(self, nodes: Set[str]) -> Set[str]:
        """
        Compute the ancestors (with respect to the directed edges) of the nodes in 'nodes'.
        (Only edges with label '-->' are used to define ancestry.) The result
        includes ``nodes`` themselves.
        """
        ancestors = set(nodes)
        changed = True
        while changed:
            changed = False
            for (u, v), labels in self.edges.items():
                if '-->' in labels and v in ancestors and u not in ancestors:
                    ancestors.add(u)
                    changed = True
        return ancestors

    def is_cond_independent(self, X: str, Y: str, Z: List[str]) -> bool:
        """Return True if X and Y are m-separated given Z.

        Breadth-first search over ``(node, flag)`` states. ``flag`` is None at
        the start node, 1 if the edge used to reach the node has an arrowhead
        at it, and 0 otherwise. An intermediate node with arrowheads on both
        the incoming and outgoing edge (a collider) is passable only if it is
        in Z or an ancestor of Z; any other intermediate node is passable only
        if it is not in Z. Returns False as soon as Y is reached.
        """
        Z_set = set(Z)
        # compute_ancestors already includes the members of Z_set.
        A_active = self.compute_ancestors(Z_set)

        # Index outgoing edge entries once instead of rescanning all edges per state.
        out_edges = {}
        for (u, w), labels in self.edges.items():
            out_edges.setdefault(u, []).append((w, labels))

        queue = deque([(X, None)])
        visited = {(X, None)}
        while queue:
            v, flag = queue.popleft()
            if v == Y and flag is not None:
                return False
            for w, labels in out_edges.get(v, ()):
                for L in labels:
                    if L == '-->':
                        new_flag, new_edge_arrow = 1, 0
                    elif L == '<--':
                        new_flag, new_edge_arrow = 0, 1
                    else:  # L == '<->'
                        new_flag, new_edge_arrow = 1, 1
                    if flag is not None:
                        if flag == 1 and new_edge_arrow == 1:
                            # Collider at v: open only if v is in An(Z).
                            if v not in A_active:
                                continue
                        elif v in Z_set:
                            # Non-collider at v: blocked by conditioning.
                            continue
                    state = (w, new_flag)
                    if state not in visited:
                        visited.add(state)
                        queue.append(state)
        return True

    def get_all_conditional_independencies(self) -> List[Tuple[str, str, List[str]]]:
        """Enumerate every ``(A, B, S)`` with A and B m-separated given S.

        Pairs follow sorted node order (A before B); S ranges over all subsets
        of the remaining nodes in increasing size. Every separating set is
        reported, not only minimal ones, so the cost is exponential in the
        number of nodes.
        """
        independencies = []
        nodes_sorted = sorted(self.nodes)
        for A, B in itertools.combinations(nodes_sorted, 2):
            conditioning_vars = [node for node in nodes_sorted if node not in (A, B)]
            for k in range(len(conditioning_vars) + 1):
                for comb in itertools.combinations(conditioning_vars, k):
                    candidate = frozenset(comb)
                    if self.is_cond_independent(A, B, list(candidate)):
                        independencies.append((A, B, sorted(candidate)))
        return independencies
    

class DAGToSimplify:
    """Removes one node from a DAG by reversing its arcs to its children.

    Each valid order in which the children of ``remove_node`` can have their
    incoming arc from it reversed yields one candidate graph; see
    ``get_all_topo_orders`` and ``create_graph_with_optimal_edges``.
    """

    def __init__(self, mcgraph: 'MCGraph', remove_node: str):
        """Snapshot the children of ``remove_node`` and each child's parents.

        Args:
            mcgraph: Graph containing ``remove_node``; it is only read here.
            remove_node: Name of the node to eliminate.
        """
        self.mcgraph = mcgraph
        self.remove_node = remove_node

        self.children = list(mcgraph.get_children(remove_node))
        # Captured from the original graph; every child's parent list
        # includes ``remove_node`` itself.
        self.parents_of_children = {
            child: list(mcgraph.get_parents(child)) for child in self.children
        }

    def get_all_topo_orders(self) -> List[List[str]]:
        """Return every ordering of ``self.children`` consistent with ancestry.

        Child ``a`` must precede child ``b`` whenever ``a`` is an ancestor of
        ``b`` in ``self.mcgraph`` (directed edges only). With no children the
        single empty ordering ``[[]]`` is returned.
        """
        successors = {child: [] for child in self.children}
        in_degree = {child: 0 for child in self.children}

        for child in self.children:
            ancestors = self.mcgraph.compute_ancestors({child})
            for other in self.children:
                if other != child and other in ancestors:
                    successors[other].append(child)
                    in_degree[child] += 1

        result = []
        path = []

        def extend():
            # Backtracking enumeration of all topological sorts.
            if len(path) == len(self.children):
                result.append(list(path))
                return
            for node in self.children:
                if in_degree[node] != 0 or node in path:
                    continue
                path.append(node)
                for nxt in successors[node]:
                    in_degree[nxt] -= 1

                extend()

                # Undo the choice of ``node`` before trying the next one.
                for nxt in successors[node]:
                    in_degree[nxt] += 1
                path.pop()

        extend()
        return result

    def create_graph_with_optimal_edges(self, ordering: List[str]) -> 'MCGraph':
        """Build the graph that results from reversing child arcs in ``ordering``.

        Works on a copy of ``self.mcgraph``:
        1. each child gets an arc to every child later in ``ordering``;
        2. each parent of a child gets an arc to every later child;
        3. each parent of ``remove_node`` gets an arc to every child;
        4. ``remove_node`` and all of its edges are deleted.
        """
        new_graph = self.mcgraph.copy()

        # Step 1 (combinations preserves the i < j nesting order).
        for earlier_child, later_child in itertools.combinations(ordering, 2):
            new_graph.add_edge(earlier_child, '-->', later_child)

        # Step 2. ``remove_node`` is skipped: any arc from it would be
        # deleted again in step 4.
        for i, child in enumerate(ordering):
            later_children = ordering[i + 1:]
            for parent in self.parents_of_children.get(child, []):
                if parent == self.remove_node:
                    continue
                for later_child in later_children:
                    new_graph.add_edge(parent, '-->', later_child)

        # Step 3
        for parent in self.mcgraph.get_parents(self.remove_node):
            for child in ordering:
                new_graph.add_edge(parent, '-->', child)

        # Step 4
        new_graph.remove_node(self.remove_node)

        return new_graph
    

class MultiNodeDAGToSimplify:
    """Removes several nodes from a DAG, searching for the cheapest strategy.

    A strategy is a node removal order together with, for each removed node,
    the order in which its child arcs are reversed (see ``DAGToSimplify``).
    Its cost is the number of directed edges in the final graph that were not
    in the original graph.
    """

    def __init__(self, mcgraph: 'MCGraph', nodes_to_remove: List[str]):
        """
        Simplified class for removing multiple nodes by reusing DAGToSimplify logic.

        Args:
            mcgraph: The original MCGraph; ``solve`` works on copies of it.
            nodes_to_remove: Nodes to remove; ``solve`` tries every permutation.
        """
        self.original_graph = mcgraph
        self.nodes_to_remove = nodes_to_remove
        self.current_graph = mcgraph.copy()
        # Set by ``solve`` to the strategy that produced ``current_graph``.
        self.removal_sequence = []
        # Number of complete strategies evaluated by ``solve`` (accumulates).
        self.orders_tried = 0

    def calculate_net_cost(self, original_graph: 'MCGraph', final_graph: 'MCGraph') -> int:
        """
        Calculate the net cost by counting num of edges added throughout the process

        Args:
            original_graph: The original graph before any removals
            final_graph: The final graph after all removals

        Returns:
            Number of directed ('-->') edges in ``final_graph`` that are not in
            ``original_graph``.
        """
        def directed(graph):
            return {pair for pair, edge_types in graph.edges.items() if '-->' in edge_types}

        return len(directed(final_graph) - directed(original_graph))

    def get_topological_orders(self, nodes: List[str]):
        """
        Yield every valid node removal order for ``nodes`` (a generator).

        Orders are topological sorts of the directed edges among ``nodes`` in
        the original graph, reversed, so nodes are removed from highest to
        lowest topological position.
        """
        adj = self.original_graph._build_directed_adjacency_list(nodes_subset=nodes)

        in_degree = {node: 0 for node in nodes}
        for u in adj:
            for v in adj[u]:
                in_degree[v] += 1

        path = []

        def find_all_toposorts_recursive():
            if len(path) == len(nodes):
                yield list(reversed(path))
                return

            for node in nodes:
                if in_degree[node] == 0 and node not in path:
                    path.append(node)
                    for neighbor in adj[node]:
                        in_degree[neighbor] -= 1

                    yield from find_all_toposorts_recursive()

                    for neighbor in adj[node]:
                        in_degree[neighbor] += 1
                    path.pop()

        yield from find_all_toposorts_recursive()

    def _collect_strategies(self, removal_order) -> Tuple[float, List[Tuple[Tuple, 'MCGraph']], int]:
        """Explore every child-reversal choice for one node removal order.

        Returns ``(best_cost, pairs, explore_count)``: ``pairs`` lists
        ``(strategy, final_graph)`` for every complete strategy achieving
        ``best_cost``, in discovery order; ``explore_count`` is the number of
        complete strategies evaluated. ``best_cost`` is ``inf`` and ``pairs``
        is empty when no strategy completes.
        """
        best_cost = float('inf')
        best_pairs = []
        explore_count = 0

        def explore(rem_nodes, current_graph, collected_orderings):
            nonlocal best_cost, best_pairs, explore_count

            if not rem_nodes:
                explore_count += 1
                total_cost = self.calculate_net_cost(self.original_graph, current_graph)
                strategy = (tuple(removal_order), tuple(map(tuple, collected_orderings)))
                if total_cost < best_cost:
                    best_cost = total_cost
                    best_pairs = [(strategy, current_graph)]
                elif total_cost == best_cost:
                    best_pairs.append((strategy, current_graph))
                return

            node = rem_nodes[0]
            # A node absent from the graph abandons this whole branch.
            if node not in current_graph.nodes:
                return

            dag_simplify = DAGToSimplify(current_graph, node)
            for ordering in dag_simplify.get_all_topo_orders():
                new_graph = dag_simplify.create_graph_with_optimal_edges(ordering)
                explore(rem_nodes[1:], new_graph, collected_orderings + [ordering])

        explore(removal_order, self.original_graph.copy(), [])
        return best_cost, best_pairs, explore_count

    def _search_one_order(self, removal_order: List[str]) -> Tuple[float, Set[Tuple], List['MCGraph'], int]:
        """Given a single order, it explores child edge reversal permutations.

        Returns ``(best_cost, optimal_strategies, final_graphs, explore_count)``:
        the minimal cost found, the set of optimal strategies in the form
        ``(node removal order, (child reversal orders per node))``, the final
        graphs achieving that cost, and the number of complete strategies
        evaluated.
        """
        best_cost, pairs, explore_count = self._collect_strategies(removal_order)
        optimal_strategies = {strategy for strategy, _ in pairs}
        final_graphs = [graph for _, graph in pairs]
        return best_cost, optimal_strategies, final_graphs, explore_count

    def solve(self) -> Tuple[List['MCGraph'], Set[Tuple]]:
        """
        Search every permutation of ``nodes_to_remove`` for the cheapest strategies.

        Returns:
            ``(best_final_graphs, all_optimal_strategies)``.

        Side effects:
            ``current_graph`` is set to the first optimal graph found and
            ``removal_sequence`` to the strategy that produced that graph;
            ``orders_tried`` is increased by the number of strategies evaluated.

        Raises:
            ValueError: if no removal order completes.
        """
        overall_best_cost = float('inf')
        best_pairs = []

        for removal_order in itertools.permutations(self.nodes_to_remove):
            cost, pairs, explore_count = self._collect_strategies(removal_order)
            self.orders_tried += explore_count

            if not pairs:
                continue

            if cost < overall_best_cost:
                overall_best_cost = cost
                best_pairs = list(pairs)
            elif cost == overall_best_cost:
                best_pairs.extend(pairs)

        if not best_pairs:
            raise ValueError("No valid removal sequence found for any of the tested orders.")

        best_final_graphs = [graph for _, graph in best_pairs]
        all_optimal_strategies = {strategy for strategy, _ in best_pairs}

        # Keep the two attributes describing the same (deterministic) optimum.
        self.removal_sequence, self.current_graph = best_pairs[0]

        return best_final_graphs, all_optimal_strategies
    
    
if __name__ == "__main__":
    # Example: N2 is a common parent of N1 and N3. Remove N2 and print each
    # distinct resulting graph once (deduplicated by directed-edge signature).
    example_nodes = ['N' + str(idx) for idx in range(1, 4)]
    example_graph = MCGraph(
        example_nodes,
        edge_tuples=[('N2', '-->', 'N1'), ('N2', '-->', 'N3')],
    )

    solver = MultiNodeDAGToSimplify(example_graph, ['N2'])
    result_graphs, _strategies = solver.solve()

    seen_signatures = set()
    for candidate in result_graphs:
        signature = get_dag_signature(candidate)
        if signature in seen_signatures:
            continue
        seen_signatures.add(signature)
        print(candidate.edges)