import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

class TransformerModel(nn.Module):
    """Sentence classifier: a TransformerEncoder followed by an MLP head.

    The encoder turns the input into one ``encoder_dim``-sized vector per
    sequence; the head is three linear layers with ``Tanh`` between them and
    dropout in front of each linear layer, ending in ``n_classes`` outputs.

    Args:
        word_embed_dim: Forwarded to the encoder as its input feature size.
        encoder_dim: Encoder output size, which is also the head's input size.
        n_enc_layers: Forwarded to the encoder as its layer count.
        dpout_model: Dropout probability forwarded to the encoder.
        dpout_fc: Dropout probability used inside the classification head.
        fc_dim: Width of the two hidden layers of the head.
        n_classes: Number of outputs of the final linear layer.
        pool_type: Pooling mode forwarded to the encoder.
        linear_fc: Stored on the instance only; this class does not consult
            it when building the head.
    """

    def __init__(
        self,
        word_embed_dim,
        encoder_dim,
        n_enc_layers,
        dpout_model,
        dpout_fc,
        fc_dim,
        n_classes,
        pool_type,
        linear_fc,
    ):
        super().__init__()

        # Keep the hyper-parameters around for later inspection.
        self.encoder_dim = encoder_dim
        self.n_enc_layers = n_enc_layers
        self.dpout_fc = dpout_fc
        self.fc_dim = fc_dim
        self.n_classes = n_classes
        self.linear_fc = linear_fc

        # The encoder is created before the head so that parameter
        # initialisation happens in the same order as the layers are used.
        self.encoder = TransformerEncoder(
            word_embed_dim, encoder_dim, n_enc_layers, pool_type, dpout_model
        )

        # The head consumes the pooled encoder output directly.
        self.inputdim = encoder_dim

        # (in_features, out_features) of each linear layer, in order.
        layer_shapes = (
            (self.inputdim, fc_dim),
            (fc_dim, fc_dim),
            (fc_dim, n_classes),
        )
        head = []
        for position, (n_in, n_out) in enumerate(layer_shapes):
            if position > 0:
                # Non-linearity between consecutive linear layers only.
                head.append(nn.Tanh())
            head.append(nn.Dropout(p=dpout_fc))
            head.append(nn.Linear(n_in, n_out))
        self.classifier = nn.Sequential(*head)

    def forward(self, s1):
        """Encode ``s1`` and return the classifier output.

        Args:
            s1: Passed unchanged to the encoder (a ``(sent, sent_len)`` pair
                for the ``TransformerEncoder`` in this file).

        Returns:
            The output of the classification head for the encoded input.
        """
        return self.classifier(self.encoder(s1))

class TransformerEncoder(nn.Module):
    """Encode a padded batch of embedded sequences into one vector each.

    Inputs are projected from ``word_embed_dim`` to ``encoder_dim``, given
    positional encodings, run through a stack of ``nn.TransformerEncoderLayer``
    with padded positions masked out, and finally pooled over the time axis.

    Args:
        word_embed_dim: Size of the last dimension of the input sequences.
        encoder_dim: Model width (``d_model``) and size of the returned vectors.
        n_enc_layers: Number of stacked encoder layers.
        pool_type: ``"mean"`` or ``"max"`` over the valid positions; any other
            value selects the output at the last valid position.
        dpout_model: Dropout probability for the positional encoding and the
            encoder layers.
        n_heads: Number of attention heads.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
    """

    def __init__(
        self, word_embed_dim, encoder_dim, n_enc_layers, pool_type, dpout_model, n_heads=8, dim_feedforward=4096
    ):
        super(TransformerEncoder, self).__init__()
        self.word_embed_dim = word_embed_dim
        self.encoder_dim = encoder_dim
        self.n_enc_layers = n_enc_layers
        self.pool_type = pool_type
        self.dpout_model = dpout_model
        self.n_heads = n_heads
        self.dim_feedforward = dim_feedforward

        # Project the embeddings up to the transformer's model width.
        self.input_linear = nn.Linear(self.word_embed_dim, self.encoder_dim)

        self.pos_encoder = PositionalEncoding(self.encoder_dim, self.dpout_model)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=self.encoder_dim,
            nhead=self.n_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dpout_model
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=self.n_enc_layers)

    def forward(self, sent_tuple):
        """Return one pooled embedding per sequence.

        Args:
            sent_tuple: Pair ``(sent, sent_len)``. ``sent`` is a tensor of
                shape ``(seq_len, batch_size, word_embed_dim)``; ``sent_len``
                holds the number of valid (non-padded) steps of each sequence
                and may be a list, a NumPy array or a tensor of any integer
                dtype. Lengths are expected to lie in ``[1, seq_len]``.

        Returns:
            Tensor of shape ``(batch_size, encoder_dim)``.
        """
        sent, sent_len = sent_tuple

        device = sent.device
        # Normalise every accepted container to an int64 tensor on the input's
        # device. The explicit dtype matters for tensors too: the last-token
        # branch feeds the lengths to ``gather``, which needs int64 indices.
        sent_len = torch.as_tensor(sent_len, dtype=torch.long, device=device)

        sent = self.input_linear(sent)  # (seq_len, batch_size, encoder_dim)

        max_len = sent.size(0)
        batch_size = sent.size(1)
        # True marks padding: position index >= sequence length.
        mask = torch.arange(max_len, device=device).expand(batch_size, max_len) >= sent_len.unsqueeze(1)

        sent = self.pos_encoder(sent)

        output = self.transformer_encoder(sent, src_key_padding_mask=mask)

        # Same mask laid out like ``output``: (seq_len, batch_size, 1).
        padded = mask.transpose(0, 1).unsqueeze(2)

        if self.pool_type == "mean":
            # Zero the padded steps with masked_fill rather than multiplying by
            # the mask, so a non-finite value there cannot become NaN.
            sum_output = output.masked_fill(padded, 0.0).sum(dim=0)
            emb = sum_output / sent_len.unsqueeze(1).float()
        elif self.pool_type == "max":
            # -inf can never win the max, so padding is ignored.
            output = output.masked_fill(padded, float('-inf'))
            emb = torch.max(output, dim=0)[0]
        else:
            # Pick the output at the last valid time step of every sequence.
            idx = (sent_len - 1).unsqueeze(0).unsqueeze(2).expand(1, batch_size, self.encoder_dim)
            emb = output.gather(0, idx).squeeze(0)

        return emb

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence, then dropout.

    A table of shape ``(max_len, 1, d_model)`` is precomputed and registered
    as the buffer ``pe``: even feature columns hold sines and odd columns hold
    cosines of ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature size of the inputs; may be even or odd.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions in the table; inputs longer than this
            cannot be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angles; without this the assignment has mismatched shapes.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        # A buffer follows the module across devices without being trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` with positional encodings added, after dropout.

        Args:
            x: Tensor of shape ``(seq_len, batch_size, d_model)`` with
                ``seq_len`` no larger than the table's ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The singleton middle dimension broadcasts over the batch.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
