#! -*- coding: utf-8
import typing
from collections import OrderedDict

import torch


class TextClassifierLSTMModel(torch.nn.Module):
    """Text classifier: token embeddings -> stacked LSTMs -> linear head.

    Token ids are embedded, layer-normalised, passed through ReLU and
    dropout, then fed through ``nrecurrent`` single-layer batch-first LSTMs,
    each followed by ReLU and dropout.  One time step per sequence is
    selected and projected to ``num_classes`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_hidden_size: embedding dimension (input size of the first LSTM).
        nrecurrent: number of stacked LSTM layers; must be >= 1.
        rnn_hidden_size: hidden size of every LSTM layer.
        pad_token_id: embedding ``padding_idx``; ``init_weights`` keeps its
            row at zero.
        emb_dropout: dropout probability after the embedding block.
        rnn_dropout: dropout probability after each LSTM layer and before
            the projection.
        eps: LayerNorm epsilon in the embedding block.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``nrecurrent`` < 1.  Without an LSTM, forward() has no
            recurrent output, and the projection expects ``rnn_hidden_size``
            features.
    """

    def __init__(self,
                 # embeddings param
                 vocab_size: int = 30522, emb_hidden_size: int = 512,
                 # lstm param
                 nrecurrent: int = 2, rnn_hidden_size: int = 512,
                 pad_token_id: int = 0,
                 emb_dropout: float = 0.5, rnn_dropout: float = 0.3,
                 eps: float = 1e-8,
                 num_classes: int = 20):
        super().__init__()
        if nrecurrent < 1:
            raise ValueError(f"nrecurrent must be >= 1, got {nrecurrent}")

        self.embeddings = torch.nn.Sequential(OrderedDict([
            ("embeddings", torch.nn.Embedding(vocab_size, emb_hidden_size,
                                              padding_idx=pad_token_id)),
            ("embeddings_norm", torch.nn.LayerNorm(emb_hidden_size, eps=eps)),
            ("embeddings_activation", torch.nn.ReLU()),
            ("embeddings_dropout", torch.nn.Dropout(p=emb_dropout)),
        ]))

        # Layer i maps rnn_hiddens[i] -> rnn_hiddens[i + 1]; only the first
        # layer consumes the embedding size.
        rnn_hiddens = [emb_hidden_size] + [rnn_hidden_size] * nrecurrent
        # ModuleList rather than Sequential: an LSTM returns an
        # (output, state) tuple, so the layers cannot be chained by
        # Sequential and are iterated explicitly in forward().  The
        # state-dict keys ("rnns.<i>. ...") are the same for both containers.
        self.rnns = torch.nn.ModuleList([
            torch.nn.LSTM(input_size, hidden_size,
                          num_layers=1, batch_first=True)
            for input_size, hidden_size in zip(rnn_hiddens, rnn_hiddens[1:])])
        self.rnn_activ = torch.nn.ReLU()
        self.rnn_dropout = torch.nn.Dropout(p=rnn_dropout)

        self.proj = torch.nn.Linear(rnn_hidden_size, num_classes)

        self.init_weights()

    def init_weights(self):
        """Initialise the embedding table and projection layer in place.

        The embedding weights and the projection weights are drawn from
        U(-0.1, 0.1), and the projection bias is zeroed.  The embedding's
        padding row is reset to zero afterwards.  The uniform fill would
        otherwise overwrite the zero vector that ``nn.Embedding`` reserves
        for ``padding_idx``, and that row never receives gradient updates.
        """
        embedding = self.embeddings[0]
        with torch.no_grad():
            embedding.weight.uniform_(-0.1, 0.1)
            if embedding.padding_idx is not None:
                embedding.weight[embedding.padding_idx].zero_()
            self.proj.bias.zero_()
            self.proj.weight.uniform_(-0.1, 0.1)

    def forward(self, input: torch.Tensor,
                eos_indices: typing.Optional[torch.Tensor] = None,
                hidden: typing.Optional[typing.Sequence[
                    typing.Tuple[torch.Tensor, torch.Tensor]]] = None):
        """Classify a batch of token-id sequences.

        Args:
            input: token ids laid out batch-first, ``(batch, seq_len)``.
            eos_indices: per-sample time-step index whose LSTM output is
                classified.  It may be a tensor or a list of ints.  If None,
                the last time step is used for every sample.
            hidden: optional per-layer initial ``(h_0, c_0)`` states, such as
                the ``hiddens`` returned by a previous call.  None or an empty
                sequence lets every LSTM start from its default zero state.

        Returns:
            ``(logits, hiddens)``.  ``logits`` has shape
            ``(batch, num_classes)`` for 1-D ``eos_indices`` or None.
            ``hiddens`` is the list of final ``(h_n, c_n)`` states, one per
            LSTM layer.
        """
        x = self.embeddings(input)

        hiddens = []
        for i, rnn in enumerate(self.rnns):
            # Restore the contiguous weight layout the fused cuDNN kernel
            # expects (it can be lost e.g. after module replication).
            rnn.flatten_parameters()
            h0 = hidden[i] if hidden is not None and len(hidden) > 0 else None
            x, state = rnn(x, h0)
            x = self.rnn_dropout(self.rnn_activ(x))
            hiddens.append(state)

        if eos_indices is None:
            o = x[:, -1]
        else:
            eos = torch.as_tensor(eos_indices, dtype=torch.long,
                                  device=x.device)
            # A single advanced-indexing gather replaces the per-sample
            # Python loop.  The batch index is reshaped so it broadcasts
            # against eos along dim 0.
            rows = torch.arange(x.size(0), device=x.device)
            rows = rows.view(-1, *([1] * (eos.dim() - 1)))
            o = x[rows, eos]
        # NOTE(review): rnn_dropout was already applied to the last layer's
        # output inside the loop, so this second application compounds it.
        # Kept to preserve existing training behaviour; confirm it is intended.
        o = self.rnn_dropout(o)
        o = self.proj(o)  # projection layer: hidden -> num_classes
        return o, hiddens
