"""
Complexity proxies for causal graphs.

Alternative measures of graph complexity for comparison
with Fisher dimension.
"""

from __future__ import annotations

from typing import Dict, Optional
import numpy as np

from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension


def graph_density(dag: DAG) -> float:
    """
    Return the edge density of a DAG.

    Density is the edge count divided by the number of unordered node
    pairs, i.e. |E| / (d(d-1)/2). Graphs with fewer than two nodes admit
    no edges and are reported as 0.0 (``num_edges`` is not queried then).

    Args:
        dag: Input DAG

    Returns:
        Density; in [0, 1] for a simple DAG
    """
    n = dag.num_nodes()
    possible_pairs = n * (n - 1) // 2
    return dag.num_edges() / possible_pairs if possible_pairs else 0.0


def max_in_degree(dag: DAG) -> int:
    """
    Return the largest number of parents held by any single node.

    This is a thin wrapper around ``DAG.max_in_degree`` so it can be used
    alongside the other proxy functions.

    Args:
        dag: Input DAG

    Returns:
        Maximum in-degree over all nodes
    """
    largest = dag.max_in_degree()
    return largest


def avg_in_degree(dag: DAG) -> float:
    """
    Return the mean number of parents per node.

    Every edge contributes exactly one parent to one node, so the average
    in-degree is simply |E| / d. An empty graph yields 0.0.

    Args:
        dag: Input DAG

    Returns:
        Average in-degree
    """
    n = dag.num_nodes()
    return 0.0 if n == 0 else dag.num_edges() / n


def avg_markov_blanket_size(dag: DAG) -> float:
    """
    Return the mean Markov-blanket size over all nodes.

    The blanket of a node v is Pa(v) ∪ Ch(v) ∪ Pa(Ch(v)); it is obtained
    from ``DAG.markov_blanket`` for each node index 0..d-1. An empty graph
    yields 0.0.

    Args:
        dag: Input DAG

    Returns:
        Average |MB(v)|
    """
    n = dag.num_nodes()
    if not n:
        return 0.0

    blanket_sizes = [len(dag.markov_blanket(v)) for v in range(n)]
    return sum(blanket_sizes) / n


def num_v_structures(dag: DAG) -> int:
    """
    Return how many v-structures (immoralities) the DAG contains.

    A v-structure is a collider i → k ← j whose endpoints i and j are not
    adjacent. The enumeration itself is delegated to ``DAG.v_structures``.

    Args:
        dag: Input DAG

    Returns:
        Count of v-structures
    """
    colliders = dag.v_structures()
    return len(colliders)


def mec_size(dag: DAG, max_size: int = 10000) -> int:
    """
    Compute size of Markov Equivalence Class.

    The exact size comes from ``CPDAG.mec_size()`` on the CPDAG of ``dag``.
    Sizes above ``max_size`` are reported as -1 so that callers can treat
    them as "too large" instead of receiving an unbounded count.

    Args:
        dag: Input DAG
        max_size: Largest MEC size that is reported; anything above it
            yields -1

    Returns:
        |[G]|, or -1 if it exceeds ``max_size`` (a negative value returned
        by the underlying computation is also normalised to -1)
    """
    cpdag = CPDAG.from_dag(dag)
    size = cpdag.mec_size()
    # NOTE(review): CPDAG.mec_size() is called without a bound, so this cap
    # limits the *reported* value, not the enumeration cost itself.
    if size < 0 or size > max_size:
        return -1
    return size


def mec_size_estimate(dag: DAG) -> int:
    """
    Return a cheap upper bound on the Markov Equivalence Class size.

    The bound is 2^k, with k the number of undirected (reversible) edges
    in the CPDAG of ``dag``; it is supplied by
    ``CPDAG.mec_size_estimate``.

    Args:
        dag: Input DAG

    Returns:
        Upper bound on |[G]|
    """
    return CPDAG.from_dag(dag).mec_size_estimate()


def num_compelled_edges(dag: DAG) -> int:
    """
    Return the number of compelled edges of ``dag``.

    Compelled edges are those that keep their orientation in every member
    of the equivalence class, i.e. the directed edges of the CPDAG.

    Args:
        dag: Input DAG

    Returns:
        Count of directed edges in the CPDAG
    """
    return CPDAG.from_dag(dag).num_directed_edges()


def num_reversible_edges(dag: DAG) -> int:
    """
    Return the number of reversible edges of ``dag``.

    Reversible edges are those whose orientation differs across members of
    the equivalence class, i.e. the undirected edges of the CPDAG.

    Args:
        dag: Input DAG

    Returns:
        Count of undirected edges in the CPDAG
    """
    return CPDAG.from_dag(dag).num_undirected_edges()


