"""
PC Algorithm for causal discovery.

The PC algorithm learns the structure of a DAG from observational data
by performing conditional independence tests.
"""

from __future__ import annotations

from typing import Set, List, Tuple, Optional, Dict, FrozenSet
from dataclasses import dataclass, field
from itertools import combinations
import numpy as np
from scipy import stats

from ..core.dag import DAG
from ..core.mec import CPDAG
from ..core.partial_correlation import (
    sample_partial_correlation,
    fisher_z_transform,
    partial_correlation_test,
)


@dataclass
class PCResult:
    """Outcome of one PC algorithm run: the learned graph plus bookkeeping."""
    # Learned CPDAG (skeleton with the edges that could be oriented).
    cpdag: CPDAG
    # Separating set recorded for each removed pair, keyed by frozenset({i, j}).
    separation_sets: Dict[FrozenSet[int], Set[int]]
    # Number of conditional independence tests that were performed.
    n_tests: int
    # One record per test, laid out as (i, j, S, pvalue, independent).
    test_results: List[Tuple[int, int, Set[int], float, bool]]
    # Elapsed time of the run as measured by the producer of this result.
    execution_time: float = 0.0

    def get_learned_dag(self) -> DAG:
        """Return a DAG sampled from the learned MEC."""
        # Resolved lazily at call time rather than at module import.
        from ..core.mec import sample_dag_from_mec as draw_member
        return draw_member(self.cpdag)


