from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn

from models.resnet_vibn_base import ResNetVIBNBase


class ResNet18IDO(nn.Module):
    """ResNet-18 encoder/decoder with multiplicative log-normal latent noise.

    The wrapped ``ResNetVIBNBase`` (``'BasicBlock'`` with ``[2, 2, 2, 2]``
    blocks per stage) is called as ``z, z_var = self.resnet(inputs)`` and
    exposes ``decoder(z)``. ``forward`` perturbs ``z`` with noise of scale
    ``alpha = alpha_max * sigmoid(z_var)`` and charges a per-unit penalty
    ``-log(alpha / alpha_max)`` on units with ``z > 0``.
    NOTE(review): the name and the math suggest Information Dropout; confirm
    against the training code that consumes the outputs.

    Args:
        num_classes: forwarded to ``ResNetVIBNBase``.
        alpha_max: upper bound of the noise scale ``alpha``. Kept as a
            non-trainable ``nn.Parameter``, so it is part of the state dict.
        l2_norm: forwarded to ``ResNetVIBNBase``.
        training_samples: noisy copies drawn per example in training mode.
        test_samples: noisy copies drawn per example in eval mode.
        gamma: scale of the exponent in the noise-only code
            ``exp(gamma * u)``, ``u ~ U(-1, 1)``.
        q: Bernoulli probability that a unit of the noise-only code is
            non-zero; must lie in ``[0, 1]``.
        maxpool: forwarded to ``ResNetVIBNBase``.
        **kwargs: forwarded to ``ResNetVIBNBase``.

    Raises:
        ValueError: if ``q`` is not in ``[0, 1]``.
    """

    def __init__(self, num_classes=100, alpha_max=1.0, l2_norm=True,
                 training_samples=1, test_samples=16, gamma=1.0, q=0.1, maxpool=False, **kwargs):
        super(ResNet18IDO, self).__init__()
        # Reject an invalid probability at construction time rather than on
        # the first forward pass.
        if not 0.0 <= q <= 1.0:
            raise ValueError('q must be a probability in [0, 1], got {}'.format(q))
        self._alpha_max = torch.nn.Parameter(torch.tensor(alpha_max), requires_grad=False)
        self._gamma = gamma
        self._training_samples = training_samples
        self._test_samples = test_samples
        self._q = q
        self.resnet = ResNetVIBNBase('BasicBlock', [2, 2, 2, 2], num_classes=num_classes,
                                     l2_norm=l2_norm, maxpool=maxpool, **kwargs)

    def forward(self, example_dict):
        """Encode the input, perturb the latent code with noise and decode it.

        Args:
            example_dict: mapping holding the input batch under ``'input1'``.

        Returns:
            dict with
              * ``'vibn_loss'``: scalar, ``latent_kld`` averaged over dim 0 and
                summed over the remaining dims;
              * ``'prediction'``: decoder output for the noisy codes. Each
                example is repeated ``training_samples`` (training mode) or
                ``test_samples`` (eval mode) times along dim 0 with
                ``repeat_interleave``, so copies of one example are adjacent;
              * ``'l_noise'``: decoder output for a pure-noise code with the
                same shape as the repeated latent code (presumably a reference
                term for the caller's loss -- confirm);
              * ``'latent_kld'``: per-unit penalty, same shape as ``z``.
        """
        samples = self._training_samples if self.training else self._test_samples

        z, z_var = self.resnet(example_dict['input1'])

        # Per-unit noise scale, between 0 and alpha_max.
        z_alpha = self._alpha_max * torch.sigmoid(z_var)

        # Straight-through estimator: the forward value is the hard indicator
        # (z > 0); the backward pass uses the gradient of sigmoid(z).
        soft = torch.sigmoid(z)
        indicator = (z > 0.0).float() - soft.detach() + soft

        # -log(alpha / alpha_max) == -log(sigmoid(z_var)). logsigmoid keeps this
        # finite with a non-vanishing gradient even where sigmoid(z_var)
        # underflows, which an eps-guarded log(alpha + eps) does not.
        latent_kld = -indicator * nn.functional.logsigmoid(z_var)
        vibn_loss = latent_kld.mean(dim=0).sum()

        z = z.repeat_interleave(samples, dim=0)
        z_alpha = z_alpha.repeat_interleave(samples, dim=0)

        # Samples are drawn directly (normal, then uniform, then Bernoulli)
        # instead of through torch.distributions objects, whose constructors
        # validate their arguments on every call.
        # Multiplicative log-normal noise exp(alpha * eps - alpha^2 / 2) with
        # eps ~ N(0, 1); the -alpha^2 / 2 shift gives the noise mean 1.
        eps = torch.randn_like(z)
        z_s = z * torch.exp(z_alpha * eps - z_alpha ** 2 / 2)
        prediction = self.decoder(z_s)

        # Noise-only code: Bernoulli(q) mask times exp(gamma * u), u ~ U(-1, 1).
        u = torch.rand_like(z) * 2.0 - 1.0
        b = torch.bernoulli(torch.full_like(z, self._q))
        s_noise = b * torch.exp(self._gamma * u)
        l_noise = self.decoder(s_noise)

        return {'vibn_loss': vibn_loss, 'prediction': prediction, 'l_noise': l_noise, 'latent_kld': latent_kld}

    def decoder(self, z):
        """Decode latent codes ``z`` with ``self.resnet.decoder``."""
        return self.resnet.decoder(z)