import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.distributions as dist


def kl_divergence(z, mu_theta, p_theta, mu_prior, p_prior, m):
    """
    Monte Carlo estimate of KL(q || prior) evaluated at sample(s) ``z``.

    Both distributions are element-wise Normals whose standard deviation is
    ``softplus(p) = log(1 + exp(p))`` of an unconstrained parameter ``p``.
    When ``z`` is drawn from q, ``log q(z) - log prior(z)`` is an unbiased
    estimate of the KL divergence.

    Args:
        z: sample(s) at which both log-densities are evaluated.
        mu_theta: mean of the variational distribution q.
        p_theta: unconstrained scale parameter of q.
        mu_prior: mean of the prior.
        p_prior: unconstrained scale parameter of the prior.
        m: divisor applied to the summed estimate (presumably the number of
            minibatches per epoch -- confirm against the caller).

    Returns:
        Scalar tensor ``sum(log q(z) - log prior(z)) / m``.
    """
    # F.softplus is the numerically stable form of log(1 + exp(p)). In float32
    # the naive expression overflows to inf for p above about 88 and rounds to
    # exactly 0 (an invalid Normal scale) for p below about -17.
    sigma_prior = F.softplus(p_prior)
    sigma_theta = F.softplus(p_theta)
    log_prior = dist.Normal(mu_prior, sigma_prior).log_prob(z)
    log_q = dist.Normal(mu_theta, sigma_theta).log_prob(z)
    return (log_q - log_prior).sum() / m


def exact_kl_divergence(mu_theta, p_theta, mu_prior, p_prior, m):
    """
    Closed-form KL(q || prior) between element-wise Normal distributions.

    For q = N(mu_q, s_q^2) and prior = N(mu_p, s_p^2) the per-element value is

        log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 * s_p^2) - 1/2,

    which is non-negative and exactly 0 when the two distributions coincide.
    Standard deviations are ``softplus(p) = log(1 + exp(p))`` of the
    unconstrained scale parameters.

    Args:
        mu_theta: mean of the variational distribution q.
        p_theta: unconstrained scale parameter of q.
        mu_prior: mean of the prior.
        p_prior: unconstrained scale parameter of the prior.
        m: divisor applied to the summed KL (presumably the number of
            minibatches per epoch -- confirm against the caller).

    Returns:
        Scalar tensor: the element-wise KL summed over all entries, divided by m.
    """
    # Stable softplus: avoids inf for large p and an exact 0 for very negative p.
    sigma_theta = F.softplus(p_theta)
    sigma_prior = F.softplus(p_prior)
    # The posterior variance enters the closed form with coefficient 1 (not 2);
    # any other coefficient makes KL(q || q) non-zero.
    numerator = (mu_theta - mu_prior).pow(2) + sigma_theta.pow(2)
    denominator = 2 * sigma_prior.pow(2)
    div_elem = torch.sum(
        torch.log(sigma_prior) - torch.log(sigma_theta) + numerator / denominator - 0.5
    )
    return div_elem / m


def reparameterize(mu, p):
    """
    Draw a sample from N(mu, softplus(p)^2) using the reparameterization trick.

    The sample is ``mu + eps * softplus(p)`` with ``eps ~ N(0, 1)`` drawn with
    the same shape as ``p``, so the result is differentiable w.r.t. both
    ``mu`` and ``p``.

    Args:
        mu: mean tensor.
        p: unconstrained scale parameter; the standard deviation is
            ``log(1 + exp(p))``.

    Returns:
        A tensor with the broadcast shape of ``mu`` and ``p``.
    """
    # F.softplus equals log(1 + exp(p)) but neither overflows to inf for large
    # p nor rounds to 0 for very negative p (float32).
    sigma = F.softplus(p)
    eps = torch.randn_like(sigma)
    return mu + eps * sigma


