import torch
import torch.nn.functional as F
from torch import nn
from .subNets.transformers_encoder.transformer import TransformerEncoder
from .subNets.BertTextEncoder import BertTextEncoder


class MULTModel(nn.Module):
    """
    Implements the MultimodalTransformer Model.

    See https://github.com/yaohungt/Multimodal-Transformer for more.

    Pipeline (as built in ``__init__`` and applied in ``forward``):
      1. a kernel-size-1 ``Conv1d`` per modality projects its features to
         ``embed_dim`` channels;
      2. for every ordered pair ``(i, j)`` with ``i != j`` a
         ``TransformerEncoder`` is applied to ``(proj_x[i], proj_x[j], proj_x[j])``;
      3. the ``n_modalities - 1`` results for modality ``i`` are concatenated
         on the feature axis and fed to a per-modality ``TransformerEncoder``;
      4. the per-modality outputs are concatenated and passed through a
         residual two-layer block followed by the output layer.
    """

    def __init__(self, n_modalities, n_features, args):
        """
        Construct a MulT model.

        Args:
            n_modalities: number of input modalities.
            n_features: per-modality input feature sizes, indexed
                ``0 .. n_modalities - 1``.
            args: configuration object. Attributes read here: ``embed_dim``,
                ``use_bert``, ``use_finetune``/``transformers``/``pretrained``
                (only when ``use_bert`` is truthy), ``text_dropout``,
                ``num_heads``, ``nlevels``, ``attn_dropout``,
                ``attn_dropout_modalities``, ``relu_dropout``, ``res_dropout``,
                ``out_dropout``, ``embed_dropout``, ``attn_mask``,
                ``all_steps``, ``train_mode`` and ``num_classes`` (only when
                ``train_mode == "classification"``).
        """
        super().__init__()
        self.n_modalities = n_modalities
        self.embed_dim = args.embed_dim
        self.use_bert = args.use_bert
        if self.use_bert:
            # Modality 0 is routed through this encoder in forward().
            # NOTE(review): n_features[0] presumably has to equal the encoder's
            # output feature size in that case - confirm against BertTextEncoder.
            self.text_model = BertTextEncoder(
                use_finetune=args.use_finetune,
                transformers=args.transformers,
                pretrained=args.pretrained,
            )
        self.text_dropout = args.text_dropout

        self.num_heads = args.num_heads
        self.layers = args.nlevels
        self.attn_dropout = args.attn_dropout
        self.attn_dropout_modalities = args.attn_dropout_modalities
        self.relu_dropout = args.relu_dropout
        self.res_dropout = args.res_dropout
        self.out_dropout = args.out_dropout
        self.embed_dropout = args.embed_dropout
        self.attn_mask = args.attn_mask
        self.all_steps = args.all_steps

        # Every modality contributes (n_modalities - 1) crossmodal streams of
        # width embed_dim, and all modalities are concatenated at the end.
        combined_dim = self.embed_dim * self.n_modalities * (self.n_modalities - 1)

        output_dim = args.num_classes if args.train_mode == "classification" else 1

        # 1. Temporal convolutional layers
        self.proj = nn.ModuleList([
            nn.Conv1d(n_features[i], self.embed_dim, kernel_size=1,
                      padding=0, bias=False)
            for i in range(n_modalities)
        ])

        # 2. Crossmodal Attentions: trans[i] holds one network per j != i,
        #    stored in increasing order of j.
        self.trans = nn.ModuleList([
            nn.ModuleList([
                self.get_network(i, j, mem=False)
                for j in range(n_modalities) if i != j
            ])
            for i in range(n_modalities)
        ])

        # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.)
        self.trans_mems = nn.ModuleList([
            self.get_network(i, i, mem=True, layers=3)
            for i in range(n_modalities)
        ])

        # Projection layers
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)

    def get_network(self, mod1, mod2, mem, layers=-1):
        """
        Create TransformerEncoder network from layer information.

        Args:
            mod1: index of the first modality (currently unused here).
            mod2: index of the second modality; selects the entry of
                ``attn_dropout_modalities`` when ``mem`` is falsy.
            mem: if falsy, build a network of width ``embed_dim``; otherwise
                build one of width ``(n_modalities - 1) * embed_dim`` using
                the shared ``attn_dropout``.
            layers: lower bound on the number of layers; the network gets
                ``max(self.layers, layers)`` layers.

        Returns:
            A new ``TransformerEncoder`` instance.
        """
        if mem:
            embed_dim = (self.n_modalities - 1) * self.embed_dim
            attn_dropout = self.attn_dropout
        else:
            embed_dim = self.embed_dim
            attn_dropout = self.attn_dropout_modalities[mod2]

        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)

    def forward(self, x):
        """
        Apply MultModel Module to Layer Input.

        Args:
            x: layer input. Has size n_modalities * [batch_size, seq_len, n_features].
                The projected modalities are stacked, so batch_size and
                seq_len must agree across modalities. When ``use_bert`` is
                set, ``x[0]`` is instead whatever ``BertTextEncoder`` accepts.

        Returns:
            Tuple ``(out, last_hs)``. ``out`` is the output-layer result:
            ``[batch_size, output_dim]`` when ``all_steps`` is falsy, otherwise
            ``[batch_size, seq_len, output_dim]``. ``last_hs`` is the list of
            per-modality last-step representations and is empty when
            ``all_steps`` is truthy.
        """
        proj_x = []
        for i, v in enumerate(x):
            if self.use_bert and i == 0:
                # Only the text encoder runs without gradient tracking. The
                # dropout and the Conv1d projection must stay outside this
                # context, otherwise self.proj[0] never receives gradients and
                # remains at its random initialisation.
                # NOTE(review): no_grad keeps the encoder frozen here even
                # when args.use_finetune is set - confirm that is intended.
                with torch.no_grad():
                    v = self.text_model(v)
                # assumes the encoder returns [batch, seq_len, features] - TODO confirm
                v = F.dropout(v.transpose(1, 2), p=self.text_dropout,
                              training=self.training)
            else:
                # Conv1d expects channels before the sequence axis.
                v = v.permute(0, 2, 1)
            proj_x.append(self.proj[i](v))

        # [n_modalities, batch, embed_dim, seq_len] -> [n_modalities, seq_len, batch, embed_dim]
        proj_x = torch.stack(proj_x).permute(0, 3, 1, 2)

        hs = []
        last_hs = []
        for i in range(self.n_modalities):
            # self.trans[i] is ordered by increasing j over all j != i.
            others = [j for j in range(self.n_modalities) if j != i]
            h = [net(proj_x[i], proj_x[j], proj_x[j])
                 for net, j in zip(self.trans[i], others)]
            h = torch.cat(h, dim=2)  # feature dimension
            h = self.trans_mems[i](h)

            if self.all_steps:
                hs.append(h)
            else:
                # Keep only the final position along the leading (sequence) axis.
                last_hs.append(h[-1])

        if self.all_steps:
            # Concatenate features, then move batch to the front.
            out = torch.cat(hs, dim=2).permute(1, 0, 2)
        else:
            out = torch.cat(last_hs, dim=1)

        # A residual block
        out_proj = self.proj2(
            F.dropout(F.relu(self.proj1(out)), p=self.out_dropout,
                      training=self.training))
        out_proj = out_proj + out

        out = self.out_layer(out_proj)
        return out, last_hs