import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable


class MLPAug(nn.Module):
    """MLP that perturbs a hidden representation of its input with Gaussian noise.

    The input is flattened per sample, passed through two linear+ReLU layers,
    shifted by a latent offset ``z`` and mapped back to the input
    dimensionality through a linear+ReLU layer and a linear+sigmoid layer.
    The output is reshaped to the shape of the original input.

    ``mu`` and ``sigma`` are learnable vectors of length ``h_dim`` that are
    shared by every sample; they are initialised uniformly in ``[0, 1)``.

    Args:
        input_dim: Number of features per sample after flattening.
        h_dim: Width of the hidden layers and of the latent offset.
    """

    def __init__(self, input_dim, h_dim=128):
        super().__init__()
        self.h_dim = h_dim

        # Input -> hidden layers.
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2hid = nn.Linear(h_dim, h_dim)

        # Learnable mean and scale of the additive latent offset. Only a
        # per-unit scale is kept, i.e. the noise is independent per unit.
        self.mu = nn.Parameter(torch.rand(h_dim))
        self.sigma = nn.Parameter(torch.rand(h_dim))

        # Hidden -> output layers.
        self.z_2hid = nn.Linear(h_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        """Return the first hidden representation ``relu(img_2hid(x))``."""
        return F.relu(self.img_2hid(x))

    def decode(self, x, z):
        """Map flattened inputs ``x`` to outputs, adding ``z`` in hidden space.

        Args:
            x: Tensor whose last dimension is ``input_dim``.
            z: Offset added to the second hidden representation; it must
                broadcast against a tensor whose last dimension is ``h_dim``.

        Returns:
            Tensor with the same shape as ``x`` and values in ``(0, 1)``
            because the last activation is a sigmoid.
        """
        hidden = torch.relu(self.img_2hid(x))
        hidden = torch.relu(self.hid_2hid(hidden))
        hidden = hidden + z
        hidden = torch.relu(self.z_2hid(hidden))
        return torch.sigmoid(self.hid_2img(hidden))

    def _augment(self, x, z):
        """Flatten ``x`` per sample, decode it with offset ``z``, restore the shape."""
        original_shape = x.shape
        flat = x.reshape(x.shape[0], -1)
        return self.decode(flat, z).reshape(original_shape)

    def forward(self, x, return_eps=False):
        """Return a randomly perturbed version of ``x``.

        A standard-normal ``epsilon`` of shape ``(batch, h_dim)`` is drawn and
        the offset ``mu + sigma * epsilon`` is added in hidden space.

        Args:
            x: Input batch; everything after the first axis is flattened.
            return_eps: When true, also return the sampled ``epsilon``.

        Returns:
            The output with the same shape as ``x``, or the tuple
            ``(output, epsilon)`` when ``return_eps`` is true.
        """
        # Sample directly on the input's device and in the parameters' dtype:
        # this avoids a host-to-device copy and keeps the dtype of the hidden
        # activations consistent when the module is not float32.
        epsilon = torch.randn(
            x.shape[0], self.h_dim, dtype=self.mu.dtype, device=x.device
        )
        z = self.mu + self.sigma * epsilon

        out = self._augment(x, z)
        if return_eps:
            return out, epsilon
        return out

    def forward_mean(self, x):
        """Return the deterministic output obtained with the offset ``mu``.

        No noise is sampled here; the offset is the learned mean alone.
        """
        return self._augment(x, self.mu)


class VAE(nn.Module):
    """Decoder half of a VAE driven by externally supplied latent statistics.

    This module has no encoder: ``forward`` receives a tensor that already
    holds a mean and a log-variance for every latent dimension, samples a
    latent code from them, projects it to a ``kernel_num x 2 x 2`` feature map
    and upsamples it with transposed convolutions to an image.

    Args:
        label: Free-form identifier, only used to build ``name``.
        image_size: Side length of the generated images; must be 32, 64 or
            128 (any other value raises ``KeyError``).
        channel_num: Number of channels of the generated images.
        kernel_num: Number of channels of the projected feature map. It is
            integer-divided by up to 32 along the decoder, so it should be
            large enough for every stage to keep at least one channel.
        z_size: Dimensionality of the latent code.
    """

    def __init__(self, label, image_size, channel_num, kernel_num, z_size):
        # configurations
        super().__init__()
        self.label = label
        self.image_size = image_size
        self.channel_num = channel_num
        self.kernel_num = kernel_num
        self.z_size = z_size

        # Spatial size and volume of the feature map the decoder starts from.
        self.feature_size = 2
        self.feature_volume = kernel_num * (self.feature_size ** 2)

        # projection
        self.project = self._linear(z_size, self.feature_volume, relu=False)

        # decoder
        # Channel divisors per stage, keyed by output size. Each stage doubles
        # the spatial size (kernel 4, stride 2, padding 1), so N stages turn
        # the 2x2 feature map into a (2 * 2**N)-pixel-wide image.
        decoder_config = {
            32: [1, 2, 4, 8, 8],
            64: [1, 2, 4, 8, 16, 16],
            128: [1, 2, 4, 8, 16, 32, 32],
        }

        decoder_shape = decoder_config[self.image_size]
        decoder = [
            self._deconv(kernel_num // div_in, kernel_num // div_out)
            for div_in, div_out in zip(decoder_shape, decoder_shape[1:])
        ]
        decoder.append(
            nn.Conv2d(
                kernel_num // decoder_shape[-1],
                out_channels=channel_num,
                kernel_size=3,
                padding=1,
            )
        )
        decoder.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder)

    def forward(self, image_syn):
        """Sample a latent code from ``image_syn`` and decode it.

        Args:
            image_syn: 3-D tensor whose last axis holds the mean at index 0
                and the log-variance at index 1.

        Returns:
            The decoded images, in ``[-1, 1]`` because of the final ``Tanh``.
        """
        mean, logvar = image_syn[:, :, 0], image_syn[:, :, 1]
        return self._decode_latent(self.z(mean, logvar))

    def forward_fixed(self, image_syn, eps=None):
        """Decode the mean stored in ``image_syn`` without sampling.

        Args:
            image_syn: 3-D tensor (mean at ``[:, :, 0]``) or 2-D tensor
                (mean at ``[:, 0]``).
            eps: Optional offset added to the mean as-is.

        Returns:
            The decoded images.

        Raises:
            ValueError: If ``image_syn`` is neither 2-D nor 3-D.
        """
        if image_syn.dim() == 3:
            mean = image_syn[:, :, 0]
        elif image_syn.dim() == 2:
            mean = image_syn[:, 0]
        else:
            raise ValueError("Dimension Error")

        # NOTE(review): ``eps`` is added unscaled; the log-variance stored in
        # ``image_syn`` is not used here. Confirm that this is intended.
        z = mean if eps is None else mean + eps
        return self._decode_latent(z)

    def forward_given(self, z):
        """Decode an explicitly provided latent code ``z``."""
        return self._decode_latent(z)

    def _decode_latent(self, z):
        """Project ``z`` to a ``kernel_num x 2 x 2`` feature map and decode it."""
        z_projected = self.project(z).view(
            -1, self.kernel_num,
            self.feature_size,
            self.feature_size,
        )
        return self.decoder(z_projected)

    # ==============
    # VAE components
    # ==============

    def q(self, encoded):
        """Return the posterior mean and log-variance for encoded features.

        NOTE(review): ``q_mean`` and ``q_logvar`` are not created in
        ``__init__``, so calling this method raises ``AttributeError`` unless
        those layers are attached to the instance elsewhere.
        """
        unrolled = encoded.view(-1, self.feature_volume)
        return self.q_mean(unrolled), self.q_logvar(unrolled)

    def z(self, mean, logvar):
        """Draw ``mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        The noise is created with the device and dtype of ``logvar``, so the
        method works regardless of where the tensors live.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def reconstruction_loss(self, x_reconstructed, x):
        """Return the summed binary cross-entropy divided by the batch size.

        NOTE(review): binary cross-entropy expects inputs in ``[0, 1]`` while
        the decoder ends with ``Tanh``; rescale decoder outputs before using
        this loss on them.
        """
        return nn.BCELoss(reduction="sum")(x_reconstructed, x) / x.size(0)

    def kl_divergence_loss(self, mean, logvar):
        """Return the element-wise KL divergence to ``N(0, 1)``, averaged."""
        return ((mean**2 + logvar.exp() - 1 - logvar) / 2).mean()

    def kl_divergence_custom(self, mean_1, logvar_1, mean_2, logvar_2):
        """Return ``KL(N(mean_1, var_1) || N(mean_2, var_2))`` for diagonal Gaussians.

        Variances are given as log-variances. The dimensionality subtracted in
        the closed form is ``mean_1.shape[0]`` while the other terms are summed
        over all elements, so the result is the exact divergence for 1-D
        inputs.
        """
        d = mean_1.shape[0]
        diff = mean_1 - mean_2
        mean_term = (diff / logvar_2.exp() * diff).sum()
        var_term = (logvar_1.exp() / logvar_2.exp()).sum()
        log_var_term = (logvar_2 - logvar_1).sum()
        kl_loss = mean_term + var_term + log_var_term

        return (kl_loss - d) / 2

    def wass_custom(self, mean_1, logvar_1, mean_2, logvar_2):
        """Return the reciprocal of the squared 2-Wasserstein distance.

        The distance is the one between two diagonal Gaussians given by their
        means and log-variances. The result is infinite when both
        distributions coincide, because the distance is then zero.
        """
        mean_term = ((mean_1 - mean_2) ** 2).sum()
        std_term = ((logvar_1.exp().sqrt() - logvar_2.exp().sqrt()) ** 2).sum()
        wass_loss = mean_term + std_term

        return 1. / wass_loss

    def wasserstein_loss(self, mean, logvar):
        """Return the mean of ``mean**2 + exp(logvar)`` over all elements."""
        variance = logvar.exp()
        return (mean**2 + variance).mean()

    # =====
    # Utils
    # =====

    @property
    def name(self):
        """Identifier of the form ``VAE-<kernel_num>k-<label>-<C>x<S>x<S>``."""
        return (
            f'VAE-{self.kernel_num}k-{self.label}'
            f'-{self.channel_num}x{self.image_size}x{self.image_size}'
        )

    def sample(self, size):
        """Decode ``size`` latent codes drawn from a standard normal.

        The codes are created on the device of the module's parameters and
        the returned images are detached from the autograd graph.
        """
        reference = next(self.parameters())
        with torch.no_grad():
            z = torch.randn(
                size, self.z_size,
                dtype=reference.dtype, device=reference.device,
            )
            return self._decode_latent(z)

    def _is_on_cuda(self):
        """Return whether the first parameter of the module lives on a GPU."""
        return next(self.parameters()).is_cuda

    # ======
    # Layers
    # ======

    def _conv(self, channel_size, kernel_num):
        """Build a stride-2 convolution + batch norm + LeakyReLU stage."""
        return nn.Sequential(
            nn.Conv2d(
                channel_size, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.BatchNorm2d(kernel_num),
            nn.LeakyReLU(),
        )

    def _deconv(self, channel_num, kernel_num, stride=2):
        """Build a transposed convolution + batch norm + LeakyReLU stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                channel_num, kernel_num,
                kernel_size=4, stride=stride, padding=1,
            ),
            nn.BatchNorm2d(kernel_num),
            nn.LeakyReLU(),
        )

    def _linear(self, in_size, out_size, relu=True):
        """Build a linear layer, followed by a ReLU when ``relu`` is true."""
        if not relu:
            return nn.Linear(in_size, out_size)
        return nn.Sequential(
            nn.Linear(in_size, out_size),
            nn.ReLU(),
        )
