from typing import Iterable, Iterator, Callable, Self, Literal, Type, cast
from collections.abc import Collection
from functools import reduce
from itertools import combinations
from random import random, shuffle
import math


# Graph nodes are identified by plain integers; this alias is used in the
# annotations below and in the ``isinstance`` check inside ``DAG.dsep``.
Node = int


class Edge:
    """Directed edge ``n1 -> n2`` between two integer nodes.

    Construction asserts ``n1 < n2``, so an edge always points from the
    lower-numbered node to the higher-numbered one.
    """

    def __init__(self, n1: Node, n2: Node):
        assert n1 < n2
        self.n1: Node = n1  # tail (source) of the edge
        self.n2: Node = n2  # head (destination) of the edge

    def __repr__(self):
        return f'{self.n1} -> {self.n2}'

    def __iter__(self) -> Iterator[Node]:
        """Yield the tail ``n1`` first and the head ``n2`` second."""
        yield from (self.n1, self.n2)

    def __contains__(self, other: Node) -> bool:
        """Tell whether ``other`` is one of the two endpoints."""
        if other == self.n1:
            return True
        return other == self.n2

    def connected_to(self, n: Node) -> Node:
        """Return the endpoint opposite to ``n``.

        Raises:
            KeyError: if ``n`` is not an endpoint of this edge.
        """
        if n == self.n1:
            return self.n2
        if n == self.n2:
            return self.n1
        raise KeyError(n)


def is_chain(e1: Edge, e2: Edge) -> bool:
    """Tell whether the head of one edge is the tail of the other.

    That is, the two edges form ``a -> b -> c`` in either argument order.
    """
    if e1.n2 == e2.n1:
        return True
    return e2.n2 == e1.n1


def is_fork(e1: Edge, e2: Edge) -> bool:
    """Tell whether both edges leave the same node (``b <- a -> c``)."""
    common_tail = e1.n1
    return common_tail == e2.n1


def is_collider(e1: Edge, e2: Edge) -> bool:
    """Tell whether both edges enter the same node (``a -> c <- b``)."""
    common_head = e1.n2
    return common_head == e2.n2


