import torch
import torch.nn as nn
from .weighted_sage import WeightedSAGEConv

class BoosterGraphSAGE(nn.Module):
    """Compensation model on the similarity graph (Stage-2).

    Each column of ``binned_feature_ids`` is looked up in its own embedding
    table. The per-column embeddings are concatenated and passed through two
    ``WeightedSAGEConv`` layers, each followed by ReLU and dropout, and then
    through a final linear classifier.

    Args:
        in_feats: Embedding width used for every column.
        hidden_feats: Output width of both graph-convolution layers.
        num_classes: Number of logits produced per node.
        bin_counts: One entry per input column (presumably one per tree of an
            upstream boosted ensemble -- confirm against the caller). Entry
            ``i`` is the number of rows in the ``i``-th embedding table, i.e.
            the number of distinct bin ids column ``i`` may contain.
        dropout: Dropout probability applied after each convolution.
            Defaults to 0.5, the value previously hard-coded.

    Raises:
        ValueError: If ``bin_counts`` is empty. There would be nothing to
            embed, and ``torch.cat`` in ``_embed`` would fail on every call.
    """

    def __init__(self, in_feats: int, hidden_feats: int, num_classes: int,
                 bin_counts, *, dropout: float = 0.5):
        super().__init__()
        self.num_trees = len(bin_counts)
        if self.num_trees == 0:
            raise ValueError("bin_counts must contain at least one entry")
        # One embedding table per input column, all of width in_feats.
        self.embeddings = nn.ModuleList(
            nn.Embedding(num_bins, in_feats) for num_bins in bin_counts
        )
        # _embed concatenates the per-column embeddings along dim 1.
        self.conv_in_dim = self.num_trees * in_feats
        self.conv1 = WeightedSAGEConv(self.conv_in_dim, hidden_feats)
        self.conv2 = WeightedSAGEConv(hidden_feats, hidden_feats)
        self.classify = nn.Linear(hidden_feats, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def _embed(self, binned_feature_ids) -> torch.Tensor:
        """Embed each column with its own table and concatenate along dim 1.

        Args:
            binned_feature_ids: Integer bin ids indexed as ``[:, i]``, with
                column ``i`` looked up in ``self.embeddings[i]``.

        Returns:
            For 2-D input of N rows, a tensor of shape
            ``(N, num_trees * in_feats)``.
        """
        return torch.cat(
            [table(binned_feature_ids[:, i])
             for i, table in enumerate(self.embeddings)],
            dim=1,
        )

    def forward(self, g_sim, binned_feature_ids):
        """Compute node logits on the similarity graph.

        Args:
            g_sim: Similarity graph passed to both convolutions (presumably a
                DGL graph -- confirm). If ``g_sim.edata`` contains ``'w'``,
                those edge weights are forwarded to the convolutions.
                Otherwise ``None`` is passed.
            binned_feature_ids: Integer bin ids; see ``_embed``.

        Returns:
            ``(logits, h)``. Here ``logits`` is the classifier output and
            ``h`` is the hidden representation it was computed from, taken
            after the second ReLU and dropout.
        """
        eweight = g_sim.edata['w'] if 'w' in g_sim.edata else None
        h = self._embed(binned_feature_ids)
        # Both convolution layers share the same activation/dropout pipeline.
        for conv in (self.conv1, self.conv2):
            h = conv(g_sim, h, eweight)
            h = self.relu(h)
            h = self.dropout(h)
        logits = self.classify(h)
        return logits, h