class PCAlgorithm:
    """
    PC algorithm for causal structure learning.

    The PC algorithm works in three phases:
    1. Skeleton learning: Remove edges between conditionally independent pairs
    2. V-structure detection: Orient colliders (v-structures)
    3. Edge propagation: Apply Meek's rules to orient additional edges

    Attributes:
        alpha: Significance level for conditional independence tests
        max_cond_set_size: Maximum conditioning set size (None = d-2)
        ci_test: Name of the requested CI test. It is stored as given; every
            test currently goes through the Fisher z-test in
            ``_ci_test_func`` regardless of this value.
    """

    def __init__(
        self,
        alpha: float = 0.05,
        max_cond_set_size: Optional[int] = None,
        ci_test: str = 'fisher_z'
    ):
        """
        Initialize PC algorithm.

        Args:
            alpha: Significance level for CI tests, strictly between 0 and 1
            max_cond_set_size: Maximum |S| to consider. ``None`` means d-2;
                an explicit value is capped at d-2 when fitting.
            ci_test: CI test to use ('fisher_z' or 'partial_correlation').
                The value is only recorded; it does not select a code path.

        Raises:
            ValueError: If alpha is not in the open interval (0, 1).
        """
        if not 0 < alpha < 1:
            raise ValueError(f"Alpha must be in (0, 1), got {alpha}")

        self.alpha = alpha
        self.max_cond_set_size = max_cond_set_size
        self.ci_test = ci_test

        # State during fitting
        self._n: int = 0
        self._d: int = 0
        self._data: Optional[np.ndarray] = None
        self._adjacency: Optional[Dict[int, Set[int]]] = None
        self._sep_sets: Dict[FrozenSet[int], Set[int]] = {}
        self._test_results: List[Tuple[int, int, Set[int], float, bool]] = []

    def fit(self, X: np.ndarray) -> PCResult:
        """
        Run PC algorithm on data.

        Args:
            X: n×d data matrix (n samples, d variables)

        Returns:
            PCResult containing the learned CPDAG and metadata
        """
        import time
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted (or go negative) when the wall clock is adjusted.
        start_time = time.perf_counter()

        self._n, self._d = X.shape
        self._data = X
        # Fresh containers per fit so earlier PCResult objects keep their own.
        self._sep_sets = {}
        self._test_results = []

        # Determine max conditioning set size
        if self.max_cond_set_size is None:
            max_k = self._d - 2
        else:
            max_k = min(self.max_cond_set_size, self._d - 2)

        # Phase 1: Learn skeleton
        self._learn_skeleton(max_k)

        # Phase 2: Orient v-structures
        cpdag = self._orient_v_structures()

        # Phase 3: Apply Meek's rules
        cpdag._apply_meeks_rules()

        execution_time = time.perf_counter() - start_time

        return PCResult(
            cpdag=cpdag,
            separation_sets=self._sep_sets,
            n_tests=len(self._test_results),
            test_results=self._test_results,
            execution_time=execution_time
        )

    def _learn_skeleton(self, max_k: int) -> None:
        """
        Phase 1: Learn the skeleton using conditional independence tests.

        Start with complete graph and remove edges between
        conditionally independent pairs.

        Args:
            max_k: Largest conditioning set size to try. A negative value
                means no tests are run and the graph stays complete.
        """
        # Initialize complete adjacency
        self._adjacency = {
            i: set(range(self._d)) - {i}
            for i in range(self._d)
        }

        # Test conditioning sets of increasing size
        for k in range(max_k + 1):
            # Snapshot the edges (i < j) before this level removes any.
            edges_to_test = [
                (i, j) for i in range(self._d)
                for j in self._adjacency[i]
                if i < j
            ]

            for i, j in edges_to_test:
                # Check if still adjacent
                if j not in self._adjacency[i]:
                    continue

                # Candidate conditioning variables: current neighbors of
                # either endpoint, excluding the endpoints themselves.
                neighbors_i = self._adjacency[i] - {j}
                neighbors_j = self._adjacency[j] - {i}
                potential_cond = neighbors_i | neighbors_j

                if len(potential_cond) < k:
                    continue

                # Test all conditioning sets of size k
                for S in combinations(potential_cond, k):
                    S_set = set(S)

                    # Perform CI test
                    is_indep, pvalue = self._ci_test_func(i, j, S_set)

                    self._test_results.append((i, j, S_set, pvalue, is_indep))

                    if is_indep:
                        # Remove edge
                        self._adjacency[i].discard(j)
                        self._adjacency[j].discard(i)

                        # Record separation set
                        self._sep_sets[frozenset([i, j])] = S_set

                        # First separating set wins; stop testing this pair.
                        break

    @staticmethod
    def _two_sided_p_value(z_stat: float) -> float:
        """
        Return the two-sided standard-normal p-value for a statistic >= 0.

        Uses the survival function instead of ``1 - cdf``: the subtraction
        rounds to exactly 0.0 for large statistics, which would erase the
        p-values of strongly dependent pairs.
        """
        return float(2 * stats.norm.sf(z_stat))

    def _ci_test_func(self, i: int, j: int, S: Set[int]) -> Tuple[bool, float]:
        """
        Perform conditional independence test.

        Tests H0: X_i ⊥ X_j | X_S using Fisher's z-test.

        Args:
            i: Index of the first variable
            j: Index of the second variable
            S: Indices of the conditioning variables

        Returns:
            (is_independent, p_value)
        """
        # Compute sample partial correlation
        rho_hat = sample_partial_correlation(self._data, i, j, S)

        # Degrees of freedom
        df = self._n - len(S) - 3

        if df <= 0:
            # Not enough samples - assume independent
            return True, 1.0

        # Fisher's z-transform
        z = fisher_z_transform(rho_hat)

        # Standard error
        se = 1.0 / np.sqrt(df)

        # Test statistic
        z_stat = abs(z) / se

        # Two-sided p-value
        p_value = self._two_sided_p_value(z_stat)

        # Decision: failing to reject H0 is treated as independence.
        is_independent = p_value >= self.alpha

        return is_independent, p_value

    def _orient_v_structures(self) -> CPDAG:
        """
        Phase 2: Orient v-structures (colliders).

        A v-structure i -> k <- j exists when:
        - i and j are both adjacent to k
        - i and j are not adjacent to each other
        - k is not in the separation set of i and j

        Returns:
            A CPDAG holding the skeleton with collider edges directed.
        """
        cpdag = CPDAG(self._d)

        # Add all skeleton edges as undirected
        for i in range(self._d):
            for j in self._adjacency[i]:
                if i < j:
                    cpdag.add_undirected_edge(i, j)

        # Find and orient v-structures
        for k in range(self._d):
            # Get neighbors of k
            neighbors_k = list(self._adjacency[k])

            # Check all pairs of neighbors
            for idx_i, i in enumerate(neighbors_k):
                for j in neighbors_k[idx_i + 1:]:
                    # Check if i and j are non-adjacent
                    if j not in self._adjacency[i]:
                        # Check if k is NOT in separation set of i, j.
                        # A pair without a recorded set falls back to the
                        # empty set and is therefore oriented.
                        sep_key = frozenset([i, j])
                        sep_set = self._sep_sets.get(sep_key, set())

                        if k not in sep_set:
                            # Orient as v-structure: i -> k <- j
                            # NOTE(review): an edge already directed the
                            # other way by an earlier collider is not
                            # detected here; both directions would then be
                            # present in the directed set.
                            cpdag._undirected_edges.discard(frozenset([i, k]))
                            cpdag._undirected_edges.discard(frozenset([j, k]))
                            cpdag._directed_edges.add((i, k))
                            cpdag._directed_edges.add((j, k))

        return cpdag

    def fit_predict(self, X: np.ndarray) -> CPDAG:
        """Fit and return only the CPDAG."""
        return self.fit(X).cpdag

    def get_skeleton(self) -> Set[FrozenSet[int]]:
        """
        Return the learned skeleton after fitting.

        Returns:
            Set of undirected edges, each a frozenset of two node indices.

        Raises:
            RuntimeError: If fit() has not been called yet.
        """
        if self._adjacency is None:
            raise RuntimeError("Must call fit() first")

        # Adjacency is symmetric; i < j emits each edge exactly once.
        return {
            frozenset([i, j])
            for i in range(self._d)
            for j in self._adjacency[i]
            if i < j
        }


