import torch
import torch.nn as nn

class BERT_Arch(nn.Module):
    """Classification head on top of a BERT-style encoder.

    The encoder's pooled output goes through a two-layer MLP
    (Linear -> ReLU -> Dropout -> Linear). The resulting logits are
    normalized with ``LogSoftmax(dim=1)``, so ``forward`` returns per-class
    log-probabilities (suitable for ``nn.NLLLoss``).

    Args:
        bert: Encoder called as ``bert(sent_id, attention_mask=mask)``; its
            return value must support ``output['pooler_output']``.
        num_labels: Number of output classes.
        hidden_size: Width of ``pooler_output``. Defaults to 768, the value
            this class previously hard-coded.
        fc_hidden: Width of the intermediate linear layer (default 512).
        dropout: Dropout probability applied after the intermediate layer
            (default 0.1).
    """

    def __init__(self, bert, num_labels, hidden_size=768, fc_hidden=512,
                 dropout=0.1):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(hidden_size, fc_hidden)
        self.fc2 = nn.Linear(fc_hidden, num_labels)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return class log-probabilities for a batch of token ids.

        Args:
            sent_id: Token ids passed positionally to the encoder.
            mask: Attention mask passed as ``attention_mask`` to the encoder.

        Returns:
            Tensor of log-probabilities, normalized over dim 1 (assumes the
            pooled output is (batch, hidden_size) — TODO confirm for custom
            encoders).
        """
        output = self.bert(sent_id, attention_mask=mask)
        x = self.fc1(output['pooler_output'])
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x)
        # No activation on the logits: a ReLU here would clamp every negative
        # logit to 0, making those classes indistinguishable and blocking
        # their gradients.
        return self.softmax(logits)