import torch
import torch.nn as nn

from allennlp.nn.util import masked_softmax

class Attention(nn.Module):
    """Score hidden states with a small tanh network and normalise the scores.

    Every hidden state is projected to ``attn_dim`` features, passed through
    tanh, and reduced to one scalar score by a second linear layer. The scores
    are then handed to AllenNLP's ``masked_softmax`` along dimension 1.
    """

    def __init__(self, hidden_dim, attn_dim):
        """Build the two projection layers and the tanh non-linearity.

        Args:
            hidden_dim: Size of the last dimension of the hidden states fed
                to ``forward``.
            attn_dim: Width of the intermediate projection.
        """
        super().__init__()
        # These attribute names double as state-dict key prefixes, so renaming
        # them would break previously saved checkpoints.
        self.linear_1 = nn.Linear(in_features=hidden_dim, out_features=attn_dim)
        self.linear_2 = nn.Linear(in_features=attn_dim, out_features=1)
        self.tanh = nn.Tanh()

    def forward(self, hidden_states, mask):
        """Return masked softmax weights over dimension 1 of the scores.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_dim``;
                presumably ``(batch, seq_len, hidden_dim)`` -- TODO confirm
                against callers.
            mask: Tensor with one dimension fewer than the scores; presumably
                ``(batch, seq_len)`` marking valid positions -- TODO confirm.

        Returns:
            The result of ``masked_softmax`` applied to the per-position
            scores, whose last dimension has size 1.
        """
        projected = self.linear_1(hidden_states)
        scores = self.linear_2(self.tanh(projected))
        # The scores keep a trailing axis of size 1, so the mask gets a
        # matching trailing axis before it is applied.
        expanded_mask = mask.unsqueeze(-1)
        return masked_softmax(scores, expanded_mask, dim=1)
