
import torch
import torch.nn as nn

from allennlp.nn.util import masked_softmax
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

from allennlp.nn.util import get_text_field_mask

class Attention(nn.Module):
    """Additive attention that assigns one normalised weight per timestep.

    Every hidden state is projected to ``attn_dim`` features, squashed with
    tanh, and reduced to a single scalar score. The scores are then turned
    into weights with a masked softmax along dim 1.

    Args:
        hidden_dim: size of the last dimension of the incoming hidden states.
        attn_dim: width of the intermediate projection.
    """

    def __init__(self, hidden_dim, attn_dim):
        super().__init__()
        # Submodule names and registration order are kept stable: they define
        # the state_dict keys and the order in which parameters are initialised.
        self.linear_1 = nn.Linear(hidden_dim, attn_dim)
        self.linear_2 = nn.Linear(attn_dim, 1)
        self.tanh = nn.Tanh()

    def forward(self, hidden_states, mask):
        """Compute attention weights for ``hidden_states``.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_dim``;
                presumably (batch, seq_len, hidden_dim) — TODO confirm.
            mask: tensor matching ``hidden_states`` without its last
                dimension; positions it masks out receive no weight.

        Returns:
            Weights shaped like ``hidden_states`` but with a trailing
            dimension of size 1, normalised over dim 1.
        """
        logits = self.linear_2(self.tanh(self.linear_1(hidden_states)))
        # linear_2 emits a single feature, so the scores carry a trailing
        # singleton axis; give the mask the same axis before normalising.
        expanded_mask = mask.unsqueeze(-1)
        return masked_softmax(logits, expanded_mask, dim=1)


class TextEncoder(nn.Module):
    """Encode a text field into one vector: embed, BiLSTM, attention-pool, dropout.

    Tokens are embedded with ``word_embeddings`` and run through a
    bidirectional, batch-first LSTM. Each timestep therefore carries
    ``2 * hidden_dim`` features. The states are weighted with
    :class:`Attention`, summed over dim 1 (the sequence axis for a
    batch-first LSTM), and passed through dropout.

    Args:
        word_embeddings: module applied to the text-field tensors; its output
            feature size must equal ``input_dim`` because it feeds the LSTM.
        input_dim: LSTM input size.
        hidden_dim: LSTM hidden size per direction.
        attn_dim: width of the attention projection.
        dropout: dropout probability applied to the pooled vector. Defaults
            to 0.5, the value that was previously hard-coded.
    """

    def __init__(self, word_embeddings, input_dim, hidden_dim, attn_dim, dropout=0.5):
        super().__init__()
        self.word_embeddings = word_embeddings
        self.hidden_dim = hidden_dim
        self.bilstm = PytorchSeq2SeqWrapper(nn.LSTM(input_dim,
                                                    hidden_dim,
                                                    batch_first=True,
                                                    bidirectional=True))
        # Forward and backward LSTM states are concatenated -> 2 * hidden_dim.
        self.attn = Attention(hidden_dim * 2, attn_dim)
        self.dropout = nn.Dropout(dropout)

    def get_output_dim(self):
        """Return the feature size of the vector produced by :meth:`forward`."""
        return self.hidden_dim * 2

    def forward(self, sentence):
        """Return the attention-pooled, dropout-regularised encoding of ``sentence``.

        Args:
            sentence: text-field tensors as accepted by ``get_text_field_mask``
                and ``word_embeddings``.

        Returns:
            A tensor of shape (batch, ``get_output_dim()``).
        """
        mask = get_text_field_mask(sentence)
        sentence_embs = self.word_embeddings(sentence)
        hidden_states = self.bilstm(sentence_embs, mask)
        attn_scores = self.attn(hidden_states, mask)
        # Weighted sum over the sequence axis (dim 1), the same axis the
        # attention weights were normalised over.
        context = torch.sum(hidden_states * attn_scores, 1)
        return self.dropout(context)