from model.base import LOBAutoEncoder 

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN2_encoder(nn.Module):
    """CNN encoder that compresses an input window into a 128-feature vector.

    The network is one 2-D convolution followed by four 1-D convolutions
    (each with batch normalisation and a PReLU activation) and a final fully
    connected layer that is also followed by a PReLU.

    Args:
        temp: Length of the sequence axis left after the last convolution.
            The fully connected layer is built for ``temp * 32`` inputs, so
            the flattened output of the last convolution must have exactly
            that many features.
        padding: Zero padding applied by the first convolution on the last
            input axis.
        **kwargs: Passed through unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, temp=264, padding=2, **kwargs):
        super().__init__(**kwargs)

        # Stage 1 is the only 2-D convolution.
        # NOTE: its kernel size differs from the original reference code.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(10, 20), padding=(0, padding))
        self.bn1 = nn.BatchNorm2d(16)
        self.prelu1 = nn.PReLU()

        # Stages 2-5 are 1-D convolutions described as
        # (in_channels, out_channels, kernel_width). They are registered as
        # conv{i} / bn{i} / prelu{i}, in that order, so the submodule names
        # (and therefore the state_dict keys) stay stable.
        conv1d_specs = ((16, 16, 10), (16, 32, 8), (32, 32, 6), (32, 32, 4))
        for stage, (c_in, c_out, width) in enumerate(conv1d_specs, start=2):
            setattr(self, f"conv{stage}", nn.Conv1d(c_in, c_out, kernel_size=(width,)))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(c_out))
            setattr(self, f"prelu{stage}", nn.PReLU())

        # Head: flattened conv features -> 128-dimensional code.
        self.fc1 = nn.Linear(temp * 32, 128)
        self.prelu6 = nn.PReLU()

    def forward(self, x):
        """Encode a batch of inputs.

        Args:
            x: Input tensor without a channel axis. The original author's
                comment gives [batch_size, 100, 40] as the expected shape;
                TODO confirm against the data loader.

        Returns:
            Tensor with one 128-feature row per batch element.
        """
        # Insert the single channel axis that Conv2d expects.
        feats = x[:, None, :]

        # Stage 1 (2-D), then merge the two spatial axes into one sequence
        # axis so the remaining stages can run as 1-D convolutions.
        feats = self.prelu1(self.bn1(self.conv1(feats)))
        feats = feats.reshape(feats.shape[0], feats.shape[1], -1)

        # Stages 2-5 (1-D): convolution -> batch norm -> PReLU.
        feats = self.prelu2(self.bn2(self.conv2(feats)))
        feats = self.prelu3(self.bn3(self.conv3(feats)))
        feats = self.prelu4(self.bn4(self.conv4(feats)))
        feats = self.prelu5(self.bn5(self.conv5(feats)))

        # Flatten channels x length per sample and apply the linear head.
        flat = feats.view(feats.size(0), -1)
        return self.prelu6(self.fc1(flat))
    
class CNN2_AE(LOBAutoEncoder):
    """Autoencoder that pairs a ``CNN2_encoder`` with a Transformer decoder.

    The encoder output is treated as a length-1 target sequence, decoded by a
    6-layer ``nn.TransformerDecoder`` against an all-zero memory, and then
    projected linearly to ``seq_len * enc_in`` values.

    Args:
        d_model: Feature width of the Transformer decoder and input width of
            the output projection. The encoder built here emits 128 features
            from its final linear layer, so ``forward`` only works when
            ``d_model`` is 128. PyTorch also requires it to be divisible by
            the 8 attention heads.
        seq_len: Number of time steps in the reconstructed output.
        enc_in: Number of features per time step in the reconstructed output.
        padding: Padding forwarded to ``CNN2_encoder``.
        temp: Forwarded to ``CNN2_encoder``; sizes its fully connected layer.
        **kwargs: Forwarded unchanged to ``LOBAutoEncoder.__init__``.
    """

    def __init__(self,
                 d_model,
                 seq_len,
                 enc_in,
                 padding,
                 temp=264,
                 **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.d_model = d_model
        self.enc_in = enc_in
        self.encoder = CNN2_encoder(temp=temp, padding=padding)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.projection = nn.Linear(d_model, self.seq_len * self.enc_in, bias=True)

    def encode(self, x_enc):
        """Return the encoder's latent representation of ``x_enc``."""
        return self.encoder(x_enc)

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: Input batch in the layout ``CNN2_encoder`` accepts.

        Returns:
            Tensor viewed as ``(batch, seq_len, enc_in)``.
        """
        # Add a sequence axis: the latent vector becomes a length-1 target
        # sequence for the batch-first decoder.
        latent = self.encoder(x).unsqueeze(1)

        # The decoder has no real encoder memory to attend to, so feed zeros.
        # zeros_like (rather than torch.zeros with only device=...) keeps the
        # memory's dtype in step with the latent, so the model still runs
        # after being cast with .double() / .half() / .bfloat16().
        memory = torch.zeros_like(latent)

        decoded = self.decoder(latent, memory)
        out = self.projection(decoded)
        return out.view(out.shape[0], self.seq_len, self.enc_in)
    
