import torch
import torch.nn as nn
import numpy as np
import math

#######################################################################################################################
# 1. LateFuseBERT
#######################################################################################################################

class LateFuseBERT(nn.Module):
    """Late-fusion classifier combining a text encoder with a tabular Transformer.

    Every categorical variable is embedded and every numerical variable is
    projected by its own ``nn.Linear(1, d_model)``, so each variable becomes one
    ``d_model``-sized token. A learnable [CLS] token is prepended, the token
    sequence is passed through a Transformer encoder, and the encoder output at
    the [CLS] position is concatenated with the first-position hidden state of
    the text model before the final fully connected classification head.
    """

    def __init__(self,
                 text_model,
                 cat_vocab_sizes,
                 num_cat_var,
                 num_numerical_var,
                 d_model,
                 n_heads,
                 n_layers,
                 dropout,
                 d_fc,
                 n_classes):
        """Build the tabular stream and the classification head.

        Args:
            text_model: callable invoked as
                ``text_model(texts, attention_mask=...)``; its result must
                expose ``last_hidden_state``. Because the head takes
                ``2 * d_model`` input features, the hidden size of the text
                model has to equal ``d_model``.
            cat_vocab_sizes: list of vocabulary sizes, one per categorical
                variable (only the first ``num_cat_var`` entries are used).
            num_cat_var: number of categorical variables.
            num_numerical_var: number of numerical variables.
            d_model: embedding dimension of the tabular tokens.
            n_heads: number of attention heads in the tabular encoder.
            n_layers: number of tabular encoder layers.
            dropout: dropout rate (categorical embeddings, encoder, head).
            d_fc: hidden dimension of the final fully connected network.
            n_classes: number of output classes.
        """
        # super constructor
        super().__init__()

        # attributes
        self.text_model = text_model
        self.cat_vocab_sizes = cat_vocab_sizes  # list of vocabulary sizes for categorical variables
        self.num_cat_var = num_cat_var  # number of categorical variables
        self.num_numerical_var = num_numerical_var  # number of numerical variables
        self.n_heads = n_heads  # number of attention heads
        self.d_model = d_model  # embedding dimension
        self.n_layers = n_layers  # number of encoder layers
        self.dropout = dropout  # dropout rate
        self.cat_dropout = nn.Dropout(dropout)  # dropout after cat embedding
        self.d_fc = d_fc  # dimension of hidden layer in final fully connected layer
        self.n_classes = n_classes  # number of classes

        # categorical embeddings (one table per categorical variable)
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(num_embeddings=self.cat_vocab_sizes[i],
                         embedding_dim=self.d_model,
                         padding_idx=0)
            for i in range(self.num_cat_var)
        ])

        # linear mapper for numerical data (one scalar -> d_model projection per variable)
        self.num_linears = nn.ModuleList([
            nn.Linear(1, self.d_model) for _ in range(self.num_numerical_var)
        ])

        # classification token [CLS], is learnable
        self.tab_cls = nn.Parameter(data=torch.rand(1, self.d_model), requires_grad=True)

        # Self Attention Transformer encoder; the feed-forward width is d_model.
        # NOTE(review): nn.TransformerEncoder works on its own copies of this
        # layer, so the parameters of `tab_encoder_layers` are registered (and
        # saved in the state_dict) but are not used in forward. The attribute is
        # kept so existing checkpoints keep loading.
        self.tab_encoder_layers = nn.TransformerEncoderLayer(d_model=self.d_model,
                                                             nhead=self.n_heads,
                                                             dim_feedforward=self.d_model,
                                                             dropout=self.dropout,
                                                             batch_first=True)
        self.tab_transformer_encoder = nn.TransformerEncoder(self.tab_encoder_layers, self.n_layers)

        # last fully connected network (input: text CLS and tabular CLS concatenated)
        self.fc1 = nn.Sequential(nn.Linear(2 * self.d_model, self.d_fc),
                                 nn.ReLU(),
                                 nn.Dropout(self.dropout),
                                 nn.Linear(self.d_fc, self.n_classes))

        # weight initialization
        self.init_weights()

    def init_weights(self):
        """Re-initialise embeddings, numerical projections and the head.

        Weights get Kaiming-uniform initialisation, biases are set to zero.
        The Transformer encoder and the text model are left untouched.
        """
        # embeddings
        # NOTE(review): this also overwrites the row at padding_idx=0, which
        # nn.Embedding had initialised to zeros.
        for embedding in self.cat_embeddings:
            nn.init.kaiming_uniform_(embedding.weight)
        # numerical linear
        for linear in self.num_linears:
            nn.init.zeros_(linear.bias)
            nn.init.kaiming_uniform_(linear.weight)
        # final FC network (indices 0 and 3 are the two nn.Linear layers)
        for linear in (self.fc1[0], self.fc1[3]):
            nn.init.zeros_(linear.bias)
            nn.init.kaiming_uniform_(linear.weight)

    def forward(self, texts, attention_mask, categoricals, numericals):
        """Run both streams and classify from the concatenated [CLS] vectors.

        Args:
            texts: input forwarded to ``text_model`` as first argument.
            attention_mask: forwarded to ``text_model`` as ``attention_mask``.
            categoricals: categorical indices, indexed as ``categoricals[:, i]``
                (assumes shape (batch, num_cat_var) -- TODO confirm).
            numericals: numerical values, indexed as ``numericals[:, i]`` and
                cast to float (assumes shape (batch, num_numerical_var) --
                TODO confirm).

        Returns:
            Tuple ``(pred, text_cls, tabular_cls)``: the output of the
            classification head, the first-position hidden state of the text
            model, and the tabular encoder output at the [CLS] position.
        """
        # 1. + 2. one token per variable: column i is reshaped to (batch, 1)
        # for the embedding lookup and to (batch, 1, 1) for the linear mapper
        cat_tokens = [
            embedding(categoricals[:, i].unsqueeze(dim=1))
            for i, embedding in enumerate(self.cat_embeddings)
        ]
        num_tokens = [
            linear(numericals[:, i].unsqueeze(dim=1).unsqueeze(dim=1).float())
            for i, linear in enumerate(self.num_linears)
        ]

        # Collect only the groups that exist, so that a model without
        # categorical (or without numerical) variables does not call
        # torch.cat on an empty list. Dropout is applied to the categorical
        # embeddings only.
        token_groups = []
        if cat_tokens:
            token_groups.append(self.cat_dropout(torch.cat(cat_tokens, dim=1)))
        token_groups.extend(num_tokens)

        # 3. add classification token [CLS], * sqrt(d) prevent these input embeddings from becoming excessively small
        batch_size = token_groups[0].shape[0] if token_groups else len(texts)
        cls_tokens = self.tab_cls.unsqueeze(0).expand(batch_size, -1, -1)
        tabulars = torch.cat([cls_tokens, *token_groups], dim=1) * math.sqrt(self.d_model)

        # 4. Self attention Transformer encoder (tabular stream)
        tabulars = self.tab_transformer_encoder(tabulars)

        # 5. text model prediction
        texts = self.text_model(texts, attention_mask=attention_mask).last_hidden_state

        # 6. Concatenate CLS tokens
        text_cls = texts[:, 0, :]
        tabular_cls = tabulars[:, 0, :]
        mm_cls = torch.cat([text_cls, tabular_cls], dim=1)

        # 7. Fully connected network for classification purpose
        pred = self.fc1(mm_cls)

        return pred, text_cls, tabular_cls

