import torch
from torch import nn
from torch.distributions.normal import Normal
from torch.distributions import Bernoulli
from .vae_base import VAE
from .AdaptiveCNN import AdaptiveEncoder, AdaptiveDecoder


# flatten the encoder output to one dimension
class Flatten(nn.Module):
    """Collapse the last three dimensions of a 5-D tensor into one."""

    def forward(self, input):
        """Return ``input`` reshaped from 5-D to 3-D.

        Args:
            input: Tensor with exactly five dimensions; the first two are
                kept and the last three are merged.

        Returns:
            A contiguous tensor of shape ``(d0, d1, d2 * d3 * d4)``.
        """
        S, B, C, W, H = input.size()
        # Make the tensor contiguous *before* viewing: ``view`` raises on
        # inputs whose trailing dims are not stride-compatible (e.g. after a
        # transpose/permute). For already-contiguous inputs this is a no-op.
        return input.contiguous().view(S, B, C * W * H)

# convert the flattened array back to C*W*W (square feature map) format for conv layers
class UnFlatten(nn.Module):
    """Expand the last dimension of a 3-D tensor into ``(C, W, W)`` maps.

    The side length ``W`` is derived from the size of the flat dimension, so
    that dimension is expected to equal ``C * W * W`` for some integer ``W``.
    """

    def __init__(self, C):
        """Store the channel count ``C`` used when unflattening."""
        super().__init__()
        self.C = C

    def forward(self, input):
        """Return ``input`` reshaped from ``(d0, d1, M)`` to ``(d0, d1, C, W, W)``.

        Args:
            input: Tensor with exactly three dimensions.

        Returns:
            A contiguous 5-D tensor whose last three dimensions are
            ``(self.C, W, W)`` with ``W = int(sqrt(M / self.C))``.
        """
        lead0, lead1, flat_size = input.size()
        # Side of the square feature map implied by the flat size.
        side = int((flat_size / self.C) ** 0.5)
        target_shape = (lead0, lead1, self.C, side, side)
        unflattened = input.view(*target_shape)
        return unflattened.contiguous()


class ConvVAE(VAE):
    """Convolutional VAE built from an adaptive encoder/decoder pair.

    The encoder output is flattened and mapped by two linear layers to the
    parameters of a Normal distribution over the latent; latent samples are
    mapped back to feature maps and decoded into Bernoulli logits.
    """

    def __init__(self, args, device, img_shape, h_dim, z_dim, truncation=25,
                 enc_out_size=13):
        """Build the encoder, the variational heads and the decoder.

        Args:
            args: Forwarded unchanged to ``VAE.__init__``.
            device: Device inputs are moved to; also passed to the encoder
                and the decoder.
            img_shape: Indexable image shape; ``img_shape[0]`` is used as the
                channel count and ``img_shape[1]``/``img_shape[2]`` as the
                spatial size of the decoded output.
            h_dim: Channel count of the convolutional feature maps.
            z_dim: Size of the latent vector.
            truncation: Forwarded to the encoder and the decoder.
            enc_out_size: Side length of the square feature map produced by
                the encoder and consumed by the decoder. Defaults to 13, the
                value that was previously hard-coded.
                NOTE(review): must match what AdaptiveEncoder actually emits
                for the given input size — confirm against that module.
        """
        VAE.__init__(self, device, z_dim, args)

        # Preprocess data dimensions for the conv network: insert a new
        # dimension at index 1 and move the tensor to ``device``.
        self.proc_data = lambda x: x.unsqueeze(1).to(device)
        self.in_channels = img_shape[0]
        self.img_shape = img_shape

        # Size of one flattened encoder feature map (C * W * W).
        flat_dim = h_dim * enc_out_size * enc_out_size

        # Encoding network.
        self.encoder = AdaptiveEncoder(self.in_channels, h_dim, truncation=truncation, device=device)
        # Flattens the encoder output's last three dimensions into one.
        self.flatten = Flatten()

        # Layers computing the variational distribution parameters.
        self.enc_mu = nn.Linear(flat_dim, z_dim)
        self.enc_sig = nn.Linear(flat_dim, z_dim)

        # Map the flat latent back to C*W*W feature maps for the decoding
        # convolutional layers.
        self.latent_to_hidden = nn.Sequential(
            nn.Linear(z_dim, flat_dim), nn.ReLU(),
            UnFlatten(h_dim))

        # Decoding network.
        self.decoder = AdaptiveDecoder(self.in_channels, num_channels=h_dim, truncation=truncation, device=device)

    def encode(self, x, num_samples=5):
        """Return the variational distribution over the latent for ``x``.

        Args:
            x: Input batch; passed through ``self.proc_data`` first.
            num_samples: Forwarded to the encoder.

        Returns:
            A ``Normal`` whose mean is ``enc_mu(h)`` and whose scale is
            ``softplus(enc_sig(h))``, where ``h`` is the flattened encoder
            output.
        """
        # Preprocess data dimensions to feed into a conv layer.
        x = self.proc_data(x)

        # Encode, flatten and compute the variational parameters.
        h = self.encoder(x, num_samples)
        h = self.flatten(h)
        mu, _std = self.enc_mu(h), self.enc_sig(h)

        # softplus maps the raw output to a positive scale.
        return Normal(mu, nn.functional.softplus(_std), validate_args=False)

    def decode(self, z, num_samples=5):
        """Return the Bernoulli distribution over images given latents ``z``.

        Args:
            z: 4-D tensor of latents with shape ``(I, S, B, D)``.
            num_samples: Forwarded to the decoder.

        Returns:
            A ``Bernoulli`` parameterised by logits of shape
            ``(I, S, B, in_channels, img_shape[1], img_shape[2])``.
        """
        # Fold the leading dimension into the third one so the decoder sees a
        # 3-D ``(S, I * B, D)`` input.
        I, S, B, D = z.shape
        z = z.permute(1, 0, 2, 3)
        z = z.contiguous().view(S, I*B, D)

        # Upsample the 1-D latents into feature maps for the conv layers.
        z = self.latent_to_hidden(z)

        # Feed into the decoding network.
        x_dec = self.decoder(z, num_samples)

        # Undo the folding. ``contiguous()`` comes before ``view`` so that a
        # non-contiguous decoder output does not make ``view`` raise.
        x_dec = x_dec.contiguous().view(
            S, I, B, self.in_channels, self.img_shape[1], self.img_shape[2])
        x_dec = x_dec.permute(1, 0, 2, 3, 4, 5)

        return Bernoulli(logits=x_dec, validate_args=False)

    def lpxz(self, true_x, x_dist):
        """Return ``x_dist.log_prob(true_x)`` summed over the last three dims."""
        return x_dist.log_prob(true_x).sum([-1, -2, -3])

    def get_arch_kl(self):
        """Return the tuple ``(encoder.get_kl(), decoder.get_kl())``."""
        return self.encoder.get_kl(), self.decoder.get_kl()