class DAG:

    def __init__(self, K: int) -> None:
        self.nodes : list[Node] = list(range(K))
        self.edges : list[Edge] = []

    @property
    def K(self) -> int:
        return len(self.nodes)

    @property
    def X(self) -> list[Node]:
        return self.nodes[:-1]

    @property
    def Y(self) -> Node:
        return self.nodes[-1]

    def add_edge(self, edge: Edge) -> None:
        self.edges.append(edge)

    def parents(self, node: Node) -> Iterator[Node]:
        for edge in self.edges:
            if edge.n2 == node:
                yield edge.n1

    def children(self, node: Node) -> Iterator[Node]:
        for edge in self.edges:
            if edge.n1 == node:
                yield edge.n2

    def ancestors(self, x: Node, s: set[Node] | None = None) -> set[Node]:
        if s is None:
            s = set()

        s.add(x)

        for parent in self.parents(x):
            if parent not in s:
                self.ancestors(parent, s)
        
        return s

    def descendants(self, x: Node, s: set[Node] | None = None) -> set[Node]:
        if s is None:
            s = set()

        s.add(x)

        for child in self.children(x):
            if child not in s:
                self.descendants(child, s)
        
        return s

    def topological_order(self) -> Iterator[tuple[int, Node]]:
        parents = {
            node: list(self.parents(node))
            for node in self.nodes
        }

        depth = 0
        while parents:
            level = []
            for node, ps in parents.items():
                if not ps:
                    yield depth, node
                    level.append(node)

            for node in level:
                parents.pop(node)
                for _, ps in parents.items():
                    if node in ps:
                        ps.remove(node)

            if not level:
                # Cycle
                raise Exception('Not a DAG')
            
            depth += 1

    @classmethod
    def parse(cls, K: int, text: str) -> Self:
        graph = cls(K)

        for line in text.split('\n'):
            line = line.strip()

            a, b = map(lambda x: int(x.strip()), line.split('->'))
            graph.add_edge(Edge(a, b))
        
        return graph

    def filter_edges(self, f: Callable[[Edge], bool]) -> Self:
        graph = self.__class__(self.K)

        for edge in self.edges:
            if f(edge):
                graph.add_edge(edge)

        return graph

    def incoming_edges(self, nodes: Iterable[Node]) -> Iterator[Edge]:
        for edge in self.edges:
            if edge.n2 in nodes:
                yield edge

    def outgoing_edges(self, nodes: Iterable[Node]) -> Iterator[Edge]:
        for edge in self.edges:
            if edge.n1 in nodes:
                yield edge

    def paths(
        self, x: Node, y: Node, path: tuple[Edge, ...] = ()
    ) -> Iterator[tuple[Edge, ...]]:
        if x == y:
            yield path
        else:
            for edge in self.edges:
                if x in edge:
                    z = edge.connected_to(x)
                    if all(z not in edge for edge in path):  # z not visited
                        yield from self.paths(z, y, path + (edge,))

    def directed_paths(
        self, x: Node, y: Node, path: tuple[Node, ...] = ()
    ) -> Iterator[tuple[Node, ...]]:
        path = path + (x,)

        if x == y:
            yield path
        else:
            for ch in self.children(x):
                if ch not in path:
                    yield from self.directed_paths(ch, y, path)

    def dsep(
        self, x: Node | set[Node], y: Node | set[Node], z: set[Node]
    ) -> bool:
        # If x or y are tuples, go node by node
        if isinstance(x, set):
            return all(self.dsep(xi, y, z) for xi in x)
        if isinstance(y, set):
            return all(self.dsep(x, yi, z) for yi in y)
        
        # Now we can test for every node individually
        assert isinstance(x, Node) and isinstance(y, Node)
        return all(self.dsep_path(z, path) for path in self.paths(x, y))

    def dsep_path(self, z: Iterable[Node], path: tuple[Edge, ...]) -> bool:
        for pair in zip(path[:-1], path[1:], strict=True):
            e1, e2 = pair
            if is_chain(*pair):
                if e1.n2 == e2.n1:
                    b = e1.n2
                elif e2.n2 == e1.n1:
                    b = e2.n2

                if b in z:
                    return True
            elif is_fork(*pair):
                b = e1.n1  # = e2.n1

                if b in z:
                    return True
            elif is_collider(*pair):
                b = e1.n2  # = e2.n2

                if all(desc not in z for desc in self.descendants(b)):
                    return True
            else:
                raise ValueError(f'Invalid path: {path}')
        else:
            # No breaks, so path was not interrupted
            return False

    def dsep_directed_path(
        self, z: Iterable[Node], path: tuple[Node, ...]
    ) -> bool:
        """Check if path = V1->···->Vn, represented as path = (V1, ..., Vn)
        is d-separated (blocked) by z."""
        for node in path:
            if node in z:
                return True
        else:
            return False


def complete_graph(k: int) -> DAG:
    """Build the graph on ``k`` nodes holding an edge ``i -> j`` for every
    pair ``i < j``."""
    full = DAG(k)

    # combinations() walks the pairs in the same (i, j) order as two
    # nested index loops would.
    for low, high in combinations(range(k), 2):
        full.add_edge(Edge(low, high))

    return full


def sample_dag(k: int, p: float) -> DAG:
    """Sample a graph on ``k`` nodes where each edge ``i -> j`` (``i < j``)
    is included independently when a fresh ``random()`` draw is below ``p``."""
    sampled = DAG(k)

    # One draw per pair, visited in increasing (i, j) order.
    for low, high in combinations(range(k), 2):
        if random() < p:
            sampled.add_edge(Edge(low, high))

    return sampled


def check_all_ancestors(graph: DAG) -> bool:
    """Tell whether every node of ``graph.X`` is an ancestor of ``graph.Y``."""
    reachable = {*graph.ancestors(graph.Y)}

    for node in graph.X:
        if node not in reachable:
            return False
    return True


def check_all_parents(graph: DAG) -> bool:
    """Tell whether every node of ``graph.X`` is a direct parent of
    ``graph.Y``.

    Parents are compared as a set rather than counted, so an edge into
    ``Y`` that was added more than once cannot stand in for a missing one.
    """
    return set(graph.X) <= set(graph.parents(graph.Y))


