from model.base import LOBAutoEncoder 

import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTM_encoder(nn.Module):
    """Encode a feature sequence into a single vector with a stacked LSTM.

    The input is linearly projected from ``enc_in`` features to
    ``hidden_size``, run through a ``num_layers``-deep LSTM
    (``batch_first=True``), and the top layer's final hidden state is passed
    through one more linear layer.

    Args:
        enc_in: Number of input features per time step.
        hidden_size: Width of the input projection, the LSTM hidden state and
            the output. Defaults to 128 (the previously hard-coded value).
        num_layers: Number of stacked LSTM layers. Defaults to 3 (the
            previously hard-coded value).
    """

    def __init__(self, enc_in, hidden_size=128, num_layers=3):
        super().__init__()
        self.hidden_size = hidden_size

        # Construction order (lstm, proj, leakyReLU, output_proj) is kept as-is
        # so seeded parameter initialisation and state_dict layout don't change.
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        self.proj = nn.Linear(enc_in, hidden_size)
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.leakyReLU = nn.LeakyReLU()
        self.output_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Encode ``x`` into one vector per sequence.

        Args:
            x: Tensor of shape (batch, seq_len, enc_in), or (seq_len, enc_in)
                for a single unbatched sequence. It is cast to the dtype of
                the module's parameters (float32 unless the module was
                converted), so integer inputs are accepted.

        Returns:
            Tensor of shape (batch, hidden_size); an unbatched input yields
            shape (1, hidden_size).
        """
        # Match the parameters' dtype instead of hard-coding float32, so that
        # .double()/.half() modules work; identical to x.float() by default.
        x = x.to(self.proj.weight.dtype)

        x = self.proj(x)
        # h_n is (num_layers, batch, hidden_size) even with batch_first=True;
        # the per-step outputs are not needed.
        _, (hn, _) = self.lstm(x)

        last = hn[-1]  # final hidden state of the top LSTM layer
        # Force a 2-D (batch, hidden_size) result; for unbatched input hn[-1]
        # is 1-D and this adds the leading batch dimension of 1.
        last = last.reshape(-1, self.hidden_size)

        return self.output_proj(last)
    
class LSTM_AE(LOBAutoEncoder):
    """Autoencoder with an LSTM encoder and a Transformer decoder.

    The encoder compresses an input sequence into one latent vector. A
    6-layer Transformer decoder (8 heads) refines that vector. A linear head
    then expands it into a (batch, seq_len, enc_in) tensor.

    Args:
        d_model: Width of the Transformer decoder. NOTE: ``LSTM_encoder`` as
            constructed here emits 128 features, so ``forward`` only works
            with ``d_model == 128``. ``d_model`` must also be divisible by
            the 8 attention heads.
        seq_len: Number of time steps in the decoded output.
        enc_in: Number of features per time step (input and output).
        **kwargs: Forwarded unchanged to ``LOBAutoEncoder``.
    """

    def __init__(self,
                 d_model,
                 seq_len,
                 enc_in,
                 **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.d_model = d_model
        self.enc_in = enc_in
        self.encoder = LSTM_encoder(enc_in=enc_in)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        # Maps each decoded vector to a flattened (seq_len, enc_in) window.
        self.projection = nn.Linear(d_model, self.seq_len * self.enc_in, bias=True)

    def encode(self, x_enc):
        """Return the latent code of ``x_enc`` as produced by ``self.encoder``.

        With the encoder built in ``__init__`` this has shape (batch, 128).
        """
        return self.encoder(x_enc)

    def forward(self, x):
        """Encode ``x`` and decode it into a (batch, seq_len, enc_in) tensor.

        Presumably ``x`` is a (batch, time, enc_in) window and the output is
        its reconstruction. TODO: confirm against the training code.
        """
        # Reuse encode() so the two entry points cannot drift apart. The
        # latent vector becomes a length-1 target sequence: (batch, 1, width).
        latent = self.encode(x).unsqueeze(1)
        # There is no real memory to cross-attend to. An all-zero tensor with
        # the latent's shape, dtype and device stands in for it.
        memory = torch.zeros_like(latent)
        decoded = self.decoder(latent, memory)
        out = self.projection(decoded)  # (batch, 1, seq_len * enc_in)
        return out.view(out.shape[0], self.seq_len, self.enc_in)