import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair


class LinearMFMC(nn.Module):
    """Fully connected layer with a factorised Gaussian over its parameters.

    Every weight (and bias) has a learnable mean (``weight`` / ``bias``) and a
    learnable log-variance (``weight_variance`` / ``bias_variance``; despite
    the name these hold *log* variances, they are exponentiated before use).
    ``forward`` draws a noisy output by sampling the pre-activation directly
    instead of sampling the weights.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        prior_precision: Precision (inverse variance) of the zero-mean
            Gaussian prior used by ``kl_div``. Must be positive.
        var_dampening: Scale applied to the prior variance to obtain the
            initial parameter variance (``var_dampening / prior_precision``).
        bias: If ``False`` the layer has no bias mean or bias log-variance.
        eps: Lower bound / additive constant that keeps the output variance
            strictly positive before the square root.
    """

    def __init__(self, in_features, out_features, prior_precision=1e0, var_dampening=1.0, bias=True, eps=1e-12):
        super(LinearMFMC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_precision = prior_precision
        self.eps = eps
        self.has_bias = bias
        # Initial log-variance shared by all parameters.
        self._log_var_init = math.log(var_dampening / prior_precision)

        # torch.empty only allocates; the values are set in reset_parameters().
        self.weight = Parameter(torch.empty(out_features, in_features))
        self.weight_variance = Parameter(torch.empty(out_features, in_features))
        if self.has_bias:
            self.bias = Parameter(torch.empty(out_features))
            self.bias_variance = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise the means and log-variances.

        Means are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features));
        log-variances are set to ``log(var_dampening / prior_precision)``.
        """
        bound = 1.0 / math.sqrt(self.in_features) if self.in_features > 0 else 0.0
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.constant_(self.weight_variance, self._log_var_init)
        if self.has_bias:
            nn.init.uniform_(self.bias, -bound, bound)
            nn.init.constant_(self.bias_variance, self._log_var_init)

    def forward(self, x):
        """Return one stochastic sample of the layer output for input ``x``.

        The output mean is the usual affine map. The output variance is the
        squared input propagated through the per-weight variances, plus the
        bias variance when a bias is present. Noise is drawn on every call.
        """
        outputs_mean = F.linear(x, self.weight, self.bias)
        outputs_variance = F.linear(x**2, torch.exp(self.weight_variance)).clamp(min=self.eps)
        if self.has_bias:
            outputs_variance = outputs_variance + torch.exp(self.bias_variance)
        # Standard-normal noise with the same shape, dtype and device as the mean.
        noise = torch.randn_like(outputs_mean)
        return outputs_mean + torch.sqrt(outputs_variance + self.eps) * noise

    def kl_div(self):
        """Return the summed KL divergence from the prior as a scalar tensor.

        Closed-form KL between the factorised Gaussian
        N(mean, exp(log_variance)) and the prior N(0, 1 / prior_precision),
        summed over all weights and, if present, biases.
        """
        log_precision = math.log(self.prior_precision)
        kld = 0.5 * (
            -self.weight_variance - log_precision
            + self.prior_precision * (torch.exp(self.weight_variance) + self.weight**2)
            - 1
        ).sum()
        if self.has_bias:
            kld = kld + 0.5 * (
                -self.bias_variance - log_precision
                + self.prior_precision * (torch.exp(self.bias_variance) + self.bias**2)
                - 1
            ).sum()
        return kld


class Conv2dMFMC(_ConvNd):
    """2-D convolution with a factorised Gaussian over its parameters.

    Every kernel weight (and bias) has a learnable mean (``weight`` /
    ``bias``, created by ``_ConvNd``) and a learnable log-variance
    (``weight_variance`` / ``bias_variance``; these hold *log* variances and
    are exponentiated before use). ``forward`` draws a noisy output by
    sampling the pre-activation directly instead of sampling the weights.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (int or pair).
        stride: Convolution stride (int or pair).
        padding: Zero padding (int or pair). A string is forwarded
            unchanged to ``_ConvNd`` / ``F.conv2d``.
        dilation: Kernel dilation (int or pair).
        groups: Number of channel groups.
        bias: If ``False`` the layer has no bias mean or bias log-variance.
        prior_precision: Precision (inverse variance) of the zero-mean
            Gaussian prior used by ``kl_div``. Must be positive.
        var_dampening: Scale applied to the prior variance to obtain the
            initial parameter variance (``var_dampening / prior_precision``).
        eps: Lower bound / additive constant that keeps the output variance
            strictly positive before the square root.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, prior_precision=1e0, var_dampening=1.0, eps=1e-12):
        # Computed first so invalid prior settings fail before any allocation.
        log_var_init = math.log(var_dampening / prior_precision)

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        # _pair() would split a string into characters, so pass strings through.
        padding = padding if isinstance(padding, str) else _pair(padding)
        dilation = _pair(dilation)
        # Non-transposed convolution, no output padding, 'zeros' padding mode.
        super(Conv2dMFMC, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, 'zeros')

        # Set after Module.__init__ so the module's attribute bookkeeping exists.
        self.prior_precision = prior_precision
        self.eps = eps
        self.has_bias = bias

        self.weight_variance = Parameter(torch.full_like(self.weight, log_var_init))
        if self.has_bias:
            self.bias_variance = Parameter(torch.full_like(self.bias, log_var_init))

    def forward(self, x):
        """Return one stochastic sample of the layer output for input ``x``.

        The output mean is the usual convolution. The output variance is the
        squared input convolved with the per-weight variances, plus the
        per-channel bias variance when a bias is present. Noise is drawn on
        every call. The bias variance is broadcast as ``(1, C, 1, 1)``, so a
        batched 4-D input is expected.
        """
        outputs_mean = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        outputs_variance = F.conv2d(
            x**2, torch.exp(self.weight_variance), None,
            self.stride, self.padding, self.dilation, self.groups).clamp(min=self.eps)
        if self.has_bias:
            outputs_variance = outputs_variance + torch.exp(self.bias_variance).view(1, -1, 1, 1)
        # Standard-normal noise with the same shape, dtype and device as the mean.
        noise = torch.randn_like(outputs_mean)
        return outputs_mean + torch.sqrt(outputs_variance + self.eps) * noise

    def kl_div(self):
        """Return the summed KL divergence from the prior as a scalar tensor.

        Closed-form KL between the factorised Gaussian
        N(mean, exp(log_variance)) and the prior N(0, 1 / prior_precision),
        summed over all kernel weights and, if present, biases.
        """
        log_precision = math.log(self.prior_precision)
        kld = 0.5 * (
            -self.weight_variance - log_precision
            + self.prior_precision * (torch.exp(self.weight_variance) + self.weight**2)
            - 1
        ).sum()
        if self.has_bias:
            kld = kld + 0.5 * (
                -self.bias_variance - log_precision
                + self.prior_precision * (torch.exp(self.bias_variance) + self.bias**2)
                - 1
            ).sum()
        return kld