"""
Fisher Dimension computation for causal graphs.

The Fisher dimension F([G]) = 1/ρ_min² characterizes the sample complexity
of causal structure learning. This module implements Algorithm 1 from
the paper (Definition 3.1).
"""

from __future__ import annotations

from typing import List, Tuple, Set, Optional, NamedTuple, Dict
from dataclasses import dataclass
import numpy as np

from .dag import DAG
from .sem import LinearGaussianSEM
from .partial_correlation import partial_correlation


@dataclass
class EdgeTestResult:
    """Outcome of the partial-correlation test run for one edge.

    Attributes:
        parent: Parent endpoint of the tested edge.
        child: Child endpoint of the tested edge.
        conditioning_set: Nodes that were conditioned on for this test.
        partial_correlation: Partial correlation value obtained for the edge.
        edge_coefficient: Coefficient recorded for the edge, or None when
            it was not stored.
    """
    parent: int
    child: int
    conditioning_set: Set[int]
    partial_correlation: float
    edge_coefficient: Optional[float] = None

    @property
    def is_edge_detectable(self) -> bool:
        """Return True when |partial_correlation| is strictly above 1e-10."""
        magnitude = abs(self.partial_correlation)
        # Strict ">" so that a NaN correlation is reported as not detectable.
        return magnitude > 1e-10


@dataclass
class FisherDimensionResult:
    """
    Bundle of everything produced by a Fisher dimension computation.

    Attributes:
        fisher_dimension: F([G]) = 1/ρ_min²
        rho_min: Smallest nonzero partial correlation magnitude
        edge_tests: Per-edge test results
        hardest_edge: Edge attaining the smallest |ρ_{ij|S}| (None if absent)
        num_edges: Number of edges that were tested
    """
    fisher_dimension: float
    rho_min: float
    edge_tests: List[EdgeTestResult]
    hardest_edge: Optional[Tuple[int, int]]
    num_edges: int

    def to_dict(self) -> Dict:
        """Return a plain-dict view of this result for serialization.

        Each edge test becomes its own dictionary and its conditioning set
        is converted to a list.
        """
        serialized_tests = []
        for test in self.edge_tests:
            serialized_tests.append({
                'parent': test.parent,
                'child': test.child,
                'conditioning_set': list(test.conditioning_set),
                'partial_correlation': test.partial_correlation,
                'edge_coefficient': test.edge_coefficient
            })

        payload = {
            'fisher_dimension': self.fisher_dimension,
            'rho_min': self.rho_min,
            'hardest_edge': self.hardest_edge,
            'num_edges': self.num_edges,
        }
        # Added last so the key order matches the summary-then-details layout.
        payload['edge_tests'] = serialized_tests
        return payload


def compute_fisher_dimension(
    dag: DAG,
    sem: LinearGaussianSEM,
    include_edge_coefficients: bool = True
) -> FisherDimensionResult:
    r"""
    Compute the Fisher dimension using Definition 3.1.

    Algorithm 1: Compute Fisher Dimension (Direct Method)

    For each edge (parent, child) in E(G):
        S = Pa(child) \ {parent}  # Canonical conditioning set
        Compute ρ_{child,parent|S}
        Track minimum |ρ|

    F([G]) = 1/ρ_min²

    Partial correlations whose magnitude is at most 1e-15 (and NaN values)
    are recorded in ``edge_tests`` but are ignored when tracking the minimum.
    If the DAG has no edges, or no edge has a usable partial correlation,
    the result has ``fisher_dimension = inf``, ``rho_min = 0.0`` and
    ``hardest_edge = None``.

    Args:
        dag: The causal DAG
        sem: The Linear Gaussian SEM with parameters
        include_edge_coefficients: Whether to include β values in results

    Returns:
        FisherDimensionResult with full computation details
    """
    # NOTE: the docstring is a raw string because it contains a literal
    # backslash (set difference); a non-raw "\ " is an invalid escape.
    if dag.num_edges() == 0:
        return FisherDimensionResult(
            fisher_dimension=float('inf'),
            rho_min=0.0,
            edge_tests=[],
            hardest_edge=None,
            num_edges=0
        )

    # Covariance matrix implied by the SEM; shared by every edge test.
    Sigma = sem.covariance_matrix()

    edge_tests = []
    rho_min = float('inf')
    hardest_edge = None

    for parent, child in dag.edges:
        # Canonical conditioning set: S = Pa(child) \ {parent}
        S = dag.parents(child) - {parent}

        rho = partial_correlation(Sigma, child, parent, S)

        # β is looked up only on request; otherwise the field stays None.
        edge_coef = (
            sem.get_coefficient(parent, child)
            if include_edge_coefficients
            else None
        )

        edge_tests.append(EdgeTestResult(
            parent=parent,
            child=child,
            conditioning_set=S,
            partial_correlation=rho,
            edge_coefficient=edge_coef
        ))

        # Track the smallest magnitude among numerically nonzero correlations.
        # Chained comparison is False for NaN, so NaN never becomes rho_min.
        magnitude = abs(rho)
        if 1e-15 < magnitude < rho_min:
            rho_min = magnitude
            hardest_edge = (parent, child)

    # rho_min is still inf exactly when no edge passed the 1e-15 filter;
    # any value that was assigned is necessarily > 1e-15, so no extra
    # lower-bound check is needed before dividing.
    if rho_min == float('inf'):
        fisher_dim = float('inf')
        rho_min = 0.0
    else:
        fisher_dim = 1.0 / (rho_min ** 2)

    return FisherDimensionResult(
        fisher_dimension=fisher_dim,
        rho_min=rho_min,
        edge_tests=edge_tests,
        hardest_edge=hardest_edge,
        num_edges=dag.num_edges()
    )


