from torch import nn
from torch.distributions.normal import Normal

from vae.layers.misc import Swish


class Backbone(nn.Module):
    """ Maps from x_i to h_i

    Template for the modality-specific backbones.  Subclasses provide the
    two stages through ``build_first_layers`` and ``build_output_layer``;
    both hooks are called once, without arguments, during construction and
    their results are stored as ``first_layers`` and ``output_layer``.
    """

    def __init__(self):
        super().__init__()
        # Keep this order: first_layers is built (and registered) before
        # output_layer.
        self.first_layers = self.build_first_layers()
        self.output_layer = self.build_output_layer()

    @staticmethod
    def build_first_layers():
        """ Return the callable applied to the raw input (override me) """
        raise NotImplementedError

    @staticmethod
    def build_output_layer():
        """ Return the callable applied to the flattened features
        (override me) """
        raise NotImplementedError

    def forward(self, x):
        """ Run both stages, flattening in between.

        :param x: input batch; the first dimension is kept and every other
            dimension of the ``first_layers`` output is flattened into one.
        :return: result of ``output_layer`` on the flattened features.
        """
        x = self.first_layers(x)
        # reshape instead of view: same result (and no copy) for contiguous
        # tensors, but it also works when the features are not contiguous,
        # where view would raise a RuntimeError.
        x = x.reshape(x.size(0), -1)
        x = self.output_layer(x)
        return x


class BackboneX1(Backbone):
    """ Maps from x_1 to h_1

    Convolutional backbone: four bias-free 4x4 convolutions followed by a
    linear layer with 512 outputs.
    """

    @staticmethod
    def build_first_layers():
        """ Build the convolutional stack (3 input channels, 256 output
        channels). """
        # Miniature DCGAN as in
        # https://github.com/mhw32/multimodal-vae-public/blob/master/celeba/model.py
        #
        # The first convolution is not followed by batch norm.
        modules = [
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1, bias=False),
            Swish(),
        ]
        # (in_channels, out_channels, stride, padding) of the remaining
        # conv -> batch norm -> Swish stages.
        stage_specs = (
            (32, 64, 2, 1),
            (64, 128, 2, 1),
            (128, 256, 1, 0),
        )
        for c_in, c_out, stride, padding in stage_specs:
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=4,
                                     stride=stride, padding=padding,
                                     bias=False))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(Swish())
        return nn.Sequential(*modules)

    @staticmethod
    def build_output_layer():
        """ Build the linear head applied to the flattened conv features. """
        # Vasco et al.: "followed by a linear layer of 512 hidden units."
        # 256 * 5 * 5 is the flattened size of a 256-channel 5x5 feature map,
        # which is what the conv stack above yields for 64x64 inputs
        # (presumably the image size used here -- confirm against the data).
        flat_features = 256 * 5 * 5
        return nn.Sequential(nn.Linear(flat_features, 512), Swish())


class BackboneX2(Backbone):
    """ Maps from x_2 to h_2

    Fully connected backbone: 1024 input features, 512 output features,
    LeakyReLU after every linear layer.
    """

    @staticmethod
    def build_first_layers():
        """ Build the three hidden Linear + LeakyReLU stages. """
        # Layer widths, read pairwise as (in_features, out_features):
        #   1024 -> 1024 -> 768: same capacity as lower encoder of HMVAE
        #   768 -> 768: some further capacity (otherwise Vasco's model has
        #   less capacity than the HMVAE, because the HMVAE has more capacity
        #   in the upper encoders)
        widths = (1024, 1024, 768, 768)
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(n_in, n_out))
            modules.append(nn.LeakyReLU())
        return nn.Sequential(*modules)

    @staticmethod
    def build_output_layer():
        """ Build the final Linear + LeakyReLU stage. """
        # Final layer:
        #   512 output as in Vasco et al.
        #   Ensures that the network has sufficient capacity
        head = [nn.Linear(768, 512), nn.LeakyReLU()]
        return nn.Sequential(*head)


