"""
Factory functions for DAG generation.

Provides a unified interface for generating different types of DAGs.
"""

from typing import Optional, Union, Dict, Any
from dataclasses import dataclass, field
import numpy as np

from ..core.dag import DAG
from .chain import generate_chain, generate_random_chain
from .star import generate_star, generate_star_inward, generate_star_outward
from .complete import generate_complete_dag, generate_random_complete_dag
from .tree import generate_random_tree, generate_balanced_tree
from .erdos_renyi import generate_erdos_renyi_dag
from .benchmark import load_benchmark_graph


@dataclass
class DAGGeneratorConfig:
    """
    Configuration for DAG generation.

    Groups the generic options (type, size, seed) together with the
    parameters that only matter for particular graph families.

    Attributes:
        graph_type: Graph family, e.g. 'chain', 'star', 'complete', 'tree',
            'erdos_renyi' or 'benchmark'.
        num_nodes: Number of nodes; ignored for benchmark graphs.
        er_probability: Edge probability for Erdős-Rényi graphs.
        star_center: Index of the star's center node.
        star_center_is_parent: Whether the star's center points to the
            leaves (True) or the leaves point to the center (False).
        tree_branching_factor: Branching factor for tree graphs.
        tree_balanced: Tree-shape flag.
        benchmark_name: Name of the benchmark graph to load.
        random_state: Random seed (None means unseeded).
        kwargs: Extra generator-specific keyword arguments.
    """

    # --- generic options ---
    graph_type: str = 'erdos_renyi'
    num_nodes: int = 10

    # --- Erdős-Rényi ---
    er_probability: float = 0.2

    # --- star ---
    star_center: int = 0
    star_center_is_parent: bool = True

    # --- tree ---
    tree_branching_factor: int = 2
    tree_balanced: bool = False

    # --- benchmark ---
    benchmark_name: str = 'sachs'

    # --- reproducibility ---
    random_state: Optional[int] = None

    # --- pass-through options (fresh dict per instance) ---
    kwargs: Dict[str, Any] = field(default_factory=dict)


def generate_dag(
    graph_type: str,
    num_nodes: Optional[int] = None,
    random_state: Optional[int] = None,
    **kwargs
) -> DAG:
    """
    Generate a DAG of the specified type.

    ``graph_type`` is case-insensitive and hyphens are treated as
    underscores, so ``'Erdos-Renyi'`` is the same as ``'erdos_renyi'``.

    Args:
        graph_type: One of
            - chains: 'chain', 'random_chain'
            - stars: 'star', 'star_inward', 'star_outward'
            - complete: 'complete', 'random_complete'
            - trees: 'tree' / 'random_tree', 'balanced_tree', 'binary_tree'
            - Erdős-Rényi: 'erdos_renyi' / 'er' (default p=0.2),
              'erdos_renyi_sparse' / 'sparse' (default p=0.1),
              'erdos_renyi_moderate' / 'moderate' (default p=0.2),
              'erdos_renyi_dense' / 'dense' (default p=0.5)
            - 'benchmark'
        num_nodes: Number of nodes. Required for every type except
            'benchmark', where it is ignored.
        random_state: Random seed. It is forwarded to 'random_chain',
            'random_complete', 'tree'/'random_tree' and the Erdős-Rényi
            types, and not passed to the other generators.
        **kwargs: Type-specific options (unrecognized keys are ignored):
            - 'benchmark': 'name' (or 'benchmark_name'), default 'sachs'
            - star types: 'center' (default 0); 'star' also reads
              'center_is_parent' (default True)
            - 'complete' and Erdős-Rényi types: 'ordering'
            - tree types: 'root'; 'balanced_tree' also reads
              'branching_factor' (default 2)
            - Erdős-Rényi types: 'p' (alias 'probability')

    Returns:
        Generated DAG

    Raises:
        ValueError: If ``graph_type`` is unknown, or ``num_nodes`` is missing
            for a type that requires it.
    """
    graph_type = graph_type.lower().replace('-', '_')

    # Default edge probability for every Erdős-Rényi name accepted here;
    # an explicit 'p' (or 'probability') kwarg overrides it.
    er_default_p = {
        'erdos_renyi': 0.2, 'er': 0.2,
        'erdos_renyi_sparse': 0.1, 'sparse': 0.1,
        'erdos_renyi_moderate': 0.2, 'moderate': 0.2,
        'erdos_renyi_dense': 0.5, 'dense': 0.5,
    }
    available_types = {
        'chain', 'random_chain',
        'star', 'star_inward', 'star_outward',
        'complete', 'random_complete',
        'tree', 'random_tree', 'balanced_tree', 'binary_tree',
        'benchmark',
    } | set(er_default_p)

    # Validate the name first so a typo is reported as an unknown type
    # rather than as a missing num_nodes.
    if graph_type not in available_types:
        raise ValueError(
            f"Unknown graph type '{graph_type}'. "
            f"Available: {', '.join(sorted(available_types))}"
        )

    # Benchmark graphs have a fixed size, so num_nodes is not needed.
    if graph_type == 'benchmark':
        name = kwargs.get('name', kwargs.get('benchmark_name', 'sachs'))
        return load_benchmark_graph(name)

    # All other types require num_nodes
    if num_nodes is None:
        raise ValueError(f"num_nodes required for graph type '{graph_type}'")

    # Chain graphs
    if graph_type == 'chain':
        return generate_chain(num_nodes)
    if graph_type == 'random_chain':
        return generate_random_chain(num_nodes, random_state)

    # Star graphs
    if graph_type == 'star':
        center = kwargs.get('center', 0)
        center_is_parent = kwargs.get('center_is_parent', True)
        return generate_star(num_nodes, center=center, center_is_parent=center_is_parent)
    if graph_type == 'star_inward':
        return generate_star_inward(num_nodes, center=kwargs.get('center', 0))
    if graph_type == 'star_outward':
        return generate_star_outward(num_nodes, center=kwargs.get('center', 0))

    # Complete graphs
    if graph_type == 'complete':
        return generate_complete_dag(num_nodes, ordering=kwargs.get('ordering'))
    if graph_type == 'random_complete':
        return generate_random_complete_dag(num_nodes, random_state)

    # Tree graphs
    if graph_type in ('tree', 'random_tree'):
        return generate_random_tree(num_nodes, random_state, root=kwargs.get('root'))
    if graph_type == 'balanced_tree':
        branching = kwargs.get('branching_factor', 2)
        root = kwargs.get('root', 0)
        return generate_balanced_tree(num_nodes, branching_factor=branching, root=root)
    if graph_type == 'binary_tree':
        root = kwargs.get('root', 0)
        return generate_balanced_tree(num_nodes, branching_factor=2, root=root)

    # Erdős-Rényi family: the only names left after the membership check
    # above. All aliases honor 'p'/'probability' and 'ordering' uniformly.
    p = kwargs.get('p', kwargs.get('probability', er_default_p[graph_type]))
    ordering = kwargs.get('ordering')
    return generate_erdos_renyi_dag(num_nodes, p, random_state, ordering)


