"""
CPDAG (Completed Partially Directed Acyclic Graph) and MEC operations.

A CPDAG represents a Markov Equivalence Class (MEC) - the set of DAGs
that encode the same conditional independence relationships.
"""

from __future__ import annotations

from collections import deque
from typing import Callable, FrozenSet, Iterator, List, Optional, Set, Tuple

import numpy as np

from .dag import DAG


class CPDAG:
    """
    Completed Partially Directed Acyclic Graph.

    Represents a Markov Equivalence Class where:
    - Directed edges are compelled (same in all DAGs in the MEC)
    - Undirected edges are reversible (can go either way)

    Attributes:
        _num_nodes: Number of nodes
        _directed_edges: Set of directed edges (parent, child)
        _undirected_edges: Set of undirected edges as frozensets
    """

    def __init__(self, num_nodes: int):
        """
        Initialize an empty CPDAG.

        Args:
            num_nodes: Number of nodes

        Raises:
            ValueError: If num_nodes < 1.
        """
        if num_nodes < 1:
            raise ValueError(f"Number of nodes must be positive, got {num_nodes}")

        self._num_nodes = num_nodes
        self._directed_edges: Set[Tuple[int, int]] = set()
        self._undirected_edges: Set[FrozenSet[int]] = set()

    def num_nodes(self) -> int:
        """Return the number of nodes."""
        return self._num_nodes

    def add_directed_edge(self, parent: int, child: int) -> None:
        """Add a directed edge parent -> child (replacing any undirected edge)."""
        if not (0 <= parent < self._num_nodes and 0 <= child < self._num_nodes):
            raise ValueError("Node index out of bounds")
        if parent == child:
            raise ValueError("Self-loops not allowed")

        # Remove undirected edge if exists
        self._undirected_edges.discard(frozenset([parent, child]))
        self._directed_edges.add((parent, child))

    def add_undirected_edge(self, node1: int, node2: int) -> None:
        """Add an undirected edge node1 - node2 unless a directed one exists."""
        if not (0 <= node1 < self._num_nodes and 0 <= node2 < self._num_nodes):
            raise ValueError("Node index out of bounds")
        if node1 == node2:
            raise ValueError("Self-loops not allowed")

        # Only add if no directed edge exists
        if (node1, node2) not in self._directed_edges and \
           (node2, node1) not in self._directed_edges:
            self._undirected_edges.add(frozenset([node1, node2]))

    def has_directed_edge(self, parent: int, child: int) -> bool:
        """Check if directed edge parent -> child exists."""
        return (parent, child) in self._directed_edges

    def has_undirected_edge(self, node1: int, node2: int) -> bool:
        """Check if undirected edge node1 - node2 exists."""
        return frozenset([node1, node2]) in self._undirected_edges

    def is_adjacent(self, node1: int, node2: int) -> bool:
        """Check if two nodes are adjacent (any edge type)."""
        return (self.has_directed_edge(node1, node2) or
                self.has_directed_edge(node2, node1) or
                self.has_undirected_edge(node1, node2))

    def neighbors(self, node: int) -> Set[int]:
        """Return all neighbors of a node (any edge type)."""
        neighbors = set()

        # Directed edges
        for parent, child in self._directed_edges:
            if parent == node:
                neighbors.add(child)
            elif child == node:
                neighbors.add(parent)

        # Undirected edges
        for edge in self._undirected_edges:
            if node in edge:
                neighbors.update(edge - {node})

        return neighbors

    @property
    def directed_edges(self) -> Set[Tuple[int, int]]:
        """Return copy of directed edges."""
        return self._directed_edges.copy()

    @property
    def undirected_edges(self) -> Set[FrozenSet[int]]:
        """Return copy of undirected edges."""
        return self._undirected_edges.copy()

    def skeleton(self) -> Set[FrozenSet[int]]:
        """Return the skeleton (all edges as undirected)."""
        skeleton = set()

        for parent, child in self._directed_edges:
            skeleton.add(frozenset([parent, child]))

        skeleton.update(self._undirected_edges)

        return skeleton

    def num_directed_edges(self) -> int:
        """Return number of directed edges."""
        return len(self._directed_edges)

    def num_undirected_edges(self) -> int:
        """Return number of undirected edges."""
        return len(self._undirected_edges)

    def num_edges(self) -> int:
        """Return total number of edges."""
        return self.num_directed_edges() + self.num_undirected_edges()

    @classmethod
    def from_dag(cls, dag: DAG) -> 'CPDAG':
        """
        Convert a DAG to its CPDAG representation.

        Algorithm:
        1. Start with skeleton (all edges undirected)
        2. Orient edges involved in v-structures
        3. Apply Meek's rules until convergence

        Args:
            dag: Input DAG

        Returns:
            CPDAG representing the MEC of the input DAG
        """
        cpdag = cls(dag.num_nodes())

        # Add all edges as undirected initially
        for parent, child in dag.edges:
            cpdag._undirected_edges.add(frozenset([parent, child]))

        # Get v-structures and orient their edges
        v_structures = dag.v_structures()
        for i, k, j in v_structures:
            # Orient i -> k and j -> k
            cpdag._undirected_edges.discard(frozenset([i, k]))
            cpdag._undirected_edges.discard(frozenset([j, k]))
            cpdag._directed_edges.add((i, k))
            cpdag._directed_edges.add((j, k))

        # Apply Meek's rules until no more orientations
        cpdag._apply_meeks_rules()

        return cpdag

    def _apply_meeks_rules(self) -> None:
        """
        Apply Meek's orientation rules until convergence.

        Rules:
        R1: If k -> i - j and k, j not adjacent: orient i -> j
        R2: If i -> k -> j and i - j: orient i -> j
        R3: If i - j, i - k, i - l, j -> k, l -> k, j and l not adjacent: orient i -> k
        R4: If i - j, i - k, k -> l -> j, k and j not adjacent, and i adjacent
            to l: orient i -> j

        Edges are visited in sorted order so repeated runs are reproducible.
        """
        rules = (
            self._should_orient_r1,
            self._should_orient_r2,
            self._should_orient_r3,
            self._should_orient_r4,
        )
        changed = True

        while changed:
            changed = False

            # Snapshot: orienting an edge mutates self._undirected_edges.
            for edge in self._sorted_undirected_edges():
                if edge not in self._undirected_edges:
                    continue  # Already oriented earlier in this pass

                node1, node2 = sorted(edge)
                # Try every rule for node1 -> node2, then for node2 -> node1.
                for parent, child in ((node1, node2), (node2, node1)):
                    if any(rule(parent, child) for rule in rules):
                        self._orient_edge(parent, child)
                        changed = True
                        break

    def _orient_edge(self, parent: int, child: int) -> None:
        """Orient an undirected edge as parent -> child."""
        self._undirected_edges.discard(frozenset([parent, child]))
        self._directed_edges.add((parent, child))

    def _should_orient_r1(self, i: int, j: int) -> bool:
        """
        Meek's Rule 1: If there exists k such that:
        - k -> i (directed)
        - i - j (undirected, the edge we're considering)
        - k and j not adjacent
        Then orient i -> j
        """
        for parent, child in self._directed_edges:
            if child == i:
                k = parent
                if not self.is_adjacent(k, j):
                    return True
        return False

    def _should_orient_r2(self, i: int, j: int) -> bool:
        """
        Meek's Rule 2: If there exists k such that:
        - i -> k (directed)
        - k -> j (directed)
        - i - j (undirected, the edge we're considering)
        Then orient i -> j (avoids creating a cycle)
        """
        for k in range(self._num_nodes):
            if k == i or k == j:
                continue
            if self.has_directed_edge(i, k) and self.has_directed_edge(k, j):
                return True
        return False

    def _should_orient_r3(self, i: int, k: int) -> bool:
        """
        Meek's Rule 3: If there exist j, l such that:
        - i - j, i - l (undirected)
        - j -> k, l -> k (directed)
        - j and l not adjacent
        - i - k (undirected, the edge we're considering)
        Then orient i -> k
        """
        # Get undirected neighbors of i
        undirected_neighbors_i = []
        for edge in self._undirected_edges:
            if i in edge:
                undirected_neighbors_i.extend(edge - {i})

        # Get directed parents of k
        directed_parents_k = [p for p, c in self._directed_edges if c == k]

        # Find j, l satisfying conditions
        candidates = set(undirected_neighbors_i) & set(directed_parents_k)
        candidates.discard(k)

        candidates = list(candidates)
        for idx_j, j in enumerate(candidates):
            for l in candidates[idx_j + 1:]:
                if not self.is_adjacent(j, l):
                    return True

        return False

    def _should_orient_r4(self, i: int, j: int) -> bool:
        """
        Meek's Rule 4: If there exist k, l such that:
        - i - k (undirected)
        - k -> l (directed)
        - l -> j (directed)
        - k and j not adjacent
        - i and l adjacent (any edge type)
        - i - j (undirected, the edge we're considering)
        Then orient i -> j

        The two adjacency conditions are essential: without them the rule
        can compel edges that are reversible (e.g. when k -> j also holds,
        both orientations of i - j can extend to a DAG in the MEC).
        """
        # Undirected neighbours k of i, excluding the edge i - j itself.
        undirected_neighbors_i = [
            other
            for edge in self._undirected_edges
            if i in edge and j not in edge
            for other in edge - {i}
        ]

        for k in undirected_neighbors_i:
            if self.is_adjacent(k, j):
                continue  # R4 requires k and j to be non-adjacent
            # Find l with k -> l -> j and l adjacent to i
            for l in range(self._num_nodes):
                if l == i or l == j or l == k:
                    continue
                if (self.has_directed_edge(k, l)
                        and self.has_directed_edge(l, j)
                        and self.is_adjacent(i, l)):
                    return True

        return False

    def _sorted_undirected_edges(self) -> List[FrozenSet[int]]:
        """Return undirected edges sorted by their (smaller, larger) endpoints."""
        return sorted(self._undirected_edges, key=lambda edge: tuple(sorted(edge)))

    def _dag_from_edges(self, edges: Set[Tuple[int, int]]) -> DAG:
        """Build a DAG on this CPDAG's nodes containing exactly ``edges``."""
        dag = DAG(self._num_nodes)
        for parent, child in sorted(edges):
            dag.add_edge(parent, child)
        return dag

    def enumerate_mec(self, max_size: int = 10000) -> List[DAG]:
        """
        Enumerate all DAGs in this Markov Equivalence Class.

        Each undirected edge is tried in both directions by backtracking. A
        branch is pruned as soon as it contains a directed cycle or an
        unshielded collider that is not already directed in this CPDAG, so
        only orientations with no cycle and no new v-structure are returned.

        Warning: MEC size can be exponential in the number of undirected edges.

        Args:
            max_size: Maximum number of DAGs to enumerate

        Returns:
            List of DAGs in the MEC

        Raises:
            ValueError: If MEC is larger than max_size
        """
        dags: List[DAG] = []
        self._enumerate_recursive(
            self._sorted_undirected_edges(),
            0,
            set(self._directed_edges),
            dags,
            max_size,
        )
        return dags

    def _consistent_extensions(
        self, max_size: int
    ) -> List[FrozenSet[Tuple[int, int]]]:
        """
        Return the directed edge set of every DAG in the MEC.

        Same search as ``enumerate_mec`` but without constructing DAG objects.

        Args:
            max_size: Maximum number of edge sets to return.

        Raises:
            ValueError: If the MEC has more than ``max_size`` members.
        """
        extensions: List[FrozenSet[Tuple[int, int]]] = []
        self._enumerate_recursive(
            self._sorted_undirected_edges(),
            0,
            set(self._directed_edges),
            extensions,
            max_size,
            make_item=frozenset,
        )
        return extensions

    def _enumerate_recursive(
        self,
        undirected_list: List[FrozenSet[int]],
        index: int,
        current_directed: Set[Tuple[int, int]],
        result: List,
        max_size: int,
        make_item: Optional[Callable[[Set[Tuple[int, int]]], object]] = None,
    ) -> bool:
        """
        Recursive helper for MEC enumeration.

        Orients ``undirected_list[index]`` both ways, keeps each orientation
        accepted by ``_is_valid_partial``, and recurses on the next edge.
        Once every edge is oriented, the complete edge set is converted by
        ``make_item`` (default: build a DAG) and appended to ``result``.

        Args:
            undirected_list: Undirected edges in the order they are oriented.
            index: Position of the next edge to orient.
            current_directed: Directed edges fixed so far.
            result: Output list, appended to in place.
            max_size: Maximum number of items ``result`` may hold.
            make_item: Optional converter applied to each complete edge set.

        Returns:
            Always True (kept for backward compatibility).

        Raises:
            ValueError: If more than ``max_size`` complete orientations exist.
        """
        if index == len(undirected_list):
            # Check here, at the leaf, so an MEC with exactly max_size members
            # is accepted; only a (max_size + 1)-th member raises.
            if len(result) >= max_size:
                raise ValueError(f"MEC size exceeds {max_size}")
            if make_item is None:
                result.append(self._dag_from_edges(current_directed))
            else:
                result.append(make_item(current_directed))
            return True

        node1, node2 = sorted(undirected_list[index])
        remaining = undirected_list[index + 1:]

        for parent, child in ((node1, node2), (node2, node1)):
            extended = set(current_directed)
            extended.add((parent, child))
            if self._is_valid_partial(extended, remaining):
                self._enumerate_recursive(
                    undirected_list, index + 1, extended,
                    result, max_size, make_item,
                )

        return True

    def _is_valid_partial(
        self,
        directed: Set[Tuple[int, int]],
        remaining_undirected: List[FrozenSet[int]]
    ) -> bool:
        """
        Check if partial orientation is consistent (no cycles, no new v-structures).

        Both checks are monotone. Orienting more edges only adds arcs, so
        once either condition fails it can never be repaired. Rejecting
        early is therefore a sound pruning step.

        Args:
            directed: All directed edges so far (CPDAG's own plus choices).
            remaining_undirected: Edges still to orient; not needed by the
                checks, kept for signature compatibility.

        Returns:
            True if ``directed`` has no directed cycle and every unshielded
            collider a -> c <- b in it uses two edges already directed in
            this CPDAG.
        """
        num_nodes = self._num_nodes
        parents: List[List[int]] = [[] for _ in range(num_nodes)]
        children: List[List[int]] = [[] for _ in range(num_nodes)]
        in_degree = [0] * num_nodes
        for parent, child in directed:
            parents[child].append(parent)
            children[parent].append(child)
            in_degree[child] += 1

        # 1. No new v-structures.
        for node in range(num_nodes):
            node_parents = sorted(parents[node])
            for idx, a in enumerate(node_parents):
                for b in node_parents[idx + 1:]:
                    if self.is_adjacent(a, b):
                        continue  # Shielded collider: not a v-structure
                    if ((a, node) in self._directed_edges
                            and (b, node) in self._directed_edges):
                        continue  # v-structure already present in the CPDAG
                    return False

        # 2. No directed cycle (Kahn's algorithm: all nodes must be removable).
        ready = [node for node in range(num_nodes) if in_degree[node] == 0]
        removed = 0
        while ready:
            node = ready.pop()
            removed += 1
            for succ in children[node]:
                in_degree[succ] -= 1
                if in_degree[succ] == 0:
                    ready.append(succ)

        return removed == num_nodes

    def mec_size(self) -> int:
        """
        Return the size of the MEC (number of DAGs).

        Note: This can be expensive for large MECs.

        Returns:
            Number of member DAGs, or -1 if there are more than 100000.
        """
        try:
            return len(self._consistent_extensions(max_size=100000))
        except ValueError:
            return -1  # Too large to enumerate

    def mec_size_estimate(self) -> int:
        """
        Estimate MEC size based on undirected edges.

        Upper bound is 2^k where k is number of undirected edges.
        Actual size is usually smaller due to v-structure constraints.
        """
        k = len(self._undirected_edges)
        return 2 ** k

    def to_adjacency_matrices(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return adjacency matrices for directed and undirected edges.

        Returns:
            (directed_adj, undirected_adj) where:
            - directed_adj[i,j] = 1 if i -> j
            - undirected_adj[i,j] = undirected_adj[j,i] = 1 if i - j
        """
        d = self._num_nodes
        directed_adj = np.zeros((d, d), dtype=np.int32)
        undirected_adj = np.zeros((d, d), dtype=np.int32)

        for parent, child in self._directed_edges:
            directed_adj[parent, child] = 1

        for edge in self._undirected_edges:
            nodes = list(edge)
            undirected_adj[nodes[0], nodes[1]] = 1
            undirected_adj[nodes[1], nodes[0]] = 1

        return directed_adj, undirected_adj

    def to_pdag_matrix(self) -> np.ndarray:
        """
        Return combined PDAG adjacency matrix.

        For directed edge i -> j: matrix[i,j] = 1, matrix[j,i] = 0
        For undirected edge i - j: matrix[i,j] = matrix[j,i] = 1
        """
        d = self._num_nodes
        matrix = np.zeros((d, d), dtype=np.int32)

        for parent, child in self._directed_edges:
            matrix[parent, child] = 1

        for edge in self._undirected_edges:
            nodes = list(edge)
            matrix[nodes[0], nodes[1]] = 1
            matrix[nodes[1], nodes[0]] = 1

        return matrix

    def __eq__(self, other: object) -> bool:
        """Check equality with another CPDAG."""
        if not isinstance(other, CPDAG):
            return False
        return (self._num_nodes == other._num_nodes and
                self._directed_edges == other._directed_edges and
                self._undirected_edges == other._undirected_edges)

    def __hash__(self) -> int:
        """Hash based on structure."""
        return hash((
            self._num_nodes,
            frozenset(self._directed_edges),
            frozenset(self._undirected_edges)
        ))

    def __repr__(self) -> str:
        """String representation."""
        return (f"CPDAG(nodes={self._num_nodes}, "
                f"directed={len(self._directed_edges)}, "
                f"undirected={len(self._undirected_edges)})")

    def __str__(self) -> str:
        """Human-readable string representation."""
        lines = [f"CPDAG with {self._num_nodes} nodes:"]
        lines.append(f"  Directed edges: {len(self._directed_edges)}")
        for parent, child in sorted(self._directed_edges):
            lines.append(f"    {parent} -> {child}")
        lines.append(f"  Undirected edges: {len(self._undirected_edges)}")
        for edge in sorted(self._undirected_edges, key=lambda x: tuple(sorted(x))):
            nodes = sorted(edge)
            lines.append(f"    {nodes[0]} - {nodes[1]}")
        return "\n".join(lines)


def dag_to_cpdag(dag: DAG) -> CPDAG:
    """
    Return the CPDAG that represents the Markov Equivalence Class of ``dag``.

    Functional alias for ``CPDAG.from_dag``.

    Args:
        dag: Input DAG

    Returns:
        The CPDAG of ``dag``'s MEC
    """
    result = CPDAG.from_dag(dag)
    return result


def are_markov_equivalent(dag1: DAG, dag2: DAG) -> bool:
    """
    Check if two DAGs are Markov equivalent.

    Two DAGs are Markov equivalent iff they have:
    1. Same skeleton
    2. Same v-structures

    Args:
        dag1: First DAG
        dag2: Second DAG

    Returns:
        True if both DAGs have the same node count, skeleton and v-structures.
    """
    # Check same number of nodes
    if dag1.num_nodes() != dag2.num_nodes():
        return False

    # Check same skeleton
    if dag1.skeleton() != dag2.skeleton():
        return False

    # Check same v-structures. A v-structure (i, k, j) means i -> k <- j, so
    # (i, k, j) and (j, k, i) are the same structure. Order the two parents
    # so the comparison does not depend on which one
    # DAG.v_structures happens to report first.
    v1 = {(min(i, j), k, max(i, j)) for i, k, j in dag1.v_structures()}
    v2 = {(min(i, j), k, max(i, j)) for i, k, j in dag2.v_structures()}

    return v1 == v2


def cpdag_from_skeleton_and_vstructures(
    num_nodes: int,
    skeleton: Set[FrozenSet[int]],
    v_structures: List[Tuple[int, int, int]]
) -> CPDAG:
    """
    Create a CPDAG from skeleton and v-structures.

    Args:
        num_nodes: Number of nodes
        skeleton: Set of undirected edges as frozensets
        v_structures: List of (i, k, j) for v-structures i -> k <- j

    Returns:
        CPDAG with oriented v-structures and Meek's rules applied

    Raises:
        ValueError: If a skeleton edge does not join two distinct in-range
            nodes, if a v-structure uses an edge missing from the skeleton,
            or if two v-structures orient the same edge in opposite
            directions.
    """
    cpdag = CPDAG(num_nodes)

    # Add all skeleton edges as undirected (add_undirected_edge checks bounds)
    for edge in skeleton:
        if len(edge) != 2:
            raise ValueError(
                f"Skeleton edge must join two distinct nodes, got {sorted(edge)}"
            )
        node1, node2 = edge
        cpdag.add_undirected_edge(node1, node2)

    # Orient v-structures through the public API so bounds/self-loop checks
    # apply, and reject edges that the skeleton does not contain.
    skeleton_edges = cpdag.skeleton()
    for i, k, j in v_structures:
        for parent in (i, j):
            if frozenset((parent, k)) not in skeleton_edges:
                raise ValueError(
                    f"V-structure {i} -> {k} <- {j} uses edge {parent} - {k}, "
                    "which is not in the skeleton"
                )
            if cpdag.has_directed_edge(k, parent):
                raise ValueError(
                    f"Conflicting v-structures: edge {parent} - {k} is "
                    "oriented in both directions"
                )
            cpdag.add_directed_edge(parent, k)

    # Apply Meek's rules
    cpdag._apply_meeks_rules()

    return cpdag


def sample_dag_from_mec(cpdag: CPDAG, random_state: Optional[int] = None) -> DAG:
    """
    Sample a random DAG from the MEC.

    Compelled (directed) edges are copied unchanged. The reversible
    (undirected) edges are oriented along a maximum cardinality search (MCS)
    ordering whose ties, including each starting node, are broken at random.
    Each edge points from the earlier node to the later one. In a valid
    CPDAG the undirected components are chordal (Andersson, Madigan &
    Perlman, 1997), so this orientation adds no directed cycle and no new
    v-structure, and the result is a member of the MEC.

    By contrast, an independent coin flip per edge can turn a - b - c into
    the v-structure a -> b <- c, which lies outside the MEC.

    The distribution over MEC members is not uniform.

    Args:
        cpdag: The CPDAG representing the MEC
        random_state: Random seed

    Returns:
        A randomly sampled DAG from the MEC
    """
    rng = np.random.default_rng(random_state)
    num_nodes = cpdag.num_nodes()

    dag = DAG(num_nodes)

    # Add all directed (compelled) edges
    for parent, child in sorted(cpdag.directed_edges):
        dag.add_edge(parent, child)

    # Sorted (low, high) pairs keep the sample reproducible for a given seed.
    undirected = sorted(tuple(sorted(edge)) for edge in cpdag.undirected_edges)
    neighbors = {node: [] for node in range(num_nodes)}
    for a, b in undirected:
        neighbors[a].append(b)
        neighbors[b].append(a)

    # Maximum cardinality search over nodes that touch an undirected edge:
    # repeatedly number the node with the most already-numbered undirected
    # neighbours, breaking ties uniformly at random.
    position = {}
    weight = {node: 0 for node in range(num_nodes) if neighbors[node]}
    while weight:
        best = max(weight.values())
        candidates = sorted(node for node, w in weight.items() if w == best)
        chosen = candidates[int(rng.integers(len(candidates)))]
        position[chosen] = len(position)
        del weight[chosen]
        for neighbor in neighbors[chosen]:
            if neighbor in weight:
                weight[neighbor] += 1

    # Orient every undirected edge from the earlier to the later MCS node.
    for a, b in undirected:
        parent, child = (a, b) if position[a] < position[b] else (b, a)
        try:
            dag.add_edge(parent, child)
        except ValueError:
            # Assumes DAG.add_edge raises ValueError when the edge would close
            # a cycle. This should not happen for a valid CPDAG; for malformed
            # input, fall back to the opposite orientation as a best effort.
            dag.add_edge(child, parent)

    return dag
