import torch
import torch.nn as nn
from dgl.nn import SAGEConv

class BaseGraphSAGE(nn.Module):
    """Two-layer GraphSAGE base model over the original graph.

    Each node is described by a row of integer bin ids, one per entry of
    ``bin_counts`` (the code calls these entries "trees"). Every column has
    its own ``nn.Embedding`` table. The per-column embeddings are
    concatenated and passed through two mean-aggregating ``SAGEConv`` layers,
    each followed by ReLU and dropout, and then through a linear classifier.

    Args:
        in_feats: Embedding width for each column.
        hidden_feats: Output width of both SAGEConv layers.
        num_classes: Number of logits produced per node.
        bin_counts: Sequence of ints; ``bin_counts[i]`` is the number of rows
            in the embedding table for column ``i``.

    Raises:
        ValueError: If ``bin_counts`` is empty.
    """

    def __init__(self, in_feats, hidden_feats, num_classes, bin_counts):
        super().__init__()
        self.num_trees = len(bin_counts)
        if self.num_trees == 0:
            # With zero columns the conv input would be 0-wide and torch.cat
            # would receive an empty list on the first forward; fail early.
            raise ValueError("bin_counts must contain at least one entry")
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n_bins, in_feats) for n_bins in bin_counts]
        )
        # Concatenating one in_feats-wide embedding per column gives this width.
        self.conv_in_dim = self.num_trees * in_feats
        self.conv1 = SAGEConv(self.conv_in_dim, hidden_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(hidden_feats, hidden_feats, aggregator_type='mean')
        self.classify = nn.Linear(hidden_feats, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def _embed(self, binned_feature_ids):
        """Look up each column in its own table and concatenate the results.

        Column ``i`` of ``binned_feature_ids`` is looked up in
        ``self.embeddings[i]``, and the lookups are concatenated along dim 1.
        Columns beyond ``num_trees`` are never read.

        NOTE(review): presumably ``binned_feature_ids`` is a
        (num_nodes, num_trees) integer tensor. Confirm this against callers.
        """
        per_column = [
            table(binned_feature_ids[:, col])
            for col, table in enumerate(self.embeddings)
        ]
        return torch.cat(per_column, dim=1)

    def forward(self, g, binned_feature_ids):
        """Run the model on graph ``g``.

        Args:
            g: Graph passed to both SAGEConv layers (a DGL graph, given the
                ``dgl.nn`` import).
            binned_feature_ids: Integer bin ids with one column per tree; see
                ``_embed``.

        Returns:
            ``(logits, h)``. ``logits`` is the classifier output. ``h`` is the
            second-layer hidden representation, taken after ReLU and dropout.
            Dropout is only active in training mode, so call ``eval()`` when
            extracting deterministic embeddings.
        """
        x = self._embed(binned_feature_ids)
        h = self.dropout(self.relu(self.conv1(g, x)))
        h = self.dropout(self.relu(self.conv2(g, h)))
        logits = self.classify(h)
        return logits, h
