import numpy as np


def build_poset_from_node_dag(node_ids, node_names):
    """Build item-level and group-level precedence relations from a role DAG.

    Every item ``i`` carries a group id ``node_ids[i]``; a group id is the
    position of the group's name in ``node_names``.  A fixed list of
    ``(source_name, target_name)`` edges is lifted to items: for each edge
    whose two names both appear in ``node_names``, every item of the source
    group becomes a predecessor of every item of the target group.

    Args:
        node_ids: 1-D array-like of integer group ids, one per item.  May be
            empty.
        node_names: Sequence of group names; index ``g`` names group ``g``.
            If a name is repeated, the last occurrence wins.

    Returns:
        A dict with keys:
            ``"n"``          number of items,
            ``"preds"``      list of ``n`` sets of predecessor item indices,
            ``"succs"``      list of ``n`` sets of successor item indices,
            ``"group_ids"``  ``node_ids`` as an ``int64`` array,
            ``"group_succ"`` list of sets; entry ``g`` holds the group ids
                             that directly succeed group ``g``.  It has one
                             entry per group id up to the largest id found in
                             ``node_ids`` or referenced by a resolved edge.
    """
    ids = np.asarray(node_ids)
    n = int(ids.shape[0])
    preds = [set() for _ in range(n)]
    succs = [set() for _ in range(n)]

    name_to_nid = {name: i for i, name in enumerate(node_names)}

    node_edges = [
        ("owner", "anchor"),
        ("owner", "booster_from_owner"),
        ("owner", "booster1_from_owner"),
        ("owner", "booster2_from_owner"),
        ("owner", "booster3_from_owner"),
        ("owner", "booster4_from_owner"),
        ("anchor", "booster_from_anchor"),
        ("anchor", "booster1_from_anchor"),
        ("anchor", "booster2_from_anchor"),
        ("anchor", "booster3_from_anchor"),
        ("anchor", "booster4_from_anchor"),
        ("owner", "copier_from_owner"),
        ("owner", "poisoner_from_owner"),
        ("anchor", "copier_from_anchor"),
        ("anchor", "poisoner_from_anchor"),
    ]

    # Resolve names to group ids once; edges mentioning an unknown name are
    # dropped, for both the item-level and the group-level relation.
    resolved_edges = []
    for u_name, v_name in node_edges:
        u_nid = name_to_nid.get(u_name)
        v_nid = name_to_nid.get(v_name)
        if u_nid is None or v_nid is None:
            continue
        resolved_edges.append((u_nid, v_nid))

    # Item positions per group id, computed lazily and cached because the
    # same source group appears in many edges.
    members = {}
    for u_nid, v_nid in resolved_edges:
        for nid in (u_nid, v_nid):
            if nid not in members:
                members[nid] = [int(p) for p in np.where(ids == nid)[0]]
        sources = members[u_nid]
        targets = members[v_nid]
        if not sources or not targets:
            continue
        for i in sources:
            succs[i].update(targets)
        for j in targets:
            preds[j].update(sources)

    group_ids = np.asarray(node_ids, dtype=np.int64)

    # Size the group table so that every id that can index it is in range:
    # ids present in the data (guarding the empty case, where ``max`` would
    # raise) and ids referenced by a resolved edge even if no item of that
    # group is present.
    num_groups = int(group_ids.max()) + 1 if group_ids.size else 0
    for u_nid, v_nid in resolved_edges:
        num_groups = max(num_groups, u_nid + 1, v_nid + 1)

    group_succ = [set() for _ in range(num_groups)]
    for u_nid, v_nid in resolved_edges:
        group_succ[u_nid].add(v_nid)

    return {
        "n": n,
        "preds": preds,
        "succs": succs,
        "group_ids": group_ids,
        "group_succ": group_succ,
    }


def is_edge(P, a, b):
    """Return whether ``b`` is recorded as a direct successor of ``a``.

    ``P`` is a poset dict whose ``"succs"`` entry maps each item index to the
    collection of its successor indices.
    """
    successors_of_a = P["succs"][a]
    return b in successors_of_a