def compute_fisher_dimension_simple(
    dag: DAG,
    sem: LinearGaussianSEM
) -> Tuple[float, float]:
    """
    Convenience wrapper returning only the two headline numbers.

    Runs compute_fisher_dimension without collecting edge coefficients and
    keeps just the Fisher dimension and the minimum partial correlation.

    Args:
        dag: The causal DAG
        sem: The Linear Gaussian SEM

    Returns:
        Tuple of (fisher_dimension, rho_min)
    """
    full = compute_fisher_dimension(
        dag,
        sem,
        include_edge_coefficients=False,
    )
    return (full.fisher_dimension, full.rho_min)


def theoretical_fisher_dimension_bound(
    sem: LinearGaussianSEM
) -> float:
    """
    Evaluate the Proposition 7.1 upper bound on the Fisher dimension.

    F([G]) ≤ C_Σ / (ε² * σ_min²)

    where:
        C_Σ = λ_max(Σ), the largest eigenvalue of the covariance matrix
        ε = min |β_ij|, taken from sem.min_edge_coefficient()
        σ_min² = min σ_i², taken from sem.min_noise_variance()

    Args:
        sem: The Linear Gaussian SEM

    Returns:
        The bound as a float; inf when ε or σ_min² equals zero
    """
    from scipy import linalg

    spectrum = linalg.eigvalsh(sem.covariance_matrix())
    largest_eigenvalue = np.max(spectrum)

    min_coef = sem.min_edge_coefficient()
    min_noise_var = sem.min_noise_variance()

    # A zero in the denominator means the bound is vacuous.
    if min_coef == 0 or min_noise_var == 0:
        return float('inf')

    return float(largest_eigenvalue / (min_coef ** 2 * min_noise_var))


def is_well_conditioned_fisher(
    sem: LinearGaussianSEM,
    threshold: float = 100.0
) -> bool:
    """
    Decide whether the SEM's Fisher dimension bound is acceptably small.

    Per Proposition 7.1, for spectrally well-conditioned SEMs,
    F([G]) = O(1) independent of graph structure. This check compares the
    theoretical bound (not the exact Fisher dimension) with the threshold.

    Args:
        sem: The Linear Gaussian SEM
        threshold: Largest bound value still counted as "well-conditioned"

    Returns:
        True if the bound is less than or equal to threshold
    """
    return theoretical_fisher_dimension_bound(sem) <= threshold


def partial_correlation_for_edge(
    sem: LinearGaussianSEM,
    parent: int,
    child: int,
    use_canonical_conditioning: bool = True
) -> float:
    r"""
    Compute the partial correlation for a specific edge.

    The DAG is taken from ``sem.dag`` and the covariance matrix from
    ``sem.covariance_matrix()``.

    Args:
        sem: The Linear Gaussian SEM
        parent: Parent node
        child: Child node
        use_canonical_conditioning: If True, use S = Pa(child) \ {parent};
            if False, condition on the empty set (marginal correlation)

    Returns:
        Partial correlation ρ_{child,parent|S}
    """
    # NOTE: raw docstring, because it contains a literal backslash.
    dag = sem.dag
    Sigma = sem.covariance_matrix()

    if use_canonical_conditioning:
        # All other parents of the child.
        S = dag.parents(child) - {parent}
    else:
        S = set()

    return partial_correlation(Sigma, child, parent, S)


