import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Combiner(nn.Module):
    """Bidirectional LSTM over padded embeddings, with an optional output projection.

    Reads ``embedding_size``, ``rnn_units``, ``num_layers``, ``dropout_prob`` and
    ``output_dim`` from ``config``. Each LSTM direction gets ``rnn_units // 2``
    units. A ``linear`` projection is created only when the concatenated LSTM
    width differs from ``output_dim``.
    """

    def __init__(self, config):
        super().__init__()

        self.embedding_size = config.embedding_size
        self.rnn_units = config.rnn_units
        self.num_layers = config.num_layers
        self.dropout_prob = config.dropout_prob
        self.output_dim = config.output_dim

        # Each direction gets half the units. The concatenated output width is
        # 2 * (rnn_units // 2), which is rnn_units - 1 when rnn_units is odd.
        # The projection must be sized from this real width, not rnn_units.
        hidden_per_direction = self.rnn_units // 2
        lstm_out_dim = 2 * hidden_per_direction

        self.lstm = nn.LSTM(
            self.embedding_size,
            hidden_per_direction,
            self.num_layers,
            bidirectional=True,
            # nn.LSTM applies dropout only between stacked layers. With a single
            # layer it has no effect and only triggers a UserWarning.
            dropout=self.dropout_prob if self.num_layers > 1 else 0.0,
            batch_first=True,
        )
        # The projection exists only when needed; ``forward`` checks for it.
        if lstm_out_dim != self.output_dim:
            self.linear = nn.Linear(lstm_out_dim, self.output_dim)

    def forward(self, embeddings, word_mask):
        """Encode a padded batch with the BiLSTM (and project if configured).

        Args:
            embeddings: Padded input. Assumed (batch, seq_len, embedding_size)
                because the LSTM is ``batch_first``.
            word_mask: Mask whose sum over the last dim gives each row's
                length. Presumably 1 for real tokens and 0 for padding, with
                tokens left-aligned — confirm with callers.

        Returns:
            Tensor of shape (batch, seq_len, output_dim), where seq_len matches
            ``embeddings.size(1)``. Before the optional projection, positions
            past each row's length are zero-padded.
        """
        # pack_padded_sequence expects lengths as a CPU int64 tensor (or list).
        lengths = word_mask.sum(dim=-1).long().cpu()
        # Packing rejects zero-length rows, so such rows are run for one step.
        # Their first output position is therefore computed from padding.
        lengths = lengths.clamp(min=1)
        x_pack = pack_padded_sequence(
            embeddings, lengths, batch_first=True, enforce_sorted=False
        )
        h_pack, _ = self.lstm(x_pack)
        # total_length keeps the time axis aligned with the input/mask even when
        # every row is shorter than the padded length.
        h, _ = pad_packed_sequence(
            h_pack, batch_first=True, total_length=embeddings.size(1)
        )
        if hasattr(self, "linear"):
            h = self.linear(h)
        return h
