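"""
Chain graph generator.

Here a "chain graph" is a directed path (a simple DAG):
X_0 -> X_1 -> ... -> X_{d-1}. It is not the LWF/AMP "chain graph"
(a mixed directed/undirected graph) from the graphical-models literature.

Properties:
- d-1 edges
- No v-structures (every node has at most one parent)
- MEC size = d (each member picks one root node and orients every edge away
  from it, so no collider X_{i-1} -> X_i <- X_{i+1} ever appears)
"""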

from typing import Optional
from ..core.dag import DAG


def generate_chain(d: int, reverse: bool = False) -> DAG:
    """
    Generate a chain DAG: X_0 -> X_1 -> X_2 -> ... -> X_{d-1}

    Args:
        d: Number of nodes (must be >= 1).
        reverse: If True, reverse direction: X_{d-1} -> ... -> X_1 -> X_0

    Returns:
        Chain DAG with d-1 edges (no edges when d == 1).

    Raises:
        ValueError: If d < 1.

    Properties:
        - No v-structures: every node has at most one parent.
        - Non-adjacent X_i, X_j (i < j) are d-connected marginally, and are
          d-separated given any set containing a node strictly between them.
        - All edges undirected in CPDAG.
        - MEC has d DAGs: choose any node as the root and orient every edge
          away from it. ``reverse`` only selects which endpoint is the root
          of the returned member.
    """
    if d < 1:
        raise ValueError(f"Number of nodes must be positive, got {d}")

    dag = DAG(d)

    if d == 1:
        return dag  # Single node, no edges

    if reverse:
        # Edges are added from the high end downwards: (d-1 -> d-2), ..., (1 -> 0).
        for i in range(d - 1, 0, -1):
            dag.add_edge(i, i - 1)
    else:
        # Edges are added from the low end upwards: (0 -> 1), ..., (d-2 -> d-1).
        for i in range(d - 1):
            dag.add_edge(i, i + 1)

    return dag


def generate_chain_with_ordering(d: int, ordering: list) -> DAG:
    """
    Generate a chain DAG with a specified node ordering.

    The chain follows the ordering: ordering[0] -> ordering[1] -> ... -> ordering[d-1]

    Args:
        d: Number of nodes
        ordering: Permutation of [0, 1, ..., d-1] specifying the chain order.
            Any sized, iterable, hashable-element sequence works (list, tuple,
            1-D numpy array). Entries are converted to built-in ``int``
            before being used as node ids.

    Returns:
        Chain DAG following the specified ordering, with d-1 edges

    Raises:
        ValueError: If d < 1, if ``ordering`` does not have exactly d
            elements, or if it is not a permutation of [0, 1, ..., d-1].
    """
    if d < 1:
        raise ValueError(f"Number of nodes must be positive, got {d}")

    if len(ordering) != d:
        raise ValueError(f"Ordering must have {d} elements, got {len(ordering)}")

    if set(ordering) != set(range(d)):
        raise ValueError("Ordering must be a permutation of [0, 1, ..., d-1]")

    # The set comparison accepts anything that hashes/compares equal to an int
    # (numpy integer scalars, 1.0, True, ...). After validation every entry is
    # numerically one of 0..d-1, so int() is exact; normalizing keeps foreign
    # numeric types out of the DAG's node ids.
    nodes = [int(v) for v in ordering]

    dag = DAG(d)

    # Link each consecutive pair: nodes[k] -> nodes[k + 1].
    for parent, child in zip(nodes, nodes[1:]):
        dag.add_edge(parent, child)

    return dag


def generate_random_chain(d: int, random_state: Optional[int] = None) -> DAG:
    """
    Generate a chain DAG with random node ordering.

    Args:
        d: Number of nodes
        random_state: Random seed for reproducibility. Passed straight to
            ``numpy.random.default_rng``. None draws fresh OS entropy.

    Returns:
        Chain DAG with random ordering

    Raises:
        ValueError: If d < 1.
    """
    import numpy as np

    if d < 1:
        raise ValueError(f"Number of nodes must be positive, got {d}")

    rng = np.random.default_rng(random_state)
    # .tolist() yields built-in Python ints. list(array) would instead yield
    # np.int64 scalars, which would end up as node ids in the DAG.
    ordering = rng.permutation(d).tolist()

    return generate_chain_with_ordering(d, ordering)