#######################################################################################################################
# 2. AllTextBERT
#######################################################################################################################

class AllTextBERT(nn.Module):
    """Text-only classifier: a text model followed by a small MLP head.

    The first-position hidden state returned by ``text_model`` is fed to a
    Linear -> ReLU -> Dropout -> Linear network that produces the logits.
    """

    def __init__(self,
                 text_model,
                 d_model,
                 dropout,
                 d_fc,
                 n_classes):
        """Store the hyper-parameters and build the classification head.

        Args:
            text_model: callable invoked as
                ``text_model(texts, attention_mask=...)``; its result must
                expose ``last_hidden_state``.
            d_model: number of input features of the head.
            dropout: dropout rate used inside the head.
            d_fc: hidden dimension of the head.
            n_classes: number of output classes.
        """
        super().__init__()

        # hyper-parameters kept as attributes
        self.text_model = text_model
        self.d_model = d_model
        self.d_fc = d_fc
        self.dropout = dropout
        self.n_classes = n_classes

        # classification head
        head_layers = [
            nn.Linear(self.d_model, self.d_fc),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.d_fc, self.n_classes),
        ]
        self.fc1 = nn.Sequential(*head_layers)

        self.init_weights()

    def init_weights(self):
        """Zero the biases and Kaiming-uniform the weights of both head Linears."""
        # positions 0 and 3 of the head are the nn.Linear layers
        for position in (0, 3):
            linear = self.fc1[position]
            nn.init.zeros_(linear.bias)
            nn.init.kaiming_uniform_(linear.weight)

    def forward(self, texts, attention_mask, categoricals, numericals):
        """Classify from the text stream only.

        ``categoricals`` and ``numericals`` are accepted but not used.

        Returns:
            Tuple ``(text_pred, text_cls)``: the head output and the
            first-position hidden state of the text model.
        """
        encoded = self.text_model(texts, attention_mask=attention_mask)

        # first position along dim 1 of the last hidden state
        text_cls = encoded.last_hidden_state[:, 0, :]

        return self.fc1(text_cls), text_cls
            
##############################################################################################################

def init_model(model_type, d_model, cat_vocab_sizes,
               num_cat_var, num_numerical_var, n_heads,
               n_layers, dropout, d_fc, n_classes, seed, text_model):
    """Seed torch's RNG and build the requested model.

    Args:
        model_type: ``"LateFuseBERT"`` or ``"AllTextBERT"``.
        d_model, cat_vocab_sizes, num_cat_var, num_numerical_var, n_heads,
        n_layers, dropout, d_fc, n_classes, text_model: forwarded to the
            model constructor (``AllTextBERT`` only receives ``text_model``,
            ``d_model``, ``dropout``, ``d_fc`` and ``n_classes``).
        seed: value passed to ``torch.manual_seed`` right before the model is
            constructed.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    supported_types = ("LateFuseBERT", "AllTextBERT")
    # Fail with a clear message instead of reaching `return model` with the
    # name unbound when the type is not recognised.
    if model_type not in supported_types:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {supported_types}"
        )

    torch.manual_seed(seed)

    if model_type == "LateFuseBERT":
        return LateFuseBERT(text_model,
                            cat_vocab_sizes,
                            num_cat_var,
                            num_numerical_var,
                            d_model,
                            n_heads,
                            n_layers,
                            dropout,
                            d_fc,
                            n_classes)

    return AllTextBERT(text_model,
                       d_model,
                       dropout,
                       d_fc,
                       n_classes)
