import torch
import torch.nn as nn

class Classifier(nn.Module):
    """Feed-forward MLP classification head that outputs unnormalized logits.

    The network is a stack of ``Linear -> ReLU`` layers followed by a final
    ``Linear`` projection to ``num_class`` outputs.  No softmax is applied, so
    the output is raw logits.

    Args:
        params: Opaque configuration object.  It is stored as ``self.paras``
            and is not read anywhere inside this class (callers may rely on
            the attribute, so it is preserved).
        num_class: Number of output classes (size of the logit vector).
        input_dim: Size of the trailing dimension of the input features.
        hidden_dims: Widths of the hidden layers, in order.  The defaults
            reproduce the original 768 -> 128 -> 64 -> 16 -> 3 architecture.
    """

    def __init__(self, params, *, num_class=3, input_dim=768,
                 hidden_dims=(128, 64, 16)):
        super().__init__()
        self.paras = params
        self.num_class = num_class
        self.input_dim = input_dim

        # Build Linear/ReLU pairs for every hidden layer, then a final Linear.
        # Layer order (and hence state_dict keys model.0, model.2, ...) matches
        # the original hand-written nn.Sequential, so saved checkpoints load.
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_features, out_features in zip(dims, dims[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(dims[-1], num_class))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_dim`` (``nn.Linear``
                operates on the trailing dimension).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of ``num_class``, containing unnormalized logits.
        """
        return self.model(x)