"""LSTM layer."""

import torch
import torch.nn as nn


class LSTMLayer(nn.Module):
    """Unidirectional, batch-first LSTM layer for language modelling.

    The hidden size equals the input size (``d_model``), so the layer maps
    ``(batch, seq_len, d_model)`` inputs to ``(batch, seq_len, d_model)``
    outputs.
    """

    def __init__(
            self,
            d_model,
            n_layer=1,
            dropout=0.1,
            layer_idx=None,
            device=None,
            dtype=None,
            reinit=True):
        """Initialize the LSTM layer.

        Args:
            d_model: Input and hidden size of the LSTM.
            n_layer: Number of stacked LSTM layers.
            dropout: Dropout probability between stacked LSTM layers.
                ``nn.LSTM`` applies it only between layers, so it is not
                forwarded (it would be a no-op) when ``n_layer == 1``.
            layer_idx: Index of this layer in the enclosing model; stored on
                the instance but not used by this class.
            device: Device on which the parameters are created.
            dtype: Parameter dtype; ``None`` keeps torch's default dtype.
            reinit: If True, re-initialize the weights with
                :meth:`_reinitialize` instead of PyTorch's default init.
        """
        super().__init__()
        self.d_model = d_model
        self.layer_idx = layer_idx

        # nn.LSTM only uses dropout *between* stacked layers and emits a
        # UserWarning when dropout > 0 with a single layer.
        inter_layer_dropout = dropout if n_layer > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=d_model,
            num_layers=n_layer,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=False,
            device=device,
        )

        if reinit:
            self._reinitialize()

        # Cast only after initialization: orthogonal_ relies on a QR
        # decomposition that may not be implemented for reduced-precision
        # dtypes on every backend.
        if dtype is not None:
            self.lstm.to(dtype=dtype)

    def _reinitialize(self):
        """Re-initialize the LSTM parameters in place.

        * input-to-hidden weights (``weight_ih*``): Xavier/Glorot uniform;
        * hidden-to-hidden weights (``weight_hh*``): orthogonal;
        * biases: zero, except the forget-gate slice of ``bias_ih*``, which
          is set to 1. PyTorch sums ``bias_ih`` and ``bias_hh`` and orders the
          gates as (input, forget, cell, output), so the effective
          forget-gate bias is 1.
        """
        with torch.no_grad():
            for name, param in self.lstm.named_parameters():
                if name.startswith('weight_ih'):
                    nn.init.xavier_uniform_(param)
                elif name.startswith('weight_hh'):
                    nn.init.orthogonal_(param)
                elif name.startswith('bias_ih'):
                    param.zero_()
                    # The bias vector stacks the four gates: size 4 * hidden.
                    hidden_size = param.size(0) // 4
                    param[hidden_size:2 * hidden_size] = 1.0
                elif name.startswith('bias_hh'):
                    param.zero_()

    def forward(self, x):
        """Run the LSTM over ``x`` and return the per-timestep outputs.

        Args:
            x: Input sequence, batch-first, e.g. ``(batch, seq_len, d_model)``.

        Returns:
            Top-layer hidden states for every timestep, with the same leading
            dimensions as ``x`` and last dimension ``d_model``. The final
            ``(h_n, c_n)`` state is discarded.
        """
        output, _ = self.lstm(x)
        return output