def generate_dag_from_config(config: DAGGeneratorConfig) -> DAG:
    """
    Generate a DAG from a configuration object.

    Starts from ``config.kwargs``, then overlays the type-specific config
    fields relevant to ``config.graph_type``:

    - 'erdos_renyi' / 'er': ``er_probability`` as ``p``
    - 'star': ``star_center`` and ``star_center_is_parent``
    - 'star_inward' / 'star_outward': ``star_center``
    - 'tree' / 'random_tree' with ``tree_balanced=True``: routed to
      'balanced_tree'
    - 'balanced_tree' (including the routed case above):
      ``tree_branching_factor``
    - 'benchmark': ``benchmark_name``

    Args:
        config: DAGGeneratorConfig object

    Returns:
        Generated DAG
    """
    # Normalize like generate_dag does. Otherwise aliases such as 'ER' or
    # 'Erdos-Renyi' would bypass the type-specific fields below.
    graph_type = config.graph_type.lower().replace('-', '_')
    kwargs = dict(config.kwargs)

    if graph_type in ('erdos_renyi', 'er'):
        kwargs['p'] = config.er_probability

    if graph_type == 'star':
        kwargs['center'] = config.star_center
        kwargs['center_is_parent'] = config.star_center_is_parent
    elif graph_type in ('star_inward', 'star_outward'):
        # These variants accept a center but have a fixed edge direction.
        kwargs['center'] = config.star_center

    # Honor tree_balanced: a plain 'tree' request becomes a balanced tree.
    if graph_type in ('tree', 'random_tree') and config.tree_balanced:
        graph_type = 'balanced_tree'

    # Only the balanced tree generator consumes a branching factor.
    if graph_type == 'balanced_tree':
        kwargs['branching_factor'] = config.tree_branching_factor

    if graph_type == 'benchmark':
        kwargs['name'] = config.benchmark_name

    return generate_dag(
        graph_type,
        num_nodes=config.num_nodes,
        random_state=config.random_state,
        **kwargs
    )


