import torch.nn as nn
from transformers import BertForSequenceClassification, BertConfig
import torch
import torch.nn.functional as F


from transformers import AutoConfig, AutoModel
#ForSequenceClassification

from help_ import BertNoEmbed



class Identity(nn.Module):
    """No-op module: ``forward`` hands back its argument unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        passthrough = x
        return passthrough


# Define the Bert classifier class
class Bert(nn.Module):
    """Classifier made of a frozen pretrained encoder plus a small BERT stack.

    A pretrained encoder (``self.enc``) is loaded with ``AutoModel`` and all of
    its parameters are frozen. Its first output is projected to
    ``hidden_size`` by ``self.linear`` and passed as ``inputs_embeds`` to a
    ``BertNoEmbed`` model (``self.model``) built from ``self.config``.

    Args:
        args: Namespace-like object. ``args.model_name_or_path`` is required.
            ``args.task_name`` and ``args.ignore_mismatched_sizes`` are used
            for the encoder when present. Otherwise the ``task_name`` /
            ``ignore_mismatched_sizes`` arguments are used instead.
        num_labels: Number of labels, passed to both the encoder config and
            the top BERT config.
        task_name: ``finetuning_task`` for the top BERT config. It is also the
            encoder fallback when ``args`` has no ``task_name``.
        ignore_mismatched_sizes: Fallback used when ``args`` has no
            ``ignore_mismatched_sizes`` attribute.
        num_hidden_layers: Number of layers in the top BERT stack.
        hidden_size: Width of the top BERT stack. Must be divisible by
            ``num_attention_heads``.
        num_attention_heads: Attention heads in the top BERT stack.
        dropout: ``hidden_dropout_prob`` of the top BERT stack.
    """

    def __init__(self, args, num_labels, task_name, ignore_mismatched_sizes=True,
                 num_hidden_layers=6, hidden_size=256, num_attention_heads=8, dropout=0):
        super(Bert, self).__init__()

        # Values carried on ``args`` take precedence (original behaviour); the
        # explicit arguments are only used when ``args`` lacks them.
        enc_task_name = getattr(args, "task_name", task_name)
        enc_ignore_mismatched = getattr(args, "ignore_mismatched_sizes", ignore_mismatched_sizes)

        config_enc = AutoConfig.from_pretrained(
            args.model_name_or_path, num_labels=num_labels, finetuning_task=enc_task_name
        )
        self.enc = AutoModel.from_pretrained(
            args.model_name_or_path,
            from_tf=".ckpt" in args.model_name_or_path,
            config=config_enc,
            ignore_mismatched_sizes=enc_ignore_mismatched,
        )

        # The pretrained encoder is a fixed feature extractor: no gradients.
        for param in self.enc.parameters():
            param.requires_grad = False

        assert hidden_size % num_attention_heads == 0, "Hidden size must be multiple of number of heads"
        # Fields not listed here keep BertConfig's defaults.
        self.config = BertConfig(
            num_labels=num_labels,
            finetuning_task=task_name,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=2 * hidden_size,
            hidden_dropout_prob=dropout,
            classifier_dropout=None,
            output_hidden_states=True,
            output_attentions=True,
        )

        # NOTE(review): BertNoEmbed is a project class; presumably a BERT stack
        # that consumes ``inputs_embeds`` directly — confirm in help_.
        self.model = BertNoEmbed(self.config)
        # Input width follows the encoder's config instead of a hard-coded 768,
        # so encoders other than bert-base sized models also work.
        self.linear = nn.Linear(config_enc.hidden_size, hidden_size)

    def _project(self, inputs):
        """Run the frozen encoder and project its first output to ``hidden_size``.

        ``inputs`` must provide ``"input_ids"`` and ``"attention_mask"``. The
        encoder's first output is ``last_hidden_state`` for Hugging Face base
        models.
        """
        outputs_enc = self.enc(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        return self.linear(outputs_enc[0])

    def get_features(self, inputs):
        """Return ``self.model.get_features`` applied to the projected encoder output.

        ``inputs`` must provide ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``.
        """
        return self.model.get_features(inputs_embeds=self._project(inputs), labels=inputs["labels"])

    def forward(self, inputs):
        """Return the output of ``self.model`` on the projected encoder output.

        ``inputs`` must provide ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``.
        """
        return self.model(inputs_embeds=self._project(inputs), labels=inputs["labels"])


def set_encoder(model, model_name_or_path="bert-base-cased", num_labels=3, task_name="mnli"):
    """Replace ``model.enc`` with a freshly loaded, frozen pretrained encoder.

    The defaults reproduce the original hard-coded choice: ``bert-base-cased``
    configured for ``mnli`` with 3 labels.

    Args:
        model: Object whose ``enc`` attribute is overwritten in place.
        model_name_or_path: Hugging Face model id or local path to load. A
            path containing ``".ckpt"`` is loaded as a TensorFlow checkpoint.
        num_labels: ``num_labels`` for the encoder config.
        task_name: ``finetuning_task`` for the encoder config.

    NOTE(review): if ``model`` projects encoder outputs (e.g. ``Bert.linear``),
    the new encoder's hidden size must match that layer's input width.
    """
    config_enc = AutoConfig.from_pretrained(
        model_name_or_path, num_labels=num_labels, finetuning_task=task_name
    )
    model.enc = AutoModel.from_pretrained(
        model_name_or_path,
        from_tf=".ckpt" in model_name_or_path,
        config=config_enc,
        ignore_mismatched_sizes=False,
    )

    # Freeze the encoder: it is used only as a fixed feature extractor.
    for param in model.enc.parameters():
        param.requires_grad = False

