from abc import abstractmethod

import torch.nn as nn

from gcip.modules.mlp import MLP
from gcip.utils.activations import get_act_fn


class NodeWrapper(nn.Module):
    """Apply a wrapped layer to the ``x`` attribute of a batch object.

    The wrapper lets a layer that operates on ``batch.x`` be used where a
    callable taking and returning the whole batch object is expected.
    """

    def __init__(self, layer):
        """Store the layer to wrap.

        Args:
            layer: Callable invoked as ``layer(batch.x, **kwargs)`` in
                ``forward``.
        """
        super().__init__()
        self.layer = layer

    def forward(self, batch, **kwargs):
        """Replace ``batch.x`` with the output of the wrapped layer.

        Args:
            batch: Object exposing a readable and writable ``x`` attribute.
                It is modified in place.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                wrapped layer.

        Returns:
            The same ``batch`` object, with ``x`` overwritten by the layer's
            output.
        """
        # Forward as keyword arguments: a single ``*`` would unpack only the
        # dict keys and pass them to the layer as positional strings.
        batch.x = self.layer(batch.x, **kwargs)
        return batch