from typing import List, Tuple, Set, Optional, Dict
import itertools
from collections import deque
import heapq

def get_dag_signature(graph: 'MCGraph') -> Tuple[Tuple[str, str], ...]:
    """Return a hashable signature made of the graph's directed edges.

    The signature is used to decide whether two graphs contain the same
    directed edges, independently of the order in which those edges were
    inserted.

    Args:
        graph: Object exposing ``edges``, a mapping from ``(source, target)``
            pairs to the set of edge marks stored for that ordered pair.

    Returns:
        A sorted tuple of ``(source, target)`` pairs, one for every ordered
        pair whose mark set contains ``'-->'``.
    """
    # An edge u --> v carries the mark '-->' only under the key (u, v); the
    # mirrored key (v, u) holds '<--'.  Testing for '-->' therefore already
    # selects each directed edge exactly once.  Filtering on
    # ``source < target`` as well would drop every edge whose source name
    # sorts after its target, so it must not be applied here.
    return tuple(sorted(
        (source, target)
        for (source, target), edge_types in graph.edges.items()
        if '-->' in edge_types
    ))


# Lookup table used when a middle node ``m`` is marginalized out of a path
# ``s ? m ? e``.  The key is (mark of the s-m edge read from s, mark of the
# m-e edge read from m); the value is the mark of the edge to create between
# s and e, or '' when the pair produces no edge (m is a collider).
edge_map = {
    (first, second): result
    for first, second, result in (
        ('-->', '<--', ''),     # s --> m <-- e : collider, nothing induced
        ('<--', '<--', '<--'),  # s <-- m <-- e : e is an ancestor of s
        ('<->', '<--', ''),     # s <-> m <-- e : collider, nothing induced
        ('-->', '-->', '-->'),  # s --> m --> e : s is an ancestor of e
        ('<--', '-->', '<->'),  # s <-- m --> e : m is a common cause
        ('<->', '-->', '<->'),  # s <-> m --> e
        ('-->', '<->', ''),     # s --> m <-> e : collider, nothing induced
        ('<--', '<->', '<->'),  # s <-- m <-> e
        ('<->', '<->', ''),     # s <-> m <-> e : collider, nothing induced
    )
}

def flip(edge: str):
    """Return the mark of the same edge as read from its other endpoint.

    ``'-->'`` and ``'<--'`` are swapped; any other value (in particular the
    symmetric ``'<->'``) is returned unchanged.
    """
    if edge == '<--':
        return '-->'
    return '<--' if edge == '-->' else edge



