import torch
import torch.nn as nn

class Stage3MLP(nn.Module):
    """Stage-3 compensation model over N-dim neighborhood leaf co-occurrence vectors.

    Maps every input feature vector (last dimension ``in_dim``) to a single
    scalar output.

    Args:
        in_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the optional hidden layer. ``None`` or any value
            ``<= 0`` selects a plain linear model with no hidden layer.
        dropout: Dropout probability applied after the hidden ReLU; unused
            when there is no hidden layer.
    """

    def __init__(self, in_dim, hidden_dim=0, dropout=0.2):
        super().__init__()
        linear_only = hidden_dim is None or hidden_dim <= 0
        if linear_only:
            # Plain linear model: one affine projection down to a single output.
            self.net = nn.Linear(in_dim, 1)
            return
        # One-hidden-layer MLP. The layer order determines the state_dict keys
        # (net.0.* and net.3.*), so reordering would break saved checkpoints.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feat_nb_overlap: torch.Tensor) -> torch.Tensor:
        """Score each feature vector; an ``[N, in_dim]`` input yields ``[N, 1]``."""
        scores = self.net(feat_nb_overlap)
        return scores
