import torch.nn as nn


class RNNModel(nn.Module):
    """Token sequence model: embedding -> stacked LSTM -> linear decoder.

    Args:
        num_tokens: vocabulary size; sets both the embedding table size and
            the decoder's output width.
        embed_size: dimensionality of the token embeddings (LSTM input size).
        num_hidden: number of hidden units per LSTM layer.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability. It is forwarded to ``nn.LSTM``, which
            applies it only *between* stacked layers. For a single-layer LSTM
            it is replaced by 0 (PyTorch would apply nothing and only warn).
    """

    def __init__(self, num_tokens, embed_size,
                 num_hidden, num_layers, dropout=0.5):
        super().__init__()
        # NOTE(review): self.drop is not applied anywhere in forward(), so
        # `dropout` currently only affects the LSTM's inter-layer dropout.
        # The module is kept so existing references to `model.drop` work;
        # confirm whether embedding/output dropout was meant to be enabled.
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(num_tokens, embed_size)
        # nn.LSTM's dropout acts between layers only. With one layer it has no
        # effect and emits a UserWarning, so pass 0.0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = nn.LSTM(embed_size, num_hidden, num_layers,
                           dropout=lstm_dropout)
        self.decoder = nn.Linear(num_hidden, num_tokens)
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.init_weights()

    def init_weights(self):
        """Initialise the embedding and decoder.

        Embedding and decoder weights are drawn uniformly from
        [-0.1, 0.1], and the decoder bias is zeroed. LSTM parameters are
        left with PyTorch's default initialisation.
        """
        initrange = 0.1
        # Same order as before so random draws are reproducible across versions.
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        """Run the model over a batch of token sequences.

        Args:
            input: token indices of shape (bptt, batch_size). The layout is
                sequence-first because the LSTM is built without
                ``batch_first``.
            hidden: ``(h, c)`` tuple, each of shape
                (num_layers, batch_size, num_hidden), e.g. from
                :meth:`init_hidden`.

        Returns:
            ``(decoded, hidden)``:
                - ``decoded`` has shape (bptt * batch_size, num_tokens).
                - ``hidden`` is the LSTM's final ``(h, c)`` state.
        """
        embedded = self.encoder(input)               # (bptt, batch, embed_size)
        output, hidden = self.rnn(embedded, hidden)  # (bptt, batch, num_hidden)
        # Merge the time and batch dims so the decoder sees a 2-D batch.
        decoded = self.decoder(output.flatten(0, 1))
        return decoded, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed ``(h, c)`` LSTM state for ``batch_size`` sequences.

        Both tensors have shape (num_layers, batch_size, num_hidden). They
        share the dtype and device of the model's first parameter and do
        not require grad.
        """
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.num_hidden)
        return (weight.new_zeros(shape), weight.new_zeros(shape))