def pc_algorithm(
    X: np.ndarray,
    alpha: float = 0.05,
    max_cond_set_size: Optional[int] = None
) -> CPDAG:
    """
    One-call entry point: build a PCAlgorithm, fit it, return the CPDAG.

    Args:
        X: n×d data matrix
        alpha: Significance level passed to PCAlgorithm
        max_cond_set_size: Maximum conditioning set size passed to PCAlgorithm

    Returns:
        Learned CPDAG
    """
    # The remaining fields of the fit result are discarded.
    outcome = PCAlgorithm(
        alpha=alpha,
        max_cond_set_size=max_cond_set_size,
    ).fit(X)
    return outcome.cpdag


def pc_with_known_skeleton(
    X: np.ndarray,
    skeleton: Set[FrozenSet[int]],
    alpha: float = 0.05
) -> CPDAG:
    """
    Orient a given skeleton: the PC algorithm without its skeleton phase.

    Intended for oracle experiments in which the skeleton is supplied.
    Separation sets for the non-adjacent pairs are searched with
    ``partial_correlation_test``, colliders are oriented from them, and
    Meek's rules are applied last.

    Args:
        X: n×d data matrix
        skeleton: Known skeleton edges, each a frozenset of two node indices
        alpha: Significance level handed to the partial correlation test

    Returns:
        Learned CPDAG
    """
    _, d = X.shape

    # Symmetric neighbor lookup derived from the edge set.
    adjacency = {node: set() for node in range(d)}
    for edge in skeleton:
        ends = list(edge)
        adjacency[ends[0]].add(ends[1])
        adjacency[ends[1]].add(ends[0])

    # For every non-adjacent pair, keep the first separating set found while
    # scanning candidate subsets from smallest to largest.
    sep_sets = {}
    for i, j in combinations(range(d), 2):
        pair = frozenset([i, j])
        if pair in skeleton:
            continue

        candidates = adjacency[i] | adjacency[j]
        candidates.discard(i)
        candidates.discard(j)

        # Lazily enumerated, so nothing past the first hit is generated.
        subsets_by_size = (
            subset
            for size in range(len(candidates) + 1)
            for subset in combinations(candidates, size)
        )
        for subset in subsets_by_size:
            S_set = set(subset)
            is_indep, _ = partial_correlation_test(X, i, j, S_set, alpha)
            if is_indep:
                sep_sets[pair] = S_set
                break

    # Start from the fully undirected skeleton.
    cpdag = CPDAG(d)
    for edge in skeleton:
        ends = list(edge)
        cpdag.add_undirected_edge(ends[0], ends[1])

    # Collider rule: a - c - b with a, b non-adjacent and c outside their
    # separating set becomes a -> c <- b. A pair with no recorded set is
    # looked up as the empty set.
    for center in range(d):
        around = list(adjacency[center])
        for a, b in combinations(around, 2):
            if b in adjacency[a]:
                continue
            if center in sep_sets.get(frozenset([a, b]), set()):
                continue
            cpdag._undirected_edges.discard(frozenset([a, center]))
            cpdag._undirected_edges.discard(frozenset([b, center]))
            cpdag._directed_edges.add((a, center))
            cpdag._directed_edges.add((b, center))

    cpdag._apply_meeks_rules()

    return cpdag