def fraction_compelled(dag: DAG) -> float:
    """
    Return the share of the DAG's edges that are compelled.

    An edgeless graph is treated as fully identified and yields 1.0
    (the CPDAG is not built in that case).

    Args:
        dag: Input DAG

    Returns:
        num_compelled_edges / num_edges, in [0, 1]
    """
    edge_total = dag.num_edges()
    if edge_total == 0:
        return 1.0
    return num_compelled_edges(dag) / edge_total


def max_path_length(dag: DAG) -> int:
    """
    Return the number of edges on the longest directed path.

    Uses dynamic programming over a topological order: each node's value is
    the length of the longest path ending at it, relaxed through its
    children. Assumes ``dag.topological_sort()`` yields node indices
    0..d-1. An empty graph yields 0.

    Args:
        dag: Input DAG

    Returns:
        Longest directed path length (in edges)
    """
    n = dag.num_nodes()
    if n == 0:
        return 0

    # longest[v]: edges on the longest directed path that ends at v
    longest = dict.fromkeys(range(n), 0)
    for u in dag.topological_sort():
        for v in dag.children(u):
            candidate = longest[u] + 1
            if candidate > longest[v]:
                longest[v] = candidate

    return max(longest.values())


def graph_diameter(dag: DAG) -> int:
    """
    Return the diameter of the DAG's underlying undirected graph.

    A breadth-first search from every node computes its eccentricity
    (greatest hop distance to any reachable node) using ``dag.neighbors``.
    The result is the largest eccentricity. For a disconnected graph this is
    the largest *finite* shortest-path distance within any component;
    unreachable pairs are ignored rather than treated as infinite. An empty
    graph yields 0.

    Args:
        dag: Input DAG

    Returns:
        Diameter (longest shortest path, in edges)
    """
    from collections import deque

    n = dag.num_nodes()
    if n == 0:
        return 0

    def eccentricity(source):
        # Hop distance from `source` to every node reachable from it.
        hops = {source: 0}
        frontier = deque((source,))
        while frontier:
            current = frontier.popleft()
            for nbr in dag.neighbors(current):
                if nbr not in hops:
                    hops[nbr] = hops[current] + 1
                    frontier.append(nbr)
        return max(hops.values())

    return max(eccentricity(s) for s in range(n))


def clustering_coefficient(dag: DAG) -> float:
    """
    Return the global clustering coefficient (transitivity) of the
    underlying undirected graph.

    For every node, count the pairs of its neighbors (connected triples
    centred on it) and how many of those pairs are themselves adjacent
    (closed triples). The result is closed / connected. A triangle is
    counted once at each of its three vertices, so this equals
    3 * triangles / connected triples. Adjacency is tested with
    ``frozenset({a, b}) in dag.skeleton()``. Graphs with fewer than three
    nodes, or with no connected triples, yield 0.0.

    Args:
        dag: Input DAG

    Returns:
        Clustering coefficient in [0, 1]
    """
    n = dag.num_nodes()
    if n < 3:
        return 0.0

    skeleton = dag.skeleton()
    closed_triples = 0
    connected_triples = 0

    for center in range(n):
        adj = list(dag.neighbors(center))
        k = len(adj)
        if k < 2:
            continue

        closed_triples += sum(
            1
            for i in range(k)
            for j in range(i + 1, k)
            if frozenset((adj[i], adj[j])) in skeleton
        )
        connected_triples += k * (k - 1) // 2

    return closed_triples / connected_triples if connected_triples else 0.0


