import torch.nn as nn
import pdb
from dl.models.layers.base import CNNTrasnposedLayer

class Decoder1C(nn.Module):
    """Decode a latent vector into a single-channel, sigmoid-activated image.

    Two linear layers expand the latent code, the result is reshaped into a
    ``(batch, 64, 4, 4)`` feature map, and that map is fed through three
    ``CNNTrasnposedLayer`` blocks followed by a transposed convolution that
    emits one channel.

    The ``config`` object must expose:

    * ``latent_dim``    -- size of the incoming latent vector.
    * ``dense_dim``     -- indexable; ``dense_dim[1]`` is the width of the
      first linear layer and ``4 * dense_dim[0]`` the width of the second.
      Because ``forward`` reshapes to ``64 * 4 * 4`` values per sample,
      ``dense_dim[0]`` has to be 256 for this class to run.
    * ``hidden_states`` -- truthy to make ``forward`` also return the
      intermediate feature maps.
    """

    def __init__(self, config):
        super(Decoder1C, self).__init__()
        self.latent_dim = config.latent_dim
        self.hidden_states = config.hidden_states

        # Dense stem: latent -> dense_dim[1] -> 4 * dense_dim[0].
        first_width = config.dense_dim[1]
        self.dense1 = nn.Linear(self.latent_dim, first_width)
        self.dense2 = nn.Linear(first_width, 4 * config.dense_dim[0])
        self.relu = nn.ReLU(True)  # positional flag is ``inplace``
        self.active = nn.Sigmoid()

        # Convolutional head; the last layer maps 32 channels down to 1.
        self.hidden_layers = nn.ModuleList(
            [
                CNNTrasnposedLayer(in_channels=64, out_channels=64),
                CNNTrasnposedLayer(in_channels=64, out_channels=32),
                CNNTrasnposedLayer(in_channels=32, out_channels=32),
                nn.ConvTranspose2d(
                    in_channels=32,
                    out_channels=1,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                ),
            ]
        )

    def forward(self, input):
        """Run the decoder.

        Args:
            input: latent batch of shape ``(batch, latent_dim)``.

        Returns:
            A 2-tuple ``(image, hidden)``. ``image`` is the sigmoid of the
            last layer's output. ``hidden`` is a tuple holding the reshaped
            dense output followed by the raw (pre-sigmoid) output of every
            entry in ``hidden_layers`` when ``self.hidden_states`` is truthy,
            and an empty tuple otherwise.
        """
        features = self.relu(self.dense1(input))
        features = self.relu(self.dense2(features))
        # Unflatten into the fixed 64-channel 4x4 map the conv stack expects.
        features = features.view(features.size(0), 64, 4, 4)

        collected = []
        if self.hidden_states:
            collected.append(features)
        for layer in self.hidden_layers:
            features = layer(features)
            if self.hidden_states:
                collected.append(features)

        image = self.active(features)
        return image, tuple(collected)

class Decoder3C(Decoder1C):
    """Three-channel variant of ``Decoder1C``.

    Shares the parent's dense stem, activations, ``CNNTrasnposedLayer`` stack
    and ``forward``; the only difference is that the final transposed
    convolution produces 3 output channels instead of 1.

    ``config`` must provide the same attributes the parent constructor reads
    (``latent_dim``, ``hidden_states``, ``dense_dim``).
    """

    def __init__(self, config):
        super(Decoder3C, self).__init__(config)

        # The parent constructor has already created every module with the
        # exact hyper-parameters needed here, so only the output projection
        # is swapped rather than rebuilding (and re-initialising) all layers.
        # An explicit non-negative index is used so the assignment behaves the
        # same on PyTorch releases whose ModuleList item assignment does not
        # normalise negative indices.
        last_index = len(self.hidden_layers) - 1
        self.hidden_layers[last_index] = nn.ConvTranspose2d(
            in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1
        )


class MNISTDecoder1C(Decoder1C):
    """Smaller single-channel decoder working from a ``(32, 7, 7)`` feature map.

    Compared with ``Decoder1C`` the second linear layer emits ``7 * 7 * 32``
    values, the convolutional head is a single ``CNNTrasnposedLayer`` followed
    by a one-channel transposed convolution, and ``forward`` applies no
    activation after either linear layer.

    ``config`` must provide the attributes the parent constructor reads
    (``latent_dim``, ``hidden_states``, ``dense_dim``).
    """

    def __init__(self, config):
        super(MNISTDecoder1C, self).__init__(config)

        # ``latent_dim``, ``hidden_states``, ``dense1``, ``relu`` and
        # ``active`` are already set by the parent with identical arguments,
        # so only the two pieces that actually differ are overridden.
        self.dense2 = nn.Linear(config.dense_dim[1], 7 * 7 * 32)
        self.hidden_layers = nn.ModuleList(
            [
                CNNTrasnposedLayer(in_channels=32, out_channels=32),
                nn.ConvTranspose2d(
                    in_channels=32,
                    out_channels=1,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                ),
            ]
        )

    def forward(self, input):
        """Run the decoder.

        Args:
            input: latent batch of shape ``(batch, latent_dim)``.

        Returns:
            A 2-tuple ``(image, hidden)``. ``image`` is the sigmoid of the
            last layer's output. ``hidden`` is a tuple holding the reshaped
            dense output followed by the raw (pre-sigmoid) output of every
            entry in ``hidden_layers`` when ``self.hidden_states`` is truthy,
            and an empty tuple otherwise.
        """
        # NOTE(review): unlike Decoder1C.forward, no ReLU is applied after
        # dense1 or dense2 here (``self.relu`` is unused in this method), so
        # the two linear layers compose into one affine map -- confirm this
        # is intentional.
        output = self.dense2(self.dense1(input))
        # Unflatten into the 32-channel 7x7 map the conv head expects.
        output = output.view(output.size(0), 32, 7, 7)

        collected = []
        if self.hidden_states:
            collected.append(output)
        for hidden_layer in self.hidden_layers:
            output = hidden_layer(output)
            if self.hidden_states:
                collected.append(output)

        output = self.active(output)
        return output, tuple(collected)