"""
Curvature matrix computation for Conjecture 4.1.

The curvature matrix C_G measures edge overlap in Markov blankets
and is conjectured to characterize Fisher dimension.
"""

from __future__ import annotations

from typing import Set, List, Tuple, Optional, Dict
import numpy as np
from scipy import linalg

from ..core.dag import DAG
from ..core.mec import CPDAG


def get_edges_in_markov_blanket(dag: DAG, node: int) -> Set[Tuple[int, int]]:
    """
    Get all edges that touch the Markov blanket of a node.

    MB(v) = Pa(v) ∪ Ch(v) ∪ Pa(Ch(v))

    An edge counts as "in MB(v)" when at least one of its endpoints lies in
    MB(v) ∪ {v}.

    Args:
        dag: The DAG
        node: Node whose Markov blanket to consider

    Returns:
        New set of directed edges (parent, child) from ``dag.edges`` that
        touch MB(node) ∪ {node}.
    """
    # Copy before adding `node`. markov_blanket() might return a set the DAG
    # keeps internally, or an immutable/non-set iterable. Mutating that
    # result in place could corrupt DAG state, or fail outright.
    closed_blanket = set(dag.markov_blanket(node))
    closed_blanket.add(node)

    return {
        (parent, child)
        for parent, child in dag.edges
        if parent in closed_blanket or child in closed_blanket
    }


def compute_curvature_matrix(dag: DAG) -> np.ndarray:
    """
    Compute the curvature matrix C_G for a single DAG.

    Definition 4.1:
    (C_G)_{e, e'} = Σ_v 1[e, e' ∈ MB(v)]

    This counts how many Markov blankets contain both edges e and e'.
    Membership of an edge in MB(v) follows get_edges_in_markov_blanket.
    The result is symmetric by construction.

    Args:
        dag: Input DAG

    Returns:
        |E| × |E| curvature matrix (float64) whose row/column i corresponds
        to ``list(dag.edges)[i]``. For a DAG with no edges this is
        ``np.array([[]])``, which has size 0 (shape (1, 0)).
    """
    edges = list(dag.edges)
    m = len(edges)

    if m == 0:
        return np.array([[]])

    edge_to_idx = {e: i for i, e in enumerate(edges)}
    C = np.zeros((m, m), dtype=np.float64)

    for v in range(dag.num_nodes()):
        # Row/column indices of every edge lying in MB(v). The membership
        # filter guards against a blanket edge that is not in dag.edges.
        idx = np.fromiter(
            (edge_to_idx[e] for e in get_edges_in_markov_blanket(dag, v)
             if e in edge_to_idx),
            dtype=np.intp,
        )
        # The indices are distinct (they come from a set of distinct
        # edges), so this fancy-indexed += adds exactly 1 to every ordered
        # pair (e, e') that co-occurs in MB(v), diagonal included.
        C[np.ix_(idx, idx)] += 1

    return C


def compute_mec_curvature_matrix(
    cpdag: CPDAG,
    max_mec_size: int = 1000,
    sample_size: int = 100,
    random_state: Optional[int] = None
) -> np.ndarray:
    """
    Compute the MEC curvature matrix C_{[G]} = Σ_{G' ∈ [G]} C_{G'}.

    Each member DAG's curvature matrix is re-indexed onto the undirected
    skeleton. An edge therefore contributes to the same row/column whatever
    its orientation in a particular member DAG.

    Warning: This can be expensive for large MECs. If
    ``cpdag.enumerate_mec(max_size=max_mec_size)`` raises ``ValueError``,
    the MEC is treated as too large to enumerate. The sum is then
    approximated by averaging over ``sample_size`` sampled member DAGs and
    rescaling by ``cpdag.mec_size_estimate()``.

    Args:
        cpdag: CPDAG representing the MEC
        max_mec_size: Maximum MEC size to enumerate exactly
        sample_size: Number of samples if using approximation
        random_state: Random seed for sampling

    Returns:
        |E| × |E| curvature matrix summed over the MEC. Row/column i
        corresponds to ``list(cpdag.skeleton())[i]``. For an empty skeleton
        this is ``np.array([[]])`` (size 0).

    Raises:
        ValueError: If the sampling fallback is needed and
            ``sample_size`` is not positive.
    """
    edges = list(cpdag.skeleton())
    m = len(edges)

    if m == 0:
        return np.array([[]])

    # Skeleton edges are keyed orientation-free so (a, b) and (b, a) hit the
    # same row/column.
    edge_to_idx = {frozenset(e) if isinstance(e, tuple) else e: i
                   for i, e in enumerate(edges)}

    try:
        # Materialise inside the try. A ValueError raised lazily during
        # enumeration (if enumerate_mec is a generator) still selects the
        # sampling fallback. ValueErrors from the curvature computation
        # below are no longer mistaken for "MEC too large".
        mec_dags = list(cpdag.enumerate_mec(max_size=max_mec_size))
    except ValueError:
        mec_dags = None

    C_total = np.zeros((m, m), dtype=np.float64)

    if mec_dags is not None:
        for dag in mec_dags:
            _accumulate_skeleton_curvature(C_total, dag, edge_to_idx)
        return C_total

    # MEC too large - use sampling
    if sample_size <= 0:
        raise ValueError(
            f"sample_size must be positive for the sampling fallback, got {sample_size}"
        )

    from ..core.mec import sample_dag_from_mec
    rng = np.random.default_rng(random_state)

    for _ in range(sample_size):
        seed = int(rng.integers(0, 2**31))
        dag = sample_dag_from_mec(cpdag, random_state=seed)
        _accumulate_skeleton_curvature(C_total, dag, edge_to_idx)

    # Scale the sample mean up to an (estimated) sum over the whole MEC.
    estimated_mec_size = cpdag.mec_size_estimate()
    C_total *= estimated_mec_size / sample_size

    return C_total