def verify_lemma_7_1(
    sem: LinearGaussianSEM,
    parent: int,
    child: int,
    tolerance: float = 1e-6
) -> Tuple[bool, float, float]:
    r"""
    Verify Lemma 7.1 formula for an edge.

    Lemma 7.1 states:
    ρ_{ij|S} = β_ji * sqrt(Var(X_j | X_S) / Var(X_i | X_S))

    where S = Pa(i) \ {j}

    The partial correlation computed from the covariance matrix is compared
    with the closed-form value above. If Var(X_i | X_S) is not strictly
    positive, or Var(X_j | X_S) is negative, the formula value falls back
    to 0.0.

    Args:
        sem: The Linear Gaussian SEM
        parent: Parent node (j in the formula)
        child: Child node (i in the formula)
        tolerance: Absolute tolerance for comparison (strict "<")

    Returns:
        Tuple of (formula_matches, computed_rho, formula_rho)
    """
    # NOTE: raw docstring, because it contains a literal backslash.
    dag = sem.dag
    S = dag.parents(child) - {parent}

    # Partial correlation obtained directly from the covariance matrix.
    Sigma = sem.covariance_matrix()
    computed_rho = partial_correlation(Sigma, child, parent, S)

    # Ingredients of the closed-form expression.
    beta_ji = sem.get_coefficient(parent, child)
    var_j_given_S = sem.conditional_variance(parent, S)
    var_i_given_S = sem.conditional_variance(child, S)

    # Guard the division and the square root against degenerate variances.
    if var_i_given_S > 0 and var_j_given_S >= 0:
        formula_rho = beta_ji * np.sqrt(var_j_given_S / var_i_given_S)
    else:
        formula_rho = 0.0

    matches = abs(computed_rho - formula_rho) < tolerance

    return matches, computed_rho, formula_rho


def analyze_fisher_dimension_components(
    dag: DAG,
    sem: LinearGaussianSEM
) -> Dict:
    """
    Break the Fisher dimension down into its contributing quantities.

    Combines the Fisher dimension result with eigenvalue statistics of the
    covariance matrix, edge-coefficient and partial-correlation magnitudes,
    noise-variance extremes, the theoretical bound, and basic graph counts.

    Args:
        dag: The causal DAG
        sem: The Linear Gaussian SEM

    Returns:
        Dictionary with analysis results; coefficient and correlation
        statistics are None when there is nothing to summarize
    """
    from scipy import linalg

    fd = compute_fisher_dimension(dag, sem)

    # Spectrum of the covariance matrix.
    spectrum = linalg.eigvalsh(sem.covariance_matrix())
    lam_lo = np.min(spectrum)
    lam_hi = np.max(spectrum)
    if lam_lo > 0:
        cond_number = float(lam_hi / lam_lo)
    else:
        cond_number = float('inf')

    # Magnitudes collected in one pass over the edge tests.
    beta_mags = []
    rho_mags = []
    for test in fd.edge_tests:
        rho_mags.append(abs(test.partial_correlation))
        if test.edge_coefficient is not None:
            beta_mags.append(abs(test.edge_coefficient))

    has_betas = bool(beta_mags)
    has_rhos = bool(rho_mags)

    return {
        'fisher_dimension': fd.fisher_dimension,
        'rho_min': fd.rho_min,
        'hardest_edge': fd.hardest_edge,

        # Spectral properties
        'lambda_min': float(lam_lo),
        'lambda_max': float(lam_hi),
        'condition_number': cond_number,

        # Edge coefficients
        'beta_min': float(min(beta_mags)) if has_betas else None,
        'beta_max': float(max(beta_mags)) if has_betas else None,
        'beta_mean': float(np.mean(beta_mags)) if has_betas else None,

        # Noise variances
        'sigma_min': sem.min_noise_variance(),
        'sigma_max': sem.max_noise_variance(),

        # Partial correlations
        'rho_mean': float(np.mean(rho_mags)) if has_rhos else None,
        'rho_std': float(np.std(rho_mags)) if has_rhos else None,

        # Theoretical bound
        'theoretical_bound': theoretical_fisher_dimension_bound(sem),

        # Graph structure
        'num_nodes': dag.num_nodes(),
        'num_edges': dag.num_edges(),
        'max_in_degree': dag.max_in_degree(),
        'num_v_structures': len(dag.v_structures()),
    }


def sample_complexity_prediction(
    fisher_dim: float,
    d: int,
    delta: float = 0.05,
    constant: float = 8.0
) -> int:
    """
    Estimate the sample size required according to Theorem 4.1.

    n = C * F([G]) * log(d/δ)

    The estimate is rounded up to the next integer. An infinite Fisher
    dimension is passed through as float('inf') rather than an int.

    Args:
        fisher_dim: Fisher dimension F([G])
        d: Number of nodes
        delta: Failure probability
        constant: Constant C (depends on max degree)

    Returns:
        Predicted minimum sample size
    """
    # Infinity cannot be converted to int, so it is returned unchanged.
    if fisher_dim == float('inf'):
        return float('inf')

    return int(np.ceil(constant * fisher_dim * np.log(d / delta)))


