from torch import nn
from torch.distributions.normal import Normal

from vae.layers.misc import Swish


class LowerDecoder(nn.Module):
    """ Maps from z_i to x_i

    Template base class: ``forward`` turns the tensor returned by
    ``infer_logits`` into a Normal distribution and draws a reparameterised
    sample from it. Subclasses either implement ``infer_logits`` or override
    ``forward`` entirely.
    """

    def __init__(self):
        super().__init__()
        # Used to turn the raw scale half of the logits into a positive std
        self.softplus = nn.Softplus()

    def infer_logits(self, x):
        """ Hook: map the latent ``x`` to logits for ``forward``.

        The returned tensor is split into two equal halves along its last
        dimension (mean first, raw scale second), so that dimension must be
        even.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement infer_logits() "
            "or override forward()"
        )

    def forward(self, x):
        """ Decode ``x`` into a Normal distribution and a sample from it.

        Returns:
            A tuple ``(output, samples)`` where ``output['rec']`` holds the
            distribution under ``'dist'`` and the same ``samples`` tensor
            under ``'samples'``.
        """
        logits = self.infer_logits(x)
        # First half of the last dim is the mean, second half the raw scale
        mu, scale = logits.chunk(2, dim=-1)
        scale = self.softplus(scale)  # keep the std strictly positive
        dist = Normal(mu, scale)
        # rsample keeps the draw differentiable w.r.t. mu and scale
        samples = dist.rsample()
        output = {'rec': {'dist': dist,
                          'samples': samples  # K x N x D
                          }}
        return output, samples


class LowerDecoderX1(LowerDecoder):
    """ Maps from z_1 to x_1

    Deterministic image decoder: a linear layer lifts the latent to a
    256 x 5 x 5 feature map, a stack of transposed convolutions upsamples it
    to a 3-channel map, and a sigmoid squashes the result into (0, 1).
    No distribution is produced; the sigmoid output itself is returned as
    the "samples".
    """

    def __init__(self, latent_size):
        super().__init__()
        self.upsample_layers = self.build_upsample_layers(latent_size)
        self.layers = self.build_layers()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """ Decode latent ``x`` into an image tensor.

        A rank-3 input keeps its two leading dimensions in the output; any
        other rank comes back with a single flattened batch dimension.

        Returns:
            A tuple ``(output, image)`` where ``output['rec']['dist']`` is
            ``None`` and ``output['rec']['samples']`` is ``image``.
        """
        feats = self.upsample_layers(x)
        # Remember the two leading dims of a rank-3 input so they can be
        # restored after the convolutions, which need a single batch dim
        num_k, num_n = (
            (feats.size(0), feats.size(1)) if feats.dim() == 3 else (None, None)
        )
        feats = feats.view(-1, 256, 5, 5)
        image = self.layers(feats)
        if num_k and num_n:
            image = image.view(num_k, num_n, *image.shape[-3:])
        image = self.sigmoid(image)  # samples
        return {'rec': {'dist': None, 'samples': image}}, image

    @staticmethod
    def build_upsample_layers(latent_size):
        """ Linear lift from the latent to a flattened 256 x 5 x 5 map. """
        return nn.Sequential(
            nn.Linear(latent_size, 256 * 5 * 5),
            Swish(),
        )

    @staticmethod
    def build_layers():
        """ Transposed-convolution stack from 256 channels down to 3. """
        # Transposed miniature DCGAN
        # as in https://github.com/mhw32/multimodal-vae-public

        def up_block(c_in, c_out, stride, padding):
            # One upsampling stage: transposed conv -> batch norm -> Swish
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride,
                                   padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                Swish(),
            ]

        # Spatial size per stage: 5 -> 8 -> 16 -> 32 -> 64
        return nn.Sequential(
            *up_block(256, 128, stride=1, padding=0),
            *up_block(128, 64, stride=2, padding=1),
            *up_block(64, 32, stride=2, padding=1),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1,
                               bias=False),
        )


class LowerDecoderX2(LowerDecoder):
    """ Maps from z_2 to x_2

    MLP decoder. Only the logits network is defined here; turning the logits
    into a Normal distribution and sampling from it is done by the inherited
    ``LowerDecoder.forward``, which calls ``infer_logits``.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.layers = self.build_layers(latent_size)

    def infer_logits(self, x):
        """ Run the MLP on ``x``.

        The last dimension of the result has 2048 entries: 1024 means
        followed by 1024 raw (pre-softplus) scales.
        """
        return self.layers(x)  # logits

    @staticmethod
    def build_layers(latent_size):
        """ Build the MLP from ``latent_size`` inputs to 2 * 1024 logits. """
        # Mirrors encoder
        layers = nn.Sequential(
            nn.Linear(latent_size, 768),
            nn.LeakyReLU(),
            nn.Linear(768, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            # Doubled width: one half for the mean, one for the scale
            nn.Linear(1024, 1024 * 2),
        )
        return layers


class UpperDecoder(nn.Module):
    """ Maps from g to z_i

    A shared MLP backbone followed by an output layer whose result is split
    into the mean and raw scale of a Normal distribution over z_i.
    Subclasses may override ``build_output_layer``; the default provided
    here makes the base class usable on its own.
    """

    # Width of every hidden layer in the backbone
    hidden_size = 500

    def __init__(self, g_size, z_size):
        super().__init__()
        self.backbone = self.build_backbone(g_size)
        self.output_layer = self.build_output_layer(z_size)
        # Used to turn the raw scale half of the logits into a positive std
        self.softplus = nn.Softplus()

    def forward(self, x):
        """ Decode ``x`` into a Normal distribution and a sample from it.

        Returns:
            A tuple ``(output, samples)`` where ``output['prior']`` holds the
            distribution under ``'dist'`` and the same ``samples`` tensor
            under ``'samples'``.
        """
        hidden = self.backbone(x)
        logits = self.output_layer(hidden)
        # First half of the last dim is the mean, second half the raw scale
        mu, scale = logits.chunk(2, dim=-1)
        scale = self.softplus(scale)  # keep the std strictly positive
        dist = Normal(mu, scale)
        # rsample keeps the draw differentiable w.r.t. mu and scale
        samples = dist.rsample()
        output = {'prior': {'dist': dist,
                            'samples': samples  # K x N x D
                            }}
        return output, samples

    def build_backbone(self, g_size):
        """ Build the three-layer MLP from ``g_size`` to ``hidden_size``. """
        # Vasco et al. do not specify this network
        # We use slightly less capacity than the encoder from h to g
        # (because there is only one upper encoder but multiple upper decoders)
        layers = nn.Sequential(
            nn.Linear(g_size, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.LeakyReLU(),
        )
        return layers

    def build_output_layer(self, z_size):
        """ Build the layer mapping backbone features to the logits.

        ``forward`` splits the logits in two along the last dimension, so
        the layer emits ``2 * z_size`` features: ``z_size`` means and
        ``z_size`` raw scales.
        """
        return nn.Linear(self.hidden_size, z_size * 2)


class UpperDecoderX1(UpperDecoder):
    """ Maps from g to z_1 """

    def build_output_layer(self, z_size):
        """ Linear head emitting ``2 * z_size`` features per input. """
        # Twice z_size: the result is later split into two equal halves
        return nn.Linear(in_features=self.hidden_size,
                         out_features=2 * z_size)


class UpperDecoderX2(UpperDecoder):
    """ Maps from g to z_2 """

    def build_output_layer(self, z_size):
        """ Linear head emitting ``2 * z_size`` features per input. """
        # Twice z_size: the result is later split into two equal halves
        return nn.Linear(in_features=self.hidden_size,
                         out_features=2 * z_size)
