import torch
import torch.nn as nn
import dgl.function as fn

class WeightedSAGEConv(nn.Module):
    """Weighted-mean SAGEConv layer that uses per-edge weights (stored in edata['w']).

    Output is ``fc_self(h_dst) + fc_neigh(neigh)``. ``neigh`` depends on whether
    edge weights are given:

    * no weights: the plain mean of in-neighbour features;
    * with weights: ``sum_e w_e * h_src(e) / (sum_e w_e + eps)`` over in-edges.

    Args:
        in_feats: Input feature size.
        out_feats: Output feature size.
        bias: Whether the self-projection ``fc_self`` has a bias. The neighbour
            projection never has one, so the layer has at most one bias term.
        eps: Added to the weight sum to avoid division by zero. For example,
            nodes with no in-edges get a zero weight sum.
    """

    def __init__(self, in_feats, out_feats, bias=True, eps=1e-6):
        super().__init__()
        self.fc_self = nn.Linear(in_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(in_feats, out_feats, bias=False)
        self.eps = eps

    @staticmethod
    def _weighted_mean(wsum_h, wsum, eps):
        """Return ``wsum_h / (wsum + eps)``, broadcasting ``wsum`` over feature dims.

        ``wsum`` has the per-edge weight's shape without the edge axis, e.g.
        ``(N,)`` when weights are ``(E,)``. ``wsum_h`` has the feature shape,
        e.g. ``(N, D)``. Without padding, torch would align ``N`` against the
        trailing dim ``D``, so ``wsum`` is right-padded with singleton dims first.
        """
        extra = wsum_h.dim() - wsum.dim()
        if extra > 0:
            wsum = wsum.reshape(tuple(wsum.shape) + (1,) * extra)
        return wsum_h / (wsum + eps)

    def forward(self, g, x, eweight=None):
        """Run one round of (optionally edge-weighted) mean aggregation.

        Args:
            g: A DGL graph, or a message-flow block.
            x: Node features. Either a single tensor or a
                ``(feat_src, feat_dst)`` tuple. For a single tensor on a block,
                the first ``g.number_of_dst_nodes()`` rows are used as the
                destination features. This follows DGL's block convention, as
                used by DGL's own SAGEConv.
            eweight: Optional per-edge weights. Expected shape is ``(E,)`` or
                ``(E, 1)``; assumed non-negative, so the normalizer is
                meaningful (TODO confirm with callers). If None, an unweighted
                mean over in-neighbours is used.

        Returns:
            Tensor ``fc_self(feat_dst) + fc_neigh(neigh)`` for the destination nodes.
        """
        if isinstance(x, tuple):
            feat_src, feat_dst = x
        else:
            feat_src = feat_dst = x
            if g.is_block:
                # Blocks list destination nodes first among the source nodes.
                feat_dst = x[:g.number_of_dst_nodes()]
        with g.local_scope():
            # For a homogeneous (non-block) graph, srcdata/dstdata are both ndata.
            g.srcdata['h'] = feat_src
            if eweight is None:
                g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                neigh = g.dstdata['neigh']
            else:
                g.edata['w'] = eweight
                # Numerator: sum over in-edges of w_e * h_src.
                g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'wsum_h'))
                # Denominator: sum over in-edges of w_e.
                g.update_all(fn.copy_e('w', 'wmsg'), fn.sum('wmsg', 'wsum'))
                neigh = self._weighted_mean(
                    g.dstdata['wsum_h'], g.dstdata['wsum'], self.eps
                )
            return self.fc_self(feat_dst) + self.fc_neigh(neigh)
