from typing import Optional, Tuple, Union

import torch
from torch import nn

import lightning as L


class LSTMNet(L.LightningModule):
    """LSTM followed by a two-layer fully connected head applied per time step.

    The network consumes batch-first sequences of shape ``(N, L, H_in)``,
    runs them through an ``nn.LSTM`` and maps every time step of the LSTM
    output ``(N, L, D * H_out)`` through ``Linear -> ReLU -> Linear`` to
    ``output_size`` features.
    """

    def __init__(
        self,
        seq_len: int = 100,
        input_size: int = 45,
        hidden_size: int = 128,
        output_size: int = 33,
        num_layers: int = 2,
        hidden_size_fc2: int = 256,
        bidirectional: bool = False,
        dropout: float = 0.0
    ) -> None:
        """Build the LSTM and the fully connected layers.

        Args:
            seq_len: Sequence length; stored as ``self.L`` (not otherwise
                used by the layers defined here).
            input_size: Number of features per time step of the input.
            hidden_size: Hidden/cell state size of the LSTM.
            output_size: Number of features produced per time step.
            num_layers: Number of stacked LSTM layers.
            hidden_size_fc2: Width of the hidden fully connected layer.
            bidirectional: Whether the LSTM is bidirectional.
            dropout: Dropout value forwarded to ``nn.LSTM``.
        """
        super().__init__()

        # Must be called directly from __init__ so the constructor arguments
        # are captured as hyperparameters.
        self.save_hyperparameters(logger=False)

        self.L: int = seq_len
        self.H_in: int = input_size
        self.H_cell: int = hidden_size
        self.H_out: int = hidden_size
        # Number of LSTM directions.
        self.D: int = 2 if bidirectional else 1
        self.num_layers: int = num_layers

        # The FC head sees the concatenated outputs of all directions.
        self.H_in_fc: int = hidden_size * self.D
        self.H_out_fc: int = output_size

        # NOTE: these two layers are currently NOT used by forward() (see the
        # comment there). They stay registered so the module's parameter set
        # and state_dict keys remain unchanged.
        self.linear_h_in = nn.Linear(self.H_in, self.H_cell * self.num_layers)
        self.linear_c_in = nn.Linear(self.H_in, self.H_cell * self.num_layers)
        # Input shape: (N, L, H_in)
        # LSTM output shape: (N, L, D * H_out)
        self.lstm = nn.LSTM(
            input_size=self.H_in,
            hidden_size=self.H_cell,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout
        )
        self.linear = nn.Linear(self.H_in_fc, hidden_size_fc2)
        self.linear2 = nn.Linear(hidden_size_fc2, self.H_out_fc)

    def forward(
        self,
        x: torch.Tensor,
        hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the LSTM and the fully connected head on a batch of sequences.

        Args:
            x: Batch-first input of shape ``(N, L, H_in)``.
            hiddens: Optional ``(h_0, c_0)`` pair used as the initial LSTM
                state, each of shape ``(num_layers * D, N, H_out)``. When
                omitted, small random states (standard normal scaled by
                0.01) are drawn, so the output is not deterministic even in
                eval mode.

        Returns:
            The per-time-step output of shape ``(N, L, output_size)`` when
            ``hiddens`` is ``None``; otherwise the tuple
            ``(output, (h_t, c_t))`` with the final LSTM state.
        """
        # Disabled alternative (the trick from PIP): derive the initial state
        # from the first frame with the FC layers, e.g.
        #   h_0 = self.linear_h_in(x[:, 0:1]).view(self.num_layers, -1, self.H_cell)
        #   c_0 = self.linear_c_in(x[:, 0:1]).view(self.num_layers, -1, self.H_cell)
        # This FC network might contribute to "cheating" in the first node,
        # e.g. by giving the model a high speed. It could perhaps be
        # regularised by self-supervising it on later LSTM hidden states.
        if hiddens is None:
            # Draw the states directly on the input's device and in its dtype,
            # avoiding a CPU allocation + copy and a dtype mismatch with x.
            state_shape = (self.num_layers * self.D, x.size(0), self.H_out)
            h_0 = torch.randn(state_shape, device=x.device, dtype=x.dtype) * 0.01
            c_0 = torch.randn(state_shape, device=x.device, dtype=x.dtype) * 0.01
        else:
            h_0 = hiddens[0]
            c_0 = hiddens[1]

        out, (h_t, c_t) = self.lstm(x, (h_0, c_0))
        # out: tensor of shape (N, L, D * H_out)
        out = self.linear2(torch.nn.functional.relu(self.linear(out)))
        if hiddens is not None:
            # Callers that carry state across calls get the updated state back.
            return out, (h_t, c_t)
        return out


if __name__ == "__main__":
    # Smoke check: building the model with its default arguments must not raise.
    _ = LSTMNet()
