from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

from models.resnet_vibn_base import ResNetVIBNBase


class ResNet18VIBNMP(nn.Module):
    """ResNet-18 variational bottleneck model with a learned mixture prior.

    The encoder/decoder pair is provided by ``ResNetVIBNBase``.  On top of it
    this module keeps a learnable diagonal-Gaussian mixture (weights, means
    and log-variances) over the latent tensor, which is used both for the
    ``vibn_loss`` regulariser and for drawing latent samples from the prior.

    Args:
        num_classes: Forwarded to ``ResNetVIBNBase``.
        l2_norm: Forwarded to ``ResNetVIBNBase``.
        training_samples: Latent samples drawn per input when the module is
            in training mode.
        test_samples: Latent samples drawn per input when the module is in
            eval mode.
        gamma: Multiplier applied to the noise term when sampling the latent.
        maxpool: Forwarded to ``ResNetVIBNBase``.
        num_mixture_components: Number of components of the mixture prior.
        **kwargs: Forwarded unchanged to ``ResNetVIBNBase``.
    """

    # Added inside log/sqrt of the encoder variance to keep them finite.
    _EPS = 1e-24

    def __init__(self, num_classes=100, l2_norm=True, training_samples=1, test_samples=16,
                 gamma=1.0, maxpool=False, num_mixture_components=32, **kwargs):
        super(ResNet18VIBNMP, self).__init__()
        self._gamma = gamma
        self._training_samples = training_samples
        self._test_samples = test_samples
        self._num_mixture_components = num_mixture_components
        self.resnet = ResNetVIBNBase('BasicBlock', [2, 2, 2, 2], num_classes=num_classes,
                                     l2_norm=l2_norm, maxpool=maxpool, **kwargs)

        # Latent tensor shape (channels, height, width) of the mixture prior.
        # NOTE(review): hard-coded; presumably has to match the latent emitted
        # by ResNetVIBNBase for this configuration -- confirm.
        self._C, self._H, self._W = 256, 4, 4
        latent_shape = [num_mixture_components, self._C, self._H, self._W]
        self.mixture_log_weights = nn.Parameter(torch.zeros([num_mixture_components]), requires_grad=True)
        self.mixture_means = nn.Parameter(0.1 * torch.randn(latent_shape), requires_grad=True)
        self.mixture_log_vars = nn.Parameter(torch.zeros(latent_shape), requires_grad=True)

    def forward(self, example_dict):
        """Encode the input, sample latents, decode them and score the prior.

        Args:
            example_dict: Mapping whose ``'input1'`` entry is passed to the
                encoder.

        Returns:
            dict with the keys

            * ``'vibn_loss'``: mixture-weighted sum over components of the
              per-component latent term (batch-averaged, then summed).
            * ``'prediction'``: decoder output for the sampled latents.
            * ``'l_noise'``: decoder output for latents sampled from the
              (detached) mixture prior.
            * ``'latent_kld'``: element-wise latent term of the *last*
              mixture component only (kept as-is for existing callers).
        """
        samples = self._training_samples if self.training else self._test_samples

        inputs = example_dict['input1']
        # The second encoder output is exponentiated, i.e. used as a log-variance.
        z, z_log_var = self.resnet(inputs)
        z_var = torch.exp(z_log_var)

        # Every input is repeated `samples` times along the batch dimension.
        z = z.repeat_interleave(samples, dim=0)
        z_var = z_var.repeat_interleave(samples, dim=0)
        sample_batch_size = z.shape[0]

        # Reparameterised sample: mean + gamma * std * standard-normal noise.
        normals = D.Normal(torch.zeros_like(z), torch.ones_like(z)).sample()
        z_s = z + self._gamma * torch.sqrt(z_var + self._EPS) * normals
        prediction = self.decoder(z_s)

        mixture_weights = F.softmax(self.mixture_log_weights, dim=0)

        # Does not depend on the mixture component, so compute it once.
        neg_log_z_var = -torch.log(z_var + self._EPS)

        vibn_loss = 0.0
        # One component at a time keeps peak memory at a single latent batch.
        for i in range(self.mixture_means.shape[0]):
            mu = self.mixture_means[i:i + 1, :, :, :]
            log_var = self.mixture_log_vars[i:i + 1, :, :, :]
            latent_kld = 0.5 * (neg_log_z_var + log_var + (z_s - mu) ** 2 / torch.exp(log_var))
            vibn_loss += latent_kld.mean(dim=0).sum() * mixture_weights[i]

        prior_noise = self._sample_prior(mixture_weights, sample_batch_size)
        l_noise = self.decoder(prior_noise)

        return {'vibn_loss': vibn_loss, 'prediction': prediction, 'l_noise': l_noise, 'latent_kld': latent_kld}

    def _sample_prior(self, mixture_weights, batch_size):
        """Draw ``batch_size`` latent tensors from the detached mixture prior."""
        num_components = self._num_mixture_components
        flat_dim = self._C * self._H * self._W

        weights = mixture_weights.detach().unsqueeze(dim=0).expand(batch_size, -1)
        means = self.mixture_means.detach().view(1, num_components, flat_dim).expand(batch_size, -1, -1)
        # D.Normal is parameterised by the standard deviation, so convert the
        # stored log-variance with exp(0.5 * log_var) rather than exp(log_var).
        stds = torch.exp(0.5 * self.mixture_log_vars.detach()).view(
            1, num_components, flat_dim).expand(batch_size, -1, -1)

        components = D.Independent(D.Normal(means, stds), 1)
        gmm = D.MixtureSameFamily(D.Categorical(probs=weights), components)

        return gmm.sample().view(batch_size, self._C, self._H, self._W)

    def decoder(self, z):
        """Run the wrapped network's decoder on latent ``z``."""
        return self.resnet.decoder(z)