def sample_dag_all_ancestors(k: int, p: float, rejection: bool = True) -> DAG:
    """Returns a graph with k+1 elements (last one being y) in which every
    other node is an ancestor of y.

    Args:
        k: number of non-target nodes.
        p: edge probability handed to ``sample_dag``.
        rejection: when true, resample until the property holds; otherwise
            sample once and add the missing edges into y.

    Raises:
        ValueError: with ``rejection`` enabled, ``k > 0`` and ``p <= 0``.
            No edge is ever sampled in that case, so the rejection loop
            could never terminate.
    """
    if not rejection:
        graph = sample_dag(k + 1, p)
        # Add directed edges X->Y for any non-ancestor of Y,
        # moving backwards in the topological order
        # (which might make any ancestors of the node an ancestor of Y too).
        for node in reversed(graph.X):
            if graph.Y not in graph.descendants(node):
                graph.add_edge(Edge(node, graph.Y))

        assert check_all_ancestors(graph)
        return graph

    if k > 0 and p <= 0:
        raise ValueError(
            f'rejection sampling cannot succeed with p={p!r} and k={k!r}'
        )

    # Rejection sampling: keep drawing until the property holds.
    while True:
        graph = sample_dag(k + 1, p)
        if check_all_ancestors(graph):
            return graph


def sample_dag_all_ancestors_not_all_parents(
    k: int, p: float
) -> DAG:
    """Returns a graph with k+1 elements (last one being y) where every
    other node is an ancestor of y but not all of them are parents of y.

    Graphs are drawn with ``sample_dag`` until one satisfies both
    conditions.

    Raises:
        ValueError: when no such graph can ever be drawn and the loop would
            run forever: ``k < 2`` (with fewer than two non-target nodes
            every ancestor of y is also a parent), ``p <= 0`` (no edges at
            all) or ``p >= 1`` (always the complete graph).
    """
    if k < 2 or not 0 < p < 1:
        raise ValueError(
            f'no admissible graph can be sampled with k={k!r}, p={p!r}'
        )

    while True:
        graph = sample_dag(k + 1, p)
        if check_all_ancestors(graph) and not check_all_parents(graph):
            return graph


def R3(g: DAG, x: set[Node], y: set[Node], z: set[Node], w: set[Node]) -> bool:
    """Test ``y`` d-separated from ``z`` given ``x | w`` in a pruned graph.

    The pruned graph is ``g`` without the edges entering ``x`` and without
    the edges entering those nodes of ``z`` that are not ancestors of any
    node of ``w`` once the edges into ``x`` are gone.  This matches the
    graphical condition of rule 3 of the do-calculus.
    """
    into_x = tuple(g.incoming_edges(x))
    without_into_x = g.filter_edges(lambda edge: edge not in into_x)

    # Ancestors of w are computed in the graph that already lost the
    # edges entering x.
    w_ancestors: set[Node] = set()
    for wi in w:
        w_ancestors.update(without_into_x.ancestors(wi))
    z_not_above_w = z.difference(w_ancestors)

    dropped = into_x + tuple(g.incoming_edges(z_not_above_w))
    pruned = g.filter_edges(lambda edge: edge not in dropped)

    return pruned.dsep(y, z, x.union(w))


def is_frontier(graph: DAG, x: Node, y: Node, z: Iterable[Node]) -> bool:
    """Tell whether ``z`` blocks every directed path from ``x`` to ``y``.

    Blocking is decided per path by ``graph.dsep_directed_path``; the
    answer is ``True`` when there is no directed path at all.
    """
    for path in graph.directed_paths(x, y):
        if not graph.dsep_directed_path(z, path):
            return False
    return True


def parts_of(s: Collection) -> Iterator[tuple]:
    """Yield every sub-collection of ``s`` as a tuple, smallest sizes first.

    Within one size the tuples come in the order produced by
    ``itertools.combinations``.
    """
    for size in range(len(s) + 1):
        yield from combinations(s, size)


