# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Forwarded as the `align_corners` argument of every bilinear F.interpolate call in this file.
align_corners = True
# Width of the latent vector: output size of the encoder heads and input size of the decoder's linear layer.
N_latent = 128


"""
https://github.com/davidstutz/disentangling-robustness-generalization/blob/master/training/train_vae_gan2.py
"""
def latent_loss(output_mu, output_logvar):
    """
    KL-divergence style latent loss, summed over every element.

    Evaluates ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.

    :param output_mu: predicted means
    :type output_mu: torch.Tensor
    :param output_logvar: predicted log-variances (broadcastable with ``output_mu``)
    :type output_logvar: torch.Tensor
    :return: scalar loss (sum, not mean)
    :rtype: torch.Tensor
    """

    # Per-element term; it is zero exactly when mu == 0 and logvar == 0.
    per_element = 1 + output_logvar - output_mu.pow(2) - output_logvar.exp()
    total = torch.sum(per_element)
    return -0.5 * total


def reconstruction_loss(batch_images, output_images, absolute_error=True):
    """
    Summed reconstruction loss between two image tensors.

    :param batch_images: original images
    :type batch_images: torch.Tensor
    :param output_images: reconstructed images
    :type output_images: torch.Tensor
    :param absolute_error: if truthy, sum absolute differences (L1);
        otherwise sum squared differences
    :type absolute_error: bool
    :return: scalar loss (sum over all elements)
    :rtype: torch.Tensor
    """

    residual = batch_images - output_images
    if not absolute_error:
        return torch.sum(residual * residual)
    return torch.sum(residual.abs())

    
def reconstruction_error(batch_images, output_images):
    """
    Mean squared error between two image tensors.

    :param batch_images: target images
    :type batch_images: torch.Tensor
    :param output_images: predicted images
    :type output_images: torch.Tensor
    :return: scalar error (mean over all elements)
    :rtype: torch.Tensor
    """

    residual = batch_images - output_images
    squared = residual * residual
    return squared.mean()


def decoder_loss(output_reconstructed_classes):
    """
    Adversarial loss for the decoder.

    Evaluates ``-sum(log(sigmoid(logits) + 1e-12))``, which is small when the
    sigmoid of the given logits is close to 1.

    :param output_reconstructed_classes: discriminator logits for reconstructions
    :type output_reconstructed_classes: torch.Tensor
    :return: scalar loss (sum over all elements)
    :rtype: torch.Tensor
    """

    probabilities = torch.sigmoid(output_reconstructed_classes)
    # The small offset keeps log() finite when a probability underflows to 0.
    log_probabilities = torch.log(probabilities + 1e-12)
    return -log_probabilities.sum()


def reparameterize(mu, logvar):
    """
    Reparameterization trick: draw ``mu + eps * std`` with ``eps ~ N(0, 1)``.

    :param mu: mean vector
    :type mu: torch.Tensor
    :param logvar: log-variance vector; ``std`` is ``exp(0.5 * logvar)``
    :type logvar: torch.Tensor
    :return: random sample with the shape of ``logvar``
    :rtype: torch.Tensor
    """

    sigma = logvar.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    sample = noise.mul(sigma)
    # In-place add on the freshly created product; mu itself is not modified.
    sample.add_(mu)
    return sample


def discriminator_loss(output_real_classes, output_reconstructed_classes):
    """
    Adversarial loss for the discriminator.

    Evaluates ``-sum(log(sigmoid(real) + 1e-12) + log(1 - sigmoid(recon) + 1e-12))``;
    the two terms are added elementwise before the sum is taken.

    :param output_real_classes: discriminator logits for real samples
    :type output_real_classes: torch.Tensor
    :param output_reconstructed_classes: discriminator logits for reconstructions
    :type output_reconstructed_classes: torch.Tensor
    :return: scalar loss (sum over all elements)
    :rtype: torch.Tensor
    """

    eps = 1e-12  # keeps log() finite when a probability saturates
    real_term = torch.log(torch.sigmoid(output_real_classes) + eps)
    reconstructed_term = torch.log(1 - torch.sigmoid(output_reconstructed_classes) + eps)
    combined = real_term + reconstructed_term
    return -combined.sum()
    
    
