"""
Benchmark graph loader.

Provides access to standard benchmark graphs from the causal discovery
literature (Sachs, Child, Alarm, Insurance, etc.).
"""

from typing import Dict, Optional, Tuple, List
import numpy as np

from ..core.dag import DAG


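# Benchmark graph definitions.
# Entry names follow the standard networks of the bnlearn repository. Entries
# whose edge lists are marked "simplified" below are stand-in topologies, not
# the published structures.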

# Registry of benchmark graphs, keyed by lower-case name. Each entry holds:
#   "description": short human-readable label
#   "num_nodes":   number of nodes; node indices run from 0 to num_nodes - 1
#   "edges":       list of directed (parent, child) index pairs
#   "node_names":  one name per node, indexed by node id
BENCHMARK_GRAPHS = {
    "sachs": {
        "description": "Protein signaling network (Sachs et al., 2005)",
        "num_nodes": 11,
        "edges": [
            # Node indices: 0=Raf, 1=Mek, 2=Plcg, 3=PIP2, 4=PIP3,
            # 5=Erk, 6=Akt, 7=PKA, 8=PKC, 9=P38, 10=Jnk
            (0, 1),   # Raf -> Mek
            (1, 5),   # Mek -> Erk
            (2, 3),   # Plcg -> PIP2
            (2, 8),   # Plcg -> PKC
            (3, 8),   # PIP2 -> PKC
            (4, 3),   # PIP3 -> PIP2
            (4, 2),   # PIP3 -> Plcg
            (4, 6),   # PIP3 -> Akt
            (5, 6),   # Erk -> Akt
            # An Akt -> Raf edge (6, 0) is deliberately NOT listed: together
            # with Raf -> Mek -> Erk -> Akt it closes a directed cycle, so
            # the graph would not be a DAG. The edge it was meant to encode,
            # PKA -> Raf, is the next entry.
            # NOTE(review): the bnlearn version of this network reportedly
            # also contains PKC -> PKA, which is absent here -- confirm
            # against the intended reference structure.
            (7, 0),   # PKA -> Raf
            (7, 1),   # PKA -> Mek
            (7, 5),   # PKA -> Erk
            (7, 6),   # PKA -> Akt
            (7, 9),   # PKA -> P38
            (7, 10),  # PKA -> Jnk
            (8, 9),   # PKC -> P38
            (8, 10),  # PKC -> Jnk
            (8, 0),   # PKC -> Raf
            (8, 1),   # PKC -> Mek
        ],
        "node_names": [
            "Raf", "Mek", "Plcg", "PIP2", "PIP3",
            "Erk", "Akt", "PKA", "PKC", "P38", "Jnk"
        ]
    },

    "asia": {
        "description": "Asia (Lung Cancer) network",
        "num_nodes": 8,
        "edges": [
            # 0=asia, 1=tub, 2=smoke, 3=lung, 4=bronc, 5=either, 6=xray, 7=dysp
            (0, 1),   # asia -> tub
            (2, 3),   # smoke -> lung
            (2, 4),   # smoke -> bronc
            (1, 5),   # tub -> either
            (3, 5),   # lung -> either
            (5, 6),   # either -> xray
            (5, 7),   # either -> dysp
            (4, 7),   # bronc -> dysp
        ],
        "node_names": ["asia", "tub", "smoke", "lung", "bronc", "either", "xray", "dysp"]
    },

    "child": {
        "description": "Child network (medical diagnosis)",
        "num_nodes": 20,
        "edges": [
            # Simplified Child network structure
            (0, 1), (0, 2), (1, 3), (1, 4), (2, 4), (2, 5),
            (3, 6), (3, 7), (4, 7), (4, 8), (5, 8), (5, 9),
            (6, 10), (6, 11), (7, 11), (7, 12), (8, 12), (8, 13),
            (9, 13), (9, 14), (10, 15), (11, 15), (11, 16),
            (12, 16), (12, 17), (13, 17), (14, 18), (15, 19),
        ],
        "node_names": [f"X{i}" for i in range(20)]
    },

    "alarm": {
        "description": "ALARM network (medical monitoring)",
        "num_nodes": 37,
        "edges": [
            # ALARM network edges (simplified)
            # NOTE: the edges below only reference nodes 0-32, so nodes
            # 33-36 have no incident edges.
            (0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (3, 5),
            (4, 5), (4, 6), (5, 7), (6, 7), (6, 8), (7, 9),
            (8, 9), (8, 10), (9, 11), (10, 11), (10, 12),
            (11, 13), (12, 13), (12, 14), (13, 15), (14, 15),
            (14, 16), (15, 17), (16, 17), (16, 18), (17, 19),
            (18, 19), (18, 20), (19, 21), (20, 21), (20, 22),
            (21, 23), (22, 23), (22, 24), (23, 25), (24, 25),
            (24, 26), (25, 27), (26, 27), (26, 28), (27, 29),
            (28, 29), (28, 30), (29, 31), (30, 31), (30, 32),
        ],
        "node_names": [f"X{i}" for i in range(37)]
    },

    "insurance": {
        "description": "Insurance network (risk assessment)",
        "num_nodes": 27,
        "edges": [
            # Insurance network edges (simplified)
            (0, 1), (0, 2), (0, 3), (1, 4), (1, 5), (2, 5),
            (2, 6), (3, 6), (3, 7), (4, 8), (4, 9), (5, 9),
            (5, 10), (6, 10), (6, 11), (7, 11), (7, 12),
            (8, 13), (8, 14), (9, 14), (9, 15), (10, 15),
            (10, 16), (11, 16), (11, 17), (12, 17), (12, 18),
            (13, 19), (14, 19), (14, 20), (15, 20), (15, 21),
            (16, 21), (16, 22), (17, 22), (17, 23), (18, 23),
            (19, 24), (20, 24), (20, 25), (21, 25), (22, 26),
            (23, 26),
        ],
        "node_names": [f"X{i}" for i in range(27)]
    },

    "water": {
        "description": "Water network",
        "num_nodes": 32,
        "edges": [
            # Chain 0 -> 1 -> ... -> 31
            (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6),
            (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12),
            (12, 13), (13, 14), (14, 15), (15, 16), (16, 17),
            (17, 18), (18, 19), (19, 20), (20, 21), (21, 22),
            (22, 23), (23, 24), (24, 25), (25, 26), (26, 27),
            (27, 28), (28, 29), (29, 30), (30, 31),
            # Four long-range edges that skip 16 positions along the chain
            (0, 16), (8, 24), (4, 20), (12, 28),
        ],
        "node_names": [f"X{i}" for i in range(32)]
    },

    # Small benchmark for testing
    "small_test": {
        "description": "Small test network",
        "num_nodes": 5,
        "edges": [
            (0, 1), (0, 2), (1, 3), (2, 3), (3, 4)
        ],
        "node_names": ["A", "B", "C", "D", "E"]
    }
}


def load_benchmark_graph(name: str) -> DAG:
    """
    Build the DAG for a registered benchmark network.

    The registry is queried with the lower-cased ``name``, so the lookup
    is case-insensitive.

    Args:
        name: Benchmark identifier, e.g. "sachs" or "Asia"

    Returns:
        DAG constructed from the entry's node count and edge list

    Raises:
        ValueError: If no benchmark is registered under ``name``; the
            message lists the available names in sorted order
    """
    key = name.lower()

    # Success path first: construct the graph straight from the registry entry.
    if key in BENCHMARK_GRAPHS:
        spec = BENCHMARK_GRAPHS[key]
        return DAG(spec["num_nodes"], spec["edges"])

    # The message reports the name exactly as the caller supplied it.
    available = ", ".join(sorted(BENCHMARK_GRAPHS))
    raise ValueError(f"Unknown benchmark graph '{name}'. Available: {available}")


def get_benchmark_info(name: str) -> Dict:
    """
    Get metadata about a benchmark graph.

    Args:
        name: Name of the benchmark graph (case-insensitive)

    Returns:
        Dict with keys:
            name: the lower-cased benchmark name
            description: the entry's description string
            num_nodes: number of nodes
            num_edges: number of entries in the edge list
            node_names: a fresh list of node names; if the entry defines
                none, the names default to "X0", "X1", ...

    Raises:
        ValueError: If graph name is not recognized
    """
    name_lower = name.lower()

    if name_lower not in BENCHMARK_GRAPHS:
        available = ", ".join(sorted(BENCHMARK_GRAPHS.keys()))
        raise ValueError(f"Unknown benchmark graph '{name}'. Available: {available}")

    info = BENCHMARK_GRAPHS[name_lower]
    num_nodes = info["num_nodes"]

    node_names = info.get("node_names")
    if node_names is None:
        node_names = [f"X{i}" for i in range(num_nodes)]

    return {
        "name": name_lower,
        "description": info["description"],
        "num_nodes": num_nodes,
        "num_edges": len(info["edges"]),
        # Copy so that callers mutating the result cannot alter the
        # module-level registry entry.
        "node_names": list(node_names),
    }


def list_benchmark_graphs() -> List[str]:
    """List the registered benchmark names in sorted order."""
    names = list(BENCHMARK_GRAPHS)
    names.sort()
    return names


def load_all_benchmarks() -> Dict[str, DAG]:
    """Build a DAG for every registered benchmark, keyed by benchmark name."""
    graphs = {}
    for key in BENCHMARK_GRAPHS:
        graphs[key] = load_benchmark_graph(key)
    return graphs


def benchmark_summary() -> Dict[str, Dict]:
    """Collect the metadata dict of every registered benchmark, keyed by name."""
    return dict((key, get_benchmark_info(key)) for key in BENCHMARK_GRAPHS)


def load_benchmark_with_sem(
    name: str,
    beta_range: Tuple[float, float] = (0.3, 0.6),
    sigma_range: Tuple[float, float] = (1.0, 1.0),
    random_state: Optional[int] = None
) -> Tuple[DAG, 'LinearGaussianSEM']:
    """
    Load a benchmark graph and attach a randomly parameterised SEM.

    The graph comes from ``load_benchmark_graph``; the SEM is produced by
    ``LinearGaussianSEM.random`` with the three keyword options forwarded
    unchanged.

    Args:
        name: Name of the benchmark graph
        beta_range: Range passed on as ``beta_range`` (edge coefficients)
        sigma_range: Range passed on as ``sigma_range`` (noise scale)
        random_state: Seed passed on as ``random_state``

    Returns:
        Tuple of (dag, sem)
    """
    # Imported at call time rather than at module load.
    from ..core.sem import LinearGaussianSEM

    graph = load_benchmark_graph(name)

    sem_options = {
        "beta_range": beta_range,
        "sigma_range": sigma_range,
        "random_state": random_state,
    }
    model = LinearGaussianSEM.random(graph, **sem_options)

    return graph, model


# Convenience wrapper functions for experiment imports
def generate_sachs_network() -> DAG:
    """Load the registered "sachs" benchmark (protein signaling network)."""
    return load_benchmark_graph(name="sachs")


def generate_child_network() -> DAG:
    """Load the registered "child" benchmark (medical diagnosis network)."""
    return load_benchmark_graph(name="child")


def generate_alarm_network() -> DAG:
    """Load the registered "alarm" benchmark (medical monitoring network)."""
    return load_benchmark_graph(name="alarm")


def generate_insurance_network() -> DAG:
    """Load the registered "insurance" benchmark (risk assessment network)."""
    return load_benchmark_graph(name="insurance")


def generate_asia_network() -> DAG:
    """Load the registered "asia" benchmark (lung cancer network)."""
    return load_benchmark_graph(name="asia")


def generate_water_network() -> DAG:
    """Load the registered "water" benchmark."""
    return load_benchmark_graph(name="water")
