"""
https://github.com/facebookresearch/disentangling-correlated-factors/blob/main/dent/models/decoder/montero_large.py
"""

# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2019 Yann Dubois, Aleco Kastanos, Dave Lines, Bart Melman
# Copyright (c) 2018 Schlumberger
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes to images."""

    def __init__(self, img_size, latent_dim=10):
        r"""Decoder as used in [1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images as (channels, height, width), e.g. (1, 32, 32) or
            (3, 64, 64). Height and width must be positive multiples of 32,
            because the five stride-2 transposed convolutions upsample the
            starting feature map by a factor of 2**5.

        latent_dim : int
            Dimensionality of the latent input.

        Raises
        ------
        ValueError
            If the height or width of ``img_size`` is not a positive multiple
            of 32.

        Model Architecture
        ------------
        - 2 fully connected layers:
          latent_dim -> 256 -> 256 * (height / 32) * (width / 32)
        - 5 transposed convolutional layers (4 x 4 kernel, stride of 2,
          padding of 1) with channels 256 -> 128 -> 128 -> 64 -> 64 -> n_chan
        - ReLU after every layer except the last, which uses a sigmoid.

        References:
            [1] Montero et al. "Lost in Latent Space: Disentangled Models and
            the Challenge of Combinatorial Generalisation."
        """
        super(Decoder, self).__init__()

        # Layer parameters
        kernel_size = 4
        hidden_dim = 256
        # Each transposed conv below (kernel 4, stride 2, padding 1) doubles
        # the spatial size, so five of them upsample by 2**5 = 32.
        n_upsamplings = 5
        scale = 2 ** n_upsamplings

        self.img_size = img_size
        n_chan = self.img_size[0]
        height, width = int(self.img_size[1]), int(self.img_size[2])
        if min(height, width) < scale or height % scale or width % scale:
            raise ValueError(
                f"Image height and width must be positive multiples of "
                f"{scale}, got img_size={tuple(img_size)}."
            )

        # Shape required to start transpose convs; (256, 2, 2) for 64x64 images.
        self.reshape = (256, height // scale, width // scale)

        # Fully connected layers
        self.lin1 = nn.Linear(latent_dim, hidden_dim)
        self.lin2 = nn.Linear(
            hidden_dim, self.reshape[0] * self.reshape[1] * self.reshape[2])

        # Convolutional layers
        cnn_kwargs = dict(stride=2, padding=1)
        self.convT1 = nn.ConvTranspose2d(256, 128, kernel_size, **cnn_kwargs)
        self.convT2 = nn.ConvTranspose2d(128, 128, kernel_size, **cnn_kwargs)
        self.convT3 = nn.ConvTranspose2d(128, 64, kernel_size, **cnn_kwargs)
        self.convT4 = nn.ConvTranspose2d(64, 64, kernel_size, **cnn_kwargs)
        self.convT5 = nn.ConvTranspose2d(64, n_chan, kernel_size, **cnn_kwargs)

    def forward(self, z):
        """Decode a batch of latent codes into images.

        Parameters
        ----------
        z : torch.Tensor
            Latent codes of shape (batch_size, latent_dim).

        Returns
        -------
        dict
            ``{'reconstructions': x}`` where ``x`` has shape
            (batch_size, *img_size) and holds sigmoid outputs.
        """
        batch_size = z.size(0)

        # Fully connected layers with ReLu activations
        x = torch.relu(self.lin1(z))
        x = torch.relu(self.lin2(x))
        # Unflatten into the (channels, height, width) map the convs expect.
        x = x.view(batch_size, *self.reshape)

        # Convolutional layers with ReLu activations
        x = torch.relu(self.convT1(x))
        x = torch.relu(self.convT2(x))
        x = torch.relu(self.convT3(x))
        x = torch.relu(self.convT4(x))
        # Sigmoid on the last layer keeps outputs in the unit interval.
        x = torch.sigmoid(self.convT5(x))

        return {'reconstructions': x}
