from transformers import AutoModel
import torch
import torch.nn as nn

class BERTNetwork(torch.nn.Module):
    """Text encoder: pretrained ``bert-base-uncased`` plus a linear projection head.

    Behaviour is selected by substrings of ``opt.text_arch``:

    * ``'frozen'``    -- BERT parameters get ``requires_grad = False`` and the
      BERT encoder is kept in eval mode. The projection head is created after
      freezing, so it stays trainable.
    * ``'normalize'`` -- the projected embedding is L2-normalised along its
      last dimension.

    Args:
        opt: options object; this class reads ``opt.text_arch`` (checked with
            ``in``) and ``opt.embed_dim`` (output size of the projection head).
    """

    def __init__(self, opt):
        super().__init__()

        self.pars = opt
        self.name = 'BERT'

        self.model = AutoModel.from_pretrained("bert-base-uncased")

        # Remembered so that `train()` can keep a frozen encoder in eval mode.
        self._frozen = 'frozen' in opt.text_arch
        if self._frozen:
            for param in self.model.parameters():
                param.requires_grad = False
            self.model.eval()
        else:
            self.model.train()

        # Width of BERT's output features (768 for bert-base-uncased); read
        # from the config instead of being hard-coded.
        hidden_size = self.model.config.hidden_size
        # Attached after freezing, so its parameters keep requires_grad=True.
        self.model.last_linear = nn.Linear(hidden_size, opt.embed_dim)

        self.enc_out_dim = hidden_size
        # Optional callable applied to the output in eval mode only; None here,
        # presumably assigned by external code.
        self.out_adjust = None

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen BERT encoder in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would turn
        dropout back on inside a frozen encoder. After delegating to it, the
        encoder is switched back to eval when frozen. The only other module
        under ``self.model`` touched by this is the ``nn.Linear`` head, which
        behaves the same in both modes.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        if self._frozen:
            self.model.eval()
        return self

    def forward(self, input_ids, token_type_ids, attention_mask, **kwargs):
        """Encode a batch of token ids into embeddings.

        Args:
            input_ids, token_type_ids, attention_mask: passed straight through
                to the BERT model as keyword arguments.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``(x, (emb_out, emb_out))``. ``x`` is the projected, and optionally
            normalised / adjusted, embedding. ``emb_out`` is element ``[1]`` of
            the BERT output (presumably the pooled output for BertModel).
        """
        emb_out = self.model(input_ids=input_ids,
                             token_type_ids=token_type_ids,
                             attention_mask=attention_mask)[1]
        x = self.model.last_linear(emb_out)

        if 'normalize' in self.pars.text_arch:
            # Normalise the projected embedding, not the raw BERT features,
            # so the output width is always opt.embed_dim.
            x = torch.nn.functional.normalize(x, dim=-1)
        # `self.training` is the mode flag; `self.train` is a method (always
        # truthy). `is not None` avoids relying on module truthiness.
        if self.out_adjust is not None and not self.training:
            x = self.out_adjust(x)

        return x, (emb_out, emb_out)