class FR1:
    """Reduce a set of nodes by dropping those that are cut off from Y.

    Nodes are handled as plain integers, node sets as tuples/sets.
    ``run`` walks the given nodes in decreasing order; a node that is not a
    parent of Y is dropped when the nodes handled before it (restricted to
    its own descendants) block every directed path from it to Y.
    """

    def __init__(self, graph: DAG):
        self.graph: DAG = graph

        # (blocking nodes..., x) -> result of is_frontier(x, blocking nodes)
        self.fr_cache: dict[tuple[Node, ...], bool] = {}
        # every reduced tuple returned by run() since the last clear()
        self.v_cache: set[tuple[Node, ...]] = set()

        self.parents_Y: set[Node] = set(self.graph.parents(self.graph.Y))
        self.children: dict[Node, set[Node]] = {
            node: set(self.graph.children(node))
            for node in self.graph.X
        }
        # descendants() includes the node itself
        self.De_X: dict[Node, set[Node]] = {
            node: set(self.graph.descendants(node))
            for node in self.graph.X
        }

    def is_frontier(self, x: Node, s: tuple[Node, ...]) -> bool:
        """Tell whether Y is unreachable from ``x`` once ``s`` is removed.

        Performs a breadth-first walk along child edges starting at ``x``
        and never entering a node of ``s``.  Returns ``False`` as soon as Y
        shows up in the current layer, ``True`` when the walk dies out.
        """
        target = self.graph.Y
        layer = {x}
        visited = layer.union(s)

        while layer and target not in layer:
            reached: set[Node] = set()
            for node in layer:
                reached |= self.children[node]

            layer = reached - visited
            visited |= layer

        return not layer

    def run(self, s: Iterable[Node]) -> tuple[Node, ...]:
        """Return the nodes of ``s`` that survive the reduction, in
        decreasing order; the result is also recorded in ``v_cache``.

        Raises:
            KeyError: if a node that is neither in ``graph.X`` nor a parent
                of Y is given (it has no entry in ``De_X``).
        """
        ordered = sorted(s, reverse=True)

        p: tuple[Node, ...] = ()  # nodes handled so far
        z: set[Node] = set()      # nodes found to be removable

        for x in ordered:
            if x not in self.parents_Y:
                descendants = self.De_X[x]
                p_de = tuple(elem for elem in p if elem in descendants)
                t = p_de + (x,)  # cache key for the pair (p_de, x)

                fr = self.fr_cache.get(t)
                if fr is None:
                    fr = self.fr_cache[t] = self.is_frontier(x, p_de)

                if fr:
                    z.add(x)

            # Removable nodes stay in p and keep acting as blockers.
            p = p + (x,)

        kept = tuple(x for x in ordered if x not in z)
        self.v_cache.add(kept)

        return kept

    def clear(self) -> None:
        """Clear v_cache, but keep fr_cache (stable for G)"""
        self.v_cache.clear()

    def cache_redirect(self, comb: tuple[int, ...]) -> tuple[Node, ...]:
        """Alias of ``run`` (used for SHAP)."""
        return self.run(comb)

