
import torch
import torch.nn as nn

from allennlp.nn.util import masked_softmax
from allennlp.nn.util import get_text_field_mask
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

from model.attention import Attention

class TextEncoder(nn.Module):
    """Encode a tokenized sentence into a single vector via BiLSTM + attention pooling.

    Pipeline: embed tokens -> bidirectional LSTM -> attention weights over the
    LSTM outputs -> weighted sum over the sequence dimension (dim 1).

    Args:
        word_embeddings: Module mapping the ``sentence`` argument of
            ``forward`` to token embeddings; its output feature size is
            presumably ``input_dim`` (TODO confirm at the call site).
        input_dim: Input feature size of the LSTM.
        hidden_dim: Hidden size of each LSTM direction. The attention module
            is built for ``2 * hidden_dim`` features because the LSTM is
            bidirectional.
        attn_dim: Internal size passed through to ``Attention``.
    """

    def __init__(self, word_embeddings, input_dim, hidden_dim, attn_dim):
        super().__init__()
        # Attribute names double as state_dict keys; keep them (and their
        # registration order) stable for checkpoint compatibility.
        self.word_embeddings = word_embeddings
        self.hidden_dim = hidden_dim

        recurrent = nn.LSTM(
            input_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.bilstm = PytorchSeq2SeqWrapper(recurrent)

        # Forward and backward directions are concatenated by the LSTM.
        encoder_out_dim = 2 * hidden_dim
        self.attn = Attention(encoder_out_dim, attn_dim)

    def forward(self, sentence):
        """Return the attention-pooled sentence representation.

        Args:
            sentence: Token input accepted by both ``get_text_field_mask`` and
                ``self.word_embeddings`` (presumably an AllenNLP TextField
                tensor dict; verify against the caller).

        Returns:
            Tensor from summing the weighted BiLSTM outputs over dim 1.
            NOTE(review): this relies on ``Attention`` returning weights that
            broadcast against the BiLSTM outputs (e.g. shaped
            (batch, seq_len, 1)) — TODO confirm in model.attention.
        """
        token_mask = get_text_field_mask(sentence)
        embedded = self.word_embeddings(sentence)
        encoded = self.bilstm(embedded, token_mask)
        weights = self.attn(encoded, token_mask)
        weighted = encoded * weights
        return weighted.sum(dim=1)