class Encoder(nn.Module):
    """ Maps from h to z or g

    Shared constructor for the encoders.  Subclasses implement
    ``build_layers``; every constructor argument is forwarded to it
    unchanged and the result is stored as ``layers``.  A ``Softplus`` is
    stored as ``softplus`` for use by the subclasses.
    """

    def __init__(self, *args, **kwargs):
        super(Encoder, self).__init__()
        network = self.build_layers(*args, **kwargs)
        self.layers = network
        self.softplus = nn.Softplus()


class EncoderG(Encoder):
    """ Maps from h to g

    Parameterises a diagonal Normal over g from a 1024-dimensional input
    (512 * 2, presumably the two 512-dimensional backbone outputs
    concatenated -- confirm against the caller).
    """

    @staticmethod
    def build_layers(latent_size, hidden_size=768):
        """ Build the MLP producing ``2 * latent_size`` outputs.

        :param latent_size: dimensionality of g.
        :param hidden_size: width of the three hidden layers.
        """
        # Vasco et al. use "three linear layers with 256 hidden units"
        # We increase the amount of hidden units to make sure Vasco's HMVAE
        # has similar capacity to the HMVAE
        modules = [nn.Linear(512 * 2, hidden_size), nn.LeakyReLU()]
        for _ in range(2):
            modules.append(nn.Linear(hidden_size, hidden_size))
            modules.append(nn.LeakyReLU())
        # Output head: no nonlinearity; first half is the mean, second half
        # the pre-softplus scale (see forward).
        modules.append(nn.Linear(hidden_size, latent_size * 2))
        return nn.Sequential(*modules)

    def forward(self, h, k=None):
        """ Compute the posterior over g and draw reparameterised samples.

        :param h: input features; the last dimension must be 1024.
        :param k: optional number of samples.  If given, the samples get an
            extra leading dimension of size ``k``; if ``None`` a single
            sample without that dimension is drawn.
        :return: tuple ``(output, samples)`` where ``output`` is
            ``{'posterior': {'dist': ..., 'samples': ...}}`` and ``samples``
            is the same tensor as stored in the dict.
        """
        # Split the network output in two halves along the last dimension.
        mu, raw_scale = self.layers(h).chunk(2, dim=-1)
        # Softplus maps the raw values to non-negative scales.
        dist = Normal(mu, self.softplus(raw_scale))
        # With k: K x N x D (assuming h is N x 1024), otherwise N x D.
        samples = dist.rsample() if k is None else dist.rsample((k,))
        posterior = {'dist': dist, 'samples': samples}
        return {'posterior': posterior}, samples


class EncoderZ(Encoder):
    """ Maps from h to z

    Parameterises a diagonal Normal over z from a 512-dimensional input
    using a single linear layer.
    """

    @staticmethod
    def build_layers(latent_size):
        """ Build the linear map producing ``2 * latent_size`` outputs.

        :param latent_size: dimensionality of z.
        """
        # Vasco et al. do not explicitly describe this network.
        # We avoid a nonlinearity here. That is because the natural inference
        # model would contain the factor q(g|z) and using a simple function
        # f(h)=z "approximates" this.
        projection = nn.Linear(512, latent_size * 2)
        return nn.Sequential(projection)

    def forward(self, x, k=None):
        """ Compute the posterior over z and draw reparameterised samples.

        :param x: input features; the last dimension must be 512.
        :param k: optional number of samples.  If given, the samples get an
            extra leading dimension of size ``k``; if ``None`` a single
            sample without that dimension is drawn.
        :return: tuple ``(output, samples)`` where ``output`` is
            ``{'posterior': {'dist': ..., 'samples': ...}}`` and ``samples``
            is the same tensor as stored in the dict.
        """
        # First half of the last dimension is the mean, second half the
        # pre-softplus scale.
        mu, raw_scale = self.layers(x).chunk(2, dim=-1)
        # Softplus maps the raw values to non-negative scales.
        dist = Normal(mu, self.softplus(raw_scale))
        # With k: K x N x D (assuming x is N x 512), otherwise N x D.
        samples = dist.rsample() if k is None else dist.rsample((k,))
        posterior = {'dist': dist, 'samples': samples}
        return {'posterior': posterior}, samples