def _accumulate_skeleton_curvature(
    C_total: np.ndarray,
    dag: DAG,
    edge_to_idx: Dict,
) -> None:
    """Add C_dag into C_total in place, mapping DAG edge rows onto skeleton rows."""
    C_dag = compute_curvature_matrix(dag)

    dag_positions = []
    skel_positions = []
    # compute_curvature_matrix orders rows by list(dag.edges). Iterate the
    # same way so position i here is row/column i of C_dag.
    for i, edge in enumerate(list(dag.edges)):
        skel_idx = edge_to_idx.get(frozenset(edge))
        if skel_idx is not None:
            dag_positions.append(i)
            skel_positions.append(skel_idx)

    if not skel_positions:
        return

    src = np.asarray(dag_positions, dtype=np.intp)
    dst = np.asarray(skel_positions, dtype=np.intp)
    # np.add.at accumulates correctly even if two DAG edges ever map to the
    # same skeleton edge, matching an element-by-element += loop.
    np.add.at(C_total, (dst[:, None], dst[None, :]), C_dag[np.ix_(src, src)])


def compute_curvature_eigenvalues(
    C: np.ndarray,
    k: Optional[int] = None
) -> np.ndarray:
    """
    Return the spectrum of a curvature matrix.

    The full spectrum is always computed with ``scipy.linalg.eigvalsh``,
    which reads only one triangle of ``C``, so ``C`` is assumed symmetric.
    ``k`` merely truncates the result afterwards.

    Args:
        C: Curvature matrix
        k: If given, keep only the first ``k`` (i.e. smallest) eigenvalues;
            None keeps all of them.

    Returns:
        Eigenvalues in ascending order, or an empty array when ``C`` has
        no entries.
    """
    if not C.size:
        return np.array([])

    spectrum = linalg.eigvalsh(C)
    return spectrum if k is None else spectrum[:k]


def estimate_fisher_dimension_from_curvature(
    curvature_matrix: np.ndarray,
    regularization: float = 1e-10
) -> float:
    """
    Estimate Fisher dimension as the reciprocal smallest curvature eigenvalue.

    Under Conjecture 4.1: F([G]) ≈ 1 / λ_min(C_{[G]})

    Args:
        curvature_matrix: Curvature matrix C_{[G]}
        regularization: Threshold at or below which λ_min is treated as zero

    Returns:
        ``1 / λ_min``, or ``inf`` when the matrix is empty or
        ``λ_min <= regularization``.
    """
    if not curvature_matrix.size:
        return float('inf')

    smallest = compute_curvature_eigenvalues(curvature_matrix)[0]

    # A (numerically) singular curvature matrix makes the estimate unbounded.
    return float('inf') if smallest <= regularization else 1.0 / smallest


