import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

# Presumably the token id reserved for padding; it has the same value (0) as the
# ``padding_idx`` passed to ``WordGRU.emb`` -- TODO confirm they must agree.
PAD_CHAR: int = 0
# Presumably the size of a character vocabulary (one id per byte value 0-255);
# it is not referenced in the code shown here -- verify against its users.
NUM_CHARS: int = 256


class WordGRU(nn.Module):
    """Encode batches of token-id sequences into fixed-size vectors with a GRU.

    Token ids are embedded with an ``nn.Embedding`` whose row 0 is its padding
    entry (``padding_idx=0``). A single-layer, unidirectional GRU is then run
    over the packed (length-aware) sequences, so positions past each row's
    length do not influence the result.

    Args:
        msg_hdim: Hidden size of the GRU, i.e. the size of each output vector.
        msg_edim: Size of each token embedding (the GRU input size).
        vocab_size: Number of rows in the embedding table; token ids must lie
            in ``[0, vocab_size)``.
    """

    def __init__(self, msg_hdim, msg_edim, vocab_size=1000):
        super().__init__()
        self.msg_hdim = msg_hdim
        self.msg_edim = msg_edim

        self.emb = nn.Embedding(vocab_size, self.msg_edim, padding_idx=0)
        self.rnn = nn.GRU(self.msg_edim, self.msg_hdim, batch_first=True)

    def forward(self, messages, messages_len):
        """Return the GRU's final hidden state for every sequence in the batch.

        Args:
            messages: Integer tensor of token ids, shape ``(batch, seq_len)``
                (the GRU is built with ``batch_first=True``). Only the first
                ``messages_len[i]`` positions of row ``i`` are read.
            messages_len: Integer tensor of shape ``(batch,)`` with the number
                of valid tokens per row, on any device. A length of 0 (an
                empty message) is treated as length 1 instead of raising.

        Returns:
            Tensor of shape ``(batch, msg_hdim)``: the hidden state after each
            row's last valid token.
        """
        messages_emb = self.emb(messages)
        # pack_padded_sequence raises on zero lengths, so an empty message would
        # crash the whole batch. Clamp to 1: the GRU then reads only position 0
        # of that row (presumably padding id 0, whose embedding is the
        # zero-initialised padding row). Lengths must be on the CPU for packing.
        lengths = messages_len.clamp(min=1).cpu()
        packed_input = pack_padded_sequence(
            messages_emb, lengths, enforce_sorted=False, batch_first=True
        )
        # hidden has shape (num_layers * num_directions, batch, msg_hdim). This
        # GRU is single-layer and unidirectional, so hidden[0] holds each row's
        # final state.
        _, hidden = self.rnn(packed_input)
        return hidden[0]
