import torch.nn as nn

class NodeNN(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    The hidden layer and the output layer both have ``out_dim`` units.

    Args:
        in_dim: Number of input features expected by the first linear layer.
        out_dim: Number of features produced by both linear layers.
        dropout_p: Probability handed to ``nn.Dropout``. Defaults to 0.5,
            the value that used to be hard-coded.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout_p: float = 0.5) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """Run ``x`` through fc1, ReLU, dropout and fc2, and return the result.

        ``x`` is fed straight into ``nn.Linear(in_dim, out_dim)``, so its
        last dimension must be ``in_dim``.
        """
        hidden = nn.functional.relu(self.fc1(x))
        # nn.Dropout only drops activations while the module is in training
        # mode; after ``.eval()`` it is an identity op.
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

    def reset(self, gain: float = 1.0) -> None:
        """Re-initialise both linear layers in place.

        Weights are drawn with Xavier-uniform initialisation and biases are
        set to zero. The constructor does not call this method, so the layers
        keep ``nn.Linear``'s default initialisation until it is invoked.

        Args:
            gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
                Defaults to 1.0, which is that function's own default.
        """
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)
