import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from models import CrossAttentionModule, CrossAttentionModule_blend
import pandas as pd
# from torch_geometric.data.batch import DataBatch
from utils.embedding_utils import TokenizedInput

# class TokenizedInput:
#     def __init__(self, tokenized_title):
#         self.tokenized_title = tokenized_title

def tokenized_texts(text_list, tokenizer, device):
    """Tokenize a list of strings and wrap the tensors in a ``TokenizedInput``.

    Args:
        text_list: Strings to tokenize as one padded/truncated batch.
        tokenizer: A Hugging Face tokenizer; it is called with
            ``padding=True, truncation=True, return_tensors='pt'``.
        device: Target device for the returned tensors.

    Returns:
        ``TokenizedInput`` built from a dict with only the ``"input_ids"`` and
        ``"attention_mask"`` tensors (other tokenizer outputs are dropped),
        both moved to ``device``.
    """
    batch = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt')
    # Only these two fields are consumed downstream by the text encoder.
    kept = {name: batch[name].to(device) for name in ("input_ids", "attention_mask")}
    return TokenizedInput(kept)

class HuggingFaceEncoder(nn.Module):
    """BERT text encoder with an optional attention stage and an MLP head.

    The BERT ``pooler_output`` of the input text is optionally refined by one
    of the attention variants selected by ``cfg.attention``:

    * ``"self"``  - ``nn.TransformerEncoder`` applied over the pooled vectors of
      the whole batch (the batch is fed as one sequence);
    * ``"cross"`` - ``CrossAttentionModule`` attending from each text to the
      pooled embeddings of keyword phrases loaded from a CSV;
    * ``"scsc"``  - ``CrossAttentionModule_blend`` with layer order
      self/cross/self/cross against the same keyword embeddings;
    * any other value - no attention (``attention_name == "baseline"``).

    The result is projected by an MLP of ``cfg.textmlp_n_layers`` Linear
    layers, with BatchNorm1d + ReLU (+ Dropout when
    ``cfg.textmlp_dropout_prob > 0``) between them.

    Config fields read: ``hf_textencoder_model_id``, ``textmlp_n_layers``,
    ``textmlp_hidden_dim``, ``embedding_dim``, ``textmlp_dropout_prob``,
    ``attention``, ``num_attention_layers`` (attention variants "self" and
    "cross"), ``freeze_text_encoders``, and optionally ``keywords_csv_path``
    (variants "cross" and "scsc").
    """

    # Historical location of the keyword CSV; used unless cfg.keywords_csv_path is set.
    _DEFAULT_KEYWORDS_CSV = "/workspace/data/physical_properties.csv"

    def __init__(self, cfg):
        super().__init__()
        self.text_encoder = BertModel.from_pretrained(cfg.hf_textencoder_model_id)
        self.n_layers = cfg.textmlp_n_layers
        self.hidden_dim = cfg.textmlp_hidden_dim
        self.embedding_dim = cfg.embedding_dim
        self.dropout_prob = cfg.textmlp_dropout_prob
        # Width of the pooled BERT output (768 for bert-base); the attention
        # modules must operate at this width.
        d_model = self.text_encoder.config.hidden_size

        if cfg.attention == "self":
            self.attention_name = "self"
            # Kept as a registered attribute so existing checkpoints with
            # ``encoder_layer.*`` keys still load. nn.TransformerEncoder stores
            # deep copies of this layer, so these particular weights are not
            # used in forward().
            self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
            self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=cfg.num_attention_layers)
        elif cfg.attention == "cross":
            self.attention_name = "cross"
            self.keywords = self._load_keywords(cfg)
            self.cross_attention = CrossAttentionModule.TransformerEncoderLayer_SelfAndCrossAttention(
                layer_order_for_src1=["c"] * cfg.num_attention_layers,
                layer_order_for_src2=["n"] * cfg.num_attention_layers,
                d_model=d_model,
                nhead=8,
                dropout=0.1,
            )
        elif cfg.attention == "scsc":
            self.attention_name = "scsc"
            self.keywords = self._load_keywords(cfg)
            self.self_and_cross_attention = CrossAttentionModule_blend.TransformerEncoderLayer_SelfAndCrossAttention(
                layer_order_for_src1=["s", "c", "s", "c"],
                layer_order_for_src2=["n", "n", "n", "n"],
                d_model=d_model,
                nhead=8,
                dropout=0.1,
            )
        else:
            self.attention_name = "baseline"

        # Freeze all the parameters in the BERT model
        if cfg.freeze_text_encoders:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        # MLP: Linear layers; every layer except the last is followed by
        # BatchNorm1d + ReLU, plus Dropout when dropout_prob > 0.
        mlp_layers = []
        for i in range(self.n_layers):
            is_last = i == self.n_layers - 1
            input_dim = self.embedding_dim if i == 0 else self.hidden_dim
            output_dim = self.embedding_dim if is_last else self.hidden_dim
            mlp_layers.append(nn.Linear(input_dim, output_dim))
            if not is_last:
                mlp_layers.append(nn.BatchNorm1d(output_dim))
                mlp_layers.append(nn.ReLU())
                if self.dropout_prob > 0:  # Only add dropout if dropout probability is greater than zero
                    mlp_layers.append(nn.Dropout(self.dropout_prob))

        # Sequential container for our MLP layers
        self.mlp = nn.Sequential(*mlp_layers)

    @classmethod
    def _load_keywords(cls, cfg):
        """Read and tokenize the keyword phrases used as cross-attention memory.

        Reads the ``"Physical Properties"`` column of the CSV at
        ``cfg.keywords_csv_path`` (or ``_DEFAULT_KEYWORDS_CSV`` when unset) and
        tokenizes it with the BERT tokenizer of ``cfg.hf_textencoder_model_id``.
        """
        tokenizer = BertTokenizer.from_pretrained(cfg.hf_textencoder_model_id)
        csv_path = getattr(cfg, "keywords_csv_path", None) or cls._DEFAULT_KEYWORDS_CSV
        keywords_list = pd.read_csv(csv_path)["Physical Properties"].tolist()
        # Place on the GPU when one exists (previous behaviour) but do not crash
        # on CPU-only hosts; forward() moves the tensors to the encoder's device.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        return tokenized_texts(keywords_list, tokenizer, device=device)

    def _encode_keywords(self, device):
        """Run the stored keyword tokens through ``self.text_encoder`` on ``device``."""
        tokens = self.keywords.tokenized_title
        return self.text_encoder(
            input_ids=tokens["input_ids"].to(device),
            attention_mask=tokens["attention_mask"].to(device),
        )

    def forward(self, data):
        """Encode ``data.tokenized_title`` and project it through the MLP.

        Args:
            data: Object with a ``tokenized_title`` mapping holding
                ``"input_ids"`` and ``"attention_mask"`` tensors.

        Returns:
            Output of ``self.mlp`` applied to the (optionally attention-refined)
            pooled BERT embeddings.
        """
        input_ids = data.tokenized_title["input_ids"]
        attention_mask = data.tokenized_title["attention_mask"]
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output

        if self.attention_name == "self":
            # unsqueeze(0) feeds the whole batch as a single sequence, so
            # self-attention mixes information across the texts in the batch.
            features = self.transformer_encoder(pooled.unsqueeze(0))[0, :, :]
            if isinstance(data, TokenizedInput):
                # NOTE(review): only the first row is kept for TokenizedInput
                # inputs; the intent is not visible in this file - confirm.
                features = features[0, :].unsqueeze(0)
        elif self.attention_name == "cross":
            keyword_pooled = self._encode_keywords(pooled.device).pooler_output
            features, _ = self.cross_attention(
                pooled.unsqueeze(1),
                keyword_pooled.unsqueeze(0).repeat(pooled.shape[0], 1, 1),
                src1_key_padding_mask=None,
                src2_key_padding_mask=None,
            )
            features = features[:, 0, :]
        elif self.attention_name == "scsc":
            keyword_pooled = self._encode_keywords(pooled.device).pooler_output
            # NOTE(review): unlike the "cross" branch, src1 is passed without
            # unsqueeze(1) and no [:, 0, :] is taken afterwards; confirm this
            # matches what CrossAttentionModule_blend expects.
            features, _ = self.self_and_cross_attention(
                pooled,
                keyword_pooled.unsqueeze(0).repeat(pooled.shape[0], 1, 1),
                src1_key_padding_mask=None,
                src2_key_padding_mask=None,
            )
            if isinstance(data, TokenizedInput):
                # NOTE(review): same first-row selection as the "self" branch - confirm.
                features = features[0, :].unsqueeze(0)
        else:
            features = pooled

        # Pass the pooled (and possibly attended) features through the MLP head.
        return self.mlp(features)
