import torch
import dgl
import dgl.function as fn

def _to_bidirected_compat(g, copy_ndata=True):
    """Call ``dgl.to_bidirected`` with a fallback for a narrower signature.

    The call is first attempted with the ``copy_ndata`` keyword. If that
    raises ``TypeError``, the call is retried with the graph as the only
    argument.

    Args:
        g: Graph handed to ``dgl.to_bidirected``.
        copy_ndata: Forwarded as the ``copy_ndata`` keyword on the first
            attempt; not used by the fallback call.

    Returns:
        Whatever ``dgl.to_bidirected`` returns.
    """
    keyword_args = {'copy_ndata': copy_ndata}
    try:
        bidirected = dgl.to_bidirected(g, **keyword_args)
    except TypeError:
        # Retry with the positional-only form of the call.
        bidirected = dgl.to_bidirected(g)
    return bidirected

def compute_neighbor_leaf_overlap(g: dgl.DGLGraph, binned_feature: torch.Tensor, undirected=True):
    """
    Count, per node and per tree, the one-hop neighbors sharing the node's leaf.

    Entry ``[v, t]`` of the result is the number of edges arriving at node
    ``v`` whose source node has the same value as ``v`` in column ``t`` of
    ``binned_feature``, i.e. lands in the same leaf of the t-th tree.

    Args:
        g: Input graph. It is not modified: all scratch features are written
            inside a ``local_scope`` and discarded afterwards.
        binned_feature: 2-D tensor of shape [N, T] holding, for each node,
            its leaf index in each of the T trees. It is assigned as node
            data, so N is expected to equal the number of nodes in ``g``.
        undirected: If True, counting runs on the result of
            ``_to_bidirected_compat(g)``; otherwise the edges of ``g`` are
            used as they are, so only in-neighbors are counted.

    Returns:
        A float32 CPU tensor of shape [N, T] with the match counts.
    """
    g2 = _to_bidirected_compat(g) if undirected else g
    g2 = g2.cpu()
    binned_cpu = binned_feature.cpu()  # [N, T]
    N, T = binned_cpu.shape
    counts = torch.zeros((N, T), dtype=torch.float32)

    def edge_match(edges):
        # 1.0 where both endpoints of the edge fall into the same leaf, else 0.0.
        return {'m': (edges.src['leaf_t'] == edges.dst['leaf_t']).to(torch.float32)}

    # When undirected is False and g already lives on the CPU, g2 is the
    # caller's own graph object. local_scope keeps the temporary 'leaf_t',
    # 'cnt' and 'm' entries from overwriting or deleting same-named features
    # on it, and drops them automatically on exit.
    with g2.local_scope():
        for t in range(T):
            g2.ndata['leaf_t'] = binned_cpu[:, t]
            g2.apply_edges(edge_match)

            # Sum the per-edge match flags at each destination node.
            g2.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'cnt'))
            counts[:, t] = g2.ndata['cnt'].view(-1)

    return counts
