#!/usr/bin/env python3
# coding: utf-8

import math
import torch
import torch.nn as nn
import numpy as np
from transformers import BertModel, BertTokenizer, GPT2Model, GPT2Tokenizer,\
    RobertaModel, RobertaTokenizer, DistilBertModel, DistilBertTokenizer



class BertEncoder(nn.Module):
    """Wrap a pretrained ``BertModel`` and return one vector per sequence.

    The constructor loads a ``BertTokenizer`` and a ``BertModel``; ``forward``
    runs the model on already-tokenized inputs and returns the output at
    sequence position 0.

    Attributes:
        tokenizer: ``BertTokenizer`` loaded from ``PRETRAINED_ROOT + bert_type``.
            It is not used by ``forward``; it is only stored on the instance.
        bert_encoder: The underlying ``BertModel``.
    """

    # Prefix prepended (by plain string concatenation) to ``bert_type`` to
    # locate locally stored checkpoints.
    PRETRAINED_ROOT = "float_distributed/pretrained_models/text_encoder/"

    def __init__(self, bert_type: str, dropout: float, freeze: bool) -> None:
        """Load the tokenizer and encoder weights.

        Args:
            bert_type: Checkpoint name. Appended to ``PRETRAINED_ROOT`` for
                the tokenizer and, unless it contains ``'clip'``, for the
                model as well.
            dropout: Value passed as both ``hidden_dropout_prob`` and
                ``attention_probs_dropout_prob``. Ignored when ``bert_type``
                contains ``'clip'``.
            freeze: If truthy, gradients are disabled for every parameter of
                the encoder.
        """
        super().__init__()

        local_path = self.PRETRAINED_ROOT + bert_type
        self.tokenizer = BertTokenizer.from_pretrained(local_path)

        if 'clip' not in bert_type:
            self.bert_encoder = BertModel.from_pretrained(
                local_path,
                add_pooling_layer=False,
                hidden_dropout_prob=dropout,
                attention_probs_dropout_prob=dropout,
                output_hidden_states=False,
            )
        else:
            # NOTE(review): this branch passes ``bert_type`` as-is (no
            # PRETRAINED_ROOT prefix) and keeps the checkpoint's default
            # config, so ``dropout`` has no effect here -- confirm intended.
            self.bert_encoder = BertModel.from_pretrained(bert_type)

        if freeze:
            # Disable gradients for all encoder parameters in one call.
            self.bert_encoder.requires_grad_(False)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode tokenized inputs and return the position-0 vector.

        Inputs are forwarded to the encoder unchanged; no device transfer is
        done here, so callers must place them on the model's device.

        Args:
            input_ids: Token ids passed to the encoder as ``input_ids``.
            attention_mask: Mask passed to the encoder as ``attention_mask``.

        Returns:
            ``output[:, 0, :]`` where ``output`` is the first element of the
            encoder's return value (presumably the last hidden state, making
            this the [CLS] embedding -- assumes inputs were tokenized with
            special tokens).
        """
        output = self.bert_encoder(input_ids=input_ids,
                                   attention_mask=attention_mask)[0]
        return output[:, 0, :]

