#! -*- coding: utf-8
import typing
from collections import OrderedDict

import torch


class PTBLSTMModel(torch.nn.Module):
    """LSTM language model: embedding -> multi-layer LSTM -> vocabulary projection.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output logits per position.
        emb_hidden_size: Dimensionality of the token embeddings.
        nrecurrent: Number of stacked LSTM layers.
        rnn_hidden_size: Hidden size of every LSTM layer.
        pad_token_id: Accepted for interface compatibility; not used by the
            code in this class.
        emb_dropout: Dropout probability applied to the embedding output.
        rnn_dropout: Dropout probability applied between LSTM layers and to
            the LSTM output before the projection.
        eps: Accepted for interface compatibility; not used by the code in
            this class.
        tie_weights: If True, the projection layer shares its weight matrix
            with the embedding table. Requires
            ``emb_hidden_size == rnn_hidden_size``.

    Raises:
        ValueError: If ``tie_weights`` is True and the embedding and
            projection weight shapes differ.
    """

    def __init__(self,
                 # embeddings param
                 vocab_size: int = 10001, emb_hidden_size: int = 400,
                 # lstm param
                 nrecurrent: int = 3, rnn_hidden_size: int = 1150,
                 pad_token_id: int = 0,
                 emb_dropout: float = 0.5, rnn_dropout: float = 0.3,
                 eps: float = 1e-8,
                 tie_weights: bool = False):
        super().__init__()

        emb = torch.nn.Embedding(vocab_size, emb_hidden_size)
        self.embeddings = torch.nn.Sequential(OrderedDict([
            ("embeddings", emb),
            ("embeddings_dropout", torch.nn.Dropout(p=emb_dropout)),
        ]))

        # Inter-layer dropout has no effect on a single-layer LSTM and makes
        # PyTorch emit a warning, so only request it when layers are stacked.
        inter_layer_dropout = rnn_dropout if nrecurrent > 1 else 0.0
        self.rnn = torch.nn.LSTM(emb_hidden_size, rnn_hidden_size, num_layers=nrecurrent,
                                 dropout=inter_layer_dropout, batch_first=True)
        self.rnn_dropout = torch.nn.Dropout(p=rnn_dropout)

        self.proj = torch.nn.Linear(rnn_hidden_size, vocab_size)
        if tie_weights:
            # Explicit raise instead of assert: asserts vanish under ``-O``
            # and the mismatch would then only surface inside forward().
            if emb.weight.shape != self.proj.weight.shape:
                raise ValueError(
                    "tie_weights requires emb_hidden_size == rnn_hidden_size, "
                    f"got emb_hidden_size={emb_hidden_size} and "
                    f"rnn_hidden_size={rnn_hidden_size}."
                )
            self.proj.weight = emb.weight

        self.init_weights()

    def init_weights(self):
        """Initialise embedding/projection weights in U(-0.1, 0.1) and zero the bias."""
        torch.nn.init.uniform_(self.embeddings[0].weight, -0.1, 0.1)
        torch.nn.init.zeros_(self.proj.bias)
        # When weights are tied this re-draws the shared tensor, which is harmless.
        torch.nn.init.uniform_(self.proj.weight, -0.1, 0.1)

    def forward(self, input: torch.Tensor,
                hidden: typing.Optional[typing.Tuple[torch.Tensor, torch.Tensor]] = None):
        """Compute per-position vocabulary logits.

        Args:
            input: Token indices; the LSTM is batch-first, so the layout is
                ``(batch, seq_len)``.
            hidden: Optional initial LSTM state passed straight to the LSTM;
                ``None`` lets the LSTM use its default initial state.

        Returns:
            A pair ``(logits, state)`` where ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``state`` is the state
            returned by the LSTM.
        """
        emb = self.embeddings(input)
        # Re-compact the LSTM weights before running the recurrent layers.
        self.rnn.flatten_parameters()
        o, h = self.rnn(emb, hidden)
        o = self.rnn_dropout(o)
        o = self.proj(o)  # projection layer: hidden -> vocab
        return o, h
