import torch.nn as nn


class MLPAE(nn.Module):
    """Fully connected autoencoder.

    Encoder: ``input_dim -> hidden_dim -> n_hidden_layers x (hidden_dim ->
    hidden_dim) -> latent_dim``, with ReLU after every layer, including the
    latent projection. The decoder mirrors this back to ``input_dim``: every
    decoder layer except the last uses ReLU, and the last one uses
    ``final_activation``.

    Args:
        input_dim: Number of input features (the last dimension of ``x``,
            which is what ``nn.Linear`` operates on).
        latent_dim: Size of the latent code.
        hidden_dim: Width of every hidden layer.
        n_hidden_layers: Number of extra ``hidden_dim -> hidden_dim`` layers
            in each of the encoder and decoder (0 is allowed).
        final_activation: Callable applied to the decoder output, e.g. an
            ``nn.Module`` such as ``nn.Sigmoid()``. ``None`` means no output
            activation (identity).
    """

    def __init__(self, input_dim, latent_dim, hidden_dim, n_hidden_layers, final_activation):
        super(MLPAE, self).__init__()

        # Kept as a public attribute for callers that read it.
        self.n_hidden_layers = n_hidden_layers

        self.encoder_input_layer = nn.Linear(
            in_features=input_dim, out_features=hidden_dim
        )
        self.encoder_hidden_layers = nn.ModuleList(
            [nn.Linear(in_features=hidden_dim, out_features=hidden_dim) for _ in range(n_hidden_layers)]
        )
        self.encoder_output_layer = nn.Linear(
            in_features=hidden_dim, out_features=latent_dim
        )

        self.decoder_input_layer = nn.Linear(
            in_features=latent_dim, out_features=hidden_dim
        )
        self.decoder_hidden_layers = nn.ModuleList(
            [nn.Linear(in_features=hidden_dim, out_features=hidden_dim) for _ in range(n_hidden_layers)]
        )
        self.decoder_output_layer = nn.Linear(
            in_features=hidden_dim, out_features=input_dim
        )

        self.activation = nn.ReLU()
        # Allow "no output activation" without forcing callers to build an Identity.
        self.final_activation = nn.Identity() if final_activation is None else final_activation

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Returns:
            Tuple ``(reconstruction, latent_vec)``.
        """
        x = self.activation(self.encoder_input_layer(x))
        for layer in self.encoder_hidden_layers:
            x = self.activation(layer(x))
        # NOTE: ReLU is applied to the latent code too, so it is non-negative.
        latent_vec = self.activation(self.encoder_output_layer(x))

        x = self.activation(self.decoder_input_layer(latent_vec))
        for layer in self.decoder_hidden_layers:
            x = self.activation(layer(x))
        x = self.final_activation(self.decoder_output_layer(x))

        return x, latent_vec


class CnnAE(nn.Module):
    """Convolutional autoencoder for square images.

    Encoder: three 3x3 conv -> BatchNorm -> ReLU stages with ``channels``,
    ``2 * channels`` and ``4 * channels`` feature maps. A 2x2 max-pool follows
    the second and third stages, so the spatial size shrinks by 4 overall.
    The result is then flattened and linearly projected to ``latent_dim``.

    Decoder: a linear projection back to the bottleneck feature map, then two
    stride-2 transposed convolutions. Each one doubles the spatial size
    (kernel 3, padding 1, output_padding 1). A final 3x3 convolution and a
    BatchNorm produce ``in_channels`` output maps.

    Args:
        latent_dim: Size of the latent vector.
        channels: Number of feature maps in the first conv layer.
        in_channels: Number of image channels (default 3).
        input_size: Height/width of the square input (default 28, giving a
            7x7 bottleneck). Must be a positive multiple of 4.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim, channels, in_channels=3, input_size=28):
        super(CnnAE, self).__init__()

        if input_size <= 0 or input_size % 4 != 0:
            # Two 2x2 pools down and two x2 transposed convs up only
            # round-trip exactly when the size is divisible by 4.
            raise ValueError(
                f"input_size must be a positive multiple of 4, got {input_size}"
            )
        # Spatial side length after the two max-pools (7 for input_size=28).
        side = input_size // 4
        flat_dim = 4 * channels * side * side

        # Layer order is unchanged from the original so that Sequential
        # indices (and therefore state_dict keys) stay compatible.
        self.encoder = nn.Sequential(
            # Layer 1
            nn.Conv2d(in_channels=in_channels, out_channels=channels, kernel_size=(3, 3), stride=1, padding=1,
                      bias=True),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            # Layer 2
            nn.Conv2d(in_channels=channels, out_channels=2*channels, kernel_size=(3, 3), stride=1, padding=1,
                      bias=True),
            nn.BatchNorm2d(channels * 2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Layer 3
            nn.Conv2d(in_channels=2*channels, out_channels=4*channels, kernel_size=(3, 3), stride=1, padding=1,
                      bias=True),
            nn.BatchNorm2d(channels * 4),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Flatten(),
            nn.Linear(flat_dim, latent_dim)
        )

        self.decoder = nn.Sequential(
            # Layer 1
            nn.Linear(latent_dim, flat_dim),
            nn.Unflatten(-1, (4*channels, side, side)),
            nn.ConvTranspose2d(in_channels=4*channels, out_channels=2*channels, kernel_size=(3, 3), stride=2, padding=1,
                               output_padding=1, bias=True),
            nn.BatchNorm2d(channels * 2),
            nn.ReLU(),

            # Layer 2
            nn.ConvTranspose2d(in_channels=2*channels, out_channels=channels, kernel_size=(3, 3), stride=2, padding=1,
                               output_padding=1, bias=True),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            # Layer 3
            nn.Conv2d(in_channels=channels, out_channels=in_channels, kernel_size=(3, 3), stride=1, padding='same',
                      bias=True),
            nn.BatchNorm2d(in_channels),
            nn.Identity()
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Returns:
            Tuple ``(reconstruction, latent)``, the same ordering as
            ``MLPAE.forward``. Previously the reconstruction was returned
            twice and the latent code was discarded.
        """
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction, latent