class FR2:
    """Bitmask version of the reduction implemented by ``FR1``.

    A node ``n`` is encoded as the integer ``1 << n`` and a set of nodes as
    the bitwise OR of its members.  All set algebra is done with bitwise
    operators so that overlapping or repeated elements cannot corrupt a
    mask.
    """

    def __init__(self, graph: DAG):
        self.graph: DAG = graph

        # encoded (blocking nodes | x) -> result of is_frontier
        self.fr_cache: dict[int, bool] = {}
        # every encoded result returned by run() since the last clear()
        self.v_cache: set[int] = set()

        self.y_encoded: int = self.encode_x(graph.Y)

        self.parents_Y: set[int] = set(
            map(self.encode_x, graph.parents(graph.Y))
        )
        # encoded node -> encoded set of its children
        self.children: dict[int, int] = {
            self.encode_x(node): self.encode_S(graph.children(node))
            for node in graph.X
        }
        # encoded node -> encoded set of its descendants
        self.DeX: dict[int, int] = {
            self.encode_x(x): self.encode_S(self.graph.descendants(x))
            for x in self.graph.X
        }

    @property
    def decoded_fr_cache(self) -> dict[tuple[Node, ...], bool]:
        """``fr_cache`` with every key expanded to a tuple of nodes
        (highest node first)."""
        return {
            tuple(self.iter_S(k, encoded=False)): v
            for k, v in self.fr_cache.items()
        }

    @property
    def decoded_v_cache(self) -> set[tuple[Node, ...]]:
        """``v_cache`` with every entry expanded to a tuple of nodes
        (highest node first)."""
        return {
            tuple(self.iter_S(k, encoded=False))
            for k in self.v_cache
        }

    @staticmethod
    def encode_x(x: int) -> int:
        """Encode node ``x`` as the single-bit mask ``1 << x``."""
        return 1 << x

    @staticmethod
    def encoded_union(encoded_elements: Iterable[int]) -> int:
        """OR together already-encoded masks (``0`` for no elements)."""
        return reduce(lambda x, y: x | y, encoded_elements, 0)

    @classmethod
    def encode_S(cls, S: Iterable[int]) -> int:
        """Encode an iterable of nodes; repeated nodes are harmless."""
        return cls.encoded_union(map(cls.encode_x, S))

    @staticmethod
    def decode_x(x: int) -> int:
        """Return the index of the highest set bit of ``x``."""
        # return int(math.log2(x))  # won't scale well with high K
        return x.bit_length() - 1

    @classmethod
    def iter_S(cls, s: int, encoded: bool = False) -> Iterator[int]:
        """Yield the members of the encoded set ``s``, highest first.

        Members are yielded as single-bit masks when ``encoded`` is true and
        as node indices otherwise.
        """
        while s > 0:
            x = cls.decode_x(s)
            xe = cls.encode_x(x)

            yield xe if encoded else x

            s ^= xe  # clear the bit just reported

    def max_encoded(self, x: int) -> int:
        """Return the highest element, encoded"""
        return self.encode_x(self.decode_x(x))

    def is_frontier(self, x: int, s: int) -> bool:
        """Given x and s encoded, is s a frontier for x?

        Walks breadth-first along child edges from ``x`` without entering
        any node of ``s``.  Returns ``False`` as soon as Y shows up in the
        current layer, ``True`` when the walk dies out.
        """
        c = x
        v = s | c  # OR, not +: s may already contain x
        while c != 0 and (self.y_encoded & c) == 0:
            c = self.encoded_union(
                self.children[xe]
                for xe in self.iter_S(c, encoded=True)
            ) & ~v
            v |= c

        return c == 0

    def run(self, S: tuple[int, ...] | int) -> int:
        """Reduce ``S`` and return the surviving nodes as an encoded set.

        Args:
            S: either an already-encoded set (``int``) or a tuple of node
                indices.

        The result is also recorded in ``v_cache``.

        Raises:
            KeyError: if a node that is neither in ``graph.X`` nor a parent
                of Y is given (it has no entry in ``DeX``).
        """
        p: int = 0  # nodes handled so far
        z: int = 0  # nodes found to be removable
        s = S if isinstance(S, int) else self.encode_S(S)

        while s > 0:
            x = self.max_encoded(s)
            s ^= x

            # If x is a parent of Y, can't be a frontier, no need to cache it
            if x not in self.parents_Y:
                p_de = p & self.DeX[x]  # p filtered to descendants of x
                code = p_de | x  # code identifying the pair (p_de, x)

                fr = self.fr_cache.get(code)
                if fr is None:
                    fr = self.fr_cache[code] = self.is_frontier(x, p_de)

                if fr:
                    z |= x  # add x to removable nodes

            p |= x  # removable nodes stay in p and keep acting as blockers

        s = p & ~z
        self.v_cache.add(s)

        return s

    def clear(self) -> None:
        """Clear v_cache, but keep fr_cache (stable for G)"""
        self.v_cache.clear()

    def cache_redirect(self, comb: tuple[int, ...]) -> int:
        """Encode ``comb`` and run the reduction (used for SHAP).

        Encoding goes through ``encode_S`` so that a node listed twice is
        not mistaken for the next higher node.
        """
        return self.run(self.encode_S(comb))
