import logging

import torch
from torch import nn
from torch.distributions.normal import Normal

from vae.layers.misc import Swish
from vae.misc import get_trainable_params

logger = logging.getLogger('custom')


class DisentanglementVae(nn.Module):
    """Multimodal VAE over two observed modalities.

    Holds one single-modality VAE per modality (``x1``: ``ImageVAE``,
    ``x2``: ``CaptionVAE``), one shared latent variable ``g`` and two
    modality-specific latent variables ``z1`` and ``z2``. All three priors
    are Normal distributions with loc 0 and scale 1; their parameters are
    stored as ``nn.Parameter`` objects with ``requires_grad=False``.

    ``forward`` is left to subclasses.
    """

    def __init__(self, args):
        """Build both modality VAEs and the fixed priors.

        Args:
            args: configuration object, forwarded unchanged to ``ImageVAE``
                and ``CaptionVAE``.
        """
        super().__init__()
        logger.info('\n====> Building multimodal disentanglement VAE.')
        self.vaes = nn.ModuleDict({'x1': ImageVAE(args),
                                   'x2': CaptionVAE(args)})
        self.modalities = list(self.vaes.keys())
        # The size of the shared latent is read from the image VAE only.
        self.stoc_dim = {
            'g': self.vaes['x1'].stoc_dim['g'],
            'z1': self.vaes['x1'].stoc_dim['z'],
            'z2': self.vaes['x2'].stoc_dim['z'],
        }

        # Prepare priors. The distribution classes are stored un-instantiated;
        # a distribution is built on demand from the parameter lists below.
        self.pg, self.pz1, self.pz2 = Normal, Normal, Normal
        self._pg_params = self._standard_normal_params(self.stoc_dim['g'])
        self._pz1_params = self._standard_normal_params(self.stoc_dim['z1'])
        self._pz2_params = self._standard_normal_params(self.stoc_dim['z2'])

        self.print_number_of_parameters()

    @staticmethod
    def _standard_normal_params(dim):
        """Return frozen ``[loc, scale]`` parameters, each of shape (1, dim)."""
        return nn.ParameterList([
            nn.Parameter(torch.zeros(1, dim), requires_grad=False),
            nn.Parameter(torch.ones(1, dim), requires_grad=False),
        ])

    def print_number_of_parameters(self):
        """Log the value of ``get_trainable_params`` per VAE and in total.

        Values are divided by 10**6 and printed with an ``M`` suffix.
        """
        params = {'VAE (x1)': get_trainable_params(self.vaes['x1']),
                  'VAE (x2)': get_trainable_params(self.vaes['x2']),
                  'Total': get_trainable_params(self)}

        logger.info('\nParameter overview:')
        for k, v in params.items():
            logger.info(f'- {k}: {v / 10 ** 6:.1f}M')

    @property
    def pg_params(self):
        """``[loc, scale]`` parameters of the prior over ``g``."""
        return self._pg_params

    @property
    def pz1_params(self):
        """``[loc, scale]`` parameters of the prior over ``z1``."""
        return self._pz1_params

    @property
    def pz2_params(self):
        """``[loc, scale]`` parameters of the prior over ``z2``."""
        return self._pz2_params

    def forward(self, *args, **kwargs):
        """Not implemented here.

        Raises:
            NotImplementedError: always; subclasses define the forward pass.
        """
        raise NotImplementedError('Define in subclass.')

    def _crossmodal_generation_training(self, output):
        """Decode each modality from another modality's shared latent.

        For every ordered pair (condition ``m1``, target ``m2``) with
        ``m1 != m2``, the target's own posterior sample of ``z`` is
        concatenated with the condition's posterior sample of ``g`` and
        decoded by the target VAE.

        Args:
            output: nested dict; reads
                ``output['posterior'][m][m]['z' | 'g']['samples']`` and
                writes ``output['reconstruction'][target][condition]``.

        Returns:
            The same ``output`` dict, updated in place.
        """
        for m1 in self.modalities:  # condition
            for m2 in self.modalities:  # target
                if m1 == m2:
                    continue
                z = output['posterior'][m2][m2]['z']['samples']
                g = output['posterior'][m1][m1]['g']['samples']
                latent = torch.cat((z, g), dim=-1)
                output['reconstruction'][m2][m1] = self.vaes[m2].decode(latent)
        return output

    def _crossmodal_generation_eval(self, output):
        """Decode each modality from another's ``g`` and a prior draw of ``z``.

        For every ordered pair (condition ``m1``, target ``m2``) with
        ``m1 != m2``, ``z`` is sampled from the target's prior, concatenated
        with the condition's posterior sample of ``g`` and decoded by the
        target VAE.

        ``g`` may carry any number of leading dimensions before the feature
        dimension (for example ``N x D`` or ``K x N x D``); one prior sample
        is drawn per leading index.

        Args:
            output: nested dict; reads
                ``output['posterior'][m][m]['g']['samples']`` and writes
                ``output['reconstruction'][target][condition]``.

        Returns:
            The same ``output`` dict, updated in place.
        """
        for m1 in self.modalities:  # condition
            for m2 in self.modalities:  # target
                if m1 == m2:
                    continue
                g = output['posterior'][m1][m1]['g']['samples']
                if m2 == 'x1':
                    pz = self.pz1(*self.pz1_params)
                else:
                    pz = self.pz2(*self.pz2_params)
                # Use the leading shape of g rather than unpacking exactly
                # three dimensions, so a 2-D g (no sample axis) also works.
                # The prior parameters have shape (1, D), so each draw has an
                # extra singleton axis in front of D that is removed here.
                z = pz.sample(g.shape[:-1]).squeeze(dim=-2)
                latent = torch.cat((z, g), dim=-1)
                output['reconstruction'][m2][m1] = self.vaes[m2].decode(latent)
        return output