def minimum_detectable_effect(
    n: int,
    d: int,
    alpha: float = 0.05,
    power: float = 0.8,
    max_cond_set_size: int = 0
) -> float:
    """
    Compute minimum detectable partial correlation given sample size.

    Based on power analysis for Fisher's z-test. The significance level is
    Bonferroni-corrected inside this function using d*(d-1)/2 as a rough
    count of pairwise tests; for d <= 1 that count is clamped to 1 (no
    correction) so the division is always defined.

    Args:
        n: Sample size
        d: Number of nodes
        alpha: Overall significance level before the Bonferroni correction
        power: Desired power
        max_cond_set_size: Maximum conditioning set size

    Returns:
        Minimum detectable |ρ|; 1.0 when n - max_cond_set_size - 3 <= 0,
        i.e. when there are too few samples for the test at all
    """
    # Effective sample size of the z-test; checked first because nothing
    # else is needed when it is non-positive.
    df = n - max_cond_set_size - 3
    if df <= 0:
        return 1.0

    from scipy import stats
    from .partial_correlation import inverse_fisher_z

    # Number of tests (rough approximation). Clamped to at least one so a
    # graph with fewer than two nodes does not cause a division by zero.
    num_tests = max(1, d * (d - 1) // 2)

    # Bonferroni correction
    alpha_corrected = alpha / num_tests

    # Critical values: two-sided test level and one-sided power quantile.
    z_alpha = stats.norm.ppf(1 - alpha_corrected / 2)
    z_beta = stats.norm.ppf(power)

    # Required effect on the Fisher-z scale.
    z_diff = (z_alpha + z_beta) / np.sqrt(df)

    # Map back to the correlation scale via the project helper
    # (presumably tanh — confirm against partial_correlation module).
    rho_min = inverse_fisher_z(z_diff)

    return abs(rho_min)


def fisher_dimension_from_rho_min(rho_min: float) -> float:
    """
    Convert a minimum partial correlation into a Fisher dimension.

    F([G]) = 1/ρ_min²

    Args:
        rho_min: Minimum nonzero partial correlation

    Returns:
        Fisher dimension; inf when rho_min is zero or negative
    """
    # "<= 0" (not "> 0" inverted) so NaN still flows through the division.
    return float('inf') if rho_min <= 0 else 1.0 / (rho_min ** 2)


def rho_min_from_fisher_dimension(fisher_dim: float) -> float:
    """
    Recover the minimum partial correlation from a Fisher dimension.

    ρ_min = 1/sqrt(F([G]))

    Args:
        fisher_dim: Fisher dimension

    Returns:
        Minimum partial correlation; 0.0 when fisher_dim is non-positive
        or infinite
    """
    degenerate = fisher_dim <= 0 or fisher_dim == float('inf')
    return 0.0 if degenerate else 1.0 / np.sqrt(fisher_dim)


def compare_fisher_dimensions(
    sems: List[LinearGaussianSEM],
    names: Optional[List[str]] = None
) -> Dict:
    """
    Compare Fisher dimensions across multiple SEMs.

    Each SEM is evaluated on its own DAG (``sem.dag``); the per-SEM
    summaries are then sorted by ascending Fisher dimension.

    Args:
        sems: List of Linear Gaussian SEMs
        names: Optional names for each SEM; defaults to "SEM_0", "SEM_1", ...
            When given, it must contain exactly one name per SEM.

    Returns:
        Dictionary with:
            'results': per-SEM summaries, sorted by Fisher dimension
            'easiest': name with the smallest Fisher dimension (or None)
            'hardest': name with the largest Fisher dimension (or None)
            'range': (smallest, largest) Fisher dimension (or (None, None))

    Raises:
        ValueError: If names is provided and its length differs from sems
    """
    if names is None:
        names = [f"SEM_{i}" for i in range(len(sems))]
    elif len(names) != len(sems):
        # zip() would otherwise silently drop the unmatched entries and
        # produce a comparison over fewer SEMs than the caller passed in.
        raise ValueError(
            f"names has {len(names)} entries but sems has {len(sems)}; "
            "provide exactly one name per SEM"
        )

    results = []

    for name, sem in zip(names, sems):
        fd_result = compute_fisher_dimension(sem.dag, sem)
        results.append({
            'name': name,
            'fisher_dimension': fd_result.fisher_dimension,
            'rho_min': fd_result.rho_min,
            'num_edges': fd_result.num_edges,
            'hardest_edge': fd_result.hardest_edge,
            'theoretical_bound': theoretical_fisher_dimension_bound(sem)
        })

    # Sort by Fisher dimension (ascending: easiest first).
    results.sort(key=lambda x: x['fisher_dimension'])

    return {
        'results': results,
        'easiest': results[0]['name'] if results else None,
        'hardest': results[-1]['name'] if results else None,
        'range': (
            results[0]['fisher_dimension'] if results else None,
            results[-1]['fisher_dimension'] if results else None
        )
    }
