"""
Tree graph generators.

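A tree on d nodes is a connected acyclic graph with exactly d-1 edges.
When every edge is oriented away from a single root, each node has at most
one parent, so the resulting DAG has no v-structures; only orientations in
which two edges point into the same node create v-structures.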
"""

from typing import Optional, List, Tuple
import numpy as np

from ..core.dag import DAG


def generate_random_tree(
    d: int,
    random_state: Optional[int] = None,
    root: Optional[int] = None
) -> DAG:
    """
    Generate a random tree DAG using a Prüfer sequence.

    The undirected tree is drawn uniformly among all labeled trees on d nodes
    by sampling a uniform Prüfer sequence. Edge directions point away from
    ``root``, which is chosen uniformly at random when not given.

    Args:
        d: Number of nodes (must be positive)
        random_state: Seed passed to ``np.random.default_rng``
        root: Root node in [0, d) (default: random)

    Returns:
        Tree DAG with d-1 edges

    Raises:
        ValueError: If d < 1 or root is outside [0, d).

    Properties:
        - Connected
        - No v-structures (every non-root node has exactly one parent)
        - All edges undirected in CPDAG
    """
    if d < 1:
        raise ValueError(f"Number of nodes must be positive, got {d}")
    # An out-of-range root would otherwise produce an edgeless graph for
    # d >= 3 (BFS from a node absent from every edge discovers nothing).
    if root is not None and not 0 <= root < d:
        raise ValueError(f"Root must be in [0, {d}), got {root}")

    rng = np.random.default_rng(random_state)

    dag = DAG(d)

    if d == 1:
        return dag

    if d == 2:
        # Only one undirected tree exists, so the root alone fixes the
        # orientation. Honor an explicit root; otherwise a single
        # rng.random() draw picks it.
        if root is None:
            root = 0 if rng.random() < 0.5 else 1
        dag.add_edge(root, 1 - root)
        return dag

    # Generate random Prüfer sequence (uniform over labeled trees)
    prufer = rng.integers(0, d, size=d - 2)

    # Convert Prüfer sequence to undirected tree edges
    edges = prufer_to_edges(prufer.tolist(), d)

    # Choose root after the sequence so the RNG draw order is stable.
    if root is None:
        root = int(rng.integers(0, d))

    # Orient edges away from root using BFS
    oriented_edges = orient_tree_from_root(edges, root, d)

    for parent, child in oriented_edges:
        dag.add_edge(parent, child)

    return dag


def prufer_to_edges(prufer: List[int], d: int) -> List[Tuple[int, int]]:
    """
    Convert a Prüfer sequence to the edges of the labeled tree it encodes.

    Uses the standard decoding: for each sequence entry, the smallest-labelled
    current leaf is joined to that entry. A min-heap of leaves makes this
    O(d log d) instead of rescanning every node at each step; the output is
    identical to the naive scan.

    Args:
        prufer: Prüfer sequence of length d-2 with entries in [0, d)
        d: Number of nodes (must be at least 2)

    Returns:
        List of d-1 undirected edges as (u, v) pairs, in decoding order

    Raises:
        ValueError: If d < 2, the sequence length is not d-2, or an entry
            lies outside [0, d).
    """
    import heapq

    if d < 2:
        raise ValueError(f"Prüfer sequences require at least 2 nodes, got {d}")
    if len(prufer) != d - 2:
        raise ValueError(f"Prüfer sequence length must be {d - 2}")

    # degree[v] = 1 + number of occurrences of v in the sequence.
    degree = [1] * d
    for node in prufer:
        # Reject out-of-range labels; negative ones would otherwise index
        # from the end of `degree` and silently build the wrong tree.
        if not 0 <= node < d:
            raise ValueError(
                f"Prüfer sequence entries must be in [0, {d}), got {node}"
            )
        degree[node] += 1

    # Nodes absent from the sequence start as leaves. The list is built in
    # ascending order, which is already a valid min-heap.
    leaves = [v for v in range(d) if degree[v] == 1]

    edges = []
    for node in prufer:
        leaf = heapq.heappop(leaves)
        edges.append((leaf, node))
        degree[node] -= 1
        # Once a node's last occurrence is consumed it becomes a leaf.
        if degree[node] == 1:
            heapq.heappush(leaves, node)

    # Exactly two leaves remain; join them, smaller label first.
    u = heapq.heappop(leaves)
    v = heapq.heappop(leaves)
    edges.append((u, v))

    return edges


def orient_tree_from_root(
    edges: List[Tuple[int, int]],
    root: int,
    d: int
) -> List[Tuple[int, int]]:
    """
    Orient undirected tree edges away from a root using breadth-first search.

    Args:
        edges: Undirected edges of a tree on nodes 0..d-1
        root: Root node; must lie in [0, d)
        d: Number of nodes

    Returns:
        Directed edges (parent, child) pointing away from root, in BFS
        discovery order. Nodes not reachable from root contribute no edges.

    Raises:
        ValueError: If root is outside [0, d).
    """
    from collections import defaultdict, deque

    # An out-of-range root appears in no edge, so BFS would silently return
    # an empty orientation; reject it explicitly instead.
    if not 0 <= root < d:
        raise ValueError(f"Root must be in [0, {d}), got {root}")

    # Build symmetric adjacency list
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS from root; the first time a node is reached, the edge is oriented
    # from the node that discovered it.
    oriented = []
    visited = {root}
    queue = deque([root])

    while queue:
        parent = queue.popleft()
        for neighbor in adj[parent]:
            if neighbor not in visited:
                visited.add(neighbor)
                oriented.append((parent, neighbor))
                queue.append(neighbor)

    return oriented


def generate_balanced_tree(
    d: int,
    branching_factor: int = 2,
    root: int = 0
) -> DAG:
    """
    Generate a balanced (complete k-ary) tree DAG.

    Nodes are laid out in level order. Position 0 holds ``root`` and
    positions 1..d-1 hold the remaining node indices in ascending order. The
    node at position p has children at positions p*k + 1, ..., p*k + k, where
    k is ``branching_factor``, truncated at d. Edges point from parent to
    child.

    Args:
        d: Number of nodes
        branching_factor: Maximum children per node
        root: Root node index in [0, d)

    Returns:
        Balanced tree DAG

    Raises:
        ValueError: If d < 1, branching_factor < 1, or root is outside [0, d).
    """
    if d < 1:
        raise ValueError(f"Number of nodes must be positive, got {d}")
    if branching_factor < 1:
        raise ValueError(f"Branching factor must be positive, got {branching_factor}")
    # With an out-of-range root, one real node would never be placed and an
    # edge to a nonexistent node would be attempted.
    if not 0 <= root < d:
        raise ValueError(f"Root must be in [0, {d}), got {root}")

    dag = DAG(d)

    # order[p] is the node placed at level-order position p.
    order = [root] + [node for node in range(d) if node != root]

    for pos, parent in enumerate(order):
        first_child = pos * branching_factor + 1
        if first_child >= d:
            # This position and every later one are leaves.
            break
        last_child = min(first_child + branching_factor, d)  # exclusive
        for child_pos in range(first_child, last_child):
            dag.add_edge(parent, order[child_pos])

    return dag


def generate_path_tree(d: int, root: int = 0) -> DAG:
    """
    Build a path (chain) tree DAG on ``d`` nodes.

    This delegates to ``generate_balanced_tree`` with one child per node.
    The chain starts at ``root`` and continues through the remaining node
    indices in ascending order.

    Args:
        d: Number of nodes
        root: First node of the chain

    Returns:
        Path tree DAG
    """
    single_child = 1
    return generate_balanced_tree(d, branching_factor=single_child, root=root)


def generate_binary_tree(d: int, root: int = 0) -> DAG:
    """
    Build a balanced binary tree DAG on ``d`` nodes.

    This delegates to ``generate_balanced_tree`` with a branching factor
    of 2. Each node has at most two children, and levels are filled in
    order starting from ``root``.

    Args:
        d: Number of nodes
        root: Root node

    Returns:
        Binary tree DAG
    """
    two_children = 2
    return generate_balanced_tree(d, branching_factor=two_children, root=root)


def tree_depth(dag: DAG, root: int = 0) -> int:
    """
    Return the height of a tree DAG measured from ``root``.

    The DAG is walked one breadth-first level at a time along
    ``dag.children``, and each node is visited at most once. The result is
    the index of the deepest non-empty level. When ``dag`` is a tree rooted
    at ``root``, this is its height.

    Args:
        dag: Tree DAG
        root: Node to measure from

    Returns:
        Deepest BFS level reachable from root (0 for a single-node DAG or a
        childless root)
    """
    if dag.num_nodes() == 1:
        return 0

    seen = {root}
    frontier = [root]
    level = 0
    while True:
        next_frontier = []
        for node in frontier:
            for child in dag.children(node):
                if child in seen:
                    continue
                seen.add(child)
                next_frontier.append(child)
        if not next_frontier:
            return level
        frontier = next_frontier
        level += 1


def is_tree(dag: DAG) -> bool:
    """
    Report whether ``dag`` has a tree skeleton.

    Two conditions are checked: the DAG has exactly d-1 edges for d nodes,
    and ``dag.is_connected()`` is true. The connectivity check is skipped
    when the edge count already fails. Edge directions are not inspected,
    so nodes with multiple parents are not rejected here. The connectivity
    notion is whatever ``DAG.is_connected`` implements (presumably weak
    connectivity; verify).
    """
    n = dag.num_nodes()
    return dag.num_edges() == n - 1 and dag.is_connected()
