import torch.nn as nn
from transformers import BertModel

class HuggingFaceEncoder(nn.Module):
    """Encode tokenized titles with a pretrained BERT model followed by an MLP head.

    BERT's ``pooler_output`` is projected by an MLP of ``cfg.textmlp_n_layers``
    Linear layers. Every Linear except the last is followed by BatchNorm1d and
    ReLU, plus Dropout when ``cfg.textmlp_dropout_prob > 0``.

    Config fields read:
        hf_textencoder_model_id: name/path passed to ``BertModel.from_pretrained``.
        textmlp_n_layers: number of Linear layers in the MLP head.
        textmlp_hidden_dim: width of the intermediate MLP layers.
        embedding_dim: output width of the final MLP layer.
        textmlp_dropout_prob: dropout probability after each non-final layer.
        freeze_text_encoders: if truthy, BERT parameters get ``requires_grad=False``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.text_encoder = BertModel.from_pretrained(cfg.hf_textencoder_model_id)
        self.n_layers = cfg.textmlp_n_layers
        self.hidden_dim = cfg.textmlp_hidden_dim
        self.embedding_dim = cfg.embedding_dim
        self.dropout_prob = cfg.textmlp_dropout_prob

        # Optionally freeze the BERT weights so that only the MLP head receives
        # gradient updates.
        if cfg.freeze_text_encoders:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        # The MLP consumes BERT's pooler_output, whose width is the encoder's
        # hidden size. That need not equal cfg.embedding_dim, so the first
        # layer's input width must come from the encoder config.
        encoder_dim = self.text_encoder.config.hidden_size
        self.mlp = self._build_mlp(
            encoder_dim,
            self.hidden_dim,
            self.embedding_dim,
            self.n_layers,
            self.dropout_prob,
        )

    @staticmethod
    def _build_mlp(in_dim, hidden_dim, out_dim, n_layers, dropout_prob):
        """Build an MLP mapping ``in_dim`` features to ``out_dim`` features.

        Args:
            in_dim: input width of the first Linear layer.
            hidden_dim: width of every intermediate layer.
            out_dim: output width of the last Linear layer.
            n_layers: number of Linear layers. With 1 layer the MLP is a single
                ``Linear(in_dim, out_dim)``. With 0 it is an empty (identity)
                ``nn.Sequential``.
            dropout_prob: Dropout probability after each non-final layer.
                Dropout modules are omitted entirely when this is not > 0.

        Returns:
            ``nn.Sequential`` whose non-final Linear layers are each followed by
            BatchNorm1d, ReLU and optional Dropout.
        """
        layers = []
        for i in range(n_layers):
            is_last = i == n_layers - 1
            layer_in = in_dim if i == 0 else hidden_dim
            layer_out = out_dim if is_last else hidden_dim
            layers.append(nn.Linear(layer_in, layer_out))
            if not is_last:
                # NOTE(review): BatchNorm1d raises in training mode when the
                # batch holds a single sample; confirm batch sizes are > 1.
                layers.append(nn.BatchNorm1d(layer_out))
                layers.append(nn.ReLU())
                if dropout_prob > 0:
                    layers.append(nn.Dropout(dropout_prob))
        return nn.Sequential(*layers)

    def forward(self, data):
        """Return MLP-projected BERT embeddings for ``data.tokenized_title``.

        Args:
            data: object whose ``tokenized_title`` mapping provides
                ``"input_ids"`` and ``"attention_mask"`` (presumably
                tokenizer output tensors of shape (batch, seq_len) -- confirm
                against the caller).

        Returns:
            The MLP applied to the encoder's ``pooler_output``.
        """
        input_ids = data.tokenized_title["input_ids"]
        attention_mask = data.tokenized_title["attention_mask"]
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)

        # pooler_output is only populated when the BERT model was built with
        # its pooling layer (the from_pretrained default).
        text_embeddings = self.mlp(outputs.pooler_output)

        return text_embeddings
