"""
DAG (Directed Acyclic Graph) representation and operations.

This module provides the core DAG data structure used throughout
the Fisher dimension framework for representing causal graphs.
"""

from __future__ import annotations

from typing import Set, List, Tuple, Optional, FrozenSet, Iterator
from collections import deque
import numpy as np


class DAG:
    """
    Directed Acyclic Graph representation for causal models.

    Nodes are the integers ``0 .. num_nodes - 1``. Acyclicity is enforced on
    every edge insertion, so a ``DAG`` instance is always acyclic.

    Attributes:
        _num_nodes: Number of nodes in the graph
        _edges: Set of directed edges as (parent, child) tuples
        _adjacency_matrix: Cached adjacency matrix (invalidated on modification)
        _parents: Parent sets per node (kept in sync with _edges)
        _children: Children sets per node (kept in sync with _edges)
        _topological_order: Cached topological ordering (invalidated on modification)
    """

    def __init__(self, num_nodes: int, edges: Optional[List[Tuple[int, int]]] = None):
        """
        Initialize a DAG.

        Args:
            num_nodes: Number of nodes (labeled 0 to num_nodes-1)
            edges: Optional list of (parent, child) tuples

        Raises:
            ValueError: If num_nodes < 1, a node is out of bounds, or edges
                create a cycle
        """
        if num_nodes < 1:
            raise ValueError(f"Number of nodes must be positive, got {num_nodes}")

        self._num_nodes = num_nodes
        self._edges: Set[Tuple[int, int]] = set()

        # Per-node adjacency, maintained incrementally by add/remove_edge
        self._parents: List[Set[int]] = [set() for _ in range(num_nodes)]
        self._children: List[Set[int]] = [set() for _ in range(num_nodes)]
        # Lazily computed caches, reset by _invalidate_cache()
        self._adjacency_matrix: Optional[np.ndarray] = None
        self._topological_order: Optional[List[int]] = None

        # Add edges if provided (each insertion is cycle-checked)
        if edges:
            for parent, child in edges:
                self.add_edge(parent, child)

    def _validate_node(self, node: int) -> None:
        """Raise ValueError unless ``0 <= node < num_nodes``."""
        if not 0 <= node < self._num_nodes:
            raise ValueError(f"Node {node} is out of bounds [0, {self._num_nodes})")

    def _invalidate_cache(self) -> None:
        """Invalidate cached structures after modification."""
        self._adjacency_matrix = None
        self._topological_order = None

    def _would_create_cycle(self, parent: int, child: int) -> bool:
        """
        Check if adding edge parent -> child would create a cycle.

        Uses an iterative DFS to check whether child can already reach parent
        (a self-loop is always a cycle).
        """
        if parent == child:
            return True

        # Check if child can reach parent (would create cycle)
        visited = set()
        stack = [child]

        while stack:
            current = stack.pop()
            if current == parent:
                return True
            if current in visited:
                continue
            visited.add(current)
            stack.extend(self._children[current])

        return False

    def add_edge(self, parent: int, child: int) -> None:
        """
        Add a directed edge from parent to child.

        Adding an edge that already exists is a no-op.

        Args:
            parent: Source node
            child: Target node

        Raises:
            ValueError: If nodes are out of bounds or edge would create cycle
        """
        self._validate_node(parent)
        self._validate_node(child)

        if (parent, child) in self._edges:
            return  # Edge already exists

        if self._would_create_cycle(parent, child):
            raise ValueError(f"Adding edge {parent} -> {child} would create a cycle")

        self._edges.add((parent, child))
        self._parents[child].add(parent)
        self._children[parent].add(child)
        self._invalidate_cache()

    def remove_edge(self, parent: int, child: int) -> None:
        """
        Remove a directed edge from parent to child.

        Args:
            parent: Source node
            child: Target node

        Raises:
            ValueError: If edge doesn't exist
        """
        if (parent, child) not in self._edges:
            raise ValueError(f"Edge {parent} -> {child} does not exist")

        self._edges.remove((parent, child))
        self._parents[child].remove(parent)
        self._children[parent].remove(child)
        self._invalidate_cache()

    def has_edge(self, parent: int, child: int) -> bool:
        """Check if directed edge parent -> child exists."""
        return (parent, child) in self._edges

    def parents(self, node: int) -> Set[int]:
        """Return a copy of the set of parents of a node."""
        self._validate_node(node)
        return self._parents[node].copy()

    def children(self, node: int) -> Set[int]:
        """Return a copy of the set of children of a node."""
        self._validate_node(node)
        return self._children[node].copy()

    def ancestors(self, node: int) -> Set[int]:
        """
        Return all ancestors of a node (nodes that can reach this node).

        Uses BFS to find all nodes with a directed path to the given node.
        The node itself is not included.
        """
        self._validate_node(node)
        ancestors = set()
        queue = deque(self._parents[node])

        while queue:
            current = queue.popleft()
            if current not in ancestors:
                ancestors.add(current)
                queue.extend(self._parents[current])

        return ancestors

    def descendants(self, node: int) -> Set[int]:
        """
        Return all descendants of a node (nodes reachable from this node).

        Uses BFS to find all nodes reachable via directed paths.
        The node itself is not included.
        """
        self._validate_node(node)
        descendants = set()
        queue = deque(self._children[node])

        while queue:
            current = queue.popleft()
            if current not in descendants:
                descendants.add(current)
                queue.extend(self._children[current])

        return descendants

    def markov_blanket(self, node: int) -> Set[int]:
        """
        Return the Markov blanket of a node.

        MB(v) = Pa(v) ∪ Ch(v) ∪ Pa(Ch(v)) \\ {v}

        The Markov blanket is the set of nodes that make the node
        conditionally independent of all other nodes.
        """
        self._validate_node(node)

        mb = set()

        # Parents
        mb.update(self._parents[node])

        # Children
        mb.update(self._children[node])

        # Parents of children (co-parents)
        for child in self._children[node]:
            mb.update(self._parents[child])

        # Remove the node itself (it is a parent of each of its children)
        mb.discard(node)

        return mb

    def topological_sort(self) -> List[int]:
        """
        Return a topological ordering of the nodes.

        Uses Kahn's algorithm for topological sorting. The result is cached
        until the graph is modified; callers always receive a fresh list.

        Returns:
            List of nodes in topological order (parents before children)

        Raises:
            RuntimeError: If a cycle is found (cannot happen via the public API)
        """
        if self._topological_order is not None:
            return self._topological_order.copy()

        # Compute in-degrees
        in_degree = [len(self._parents[i]) for i in range(self._num_nodes)]

        # Initialize queue with nodes having no parents
        queue = deque([i for i in range(self._num_nodes) if in_degree[i] == 0])
        result = []

        while queue:
            node = queue.popleft()
            result.append(node)

            for child in self._children[node]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)

        if len(result) != self._num_nodes:
            raise RuntimeError("Graph has a cycle - this should not happen in a DAG")

        self._topological_order = result
        return result.copy()

    def is_acyclic(self) -> bool:
        """
        Check if the graph is acyclic.

        By construction, this should always be True for a DAG object,
        but this method can be used for validation.
        """
        try:
            self.topological_sort()
            return True
        except RuntimeError:
            return False

    def skeleton(self) -> Set[FrozenSet[int]]:
        """
        Return the skeleton (undirected edges) of the DAG.

        Returns:
            Set of frozensets, each containing two adjacent nodes
        """
        return {frozenset([parent, child]) for parent, child in self._edges}

    def v_structures(self) -> List[Tuple[int, int, int]]:
        """
        Return all v-structures (immoralities) in the DAG.

        A v-structure is i → k ← j where i and j are not adjacent.

        Returns:
            List of (i, k, j) tuples where i → k ← j is a v-structure
            with i < j for canonical ordering
        """
        v_structs = []
        skeleton = self.skeleton()

        for k in range(self._num_nodes):
            parents_k = list(self._parents[k])

            # Check all unordered pairs of parents
            for idx_i, i in enumerate(parents_k):
                for j in parents_k[idx_i + 1:]:
                    # Check if i and j are non-adjacent
                    if frozenset([i, j]) not in skeleton:
                        # Ensure canonical ordering (i < j)
                        if i < j:
                            v_structs.append((i, k, j))
                        else:
                            v_structs.append((j, k, i))

        return v_structs

    def num_nodes(self) -> int:
        """Return the number of nodes."""
        return self._num_nodes

    def num_edges(self) -> int:
        """Return the number of edges."""
        return len(self._edges)

    @property
    def edges(self) -> Set[Tuple[int, int]]:
        """Return a copy of the edge set."""
        return self._edges.copy()

    @property
    def nodes(self) -> List[int]:
        """Return list of node indices."""
        return list(range(self._num_nodes))

    def max_degree(self) -> int:
        """Return the maximum degree (in + out) of any node."""
        return max(
            len(self._parents[i]) + len(self._children[i])
            for i in range(self._num_nodes)
        )

    def max_in_degree(self) -> int:
        """Return the maximum in-degree of any node."""
        return max(len(self._parents[i]) for i in range(self._num_nodes))

    def max_out_degree(self) -> int:
        """Return the maximum out-degree of any node."""
        return max(len(self._children[i]) for i in range(self._num_nodes))

    def in_degree(self, node: int) -> int:
        """Return the in-degree of a node."""
        self._validate_node(node)
        return len(self._parents[node])

    def out_degree(self, node: int) -> int:
        """Return the out-degree of a node."""
        self._validate_node(node)
        return len(self._children[node])

    def degree(self, node: int) -> int:
        """Return the total degree (in + out) of a node."""
        self._validate_node(node)
        return len(self._parents[node]) + len(self._children[node])

    def adjacency_matrix(self) -> np.ndarray:
        """
        Return the adjacency matrix of the DAG.

        A[i,j] = 1 if there is an edge i → j, 0 otherwise.

        Returns:
            d×d int32 numpy array where d is the number of nodes (a copy of
            the cached matrix, so callers may mutate it freely)
        """
        if self._adjacency_matrix is not None:
            return self._adjacency_matrix.copy()

        adj = np.zeros((self._num_nodes, self._num_nodes), dtype=np.int32)
        for parent, child in self._edges:
            adj[parent, child] = 1

        self._adjacency_matrix = adj
        return adj.copy()

    def to_adjacency_matrix(self) -> np.ndarray:
        """Alias for adjacency_matrix() for compatibility."""
        return self.adjacency_matrix()

    @classmethod
    def from_adjacency_matrix(cls, adj: np.ndarray) -> 'DAG':
        """
        Create a DAG from an adjacency matrix.

        Args:
            adj: d×d matrix (any array-like) where a nonzero adj[i,j] means
                edge i→j

        Returns:
            DAG object

        Raises:
            ValueError: If matrix is not square or creates a cycle
        """
        # Accept nested lists and other array-likes, not only ndarrays.
        adj = np.asarray(adj)
        if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
            raise ValueError("Adjacency matrix must be square")

        num_nodes = adj.shape[0]
        # np.nonzero yields indices in row-major order (same as a nested i/j
        # scan); convert to plain ints so the edge set holds Python ints.
        rows, cols = np.nonzero(adj)
        edges = [(int(i), int(j)) for i, j in zip(rows, cols)]

        return cls(num_nodes, edges)

    def copy(self) -> 'DAG':
        """
        Create a deep copy of the DAG.

        The source graph is already acyclic, so its adjacency sets are copied
        directly instead of re-inserting (and re-cycle-checking) every edge.
        """
        new_dag = DAG(self._num_nodes)
        new_dag._edges = set(self._edges)
        new_dag._parents = [set(p) for p in self._parents]
        new_dag._children = [set(c) for c in self._children]
        return new_dag

    def subgraph(self, nodes: Set[int]) -> 'DAG':
        """
        Create a subgraph induced by the given nodes.

        Nodes are relabeled 0..len(nodes)-1 in ascending order of their
        original index.

        Args:
            nodes: Set of node indices to include

        Returns:
            New DAG containing only the specified nodes and edges between them

        Raises:
            ValueError: If a node is out of bounds or ``nodes`` is empty
        """
        nodes = set(nodes)
        for node in nodes:
            self._validate_node(node)

        # Create mapping from old to new indices
        node_list = sorted(nodes)
        old_to_new = {old: new for new, old in enumerate(node_list)}

        # Create new DAG
        new_dag = DAG(len(nodes))

        for parent, child in self._edges:
            if parent in nodes and child in nodes:
                new_dag.add_edge(old_to_new[parent], old_to_new[child])

        return new_dag

    def is_adjacent(self, node1: int, node2: int) -> bool:
        """Check if two nodes are adjacent (connected by any edge)."""
        return (node1, node2) in self._edges or (node2, node1) in self._edges

    def neighbors(self, node: int) -> Set[int]:
        """Return all neighbors (parents and children) of a node."""
        self._validate_node(node)
        return self._parents[node] | self._children[node]

    def roots(self) -> Set[int]:
        """Return all root nodes (nodes with no parents)."""
        return {i for i in range(self._num_nodes) if len(self._parents[i]) == 0}

    def leaves(self) -> Set[int]:
        """Return all leaf nodes (nodes with no children)."""
        return {i for i in range(self._num_nodes) if len(self._children[i]) == 0}

    def is_connected(self) -> bool:
        """
        Check if the underlying undirected graph is connected.

        Uses BFS from node 0, following edges in both directions.
        """
        if self._num_nodes == 0:
            return True

        visited = set()
        queue = deque([0])

        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)

            # Add all neighbors (both directions)
            queue.extend(self._parents[node])
            queue.extend(self._children[node])

        return len(visited) == self._num_nodes

    def d_separated(self, x: Set[int], y: Set[int], z: Set[int]) -> bool:
        """
        Check if X and Y are d-separated given Z.

        Uses the Bayes-ball / "reachable" algorithm. A trail is blocked by Z if:
        - it contains a chain A → B → C or fork A ← B → C with B ∈ Z, or
        - it contains a collider A → B ← C where B ∉ Z and no descendant of
          B is in Z.

        Args:
            x: First set of nodes
            y: Second set of nodes
            z: Conditioning set

        Returns:
            True if X ⊥ Y | Z in the graph (no active trail from X to Y)

        Raises:
            ValueError: If any node index is out of bounds
        """
        x, y, z = set(x), set(y), set(z)
        for node in x | y | z:
            self._validate_node(node)

        # A collider is open iff it is in Z or has a descendant in Z,
        # i.e. iff it belongs to Z ∪ An(Z).
        z_ancestors = set(z)
        for node in z:
            z_ancestors.update(self.ancestors(node))

        # Search state: (node, arrived_from_child).
        #   arrived_from_child=True  -> trail entered `node` from one of its
        #                               children (travelling "up").
        #   arrived_from_child=False -> trail entered `node` from one of its
        #                               parents (travelling "down").
        visited: Set[Tuple[int, bool]] = set()
        # Seed X as if arrived from a child, so every edge out of x may be used
        # (provided x itself is not conditioned on).
        queue = deque((node, True) for node in x)

        while queue:
            state = queue.popleft()
            if state in visited:
                continue
            visited.add(state)
            node, arrived_from_child = state

            # Reaching an unconditioned Y node means an active trail exists.
            if node in y and node not in z:
                return False

            if arrived_from_child:
                # `node` is a non-collider here (upward chain or fork):
                # the trail continues in any direction only if node ∉ Z.
                if node not in z:
                    queue.extend((parent, True) for parent in self._parents[node])
                    queue.extend((child, False) for child in self._children[node])
            else:
                # Continuing down to children is a chain: blocked if node ∈ Z.
                if node not in z:
                    queue.extend((child, False) for child in self._children[node])
                # Turning up to other parents makes `node` a collider:
                # open only if node ∈ Z ∪ An(Z).
                if node in z_ancestors:
                    queue.extend((parent, True) for parent in self._parents[node])

        return True

    def __eq__(self, other: object) -> bool:
        """Check equality with another DAG (same node count and edge set)."""
        if not isinstance(other, DAG):
            return False
        return (self._num_nodes == other._num_nodes and
                self._edges == other._edges)

    def __hash__(self) -> int:
        """
        Hash based on structure.

        NOTE: the DAG is mutable, so this hash changes after add_edge /
        remove_edge; do not mutate a DAG while it is a dict key or set member.
        """
        return hash((self._num_nodes, frozenset(self._edges)))

    def __repr__(self) -> str:
        """String representation."""
        return f"DAG(num_nodes={self._num_nodes}, num_edges={len(self._edges)})"

    def __str__(self) -> str:
        """Human-readable string representation (edges sorted)."""
        lines = [f"DAG with {self._num_nodes} nodes and {len(self._edges)} edges:"]
        for parent, child in sorted(self._edges):
            lines.append(f"  {parent} -> {child}")
        return "\n".join(lines)

    def __iter__(self) -> Iterator[int]:
        """Iterate over nodes."""
        return iter(range(self._num_nodes))

    def __len__(self) -> int:
        """Return number of nodes."""
        return self._num_nodes

    def __contains__(self, item) -> bool:
        """Check if a node (int) or an edge ((parent, child) tuple) is in the graph."""
        if isinstance(item, int):
            return 0 <= item < self._num_nodes
        elif isinstance(item, tuple) and len(item) == 2:
            return item in self._edges
        return False