class MCGraph:
    """Mixed graph with directed (``-->``/``<--``) and bidirected (``<->``) edges.

    Attributes are plain containers:

    * ``nodes`` -- list of node names in insertion order.
    * ``edges`` -- dict mapping an ordered pair ``(a, b)`` to the set of marks
      stored for that pair, each mark read from ``a``'s side (``'-->'`` means
      ``a --> b``, ``'<--'`` means ``a <-- b``).  ``add_edge`` stores every
      edge under both orientations unless told otherwise, so ``(b, a)``
      normally holds the flipped marks of ``(a, b)``.
    """

    def __init__(self, nodes: Optional[List[str]] = None, edge_tuples: Optional[List[Tuple[str, str, str]]] = None):
        """Create a graph, optionally pre-populated.

        Args:
            nodes: Node names to add, in order.
            edge_tuples: ``(start, edge_type, end)`` triples passed to
                ``add_edge`` (which also stores the reverse orientation).
        """
        self.nodes = []
        self.edges = {}
        if nodes is not None:
            for node in nodes:
                self.add_node(node)
        if edge_tuples is not None:
            for start, edge_type, end in edge_tuples:
                self.add_edge(start, edge_type, end)

    def add_node(self, node: str):
        """Append ``node`` to the node list (no duplicate check is made)."""
        self.nodes.append(node)

    def add_edge(self, start_node: str, edge_type: str, end_node: str, add_reverse=True):
        """Record ``start_node edge_type end_node``.

        Args:
            start_node: Endpoint the mark is read from.
            edge_type: One of ``'-->'``, ``'<--'``, ``'<->'``.
            end_node: The other endpoint.
            add_reverse: When true, also record the flipped mark under
                ``(end_node, start_node)``.

        Raises:
            AssertionError: If ``edge_type`` is not a known mark.
        """
        assert edge_type in ['-->', '<--', '<->']
        self.edges.setdefault((start_node, end_node), set()).add(edge_type)

        if add_reverse:
            self.add_edge(end_node, flip(edge_type), start_node, add_reverse=False)

    def remove_edge(self, start_node: str, edge_type: str, end_node: str):
        """Remove the mark from both orientations; missing marks are ignored.

        The dictionary keys themselves are kept, possibly with empty sets.
        """
        if (start_node, end_node) in self.edges:
            self.edges[(start_node, end_node)].discard(edge_type)
        if (end_node, start_node) in self.edges:
            self.edges[(end_node, start_node)].discard(flip(edge_type))

    def has_directed_edge(self, u: str, v: str) -> bool:
        """Check if there's a directed edge from u to v (u --> v)"""
        return (u, v) in self.edges and '-->' in self.edges[(u, v)]

    def has_bidirected_edge(self, u: str, v: str) -> bool:
        """Check if ``u <-> v`` is stored under the key ``(u, v)``."""
        return (u, v) in self.edges and '<->' in self.edges[(u, v)]

    def remove_node(self, node):
        """Remove ``node`` (first occurrence) and every edge entry touching it."""
        if node in self.nodes:
            self.nodes.remove(node)

        # Snapshot the keys: entries are deleted while walking them.
        for pair in list(self.edges):
            if pair[0] == node or pair[1] == node:
                del self.edges[pair]

    def marginalize_out_node(self, node: str):
        """Applies the method in the 1999 Koster paper in order to marginalize out nodes.
        This is able to handle double edges, since each edge will be stored separately
        in removed edges, so we will add another double edge to our new graph. In other
        words, if A <-> B and A --> B, B --> C and we marginalize B, this is not a problem
        since we will add both of the edges to removed edges, and thus add both A <-> C and
        A --> C to our new_graph.

        Returns:
            A new ``MCGraph`` without ``node``; ``self`` is left untouched.
        """
        new_graph = MCGraph()
        for n in self.nodes:
            if n != node:
                new_graph.add_node(n)

        # (neighbour, mark read from the neighbour) for each edge into `node`.
        removed_edges = []
        for (a, b), marks in self.edges.items():
            if node not in (a, b):
                # Untouched edge: both orientations are already stored in
                # self.edges, so copy this one without mirroring it.
                for mark in marks:
                    new_graph.add_edge(a, mark, b, add_reverse=False)
            elif b == node:
                for mark in marks:
                    removed_edges.append((a, mark))

        # Each unordered pair of removed edges is a path start - node - end;
        # edge_map says which edge (if any) that path induces.
        for (start, edge1), (end, edge2) in itertools.combinations(removed_edges, 2):
            if start == end:
                continue
            # edge2 was read from `end`; flip it so it is read from `node`.
            new_edge = edge_map[(edge1, flip(edge2))]
            if len(new_edge) > 0:
                new_graph.add_edge(start, new_edge, end)

        return new_graph

    def to_ADMG(self, nodes: List[str]):
        """Marginalize out ``nodes`` one after another and return the result.

        With an empty ``nodes`` a copy of this graph is returned.
        """
        if len(nodes) == 0:
            return self.copy()
        new_graph = self
        for node in nodes:
            new_graph = new_graph.marginalize_out_node(node)
        return new_graph

    def _build_directed_adjacency_list(self, nodes_subset=None):
        """Build adjacency list for directed edges ('-->') only.

        Only edges whose two endpoints both belong to ``nodes_subset``
        (default: ``self.nodes``) are included.  Returns ``{node: children}``.
        """
        nodes_to_use = nodes_subset if nodes_subset is not None else self.nodes
        adjacency_list = {node: set() for node in nodes_to_use}

        for (start, end), edges in self.edges.items():
            if start not in adjacency_list or end not in adjacency_list:
                continue
            for edge_type in edges:
                if edge_type == '-->':
                    adjacency_list[start].add(end)
                elif edge_type == '<--':
                    # Same edge seen from the child's side.
                    adjacency_list[end].add(start)

        return adjacency_list

    def _build_bidir_adjacency_list(self, nodes_subset=None):
        """Build adjacency list for bidirected edges ('<->') only.

        Only edges whose two endpoints both belong to ``nodes_subset``
        (default: ``self.nodes``) are included, mirroring
        ``_build_directed_adjacency_list``.
        """
        nodes_to_use = nodes_subset if nodes_subset is not None else self.nodes
        adjacency_list = {node: set() for node in nodes_to_use}

        for (start, end), edges in self.edges.items():
            # Skipping endpoints outside the node set both honours
            # nodes_subset and avoids a KeyError on adjacency_list[start].
            if '<->' in edges and start in adjacency_list and end in adjacency_list:
                adjacency_list[start].add(end)

        return adjacency_list

    def has_cycle(self):
        """Return True if the directed edges among ``self.nodes`` form a cycle.

        Uses an iterative depth-first search so that long directed chains do
        not exhaust Python's recursion limit.
        """
        adjacency_list = self._build_directed_adjacency_list()

        finished = set()   # nodes whose whole subtree has been explored
        on_path = set()    # nodes on the current DFS path
        for root in self.nodes:
            if root in finished:
                continue
            on_path.add(root)
            stack = [(root, iter(adjacency_list[root]))]
            while stack:
                current, neighbors = stack[-1]
                for neighbor in neighbors:
                    if neighbor in on_path:
                        # Back edge to a node on the current path.
                        return True
                    if neighbor not in finished:
                        on_path.add(neighbor)
                        stack.append((neighbor, iter(adjacency_list[neighbor])))
                        break
                else:
                    # All neighbours handled: leave `current`.
                    on_path.discard(current)
                    finished.add(current)
                    stack.pop()
        return False

    def get_incoming(self, node, include_bidirectional=False):
        """Return nodes ``s`` with ``s --> node`` (and ``s <-> node`` if requested)."""
        incoming = set()
        for (start, end), edges in self.edges.items():
            if end != node:
                continue
            if '-->' in edges or (include_bidirectional and '<->' in edges):
                incoming.add(start)
        return incoming

    def get_children(self, node: str) -> Set[str]:
        """gets all of the children of this node"""
        return {
            end
            for (start, end), edge_types in self.edges.items()
            if start == node and '-->' in edge_types
        }

    def get_parents(self, node: str) -> Set[str]:
        """Return every ``p`` such that ``p --> node`` is stored."""
        return {
            start
            for (start, end), edge_types in self.edges.items()
            if end == node and '-->' in edge_types
        }

    def copy(self):
        """Return a new graph with the same nodes and edge marks.

        Edges are re-added through ``add_edge`` with its default mirroring,
        so the copy always stores both orientations of every mark.
        """
        new_graph = MCGraph()
        for node in self.nodes:
            new_graph.add_node(node)
        for (start, end), edges in self.edges.items():
            for edge in edges:
                new_graph.add_edge(start, edge, end)
        return new_graph

    def compute_ancestors(self, nodes: Set[str]) -> Set[str]:
        """
        Compute the ancestors (with respect to the directed edges) of the nodes in 'nodes'.
        (Only edges with label '-->' are used to define ancestry.)

        The returned set includes the members of ``nodes`` themselves.
        """
        # Index parents once instead of rescanning every edge per round.
        parents_of = {}
        for (u, v), labels in self.edges.items():
            if '-->' in labels:
                parents_of.setdefault(v, []).append(u)

        ancestors = set(nodes)
        pending = list(ancestors)
        while pending:
            child = pending.pop()
            for parent in parents_of.get(child, ()):
                if parent not in ancestors:
                    ancestors.add(parent)
                    pending.append(parent)
        return ancestors

    def is_cond_independent(self, X: str, Y: str, Z: List[str]) -> bool:
        """Return True when no active path connects ``X`` and ``Y`` given ``Z``.

        Performs a breadth-first search over states ``(node, flag)`` where
        ``flag`` is 1 if the edge used to reach ``node`` has an arrowhead at
        ``node``, 0 if it has a tail there, and ``None`` for the start.  A
        path may continue through a collider only if that node is in ``Z`` or
        is an ancestor of ``Z``; through a non-collider only if the node is
        not in ``Z``.
        """
        Z_set = set(Z)
        A_active = self.compute_ancestors(Z_set) | Z_set

        # Index the edges leaving each node so each BFS step is local.
        outgoing = {}
        for (u, w), marks in self.edges.items():
            outgoing.setdefault(u, []).append((w, marks))

        queue = deque([(X, None)])
        visited = {(X, None)}
        while queue:
            v, flag = queue.popleft()
            if v == Y and flag is not None:
                return False
            for w, marks in outgoing.get(v, ()):
                for mark in marks:
                    if mark == '-->':
                        head_at_w, head_at_v = 1, 0
                    elif mark == '<--':
                        head_at_w, head_at_v = 0, 1
                    else:  # mark == '<->'
                        head_at_w, head_at_v = 1, 1
                    if flag is not None:
                        if flag == 1 and head_at_v == 1:
                            # v is a collider on this path.
                            if v not in A_active:
                                continue
                        elif v in Z_set:
                            # v is a non-collider that is conditioned on.
                            continue
                    state = (w, head_at_w)
                    if state in visited:
                        continue
                    visited.add(state)
                    queue.append(state)
        return True

    def get_all_conditional_independencies(self) -> List[Tuple[str, str, List[str]]]:
        """List every ``(A, B, Z)`` with ``A`` independent of ``B`` given ``Z``.

        Pairs are taken from the sorted node names with ``A`` before ``B``;
        for each pair every subset ``Z`` of the remaining nodes is tested, in
        order of increasing size.  ``Z`` is reported as a sorted list.  The
        number of tests is exponential in the number of nodes.
        """
        independencies = []
        nodes_sorted = sorted(self.nodes)
        for A, B in itertools.combinations(nodes_sorted, 2):
            conditioning_vars = [node for node in nodes_sorted if node not in (A, B)]
            for k in range(len(conditioning_vars) + 1):
                for comb in itertools.combinations(conditioning_vars, k):
                    if self.is_cond_independent(A, B, list(comb)):
                        independencies.append((A, B, sorted(frozenset(comb))))
        return independencies

    def marginalize_nodes_by_ht(self, marg: Set[str]) -> 'MCGraph':
        """Marginalize out ``marg`` and rebuild the graph from tails and districts.

        Steps, as implemented:

        1. Marginalize ``marg`` with ``to_ADMG`` to obtain ``graph2``.
        2. For every node ``v`` add ``w --> v`` for each ``w`` in the district
           of ``v`` restricted to the ancestors of ``v``, together with the
           parents of that district, ``v`` itself excluded.
        3. Add ``v <-> w`` for every pair in the same district of ``graph2``
           with no ancestral relation that is still connected by bidirected
           edges inside the ancestors of ``{v, w}``.

        Returns:
            A new ``MCGraph`` over the nodes of ``graph2``.
        """
        # Step 1: Marginalize out nodes to get an ADMG
        graph2 = self.to_ADMG(list(marg))
        nodes = list(graph2.nodes)
        new_graph = MCGraph(nodes=nodes)

        # graph2 is never modified below, so ancestor sets can be memoized.
        ancestors_cache = {}

        def ancestors_of(n):
            if n not in ancestors_cache:
                ancestors_cache[n] = graph2.compute_ancestors({n})
            return ancestors_cache[n]

        # Step 2: Add directed edges
        for v in nodes:
            # dis_an(v)(v): district of v in induced subgraph on ancestors of v
            district_anc_v = graph2._district(v, nodes_subset=ancestors_of(v))
            # Tail: dis_an(v)(v) union pa_G(dis_an(v)(v)) \ {v}
            tail = set(district_anc_v)
            for u in district_anc_v:
                tail.update(graph2.get_parents(u))
            tail.discard(v)
            for w in tail:
                new_graph.add_edge(w, '-->', v)

        # Step 3: Add bidirected edges
        # For all pairs v, w in the same district in graph2 and with no ancestral relation
        for v in nodes:
            for w in graph2._district(v):
                if w == v:
                    continue
                if w in ancestors_of(v) or v in ancestors_of(w):
                    continue
                # Ancestors of {v, w} are the union of the individual ancestor sets.
                anc_vw = ancestors_of(v) | ancestors_of(w)
                if w in graph2._district(v, nodes_subset=anc_vw):
                    new_graph.add_edge(v, '<->', w)
        return new_graph

    def _district(self, node, nodes_subset=None):
        """Return the nodes reachable from ``node`` through ``<->`` edges.

        ``node`` itself is always included.  If ``nodes_subset`` is given,
        only neighbours belonging to it are followed.
        """
        visited = set()
        stack = [node]
        while stack:
            v = stack.pop()
            if v in visited:
                continue
            visited.add(v)
            for (a, b), labels in self.edges.items():
                if '<->' in labels:
                    if a == v and (nodes_subset is None or b in nodes_subset):
                        stack.append(b)
        return visited
    