def generate_multiple_dags(
    graph_type: str,
    num_nodes: int,
    n_graphs: int,
    random_state: Optional[int] = None,
    **kwargs
) -> list:
    """
    Build ``n_graphs`` independent DAGs that all share one graph type.

    A master generator seeded with ``random_state`` hands out one fresh
    integer seed per graph. This makes the whole batch reproducible while
    each DAG still gets its own seed.

    Args:
        graph_type: Type of graph to generate (see ``generate_dag``)
        num_nodes: Number of nodes per graph
        n_graphs: How many graphs to produce
        random_state: Seed for the master generator
        **kwargs: Extra options forwarded unchanged to ``generate_dag``

    Returns:
        List of DAGs, in generation order
    """
    seed_source = np.random.default_rng(random_state)

    def _next_seed() -> int:
        # One draw per graph, in [0, 2**31).
        return int(seed_source.integers(0, 2**31))

    return [
        generate_dag(graph_type, num_nodes, random_state=_next_seed(), **kwargs)
        for _ in range(n_graphs)
    ]


def graph_type_description(graph_type: str) -> str:
    """
    Get a description of a graph type.

    The name is normalized the same way ``generate_dag`` normalizes it
    (case-insensitive, '-' treated as '_'). Aliases that ``generate_dag``
    accepts ('er', 'sparse', 'random_tree', 'binary_tree', ...) resolve to
    the description of their canonical type.

    Args:
        graph_type: Graph type name or alias.

    Returns:
        A one-line description, or "Unknown graph type" if the name is not
        recognized.
    """
    descriptions = {
        'chain': "Linear chain: X_0 -> X_1 -> ... -> X_{d-1}. No v-structures.",
        'star': "Star graph with configurable center and direction.",
        'star_inward': "Star with leaves pointing to center. Many v-structures.",
        'star_outward': "Star with center pointing to leaves. No v-structures.",
        'complete': "Complete DAG with all possible edges.",
        'tree': "Random tree structure. No v-structures.",
        'balanced_tree': "Balanced tree with fixed branching factor.",
        'erdos_renyi': "Random DAG G(d, p) with edge probability p.",
        'benchmark': "Standard benchmark graph from causal discovery literature.",
    }
    # Aliases that generate_dag routes to the same generator as a canonical
    # type (the ER density variants only change the default p).
    aliases = {
        'er': 'erdos_renyi',
        'erdos_renyi_sparse': 'erdos_renyi',
        'erdos_renyi_moderate': 'erdos_renyi',
        'erdos_renyi_dense': 'erdos_renyi',
        'sparse': 'erdos_renyi',
        'moderate': 'erdos_renyi',
        'dense': 'erdos_renyi',
        'random_tree': 'tree',
        'binary_tree': 'balanced_tree',  # balanced tree with branching factor 2
    }
    key = graph_type.lower().replace('-', '_')
    key = aliases.get(key, key)
    return descriptions.get(key, "Unknown graph type")


# Convenience functions for common experiment setups


def generate_experiment_graphs(
    d: int,
    n_per_type: int = 10,
    random_state: Optional[int] = None
) -> Dict[str, list]:
    """
    Build the standard experiment suite of DAGs.

    For each entry of a fixed suite (chain, outward/inward stars, complete,
    random tree, and Erdős-Rényi with p in {0.1, 0.2, 0.3}), produce
    ``n_per_type`` DAGs. Each DAG gets its own seed, drawn from a single
    master generator, so the full suite is reproducible from
    ``random_state``.

    Args:
        d: Number of nodes in every graph
        n_per_type: How many graphs to produce per suite entry
        random_state: Seed for the master generator

    Returns:
        Dict keyed by graph type (Erdős-Rényi entries carry a ``_p<p>``
        suffix, e.g. 'erdos_renyi_p0.1'), each mapping to its list of DAGs
    """
    seed_source = np.random.default_rng(random_state)

    suite = (
        ('chain', {}),
        ('star_outward', {}),
        ('star_inward', {}),
        ('complete', {}),
        ('tree', {}),
        ('erdos_renyi', {'p': 0.1}),
        ('erdos_renyi', {'p': 0.2}),
        ('erdos_renyi', {'p': 0.3}),
    )

    graphs_by_key = {}
    for type_name, extra in suite:
        # Disambiguate the repeated Erdős-Rényi entries by their probability.
        label = f"{type_name}_p{extra['p']}" if 'p' in extra else type_name
        graphs_by_key[label] = [
            generate_dag(
                type_name,
                d,
                random_state=int(seed_source.integers(0, 2**31)),
                **extra
            )
            for _ in range(n_per_type)
        ]

    return graphs_by_key
