import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions import Beta, RelaxedBernoulli, Bernoulli
from torch.distributions.kl import kl_divergence

''' adapted from https://github.com/kckishan/Depth_and_Dropout'''

class SampleNetworkArchitecture(nn.Module):
    """
    Samples an architecture from Beta-Bernoulli prior

    A Beta variational posterior with ``truncation`` independent components is
    sampled, turned into per-layer activation levels ``pi`` through a
    cumulative product, and then used to draw a (relaxed) Bernoulli mask ``Z``
    of shape ``(num_samples, num_neurons, truncation)``.
    """

    def __init__(self, num_neurons=64, a_prior=2., b_prior=1., num_samples=5, truncation=25, device=None):
        """
        Parameters
        ----------
        num_neurons : size of the neuron dimension of the sampled mask
        a_prior : first concentration parameter of the Beta prior
        b_prior : second concentration parameter of the Beta prior
        num_samples : stored as ``self.num_samples``; note that ``forward``
            takes its own ``num_samples`` argument and does not read this
            attribute
        truncation : number of components (columns) of the variational
            posterior and of the sampled mask
        device : passed to ``.to()`` for the prior hyper-parameters and the
            temperature; the variational parameters are created without an
            explicit device
        """
        super().__init__()
        self.device = device
        self.num_neurons = num_neurons
        # Hyper-parameters for Prior probabilities
        self.a_prior = torch.tensor(a_prior).float().to(self.device)
        self.b_prior = torch.tensor(b_prior).float().to(self.device)

        # Define a prior beta distribution
        self.beta_prior = Beta(self.a_prior, self.b_prior)

        # Temperature for Bernoulli sampling
        self.temperature = torch.tensor(3).to(self.device)
        self.truncation = truncation
        # Number of samples from IBP prior to estimate expectations
        self.num_samples = num_samples

        # Inverse softplus of the initial Beta parameters (1.1 and 1.0), so
        # that softplus(var_*) starts at those values.  The uniform ranges are
        # degenerate (low == high) and always yield the bound itself; the
        # calls are kept so the global NumPy RNG stream is consumed as before.
        a_val = np.log(np.exp(np.random.uniform(1.1, 1.1)) - 1)
        b_val = np.log(np.exp(np.random.uniform(1.0, 1.0)) - 1)

        # Define variational parameters for posterior distribution
        self.var_a = nn.Parameter(torch.full((self.truncation,), float(a_val)))
        self.var_b = nn.Parameter(torch.full((self.truncation,), float(b_val)))

    def get_var_params(self):
        """
        Maps the raw variational parameters to positive Beta parameters

        Returns
        -------
        beta_a, beta_b : tensors of length ``truncation`` equal to
            ``softplus(var_a) + 0.01`` and ``softplus(var_b) + 0.01``
        """
        # The 0.01 offset keeps both concentrations strictly above zero.
        beta_a = F.softplus(self.var_a) + 0.01
        beta_b = F.softplus(self.var_b) + 0.01
        return beta_a, beta_b

    def get_kl_beta(self):
        """
        Computes the KL divergence between posterior and prior

        Returns
        -------
        Scalar tensor: KL(Beta(beta_a, beta_b) || Beta(a_prior, b_prior))
        summed over all ``truncation`` components
        """
        beta_a, beta_b = self.get_var_params()
        beta_posterior = Beta(beta_a, beta_b)
        kl_beta = kl_divergence(beta_posterior, self.beta_prior)
        return kl_beta.sum()

    def get_kl(self):
        """
        Returns the KL term of the module, which is the Beta KL only
        """
        return self.get_kl_beta()

    def forward(self, num_samples=5, get_pi=False, get_intermediate_pi=False):
        """
        Samples a mask over (neuron, layer) positions

        Parameters
        ----------
        num_samples : number of independent samples to draw
        get_pi : if True, return summary statistics instead of the default
            tuple (takes precedence over ``get_intermediate_pi``)
        get_intermediate_pi : if True (and ``get_pi`` is False), return only
            ``pi``

        Returns
        -------
        Default:
            ``(Z, pi, threshold, threshold_array)`` where ``Z`` and ``pi`` have
            shape ``(num_samples, num_neurons, truncation)``,
            ``threshold_array`` is a NumPy array holding, per sample, the
            number of layers with at least one entry of ``Z`` above 0.01, and
            ``threshold`` is the maximum of that array, never less than 1.
        ``get_pi=True``:
            ``(Z, median, 25th percentile, 75th percentile, mean_pi)`` where
            the three statistics are taken over ``threshold_array`` and
            ``mean_pi`` is ``pi`` averaged over samples and neurons.
        ``get_intermediate_pi=True``:
            ``pi``

        The value of ``threshold`` is also stored on ``self.threshold``.
        """
        # Define variational beta distribution
        beta_a, beta_b = self.get_var_params()
        beta_posterior = Beta(beta_a, beta_b)

        # sample from variational beta distribution -> (num_samples, truncation)
        v = beta_posterior.rsample((num_samples,))

        # Convert v -> pi i.e. activation level of layer.  The cumulative
        # product is computed in log space.
        pi = torch.cumsum(v.log(), dim=1).exp()
        pi = pi.unsqueeze(1).expand(-1, self.num_neurons, -1)

        # sample active neurons given the activation probability of that layer
        if self.training:
            # Relaxed (reparameterised) sample with values in (0, 1)
            Z = RelaxedBernoulli(temperature=self.temperature, probs=pi).rsample()
        else:
            # Hard 0/1 sample
            Z = Bernoulli(probs=pi).sample()

        # A layer counts as active for a sample when at least one of its
        # neurons has a mask value above 0.01.
        layer_is_active = (Z > 0.01).sum(1) > 0
        threshold_array = layer_is_active.sum(dim=1).cpu().numpy()

        # In case of no active layers fall back to 1.  Clamping with NumPy
        # keeps `threshold` a NumPy integer in every case (it used to become a
        # torch tensor only in the fallback case).
        threshold = np.maximum(threshold_array.max(), 1)

        self.threshold = threshold
        if get_pi:
            return (
                Z,
                np.median(threshold_array),
                np.percentile(threshold_array, 25),
                np.percentile(threshold_array, 75),
                pi.mean(0).mean(0),
            )

        if get_intermediate_pi:
            return pi

        return Z, pi, threshold, threshold_array

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.num_neurons} x {self.truncation})"