import torch
import torch.nn as nn
from fairseq.models.roberta import RobertaModel

class RoBERTa(nn.Module):
    """Classifier head on top of a frozen pretrained RoBERTa encoder.

    The encoder is loaded from the local ``./roberta.large.mnli`` directory
    and all of its parameters have ``requires_grad`` set to ``False``; only
    the small feed-forward head defined here is trainable.

    Head: Dropout -> Linear(1024, 512) -> LayerNorm -> Dropout
    -> Linear(512, num_classes).
    """

    def __init__(self, config, num_classes, **kwargs):
        """Load the pretrained encoder and build the classification head.

        Args:
            config: Accepted for interface compatibility; not used here.
            num_classes: Number of output logits produced by the head.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.roberta = RobertaModel.from_pretrained(
            './roberta.large.mnli', checkpoint_file='model.pt'
        )
        # Freeze the encoder so the optimizer only updates the head.
        # NOTE(review): the encoder is still a registered submodule, so
        # calling .train() on this module also switches it to training
        # mode -- confirm whether that is intended for a frozen backbone.
        for param in self.roberta.parameters():
            param.requires_grad = False

        self.d1 = nn.Dropout(0.2)
        # 1024 must match the encoder's feature width
        # (presumably the roberta.large hidden size -- TODO confirm).
        self.l1 = nn.Linear(1024, 512)
        # Named ``bn1`` but is a LayerNorm; the attribute name is kept so
        # existing state_dict keys remain loadable.
        self.bn1 = nn.LayerNorm(512)
        self.d2 = nn.Dropout(0.2)
        self.l2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Token ids, either as a tensor or as anything
                ``torch.as_tensor`` can convert (e.g. nested lists).
                Assumed to be shaped (batch, seq_len) -- TODO confirm
                against callers.

        Returns:
            Tensor whose last dimension is ``num_classes``.
        """
        # as_tensor reuses an existing long tensor instead of copying it;
        # torch.tensor(tensor) copies on every call and emits a UserWarning.
        x = torch.as_tensor(x, dtype=torch.long)
        # Keep only the features at sequence position 0
        # (presumably the <s>/CLS token -- verify against the tokenizer).
        x = self.roberta.extract_features(x)[:, 0, :]
        x = self.d1(x)
        x = self.l1(x)
        x = self.bn1(x)
        x = self.d2(x)
        x = self.l2(x)

        return x