def curvature_condition_number(C: np.ndarray) -> float:
    """
    Ratio of the largest to the smallest eigenvalue of a curvature matrix.

    Condition number = λ_max / λ_min

    Args:
        C: Curvature matrix

    Returns:
        ``λ_max / λ_min``, or ``inf`` when ``C`` is empty or ``λ_min <= 0``.
    """
    if not C.size:
        return float('inf')

    spectrum = compute_curvature_eigenvalues(C)
    lam_min, lam_max = spectrum[0], spectrum[-1]

    if lam_min <= 0:
        return float('inf')
    return lam_max / lam_min


def analyze_curvature_structure(dag: DAG) -> Dict:
    """
    Detailed analysis of curvature matrix structure.

    Args:
        dag: Input DAG

    Returns:
        Dict with keys 'num_edges', 'curvature_matrix_shape', 'eigenvalues',
        'lambda_min', 'lambda_max', 'condition_number',
        'estimated_fisher_dim', 'sparsity', 'min_row_sum', 'max_row_sum'
        and 'mean_row_sum'. The same keys are present for an edgeless DAG,
        with None for quantities that are undefined on an empty matrix.
    """
    C = compute_curvature_matrix(dag)

    if C.size == 0:
        # Keep the key set identical to the non-empty case so callers can
        # index the result without special-casing edgeless DAGs.
        return {
            'num_edges': 0,
            'curvature_matrix_shape': (0, 0),
            'eigenvalues': [],
            'lambda_min': None,
            'lambda_max': None,
            'condition_number': None,
            'estimated_fisher_dim': float('inf'),
            'sparsity': 1.0,
            'min_row_sum': None,
            'max_row_sum': None,
            'mean_row_sum': None,
        }

    eigenvalues = compute_curvature_eigenvalues(C)
    lambda_min = float(eigenvalues[0])
    lambda_max = float(eigenvalues[-1])

    # Fraction of zero entries
    sparsity = np.sum(C == 0) / C.size

    # Row sums: total co-occurrence count of each edge ("importance")
    row_sums = np.sum(C, axis=1)

    return {
        'num_edges': dag.num_edges(),
        'curvature_matrix_shape': C.shape,
        'eigenvalues': eigenvalues.tolist(),
        'lambda_min': lambda_min,
        'lambda_max': lambda_max,
        'condition_number': lambda_max / lambda_min if lambda_min > 0 else float('inf'),
        'estimated_fisher_dim': estimate_fisher_dimension_from_curvature(C),
        'sparsity': float(sparsity),
        'min_row_sum': float(np.min(row_sums)),
        'max_row_sum': float(np.max(row_sums)),
        'mean_row_sum': float(np.mean(row_sums)),
    }


def compare_fisher_dimension_methods(
    dag: DAG,
    sem: 'LinearGaussianSEM',
    *,
    random_state: Optional[int] = None,
) -> Dict:
    """
    Compare Fisher dimension from partial correlations vs curvature.

    Tests Conjecture 4.1 by comparing both methods.

    Args:
        dag: The DAG
        sem: SEM parameters
        random_state: Seed forwarded to compute_mec_curvature_matrix. It only
            matters when the MEC is too large to enumerate and the sampling
            fallback is used. None keeps the previous (unseeded) behaviour.

    Returns:
        Dict with per-method results, the ratio
        ``F_partial_corr / F_curvature``, and 'conjecture_supported'. That
        flag is True when the ratio lies strictly between 0.1 and 10.
    """
    from ..core.fisher_dimension import compute_fisher_dimension

    # Method 1: Direct partial correlation (Definition 3.1)
    fd_result = compute_fisher_dimension(dag, sem)

    # Method 2: Curvature matrix (Conjecture 4.1)
    cpdag = CPDAG.from_dag(dag)
    C = compute_mec_curvature_matrix(cpdag, random_state=random_state)
    fd_curvature = estimate_fisher_dimension_from_curvature(C)

    eigenvalues = compute_curvature_eigenvalues(C)
    has_spectrum = len(eigenvalues) > 0

    if fd_curvature > 0:
        ratio = fd_result.fisher_dimension / fd_curvature
        conjecture_supported = 0.1 < ratio < 10
    else:
        ratio = float('inf')
        conjecture_supported = False

    return {
        'method_partial_correlation': {
            'fisher_dimension': fd_result.fisher_dimension,
            'rho_min': fd_result.rho_min,
            'hardest_edge': fd_result.hardest_edge,
        },
        'method_curvature': {
            'fisher_dimension': fd_curvature,
            'lambda_min': float(eigenvalues[0]) if has_spectrum else None,
            'lambda_max': float(eigenvalues[-1]) if has_spectrum else None,
        },
        'ratio': ratio,
        'conjecture_supported': conjecture_supported,
    }