class VariationalEncoder(nn.Module):
    """
    Convolutional encoder mapping an image batch to two ``N_latent``-sized vectors.

    The network is a stack of one to three stride-2 convolution stages
    (Conv2d -> ReLU -> BatchNorm2d), a trailing BatchNorm2d, and two linear
    heads (``fc_mu`` and ``fc_var``) applied to the flattened feature map.

    :param resolution: sequence whose first entry is the number of input
        channels and whose second entry is the spatial size; the same size is
        used for height and width, so inputs are assumed square
    :param compress_mode: number of stride-2 stages; values >= 2 add a second
        stage and values >= 3 add a third (at least one and at most three
        stages are always built)
    :type compress_mode: int
    :param resize_dim: if truthy, inputs are bilinearly resized to
        ``(resize_dim, resize_dim)`` before encoding and this size replaces
        ``resolution[1]`` when sizing the linear heads
    :type resize_dim: int or None
    """

    def __init__(self, resolution, compress_mode=1, resize_dim=None):
        super().__init__()

        self.resize_dim = resize_dim
        self.prior_channels = resolution[0]

        channels = self.prior_channels
        # 3-channel inputs are widened x4 by the first stage, everything else x2.
        widen = 4 if channels == 3 else 2

        layers = [self._downsample(channels, channels * widen)]
        channels *= widen

        # compress_mode 2 and 3 each append one more channel-doubling stage.
        for threshold in (2, 3):
            if compress_mode >= threshold:
                layers.append(self._downsample(channels, channels * 2))
                channels *= 2

        # Count the stages actually built so the flattened size below matches
        # the network even when compress_mode lies outside 1..3.
        num_downsamples = len(layers)

        layers.append(nn.Sequential(
            nn.BatchNorm2d(channels)
        ))

        self.encoder = nn.Sequential(*layers)

        img_dim = resize_dim if resize_dim else resolution[1]
        # Each Conv2d(kernel_size=4, stride=2, padding=1) maps a side of length
        # H to floor(H / 2), so floor division gives the true output side even
        # when img_dim is not a multiple of 2 ** num_downsamples.
        dim_compress = int(img_dim) // (2 ** num_downsamples)
        self.representation = int(channels * dim_compress * dim_compress)
        self.fc_mu = nn.Linear(self.representation, N_latent)
        self.fc_var = nn.Linear(self.representation, N_latent)

    @staticmethod
    def _downsample(in_channels, out_channels):
        """Build one stride-2 stage: Conv2d -> ReLU -> BatchNorm2d."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """
        Encode a batch of images.

        :param x: 4-D image batch (batch, channels, height, width)
        :type x: torch.Tensor
        :return: tuple ``(mu, var)`` holding the outputs of ``fc_mu`` and
            ``fc_var``, each of shape (batch, N_latent)
        :rtype: tuple
        """
        if self.resize_dim:
            x = F.interpolate(x, mode='bilinear', size=(self.resize_dim, self.resize_dim), align_corners=align_corners)
        x = self.encoder(x)
        # Flatten per sample so a spatial-size mismatch fails loudly in the
        # linear layers instead of being folded into the batch dimension.
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        # NOTE(review): presumably consumed as a log-variance (the helpers
        # reparameterize/latent_loss in this file take `logvar`) - confirm with callers.
        var = self.fc_var(x)
        return mu, var

    
class VariationalDecoder(nn.Module):
    """
    Transposed-convolution decoder mapping ``N_latent``-sized vectors to images.

    A linear layer expands the latent vector to a feature map, which is then
    upsampled by one to three stride-2 ConvTranspose2d stages
    (ConvTranspose2d -> ReLU -> BatchNorm2d), squashed with a sigmoid and
    optionally resized to the requested output size.

    :param resolution: sequence whose first entry is the number of output
        channels and whose second entry is the spatial size used when
        ``original_dim`` is falsy
    :param compress_mode: controls the starting channel count and feature-map
        size; values >= 2 add a second upsampling stage and values >= 3 a third
    :type compress_mode: int
    :param original_dim: if truthy, the output is bilinearly resized to
        ``(original_dim, original_dim)``; if falsy, no resize is performed
    :type original_dim: int or None
    :param internal_dim: spatial size the decoder works at before the final
        resize when ``original_dim`` is truthy (previously hard-coded to 128)
    :type internal_dim: int
    """

    def __init__(self, resolution, compress_mode=1, original_dim=224, internal_dim=128):
        super().__init__()

        out_channels = resolution[0]
        # Starting channel count: x4 for 3-channel outputs, x2 otherwise,
        # doubled once more for every compress_mode step above 1.
        widen = 4 if out_channels == 3 else 2
        channels = int(out_channels * widen * (2 ** (compress_mode - 1)))

        self.resize_dim = original_dim
        self.prior_channels = channels

        img_dim = internal_dim if original_dim else resolution[1]
        dim_compress = int(img_dim // (2 ** compress_mode))
        # Plain Python int (not a NumPy scalar) for the linear layer size.
        representation = int(channels * dim_compress * dim_compress)
        self.dim_compress = dim_compress

        self.fc = nn.Linear(N_latent, representation)

        layers = []

        # One channel-halving stage for each of the thresholds that is met.
        for threshold in (3, 2):
            if compress_mode >= threshold:
                layers.append(self._upsample(channels, channels // 2))
                channels = channels // 2

        # The last stage always maps to the requested number of output channels.
        layers.append(self._upsample(channels, out_channels))

        self.decoder = nn.Sequential(*layers)

    @staticmethod
    def _upsample(in_channels, out_channels):
        """Build one stride-2 stage: ConvTranspose2d -> ReLU -> BatchNorm2d."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """
        Decode a batch of latent vectors.

        :param x: latent batch of shape (batch, N_latent)
        :type x: torch.Tensor
        :return: image batch with values in (0, 1)
        :rtype: torch.Tensor
        """
        x = self.fc(x)
        x = x.view(-1, self.prior_channels, self.dim_compress, self.dim_compress)
        x = self.decoder(x)
        # Apply the sigmoid exactly once. The earlier code applied it a second
        # time on return, which confined outputs to roughly (0.5, 0.73).
        x = torch.sigmoid(x)

        if self.resize_dim:
            x = F.interpolate(x, mode='bilinear', size=(self.resize_dim, self.resize_dim), align_corners=align_corners)

        return x