def compute_all_proxies(
    dag: DAG,
    sem: Optional[LinearGaussianSEM] = None,
    include_curvature: bool = False
) -> Dict[str, float]:
    """
    Evaluate every complexity proxy defined in this module for one DAG.

    Structural proxies are always computed, followed by the exact MEC size
    (non-positive results, i.e. the "too large" sentinel, become NaN).
    When ``sem`` is given, the Fisher dimension and ``rho_min`` from
    ``compute_fisher_dimension`` are added. When ``include_curvature`` is
    true, a curvature-based Fisher-dimension estimate is added as well.

    Args:
        dag: Input DAG
        sem: Optional SEM enabling the Fisher-dimension entries
        include_curvature: Whether to add 'curvature_fisher_estimate'

    Returns:
        Dict mapping proxy name to value
    """
    structural = (
        ('graph_density', graph_density),
        ('max_in_degree', max_in_degree),
        ('avg_in_degree', avg_in_degree),
        ('avg_markov_blanket_size', avg_markov_blanket_size),
        ('num_v_structures', num_v_structures),
        ('mec_size_estimate', mec_size_estimate),
        ('num_compelled_edges', num_compelled_edges),
        ('num_reversible_edges', num_reversible_edges),
        ('fraction_compelled', fraction_compelled),
        ('max_path_length', max_path_length),
        ('graph_diameter', graph_diameter),
        ('clustering_coefficient', clustering_coefficient),
    )

    proxies = {'num_nodes': dag.num_nodes(), 'num_edges': dag.num_edges()}
    for key, proxy_fn in structural:
        proxies[key] = proxy_fn(dag)

    # Exact MEC size; the enumeration may be expensive for large classes.
    exact_mec = mec_size(dag)
    proxies['mec_size'] = float('nan') if exact_mec <= 0 else exact_mec

    if sem is not None:
        fisher = compute_fisher_dimension(dag, sem)
        proxies['fisher_dimension'] = fisher.fisher_dimension
        proxies['rho_min'] = fisher.rho_min

    if include_curvature:
        # Imported lazily: only needed when the curvature estimate is requested.
        from ..algorithms.curvature import (
            compute_curvature_matrix,
            estimate_fisher_dimension_from_curvature
        )
        curvature = compute_curvature_matrix(dag)
        proxies['curvature_fisher_estimate'] = (
            estimate_fisher_dimension_from_curvature(curvature)
        )

    return proxies


def rank_proxies_by_correlation(
    dags: list,
    sems: list,
    sample_complexities: list,
    proxies_to_test: Optional[list] = None
) -> Dict[str, float]:
    """
    Rank complexity proxies by their correlation with empirical sample complexity.

    For each (dag, sem) pair all proxies are computed with
    ``compute_all_proxies``. Each requested proxy is then Spearman-correlated
    with ``sample_complexities`` over the instances where both values are
    not NaN. A proxy with fewer than three usable instances gets a NaN
    correlation.

    Args:
        dags: List of DAGs
        sems: Corresponding SEMs (same length as ``dags``)
        sample_complexities: Empirical n* values (same length as ``dags``)
        proxies_to_test: Which proxies to evaluate (default: a fixed subset
            including 'fisher_dimension'). Names that ``compute_all_proxies``
            does not produce are recorded as NaN.

    Returns:
        Dict mapping proxy name to Spearman correlation, ordered by
        decreasing absolute correlation, with NaN correlations last.

    Raises:
        ValueError: If ``dags``, ``sems`` and ``sample_complexities`` do not
            all have the same length.
    """
    from scipy.stats import spearmanr

    # zip() would silently truncate mismatched inputs, which previously
    # surfaced later as an obscure numpy broadcast/index error.
    if not (len(dags) == len(sems) == len(sample_complexities)):
        raise ValueError(
            "dags, sems and sample_complexities must have the same length; "
            f"got {len(dags)}, {len(sems)} and {len(sample_complexities)}"
        )

    if proxies_to_test is None:
        proxies_to_test = [
            'fisher_dimension', 'graph_density', 'max_in_degree',
            'avg_markov_blanket_size', 'num_v_structures', 'mec_size_estimate',
            'fraction_compelled'
        ]

    # Compute proxy values (iterate the dict keys so duplicate names in
    # proxies_to_test do not append twice per instance).
    proxy_values = {name: [] for name in proxies_to_test}
    for dag, sem in zip(dags, sems):
        proxies = compute_all_proxies(dag, sem)
        for name in proxy_values:
            proxy_values[name].append(proxies.get(name, float('nan')))

    # Convert to float arrays up front: np.isnan on a plain list holding
    # ints beyond int64 range would build an object array and fail.
    targets = np.asarray(sample_complexities, dtype=float)
    target_ok = ~np.isnan(targets)

    correlations = {}
    for name, values in proxy_values.items():
        arr = np.asarray(values, dtype=float)
        valid_mask = ~np.isnan(arr) & target_ok
        if np.sum(valid_mask) >= 3:
            corr, _ = spearmanr(arr[valid_mask], targets[valid_mask])
            correlations[name] = float(corr)
        else:
            correlations[name] = float('nan')

    def _rank_key(item):
        # Ascending sort: largest |corr| gets the most negative key; NaN
        # must get +inf so it lands at the end (the old -inf put it first).
        corr = item[1]
        return float('inf') if np.isnan(corr) else -abs(corr)

    return dict(sorted(correlations.items(), key=_rank_key))