def curvature_matrix_for_chain(d: int) -> np.ndarray:
    """
    Compute curvature matrix for a chain graph analytically.

    For a chain X_0 -> X_1 -> ... -> X_{d-1}, edge i is (i -> i+1) and
    MB(k) = {k-1, k+1} (clipped to the chain). This module counts an edge
    as lying in MB(k) when at least one endpoint is in MB(k) ∪ {k} =
    {k-1, k, k+1} (see get_edges_in_markov_blanket). That is exactly edges
    k-2, k-1, k and k+1, clipped to [0, d-2]. The result should therefore
    equal compute_curvature_matrix applied to the chain DAG. It is banded:
    C[i, j] is non-zero only when |i - j| <= 3.

    Args:
        d: Number of nodes

    Returns:
        (d-1) × (d-1) curvature matrix for the chain, or ``np.array([[]])``
        (size 0) when d <= 1.
    """
    if d <= 1:
        return np.array([[]])

    m = d - 1  # Number of edges

    C = np.zeros((m, m), dtype=np.float64)

    for k in range(d):
        # Edges touching {k-1, k, k+1} form the contiguous index range
        # [k-2, k+1]. Every pair inside it co-occurs in MB(k).
        lo = max(k - 2, 0)
        hi = min(k + 1, m - 1)
        C[lo:hi + 1, lo:hi + 1] += 1

    return C


def curvature_matrix_for_complete(d: int) -> np.ndarray:
    """
    Curvature matrix of the complete DAG on ``d`` nodes.

    Despite the "analytic" framing of its sibling helpers, this is computed
    numerically. It builds the DAG with ``generate_complete_dag`` and
    evaluates ``compute_curvature_matrix`` on it.

    NOTE(review): if ``generate_complete_dag`` makes every pair of nodes
    adjacent, then MB(v) ∪ {v} is the whole vertex set. Every edge would then
    lie in every blanket, giving C = d * ones((m, m)) with m = d(d-1)/2. This
    could replace the numeric path once confirmed.

    Args:
        d: Number of nodes

    Returns:
        Curvature matrix for the complete DAG, indexed like
        ``compute_curvature_matrix``.
    """
    from ..generators.complete import generate_complete_dag as _complete_dag

    return compute_curvature_matrix(_complete_dag(d))


def predict_fisher_dimension_from_structure(
    graph_type: str,
    d: int
) -> Dict:
    """
    Predict Fisher dimension based on graph structure.

    Based on theoretical analysis from the paper.

    Args:
        graph_type: One of 'chain', 'complete', 'star_outward',
            'star_inward', 'tree', 'erdos_renyi' (case-insensitive). Any
            other value, including plain 'star', is reported as unknown.
        d: Number of nodes

    Returns:
        Dict with 'curvature_prediction', 'partial_corr_prediction' and
        'note'. For an unknown graph type both predictions are None.
    """
    # Each entry maps to (lazy curvature prediction, note). Using lambdas
    # means np.log(d) runs only for the log-scaling families, so e.g. a
    # chain query with d=0 does not emit a divide-by-zero warning.
    rules = {
        # λ_min = O(1/d)
        'chain': (lambda: d, "Discrepancy reflects different notions of difficulty"),
        # λ_min = O(d), so F = O(1); partial-corr side: Proposition 6.3
        'complete': (lambda: 1, "Both methods agree: F([G]) = Θ(1)"),
        # λ_min = O(1/d)
        'star_outward': (lambda: d, "Discrepancy: curvature captures orientation difficulty"),
        # Many v-structures fix orientation
        'star_inward': (lambda: 1, "Both methods agree for inward star"),
        # Conjectured
        'tree': (lambda: np.log(d), "Conjectured scaling"),
        # Conjectured
        'erdos_renyi': (lambda: np.log(d), "Conjectured scaling for random graphs"),
    }

    rule = rules.get(graph_type.lower())
    if rule is None:
        return {
            'curvature_prediction': None,
            'partial_corr_prediction': None,
            'note': "Unknown graph type"
        }

    curvature, note = rule
    return {
        'curvature_prediction': curvature(),
        # With Θ(1) coefficients the partial-correlation prediction is 1
        # for every known family.
        'partial_corr_prediction': 1,
        'note': note,
    }
