import torch
import torch.nn as nn
import torchdrug


class MLP(nn.Module):
    """Two-layer perceptron head: ``Linear -> activation -> Linear``.

    Args:
        n_tokens: Size of the last dimension of the input embedding.
        n_hidden: Width of the hidden layer.
        activation: Module applied between the two linear layers. When
            ``None`` (the default) a fresh ``nn.ReLU()`` is created for this
            instance, so separate ``MLP`` objects never share one module.
        num_classes: Size of the last dimension of the output.
    """

    def __init__(self, n_tokens: int = 512,
                 n_hidden: int = 256,
                 activation=None,
                 num_classes: int = 21):
        super().__init__()
        # A module instance used as a default argument is built once at
        # function-definition time and shared by every MLP; create it here.
        if activation is None:
            activation = nn.ReLU()
        self.activation = activation
        self.num_classes = num_classes
        self.model = nn.Sequential(
            nn.Linear(n_tokens, n_hidden),
            self.activation,
            nn.Linear(n_hidden, self.num_classes),
        )

    def forward(self, input_embedding: torch.Tensor):
        """Map ``input_embedding`` (last dim ``n_tokens``) to ``num_classes`` outputs."""
        return self.model(input_embedding)


class GraphMLP(nn.Module):
    def __init__(self, hidden_features: int = 512, activation=nn.ReLU(), num_classes: int=21):
        super().__init__()
        self.activation = activation
        self.num_classes = num_classes
        self.model = torchdrug.layers.Sequential(
            torchdrug.layers.MultiLayerPerceptron(hidden_features, hidden_features, activation='relu'),
            torchdrug.layers.MultiLayerPerceptron(hidden_features, self.num_classes),
        )

    def forward(self, input_embedding: torch.tensor):
        return self.model(input_embedding)