class VAE(nn.Module):
    """Base VAE with one observed variable and two latent variables.

    The latents are ``z`` and ``g``. Subclasses supply the networks through
    ``_build_encoder`` and ``_build_decoder``, which are called from
    ``__init__`` after ``self.stoc_dim`` has been set.
    """

    def __init__(self, stoc_dim):
        """Store the latent sizes and build encoder and decoder.

        Args:
            stoc_dim: mapping with keys ``'z'`` and ``'g'`` giving the size
                of each latent variable.
        """
        super().__init__()
        self.stoc_dim = stoc_dim
        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()
        self.softplus = nn.Softplus()

    def forward(self, x, k):
        """Encode ``x``, sample both latents and decode their concatenation.

        Args:
            x: input passed to ``encode``.
            k: number of posterior samples to draw per latent.

        Returns:
            Dict with keys ``'posterior'`` (result of ``encode``) and
            ``'reconstruction'`` (result of ``decode``).
        """
        posterior = self.encode(x, k)
        # Decoder input is [z, g] along the feature dimension.
        latent = torch.cat(
            (posterior['z']['samples'], posterior['g']['samples']), dim=-1)
        return {'posterior': posterior,
                'reconstruction': self.decode(latent)}

    def encode(self, x, k=1):
        """Run the encoder and split its output into the two posteriors.

        The first ``2 * stoc_dim['z']`` columns parameterise ``z``; the
        remaining columns parameterise ``g``.

        Args:
            x: encoder input.
            k: number of samples to draw from each posterior.

        Returns:
            Dict with keys ``'z'`` and ``'g'``, each as returned by
            ``_encode_single_variable``.
        """
        logits = self.encoder(x)
        split = self.stoc_dim['z'] * 2
        return {'z': self._encode_single_variable(logits[:, :split], k),
                'g': self._encode_single_variable(logits[:, split:], k)}

    def _encode_single_variable(self, logits, k):
        """Build a Normal posterior from ``logits`` and draw ``k`` samples.

        ``logits`` is halved along the last dimension into a mean and a
        pre-activation scale; the scale goes through softplus.

        Args:
            logits: tensor whose last dimension holds ``[mu, raw_scale]``.
            k: number of reparameterised samples.

        Returns:
            Dict with ``'dist'`` (the ``Normal``) and ``'samples'``. The
            samples have a leading axis of size ``k`` which is dropped when
            ``k == 1``; every other axis is kept.
        """
        mu, scale = logits.chunk(2, dim=-1)
        dist = Normal(mu, self.softplus(scale))
        # Squeeze only the sample axis. A bare squeeze() would also drop a
        # batch axis of size 1 (or a latent of width 1) and silently change
        # the rank seen by downstream code.
        samples = dist.rsample((k,)).squeeze(0)  # K x N x D, or N x D if K=1
        return {'dist': dist,
                'samples': samples}

    def decode(self, z):
        """Decode a latent into a Normal over the observation.

        The decoder output is halved along the last dimension into a mean
        and a pre-activation scale; the scale goes through softplus.

        Args:
            z: decoder input.

        Returns:
            Dict with ``'dist'`` (the ``Normal``) and ``'samples'`` (one
            reparameterised draw from it).
        """
        logits = self.decoder(z)
        mu, scale = logits.chunk(2, dim=-1)
        dist = Normal(mu, self.softplus(scale))
        return {'dist': dist,
                'samples': dist.rsample()}

    def _build_encoder(self, *args, **kwargs):
        """Return the encoder network; must be overridden.

        Raises:
            NotImplementedError: always in the base class.
        """
        raise NotImplementedError('Define in subclass.')

    def _build_decoder(self, *args, **kwargs):
        """Return the decoder network; must be overridden.

        Raises:
            NotImplementedError: always in the base class.
        """
        raise NotImplementedError('Define in subclass.')


