from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn

from models.resnet_ncvibn_base import ResNetNCVIBNBase


class ResNet18NCVIBN(nn.Module):
    """ResNet-18 variational-information-bottleneck (VIBN) wrapper.

    Wraps a ``ResNetNCVIBNBase`` built from ``'BasicBlock'`` with layer counts
    ``[2, 2, 2, 2]``. The base network maps an input to a latent mean and a
    latent log-variance. This wrapper draws reparameterized latent samples,
    decodes them with the base network's decoder, and reports the KL terms.

    Args:
        num_classes: Forwarded to ``ResNetNCVIBNBase``.
        prior_precision: Forwarded to ``ResNetNCVIBNBase``.
        var_dampening: Forwarded to ``ResNetNCVIBNBase``.
        l2_norm: Forwarded to ``ResNetNCVIBNBase``.
        training_samples: Latent samples drawn per input while ``self.training``.
        test_samples: Latent samples drawn per input in eval mode.
        gamma: Scale applied to the standard-normal noise, both for the
            reparameterized samples and for the pure-noise decoder pass.
        maxpool: Forwarded to ``ResNetNCVIBNBase``.
        **kwargs: Forwarded to ``ResNetNCVIBNBase``.
    """

    def __init__(self, num_classes=100, prior_precision=1e0, var_dampening=1e-4, l2_norm=True,
                 training_samples=1, test_samples=16, gamma=1.0, maxpool=False, **kwargs):
        super(ResNet18NCVIBN, self).__init__()
        self._gamma = gamma
        self._training_samples = training_samples
        self._test_samples = test_samples
        self.resnet = ResNetNCVIBNBase(
            'BasicBlock', [2, 2, 2, 2], num_classes=num_classes, prior_precision=prior_precision,
            var_dampening=var_dampening, l2_norm=l2_norm, maxpool=maxpool, **kwargs)

    def forward(self, example_dict):
        """Encode ``example_dict['input1']``, sample the latent, and decode.

        Args:
            example_dict: Mapping whose ``'input1'`` entry is fed to the base network.

        Returns:
            dict with keys
              ``'vibn_loss'``: ``latent_kld`` averaged over dim 0, summed over the rest.
              ``'prediction'``: decoder output for the sampled latents; along dim 0
                  every input row is repeated ``samples`` times consecutively.
              ``'l_noise'``: decoder output for ``gamma``-scaled standard-normal
                  noise with the same shape as the repeated latent.
              ``'latent_kld'``: element-wise KL(N(z, var) || N(0, 1)).
              ``'decoder_kld'``: result of ``self.kl_div()``.
        """
        samples = self._training_samples if self.training else self._test_samples

        z, z_logvar = self.resnet(example_dict['input1'])
        # The second output is a log-variance. Use it directly rather than
        # log(exp(logvar) + eps): the round trip loses precision and, once the
        # variance approaches eps, silently caps the -log(var) penalty at
        # -log(1e-24) (~55.3) (and yields NaN if exp overflows).
        latent_kld = 0.5 * (-z_logvar + torch.exp(z_logvar) + z ** 2 - 1)
        vibn_loss = latent_kld.mean(dim=0).sum()

        # One row per Monte Carlo sample; copies of the same example stay adjacent.
        z = z.repeat_interleave(samples, dim=0)
        z_std = torch.exp(0.5 * z_logvar).repeat_interleave(samples, dim=0)

        normal_dist = torch.distributions.normal.Normal(torch.zeros_like(z), torch.ones_like(z))
        # Reparameterization: z + gamma * std * eps with eps ~ N(0, 1).
        z_s = z + self._gamma * z_std * normal_dist.sample()
        prediction = self.decoder(z_s)

        # Decode scaled noise alone (it does not depend on the encoder output).
        l_noise = self.decoder(self._gamma * normal_dist.sample())

        return {'vibn_loss': vibn_loss, 'prediction': prediction, 'l_noise': l_noise,
                'latent_kld': latent_kld, 'decoder_kld': self.kl_div()}

    def decoder(self, z):
        """Apply the base network's decoder to latent tensor ``z``."""
        return self.resnet.decoder(z)

    def kl_div(self):
        """Return the base network's ``kl_div()`` (presumably the decoder weight KL; defined in the base class)."""
        return self.resnet.kl_div()