class LinearVariational(nn.Module):
    """
    Mean-field (fully factorised Gaussian) approximation of ``nn.Linear``.

    Every weight and bias entry has its own variational mean (``*_mu``) and an
    unconstrained scale parameter (``*_p``); the standard deviation used when
    sampling is ``log(1 + exp(p))``. A fresh weight/bias sample is drawn on
    every forward pass.

    The weight is stored as ``(in_features, out_features)`` and applied as
    ``x @ w``, i.e. transposed relative to ``nn.Linear.weight``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: if True, also learn a variational bias.
        prior_std: standard deviation of the zero-mean normal used to randomly
            initialise ``w_mu`` and ``w_p``. This is not the initial posterior
            std: with ``w_p`` close to 0 (default ``prior_std``) the sampled
            std starts near ``log(1 + exp(0)) = log 2 ~= 0.69``.
    """

    def __init__(self, in_features, out_features, bias=True, prior_std=1e-3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.include_bias = bias

        # Variational posterior Q(w) = N(w_mu, log(1 + exp(w_p))^2), one entry
        # per weight. float32 matches the legacy torch.FloatTensor constructor
        # this replaces (same shape, dtype and RNG consumption).
        self.w_mu = nn.Parameter(
            torch.empty(in_features, out_features, dtype=torch.float32).normal_(mean=0, std=prior_std)
        )
        # Unconstrained proxy for the weight std.
        self.w_p = nn.Parameter(
            torch.empty(in_features, out_features, dtype=torch.float32).normal_(mean=0, std=prior_std)
        )
        if self.include_bias:
            # Bias mean and std proxy both start at exactly zero.
            self.b_mu = nn.Parameter(torch.zeros(out_features))
            self.b_p = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, get_wb=False):
        """
        Apply the layer with a freshly sampled weight (and bias).

        Args:
            x: input whose last dimension is ``in_features``.
            get_wb: if True, also return the sampled weight and bias.

        Returns:
            ``z = x @ w + b``, or the tuple ``(z, w, b)`` when ``get_wb`` is
            True. ``b`` is the integer 0 when the layer has no bias.
        """
        # The weight is sampled before the bias (keeps RNG order stable).
        w = reparameterize(self.w_mu, self.w_p)
        b = reparameterize(self.b_mu, self.b_p) if self.include_bias else 0

        z = x @ w + b
        if get_wb:
            return z, w, b
        return z


class Conv2dVariational(nn.Module):
    """
    Mean-field (fully factorised Gaussian) approximation of ``nn.Conv2d``.

    Every kernel and bias entry has its own variational mean (``*_mu``) and an
    unconstrained scale parameter (``*_p``); the standard deviation used when
    sampling is ``log(1 + exp(p))``. A fresh kernel/bias sample is drawn on
    every forward pass.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (filters).
        kernel_size: int for a square kernel, or an ``(kh, kw)`` pair.
        bias: if True, also learn a variational bias.
        prior_std: standard deviation of the zero-mean normal used to randomly
            initialise ``w_mu`` and ``w_p`` (not the initial posterior std).
        padding: forwarded to ``F.conv2d``.
        stride: forwarded to ``F.conv2d`` (default 1).
        dilation: forwarded to ``F.conv2d`` (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, prior_std=1e-3, padding=0,
                 stride=1, dilation=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Kept for backward compatibility: despite its name this holds the
        # kernel size, not the number of filters (that is out_channels).
        self.n_filters = kernel_size
        self.include_bias = bias
        self.padding = padding
        self.stride = stride
        self.dilation = dilation

        if isinstance(kernel_size, int):
            kernel_h, kernel_w = kernel_size, kernel_size
        else:
            kernel_h, kernel_w = kernel_size
        # Layout expected by F.conv2d: (out_channels, in_channels, kH, kW).
        weights_shape = (out_channels, in_channels, kernel_h, kernel_w)

        # Variational posterior Q(w) = N(w_mu, log(1 + exp(w_p))^2). float32
        # matches the legacy torch.FloatTensor constructor this replaces.
        self.w_mu = nn.Parameter(
            torch.empty(*weights_shape, dtype=torch.float32).normal_(mean=0, std=prior_std)
        )
        # Unconstrained proxy for the kernel std.
        self.w_p = nn.Parameter(
            torch.empty(*weights_shape, dtype=torch.float32).normal_(mean=0, std=prior_std)
        )
        if self.include_bias:
            # Bias mean and std proxy both start at exactly zero.
            self.b_mu = nn.Parameter(torch.zeros(out_channels))
            self.b_p = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x, get_wb=False):
        """
        Convolve ``x`` with a freshly sampled kernel (and bias).

        Args:
            x: input accepted by ``F.conv2d`` with ``in_channels`` channels.
            get_wb: if True, also return the sampled kernel and bias.

        Returns:
            The convolution output, or ``(z, w, b)`` when ``get_wb`` is True.
            ``b`` is None when the layer has no bias.
        """
        # The kernel is sampled before the bias (keeps RNG order stable).
        w = reparameterize(self.w_mu, self.w_p)
        # F.conv2d accepts only a Tensor or None as bias; a plain 0 raises.
        b = reparameterize(self.b_mu, self.b_p) if self.include_bias else None

        z = F.conv2d(x, w, b, stride=self.stride, padding=self.padding, dilation=self.dilation)
        if get_wb:
            return z, w, b
        return z