class ImageVAE(VAE):
    """Convolutional VAE for the image modality.

    The encoder flattens its convolutional features to ``256 * 5 * 5``
    values per example and the decoder starts from a feature map of the
    same size. Reconstructions are sigmoid-squashed tensors, not a
    distribution.
    """

    def __init__(self, args):
        """Build the image VAE from ``args.stoc_dim['z1']`` and ``['g']``."""
        logger.info('\n====> Building image VAE:')
        latent_dims = {'z': args.stoc_dim['z1'],
                       'g': args.stoc_dim['g']}
        super().__init__(latent_dims)
        self.sigmoid = nn.Sigmoid()

    def _build_encoder(self):
        """Return a ``ModuleDict`` with a conv ``'backbone'`` and ``'output'``.

        The output head emits ``2 * (dim z + dim g)`` values per example.
        """
        n_latent = self.stoc_dim['z'] + self.stoc_dim['g']
        modules = nn.ModuleDict()

        # Miniature DCGAN as in https://github.com/mhw32/multimodal-vae-public
        backbone = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1, bias=False),
            Swish(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            Swish(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(128),
            Swish(),
            nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(256),
            Swish(),
        )
        modules['backbone'] = backbone
        logger.info(f'Encoder backbone:\n{backbone}')

        head = nn.Sequential(
            nn.Linear(256 * 5 * 5, 512),
            nn.LeakyReLU(),
            nn.Linear(512, n_latent * 2),
        )
        modules['output'] = head
        logger.info(f'Encoder output:\n{head}')
        return modules

    def _build_decoder(self):
        """Return a ``ModuleDict`` with a linear ``'upsample'`` and ``'output'``.

        ``'upsample'`` maps the concatenated latent to ``256 * 5 * 5``
        features; ``'output'`` is a transposed-convolution stack ending in
        three channels.
        """
        n_latent = self.stoc_dim['z'] + self.stoc_dim['g']
        modules = nn.ModuleDict()

        upsample = nn.Sequential(
            nn.Linear(n_latent, 256 * 5 * 5),
            Swish(),
        )
        modules['upsample'] = upsample
        logger.info(f'Decoder output:\n{upsample}')

        # Transposed miniature DCGAN
        # as in https://github.com/mhw32/multimodal-vae-public
        deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=1, padding=0,
                               bias=False),
            nn.BatchNorm2d(128),
            Swish(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.BatchNorm2d(64),
            Swish(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.BatchNorm2d(32),
            Swish(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1,
                               bias=False),
        )
        modules['output'] = deconv
        logger.info(f'Decoder backbone:\n{deconv}')
        return modules

    def encode(self, x, k=1):
        """Encode images into posteriors over ``z`` and ``g``.

        Args:
            x: image batch fed to the conv backbone (the flatten to
                ``256 * 5 * 5`` presumably assumes 3 x 64 x 64 inputs —
                TODO confirm against the data pipeline).
            k: number of samples to draw from each posterior.

        Returns:
            Dict with keys ``'z'`` and ``'g'``, each as returned by
            ``_encode_single_variable``.
        """
        features = self.encoder['backbone'](x).view(-1, 256 * 5 * 5)
        logits = self.encoder['output'](features)
        split = self.stoc_dim['z'] * 2
        return {'z': self._encode_single_variable(logits[:, :split], k),
                'g': self._encode_single_variable(logits[:, split:], k)}

    def decode(self, x):
        """Decode latents into images.

        A 3-D input is treated as having two leading axes in front of the
        latent features: they are merged for the convolutions and restored
        afterwards. Any other rank is decoded with a single leading axis.

        Args:
            x: concatenated latent tensor.

        Returns:
            Dict with ``'dist'`` set to ``None`` and ``'samples'`` holding
            the sigmoid of the transposed-convolution output.
        """
        hidden = self.decoder['upsample'](x)
        # Remember the two leading sizes only for 3-D inputs.
        leading = (hidden.size(0), hidden.size(1)) if hidden.dim() == 3 else ()
        images = self.decoder['output'](hidden.view(-1, 256, 5, 5))
        if leading and all(leading):
            images = images.view(*leading, images.size(-3), images.size(-2),
                                 images.size(-1))
        return {'dist': None,
                'samples': self.sigmoid(images)}  # samples


class CaptionVAE(VAE):
    """Fully connected VAE for the second modality.

    Encoder and decoder are stacks of ``nn.Linear`` layers separated by
    ``nn.LeakyReLU``. The encoder takes 1024 input features and the decoder
    emits ``1024 * 2`` values per example.
    """

    def __init__(self, args):
        """Build the VAE from ``args.stoc_dim['z2']`` and ``['g']``."""
        logger.info('\n====> Building attribute VAE:')
        latent_dims = {'z': args.stoc_dim['z2'],
                       'g': args.stoc_dim['g']}
        super().__init__(latent_dims)

    @staticmethod
    def _linear_stack(widths, activate_last):
        """Chain ``Linear`` layers over consecutive ``widths``.

        A ``LeakyReLU`` follows every layer except the last one; it follows
        the last one too when ``activate_last`` is true.
        """
        layers = []
        final = len(widths) - 2
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if activate_last or idx < final:
                layers.append(nn.LeakyReLU())
        return nn.Sequential(*layers)

    def _build_encoder(self):
        """Return the encoder; it emits ``2 * (dim z + dim g)`` values."""
        n_latent = self.stoc_dim['z'] + self.stoc_dim['g']
        encoder = self._linear_stack(
            (1024, 1024, 768, 768, 768, n_latent * 2), activate_last=False)
        logger.info(f'Encoder:\n{encoder}')
        return encoder

    def _build_decoder(self):
        """Return the decoder; it maps the latent to ``1024 * 2`` values."""
        n_latent = self.stoc_dim['z'] + self.stoc_dim['g']
        # NOTE(review): the final LeakyReLU is applied to the values that are
        # later split into mean and scale — confirm this is intended.
        decoder = self._linear_stack(
            (n_latent, 768, 768, 1024, 1024, 1024 * 2), activate_last=True)
        logger.info(f'Decoder:\n{decoder}')
        return decoder
