
import torch
import torch.nn as nn

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, masked_mean

from model.attention import Attention

class MentionEncoder(nn.Module):
    """Encode a mention together with its left and right context.

    The representation returned by ``forward`` is the concatenation of:

    * the masked mean of the mention's word embeddings (with dropout),
    * an attention-weighted sum over the hidden states of two separate
      BiLSTMs run on the left and the right context,
    * optionally, the summed embeddings of hand-crafted features
      (with dropout), when ``features_to_idx`` was supplied.

    ``self.mention_dim`` holds the width of that concatenation, assuming
    ``word_embeddings`` produces vectors of size ``input_dim`` (which the
    BiLSTMs already require).

    Args:
        word_embeddings: Callable module mapping a token field to
            embeddings; shared by the mention and both contexts.
            Presumably an AllenNLP text-field embedder -- confirm against
            the caller.
        input_dim: Size of the embeddings fed to the BiLSTMs.
        hidden_dim: Hidden size of each LSTM direction.
        attn_dim: Second constructor argument of ``Attention``.
        features_to_idx: Optional feature vocabulary. When given, a
            feature embedding table is created and ``forward`` requires
            ``features``.
        device: Unused; kept so existing call sites keep working.
        feat_dim: Width of each feature embedding (default 50).
        dropout: Dropout probability applied to the mention average and
            to the feature representation (default 0.5).
    """

    def __init__(self, word_embeddings, input_dim, hidden_dim, attn_dim,
                 features_to_idx=None, device=None, feat_dim=50, dropout=0.5):
        super().__init__()
        self.word_embeddings = word_embeddings
        self.hidden_dim = hidden_dim

        # Separate encoders for each side of the mention. Attribute names
        # and creation order are kept stable so checkpoints and seeded
        # initialisation are unaffected.
        self.left_bilstm = PytorchSeq2SeqWrapper(
            nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        )
        self.right_bilstm = PytorchSeq2SeqWrapper(
            nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        )

        # Bidirectional LSTM states are 2 * hidden_dim wide.
        self.attn = Attention(hidden_dim * 2, attn_dim)

        self.feat_to_idx = features_to_idx
        self.mention_dim = 2 * hidden_dim + input_dim
        if features_to_idx is not None:
            # One extra row because index 0 is reserved as the padding index.
            self.feat_embs = nn.Embedding(len(features_to_idx) + 1, feat_dim,
                                          padding_idx=0)
            self.mention_dim += feat_dim

        self.dropout = nn.Dropout(dropout)

    def _encode_context(self, tokens, encoder):
        """Return ``(hidden_states, mask)`` for one context side."""
        mask = get_text_field_mask(tokens)
        embedded = self.word_embeddings(tokens)
        return encoder(embedded, mask), mask

    def forward(self, mention_tokens, left_tokens,
                right_tokens, features=None):
        """Build the mention representation for a batch.

        Args:
            mention_tokens: Token field for the mention itself.
            left_tokens: Token field for the context left of the mention.
            right_tokens: Token field for the context right of the mention.
            features: Feature indices; required when the encoder was built
                with ``features_to_idx``, ignored otherwise. Presumably
                shaped (batch, num_features) -- confirm against the caller.

        Returns:
            Tensor obtained by concatenating, along dim 1, the averaged
            mention embedding, the attended context vector and (if
            enabled) the feature representation.

        Raises:
            ValueError: If features are enabled but ``features`` is None.
        """
        use_features = self.feat_to_idx is not None
        if use_features and features is None:
            # Fail fast with a readable message instead of an opaque error
            # raised from inside nn.Embedding after all the other work.
            raise ValueError(
                "MentionEncoder was built with features_to_idx, so `features` "
                "must be passed to forward()."
            )

        # Mention: masked average of its word embeddings.
        mention_mask = get_text_field_mask(mention_tokens)
        mention_embs = self.word_embeddings(mention_tokens)
        avg_mentions = masked_mean(mention_embs, mention_mask.unsqueeze(-1), dim=1)
        avg_mentions = self.dropout(avg_mentions)

        # Context: encode each side independently, then attend over both
        # sides jointly by concatenating along the sequence dimension.
        left_hidden_states, left_mask = self._encode_context(left_tokens, self.left_bilstm)
        right_hidden_states, right_mask = self._encode_context(right_tokens, self.right_bilstm)

        full_mask = torch.cat([left_mask, right_mask], dim=1)
        full_hidden = torch.cat([left_hidden_states, right_hidden_states], dim=1)

        # NOTE(review): assumes Attention returns weights that broadcast
        # against full_hidden (e.g. a trailing singleton dim) -- confirm.
        attn_scores = self.attn(full_hidden, full_mask)
        context = torch.sum(full_hidden * attn_scores, dim=1)

        parts = [avg_mentions, context]
        if use_features:
            # Sum the embeddings of all features present for each mention.
            feat_rep = torch.sum(self.feat_embs(features), dim=1)
            parts.append(self.dropout(feat_rep))

        return torch.cat(parts, dim=1)