def topo_sort(P, rng=np.random):
    """Return a randomised topological order of the items of ``P``.

    Repeatedly picks one item at random among those whose predecessors have
    all been emitted already, appends it to the order, and releases its
    successors.

    Args:
        P: Poset dict with ``"n"`` (item count) and ``"preds"`` / ``"succs"``
            (per-item collections of predecessor / successor indices).
        rng: Source of randomness.  Either an object exposing the legacy
            ``randint(high)`` method (the ``numpy.random`` module, a
            ``numpy.random.RandomState``) or a ``numpy.random.Generator``,
            which exposes ``integers(high)`` instead.

    Returns:
        ``int64`` array of length ``n`` listing the item indices in order.

    Raises:
        ValueError: If not every item could be ordered, i.e. the relation
            contains a cycle.
    """
    n = P["n"]
    succs = P["succs"]
    preds = P["preds"]

    # Prefer ``randint`` so existing callers keep the exact same draw
    # sequence; fall back to ``integers`` for the Generator API, which has
    # no ``randint``.
    draw = getattr(rng, "randint", None)
    if draw is None:
        draw = rng.integers

    indeg = [len(preds[i]) for i in range(n)]
    frontier = [i for i in range(n) if indeg[i] == 0]
    order = []
    while frontier:
        j = int(draw(len(frontier)))
        u = frontier.pop(j)
        order.append(u)
        for v in succs[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                frontier.append(v)
    if len(order) != n:
        raise ValueError("Poset has a cycle.")
    return np.asarray(order, dtype=np.int64)


def max_stats_prefix(P, lamb, perm, pos, k, swap_with, in_prefix, prefix_buf):
    """Count the maximal items of a prefix and sum their ``lamb`` weights.

    The prefix is ``perm[:k + 1]``.  When ``swap_with`` is given, the last
    prefix slot holds ``swap_with`` instead of ``perm[k]``; that variant is
    assembled in ``prefix_buf``.  An item is maximal when none of its
    successors in ``P["succs"]`` is also in the prefix.

    Args:
        P: Poset dict; only ``"succs"`` is read.
        lamb: Per-item weights, indexable by item index.
        perm: Item ordering from which the prefix is taken.
        pos: Not used by this function.
        k: Index of the last prefix position.
        swap_with: Item to place at position ``k``, or ``None`` to keep
            ``perm[k]``.
        in_prefix: Scratch boolean array indexed by item; overwritten with
            the membership mask of the prefix.
        prefix_buf: Scratch array; its first ``k + 1`` entries are
            overwritten when ``swap_with`` is not ``None``.

    Returns:
        ``(count, weight_sum)``.  ``count`` is replaced by ``1`` if it is not
        positive and ``weight_sum`` by ``1e-12`` if it is not positive.
    """
    successors = P["succs"]
    end = k + 1

    # Scratch mask is reused across calls, so wipe it before marking.
    in_prefix[:] = False

    if swap_with is not None:
        if k > 0:
            prefix_buf[:k] = perm[:k]
        prefix_buf[k] = int(swap_with)
        members = prefix_buf[:end]
    else:
        members = perm[:end]

    for item in members:
        in_prefix[int(item)] = True

    n_maximal = 0
    weight_total = 0.0
    for item in members:
        item = int(item)
        for nxt in successors[item]:
            if in_prefix[nxt]:
                # Dominated by another prefix member: not maximal.
                break
        else:
            n_maximal += 1
            weight_total += float(lamb[item])

    # Floor both statistics at small positive values.
    if n_maximal <= 0:
        n_maximal = 1
    if weight_total <= 0.0:
        weight_total = 1e-12
    return n_maximal, weight_total


def group_max_stats_from_counts(cnt_vec, lam_vec, group_succ):
    """Aggregate counts and weights over the maximal non-empty groups.

    A group is skipped when its count is zero, or when any of its successor
    groups (per ``group_succ``) has a positive count.  For every remaining
    group its count and its ``lam_vec`` entry are accumulated.

    Args:
        cnt_vec: 1-D array of per-group counts.
        lam_vec: Per-group weights, indexable by group id.
        group_succ: For each group id, the collection of successor group ids.

    Returns:
        ``(count_sum, weight_sum)`` as floats.  ``count_sum`` is replaced by
        ``1.0`` if it is not positive and ``weight_sum`` by ``1e-12`` if it
        is not positive.
    """
    count_total = 0.0
    weight_total = 0.0

    for gid, occupancy in enumerate(cnt_vec):
        if occupancy == 0:
            continue
        # A populated successor group means this group is not maximal.
        if any(cnt_vec[nxt] > 0 for nxt in group_succ[gid]):
            continue
        count_total += float(occupancy)
        weight_total += float(lam_vec[gid])

    # Floor both statistics at small positive values.
    if count_total <= 0.0:
        count_total = 1.0
    if weight_total <= 0.0:
        weight_total = 1e-12
    return count_total, weight_total


