import numpy as np
from typing import List, Tuple


def build_poset_from_edges(node_names: List[str], edges: List[Tuple[str, str]]):
    """Turn a list of named precedence edges into an index-based adjacency structure.

    Each node is identified by its position in ``node_names``. If a name
    appears more than once, its last position is the one used.

    Args:
        node_names: Names of the nodes, in index order.
        edges: ``(before, after)`` name pairs meaning "before precedes after".
            A pair is silently skipped when either name is not in
            ``node_names`` or when both names resolve to the same index
            (self-loop). Repeated pairs collapse, because adjacency is stored
            in sets.

    Returns:
        A dict with keys ``"n"`` (number of nodes), ``"preds"`` (list of
        predecessor-index sets) and ``"succs"`` (list of successor-index
        sets). No acyclicity check is performed here.
    """
    count = len(node_names)
    # dict(zip(...)) keeps the last index for a repeated name, like a comprehension would.
    index_of = dict(zip(node_names, range(count)))
    predecessors = [set() for _ in range(count)]
    successors = [set() for _ in range(count)]

    for src_name, dst_name in edges:
        src = index_of.get(src_name)
        dst = index_of.get(dst_name)
        if src is not None and dst is not None and src != dst:
            successors[src].add(dst)
            predecessors[dst].add(src)

    return {"n": count, "preds": predecessors, "succs": successors}


def topo_sort(P, rng):
    """Return one topological order of poset ``P``, breaking ties at random.

    Uses Kahn's algorithm. At each step, one node is picked uniformly at
    random from the current frontier (nodes whose predecessors have all been
    emitted), appended to the order, and removed from its successors'
    in-degree counts.

    Args:
        P: Dict with ``"n"`` (node count), ``"preds"`` and ``"succs"``
            (per-node sets of neighbour indices).
        rng: Either a legacy ``np.random.RandomState`` (drawn via ``randint``)
            or a ``np.random.Generator`` (drawn via ``integers``).

    Returns:
        ``np.ndarray`` of dtype int64 holding every node index exactly once.

    Raises:
        ValueError: If ``P`` contains a cycle, so not every node can be emitted.
    """
    n = P["n"]
    succs = P["succs"]
    preds = P["preds"]
    # Generator exposes ``integers`` and RandomState exposes ``randint``.
    # Called with a single ``high`` argument, both draw from [0, high), so the
    # RandomState draw sequence is unchanged.
    draw_index = rng.integers if hasattr(rng, "integers") else rng.randint

    indeg = np.array([len(preds[i]) for i in range(n)], dtype=np.int64)
    frontier = [i for i in range(n) if indeg[i] == 0]
    order = []
    while frontier:
        j = int(draw_index(len(frontier)))
        # pop(j) keeps the remaining frontier order, which later draws depend on.
        u = int(frontier.pop(j))
        order.append(u)
        for v in succs[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                frontier.append(v)
    # Nodes on a cycle never reach in-degree 0 and are never emitted.
    if len(order) != n:
        raise ValueError("Poset has a cycle.")
    return np.asarray(order, dtype=np.int64)


def build_income_poset(node_names: List[str]):
    """Build the hand-specified precedence poset over income-dataset feature names.

    A fixed list of ``(before, after)`` feature-name pairs is restricted to
    pairs whose two names both occur in ``node_names``. The remaining pairs
    are turned into a poset indexed by position in ``node_names``.

    Args:
        node_names: Feature names present in the data, in index order.

    Returns:
        The poset dict produced by ``build_poset_from_edges``.

    Raises:
        ValueError: Propagated from ``topo_sort`` if the retained edges form a cycle.
    """
    # Each pair (a, b) means "a precedes b". The order of this sequence is kept
    # unchanged, because it fixes the insertion order into the adjacency sets.
    candidate_edges = (
        ("age", "education"), ("age", "marital-status"), ("age", "occupation"),
        ("age", "relationship"), ("age", "capital-gain"), ("age", "capital-loss"),
        ("sex", "education"), ("sex", "occupation"), ("sex", "workclass"),
        ("sex", "marital-status"), ("sex", "relationship"),
        ("race", "education"), ("race", "occupation"), ("race", "workclass"),
        ("native-country", "education"), ("native-country", "occupation"), ("native-country", "workclass"),
        ("education", "occupation"), ("occupation", "workclass"), ("occupation", "hours-per-week"),
        ("workclass", "hours-per-week"), ("education", "capital-gain"), ("education", "capital-loss"),
        ("occupation", "capital-gain"), ("occupation", "capital-loss"),
        ("marital-status", "relationship"),
    )
    present = set(node_names)
    kept_edges = [
        (before, after)
        for before, after in candidate_edges
        if before in present and after in present
    ]
    poset = build_poset_from_edges(node_names, kept_edges)

    # Validation only: topo_sort raises ValueError on a cycle; its order is not needed.
    topo_sort(poset, rng=np.random.RandomState(0))

    return poset


def build_ordered_partition_poset(
    node_names: List[str],
    upstream_block: List[str],
    downstream_block: List[str]
):
    """Build a two-level poset where every upstream node precedes every downstream node.

    An edge ``(u, v)`` is added for each ``u`` in ``upstream_block`` and each
    ``v`` in ``downstream_block``, provided both names occur in
    ``node_names``. No constraints are added between nodes of the same block.

    Args:
        node_names: All node names, in index order.
        upstream_block: Names that must come first.
        downstream_block: Names that must come after every upstream name.

    Returns:
        The poset dict produced by ``build_poset_from_edges``.

    Raises:
        ValueError: Propagated from ``topo_sort`` if the edges form a cycle.
            This happens, for example, when two distinct names appear in both
            blocks.
    """
    present = set(node_names)
    # Same (upstream-major) pair order as the full Cartesian product, filtered afterwards.
    cross_edges = [
        (src, dst)
        for src in upstream_block
        for dst in downstream_block
        if src in present and dst in present
    ]
    poset = build_poset_from_edges(node_names, cross_edges)

    # Validation only: topo_sort raises ValueError on a cycle; its order is not needed.
    topo_sort(poset, rng=np.random.RandomState(0))

    return poset