class BidirectionalSolver:
    """Search for acyclic graphs that replace bidirected edges by directed ones.

    The solver marginalizes ``nodes_to_remove`` out of ``initial_graph``,
    splits the result into components connected by ``<->`` edges, enumerates
    the topological orderings of each component to derive directed
    replacements, and finally combines the cheapest per-component choices
    that yield an acyclic graph.

    Attributes set by ``solve``: ``marginalized_graph``,
    ``bidirectional_components`` and ``base_graph``.  ``graphs_considered``
    counts the orderings examined so far.
    """

    def __init__(self, initial_graph: 'MCGraph', nodes_to_remove: Set[str]):
        self.initial_graph = initial_graph
        self.nodes_to_remove = nodes_to_remove
        self.marginalized_graph = None
        self.bidirectional_components = []
        self.base_graph = None
        self.graphs_considered = 0

    def solve(self) -> List['MCGraph']:
        """
        Goes through a few steps
        1. first marginalizes the nodes to remove (``marginalize_nodes_by_ht``)
        2. gets the bidirectional components
        3. finds all valid topological orderings and simplifies the component
        4. combines to find all possible minimal solutions

        Returns the list of acyclic graphs with the smallest total number of
        added edges; the list is empty if no acyclic combination exists.
        """
        self.marginalized_graph = self.initial_graph.marginalize_nodes_by_ht(self.nodes_to_remove)

        self.bidirectional_components = self._find_bidirectional_components()
        self.base_graph = self._create_base_graph()

        component_solutions_by_cost = [
            self._solve_component(component)
            for component in self.bidirectional_components
        ]

        return self._combine_component_solutions(component_solutions_by_cost)

    def _create_base_graph(self) -> 'MCGraph':
        """
        Base graph: the marginalized graph with all of its bidir edges removed
        """
        new_graph = MCGraph(nodes=self.marginalized_graph.nodes)
        for (u, v), edge_types in self.marginalized_graph.edges.items():
            for edge_type in edge_types:
                if edge_type != '<->':
                    # Both orientations are already stored in the source graph.
                    new_graph.add_edge(u, edge_type, v, add_reverse=False)

        return new_graph

    def _find_bidirectional_components(self) -> List[Set[str]]:
        """
        Gets the unique bidirectional components. Output is a
        list of bidirectional components, each of which is a set.
        Components consisting of a single node are left out.
        """
        visited = set()
        components = []

        for node in self.marginalized_graph.nodes:
            if node not in visited:
                component = self._dfs_bidirectional_component(node, visited)
                if len(component) > 1:  # single nodes are not components
                    components.append(component)

        return components

    def _dfs_bidirectional_component(self, start_node: str, visited: Set[str]) -> Set[str]:
        """
        Does DFS to find components, adds all found nodes to the visited set,
        and returns all of the nodes in that component as a set
        """
        component = set()
        stack = [start_node]

        while stack:
            node = stack.pop()
            if node in visited:
                continue

            visited.add(node)
            component.add(node)

            for (u, v), edge_types in self.marginalized_graph.edges.items():
                if '<->' in edge_types and u == node and v not in visited:
                    stack.append(v)

        return component

    def _solve_component(self, component: Set[str]) -> Dict[int, List['MCGraph']]:
        """
        Output is dict where the key is the cost (num edges) and the value is a list of
        graphs with that cost

        Every valid topological ordering of the component is evaluated with
        ``_get_edges_for_ordering`` and the resulting graph is filed under
        the number of edges that ordering required.  The dict is empty when
        the component admits no topological ordering.
        """
        # Sort so the enumeration order (and hence the order of the returned
        # graphs) does not depend on set iteration order / hash seeding.
        component_list = sorted(component)
        solutions_by_cost = {}

        for ordering in self._get_all_topological_orderings(component_list):
            self.graphs_considered += 1
            comp_graph, num_edges = self._get_edges_for_ordering(ordering, component)
            solutions_by_cost.setdefault(num_edges, []).append(comp_graph)

        return solutions_by_cost

    def _generate_cost_combinations(self, component_solutions_by_cost: List[Dict[int, List['MCGraph']]]):
        """
        Generate all possible cost combinations in order of cost

        Start off with one entry in the heap, the min of each of the costs per component

        Then, when processing min of the heap, we find its 'neighbors' by taking those that have one higher cost,
        so one higher index in the list

        We will never get into a situation where we have min stuff not added to heap and process something else that's min
        because of strictly non-decreasing sums. Since we never "kill" any branches, that means all of the branches
        are incorporated in the heap, and they are all at least as expensive.

        Yields tuples holding one cost per component.  Nothing is yielded if
        some component has no solution at all, because no combination can
        then be formed.
        """
        component_costs = [sorted(component_dict.keys()) for component_dict in component_solutions_by_cost]
        if any(not costs for costs in component_costs):
            return

        def costs_at(indices):
            return tuple(costs[i] for costs, i in zip(component_costs, indices))

        # Heap entries are (total cost, per-component index into its sorted
        # cost list).  Costs are strictly increasing per component, so
        # comparing index tuples orders ties exactly like comparing costs.
        start = (0,) * len(component_costs)
        heap = [(sum(costs_at(start)), start)]
        visited = {start}

        while heap:
            _, indices = heapq.heappop(heap)
            yield costs_at(indices)

            for position, cost_idx in enumerate(indices):
                if cost_idx + 1 < len(component_costs[position]):
                    successor = indices[:position] + (cost_idx + 1,) + indices[position + 1:]
                    if successor not in visited:
                        visited.add(successor)
                        heapq.heappush(heap, (sum(costs_at(successor)), successor))

    def _get_all_topological_orderings(self, nodes: List[str]):
        """Yield every topological ordering of ``nodes``, reversed.

        Only ``-->`` edges of the marginalized graph between members of
        ``nodes`` are taken into account.  Each yielded list places nodes
        without remaining successors first (the ordering is built sources
        first and then reversed).
        """
        node_set = set(nodes)
        adj = {node: [] for node in nodes}          # children inside `nodes`
        in_degree = {node: 0 for node in nodes}     # unplaced parents per node

        for (u, v), edge_types in self.marginalized_graph.edges.items():
            if u in node_set and v in node_set and '-->' in edge_types:
                adj[u].append(v)
                in_degree[v] += 1

        path = []
        placed = set()

        def extend():
            if len(path) == len(nodes):
                # nodes get removed highest to lowest, so we flip it
                yield list(reversed(path))
                return

            for node in nodes:
                if in_degree[node] == 0 and node not in placed:
                    path.append(node)
                    placed.add(node)
                    for neighbor in adj[node]:
                        in_degree[neighbor] -= 1

                    yield from extend()

                    # backtracking
                    for neighbor in adj[node]:
                        in_degree[neighbor] += 1
                    path.pop()
                    placed.discard(node)

        yield from extend()

    def _create_component_subgraph(self, component: Set[str]) -> 'MCGraph':
        """
        when we want to simplify just one component, use this to get the minimal subgraph
        We add in the parents and then all of the directed edges that already exist
        Alternatively, this outputs the relevant subset of the base graph
        """
        parents = set()
        for node in component:
            parents.update(self.marginalized_graph.get_parents(node))

        relevant_nodes = component | parents

        subgraph = MCGraph()
        for node in relevant_nodes:
            subgraph.add_node(node)

        for (u, v), edge_types in self.marginalized_graph.edges.items():
            if u in relevant_nodes and v in relevant_nodes:
                for edge_type in edge_types:
                    if edge_type != '<->':
                        # other direction already added, since we construct from existing graph
                        subgraph.add_edge(u, edge_type, v, add_reverse=False)

        return subgraph

    def _get_edges_for_ordering(self, ordering: List[str], component: Set[str]) -> Tuple['MCGraph', int]:
        """
        Want to count the number of edges we added given a component and a topological
        ordering for that component + the edges for that, returns (temp_graph, num added)

        Start with a subgraph, get rid of the bidirectional edges
        Then, going from high to low topological ordering, process and add edges
        For this, we just check if the edge is needed by checking if it's blocked along all paths
        Otherwise, we add the necessary edges that don't already exist
        """
        temp_graph = self._create_component_subgraph(component)

        edges_added_count = 0

        for i, node_i in enumerate(ordering):
            for node_j in ordering[i + 1:]:
                if self._is_edge_blocked_in_current_graph(node_i, node_j, component, temp_graph):
                    continue

                if not temp_graph.has_directed_edge(node_j, node_i):
                    edges_added_count += 1
                    temp_graph.add_edge(node_j, '-->', node_i)

                # Every current parent of node_j also becomes a parent of node_i.
                # get_parents returns a fresh set, so adding edges is safe here.
                for parent in temp_graph.get_parents(node_j):
                    if not temp_graph.has_directed_edge(parent, node_i):
                        edges_added_count += 1
                        temp_graph.add_edge(parent, '-->', node_i)

        return temp_graph, edges_added_count

    def _combine_component_solutions(self, component_solutions_by_cost: List[Dict[int, List['MCGraph']]]) -> List['MCGraph']:
        """
        Input is component solutions by cost for all component sols, so [for component 1: {cost: [graphs], etc. }, then for graph 2 ]
        The function gets all cost combinations in sorted order with generate_cost_combinations

        Then, for each of these, we have diff cost combos, e.g. (2, 1, 3). For all of these, we go through
        all possible graphs of cost 2 for component 1, then all of cost 1 for component 2, and so on

        Returns every acyclic combined graph whose total cost equals the
        smallest total cost for which an acyclic graph exists.  With no
        components at all, a copy of the marginalized graph is returned; if
        any component has no solution the result is an empty list.
        """
        if not component_solutions_by_cost:
            return [self.marginalized_graph.copy()]

        all_acyclic_graphs = []
        current_min_cost = None

        for cost_combo in self._generate_cost_combinations(component_solutions_by_cost):
            total_cost = sum(cost_combo)
            if current_min_cost is not None and total_cost > current_min_cost:
                break  # stop searching, higher num of edges

            component_choices = [
                component_solutions_by_cost[i][cost]
                for i, cost in enumerate(cost_combo)
            ]

            acyclic_graphs = []
            for combination in itertools.product(*component_choices):
                cur_graph = self.base_graph.copy()

                for component_graph in combination:
                    for (u, v), edge_types in component_graph.edges.items():
                        for edge_type in edge_types:
                            cur_graph.add_edge(u, edge_type, v)

                if not cur_graph.has_cycle():
                    acyclic_graphs.append(cur_graph)

            if acyclic_graphs:
                if current_min_cost is None:
                    current_min_cost = total_cost
                all_acyclic_graphs.extend(acyclic_graphs)

        return all_acyclic_graphs

    def _is_edge_blocked_in_current_graph(self, i: str, j: str, component: Set[str], graph: 'MCGraph') -> bool:
        """
        Determines if we add an edge by checking if there's a path of all lower topological order

        The blocking vertices are the component members that are children of
        both ``i`` and ``j`` in ``graph``.  The edge is blocked when no
        bidirected path from ``j`` to ``i`` avoids all of them.
        """
        blocking_vertices = {
            k for k in component
            if graph.has_directed_edge(i, k) and graph.has_directed_edge(j, k)
        }

        return not self._has_path_avoiding(j, i, blocking_vertices, component)

    def _has_path_avoiding(self, start: str, end: str, avoid_set: Set[str],
                                   component: Set[str]) -> bool:
        """
        Checks if there exists a path of only bidir edges that is not blocked,
        given two vertices and a set that blocks

        The search runs on the marginalized graph and only steps onto nodes
        of ``component``; nodes in ``avoid_set`` are never expanded.
        """
        if start == end:
            return True

        graph = self.marginalized_graph

        visited = set()
        queue = deque([start])

        while queue:
            current = queue.popleft()

            if current == end:
                return True

            if current in visited or current in avoid_set:
                continue

            visited.add(current)

            for (u, v), edge_types in graph.edges.items():
                # Membership tests only: a truthiness test on the neighbour
                # would wrongly skip falsy node names.
                if u == current and v in component and '<->' in edge_types and v not in visited:
                    queue.append(v)

        return False
    
if __name__ == "__main__":
    # Small demo: N2 points to both N1 and N3 and is then marginalized out;
    # every solution graph found by the solver is printed as its edge dict.
    demo_nodes = ['N1', 'N2', 'N3']
    demo_edges = [
        ('N2', '-->', 'N1'),
        ('N2', '-->', 'N3'),
    ]
    test_graph = MCGraph(nodes=demo_nodes, edge_tuples=demo_edges)

    simplifier = BidirectionalSolver(test_graph, ['N2'])
    for solution in simplifier.solve():
        print(solution.edges)