def estimate_sample_complexity(
    dag: DAG,
    sem: 'LinearGaussianSEM',
    alpha: float = 0.05,
    power: float = 0.9,
    n_trials: int = 50,
    sample_sizes: Optional[List[int]] = None,
    random_state: Optional[int] = None
) -> int:
    """
    Estimate sample complexity empirically.

    Scans the candidate sample sizes in ascending order and returns the
    first n at which PC recovers the correct MEC (structural Hamming
    distance 0 to the true CPDAG) in at least a ``power`` fraction of
    trials. The scan stops at the first success, so larger sizes are not
    evaluated.

    Args:
        dag: True DAG
        sem: SEM parameters; must provide ``sample(n, random_state=seed)``
        alpha: Significance level for PC
        power: Required success probability
        n_trials: Number of trials per sample size (at least 1)
        sample_sizes: Sample sizes to try, in any order
            (default: 50, 100, 200, 500, 1000, 2000, 5000, 10000)
        random_state: Random seed

    Returns:
        Estimated minimum sample size, or the largest candidate if none
        reached the required success rate

    Raises:
        ValueError: If n_trials is less than 1 or sample_sizes is empty.
    """
    # Validate first so bad arguments fail before any sampling is done.
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    if sample_sizes is None:
        candidate_sizes = [50, 100, 200, 500, 1000, 2000, 5000, 10000]
    else:
        # Ascending order is what makes "first success" the minimum and the
        # last element the largest size tried.
        candidate_sizes = sorted(sample_sizes)

    if not candidate_sizes:
        raise ValueError("sample_sizes must contain at least one sample size")

    from ..metrics.shd import structural_hamming_distance

    rng = np.random.default_rng(random_state)

    true_cpdag = CPDAG.from_dag(dag)

    def success_rate(n: int) -> float:
        """Compute success rate at sample size n."""
        successes = 0
        for _ in range(n_trials):
            # Each trial gets its own seed drawn from the shared generator.
            seed = int(rng.integers(0, 2**31))
            X = sem.sample(n, random_state=seed)

            result = PCAlgorithm(alpha=alpha).fit(X)

            if structural_hamming_distance(true_cpdag, result.cpdag) == 0:
                successes += 1

        return successes / n_trials

    # Find minimum n with success rate >= power
    for n in candidate_sizes:
        if success_rate(n) >= power:
            return n

    # Return largest tried if none succeeded
    return candidate_sizes[-1]
