import collections
import math
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
from torch.quantization.qconfig import QConfig


def get_kernel_size(x, n):
    """
    Normalise a size specification to a tuple.

    :param x: either an iterable of sizes or a single value
    :param n: number of repetitions used when ``x`` is a single value

    :return: ``tuple(x)`` when ``x`` is iterable (``n`` is ignored in that case), otherwise a tuple holding
        ``x`` repeated ``n`` times.
    """
    already_a_sequence = isinstance(x, collections.abc.Iterable)
    values = x if already_a_sequence else repeat(x, n)
    return tuple(values)


class _BaseVariationalLayer(nn.Module):
    """
    Common base class of the variational layers.

    It is a plain :class:`torch.nn.Module` that carries the ``dnn_to_bnn_flag`` switch and offers :meth:`kl_div`,
    which evaluates the KL divergence between two gaussians :math:`Q` and :math:`P` and returns it as a
    :obj:`torch.Tensor`.
    """

    def __init__(self):
        super(_BaseVariationalLayer, self).__init__()
        # Backing field of the ``dnn_to_bnn_flag`` property; switched off on construction.
        self._dnn_to_bnn_flag = False

    @property
    def dnn_to_bnn_flag(self):
        """Current value of the flag (``False`` until it is assigned)."""
        return self._dnn_to_bnn_flag

    @dnn_to_bnn_flag.setter
    def dnn_to_bnn_flag(self, value):
        self._dnn_to_bnn_flag = value

    def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
        """
        Calculates the KL divergence between two gaussians (Q || P).

        The divergence is evaluated element-wise and then averaged over all elements.

        :param mu_q: mean of distribution Q
        :type mu_q: torch.Tensor
        :param sigma_q: deviation of distribution Q
        :type sigma_q: torch.Tensor
        :param mu_p: mean of distribution P
        :type mu_p: torch.Tensor
        :param sigma_p: deviation of distribution P
        :type sigma_p: torch.Tensor

        :return: the mean of the element-wise KL divergence between Q and P.
        """
        log_sigma_ratio = torch.log(sigma_p) - torch.log(sigma_q)
        scaled_second_moment = (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * (sigma_p ** 2))
        elementwise_kl = log_sigma_ratio + scaled_second_moment - 0.5
        return elementwise_kl.mean()


class LinearFlipout(_BaseVariationalLayer):
    """
    Bayesian linear layer that perturbs its deterministic output with sign-flipped weight noise (Flipout).

    The posterior mean is ``mu_weight`` / ``mu_bias`` and the posterior deviation is ``softplus(rho_weight)`` /
    ``softplus(rho_bias)``. Every forward pass draws one gaussian noise sample into the ``eps_*`` buffers and
    multiplies the input and the perturbation output by random sign tensors shaped like ``x`` and like the output.

    :param in_features: size of the last dimension of the input
    :param out_features: size of the last dimension of the output
    :param prior_mean: value the ``prior_*_mu`` buffers are filled with
    :param prior_variance: value the ``prior_*_sigma`` buffers are filled with. NOTE: despite its name this value
        is handed to :meth:`kl_div` as ``sigma_p``; the parameter name is kept for backward compatibility.
    :param posterior_mu_init: mean of the normal distribution (std 0.1) used to initialise ``mu_*``
    :param posterior_rho_init: mean of the normal distribution (std 0.1) used to initialise ``rho_*``
    :param bias: when ``False`` the bias parameters and bias buffers are registered as ``None``
    """

    def __init__(self,
                 in_features,
                 out_features,
                 prior_mean=0,
                 prior_variance=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-3.0,
                 bias=True):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        weight_shape = (out_features, in_features)
        self.mu_weight = nn.Parameter(torch.empty(weight_shape))
        self.rho_weight = nn.Parameter(torch.empty(weight_shape))
        # persistent=False keeps the noise and prior tensors out of the state dict.
        for name in ('eps_weight', 'prior_weight_mu', 'prior_weight_sigma'):
            self.register_buffer(name, torch.empty(weight_shape), persistent=False)

        if bias:
            self.mu_bias = nn.Parameter(torch.empty(out_features))
            self.rho_bias = nn.Parameter(torch.empty(out_features))
            for name in ('prior_bias_mu', 'prior_bias_sigma', 'eps_bias'):
                self.register_buffer(name, torch.empty(out_features), persistent=False)
        else:
            self.register_buffer('prior_bias_mu', None, persistent=False)
            self.register_buffer('prior_bias_sigma', None, persistent=False)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None, persistent=False)

        self.init_parameters()
        # Switched on by prepare(); gates the QuantStub calls in forward().
        self.quant_prepare = False

    def prepare(self):
        """
        Create the quantization stubs used by :meth:`forward` and switch ``quant_prepare`` on.

        Four stubs are configured with symmetric per-tensor ``qint8`` min/max observers, eight with ``quint8``
        min/max observers. A :class:`DeQuantStub` is stored as ``self.dequant``; this class does not call it.
        """
        def make_qint8_stub():
            observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
            return torch.quantization.QuantStub(QConfig(weight=observer, activation=observer))

        def make_quint8_stub():
            observer = MinMaxObserver.with_args(dtype=torch.quint8)
            return torch.quantization.QuantStub(QConfig(weight=observer, activation=observer))

        self.qint_quant = nn.ModuleList([make_qint8_stub() for _ in range(4)])
        self.quint_quant = nn.ModuleList([make_quint8_stub() for _ in range(8)])
        self.dequant = torch.quantization.DeQuantStub()
        self.quant_prepare = True

    def init_parameters(self):
        """Fill the prior buffers and draw the initial posterior parameters (normal, std 0.1)."""
        self.prior_weight_mu.fill_(self.prior_mean)
        self.prior_weight_sigma.fill_(self.prior_variance)

        nn.init.normal_(self.mu_weight, mean=self.posterior_mu_init, std=0.1)
        nn.init.normal_(self.rho_weight, mean=self.posterior_rho_init, std=0.1)

        if self.mu_bias is not None:
            self.prior_bias_mu.fill_(self.prior_mean)
            self.prior_bias_sigma.fill_(self.prior_variance)
            nn.init.normal_(self.mu_bias, mean=self.posterior_mu_init, std=0.1)
            nn.init.normal_(self.rho_bias, mean=self.posterior_rho_init, std=0.1)

    def kl_loss(self):
        """
        KL divergence between the current posterior and the prior.

        :return: ``kl_div`` of the weights, plus ``kl_div`` of the bias when the layer has one.
        """
        # softplus == log1p(exp(rho)) but does not overflow to inf for large rho.
        sigma_weight = F.softplus(self.rho_weight)
        kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
        if self.mu_bias is not None:
            sigma_bias = F.softplus(self.rho_bias)
            kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
        return kl

    def forward(self, x, return_kl=True):
        """
        Apply the layer with one fresh noise sample.

        :param x: input tensor whose last dimension is ``in_features``
        :param return_kl: whether to also return the KL term; forced to ``False`` when ``dnn_to_bnn_flag`` is set

        :return: ``(out, kl)`` when the KL term is requested, otherwise ``out`` only.
        """
        if self.dnn_to_bnn_flag:
            return_kl = False

        # sampling delta_W
        # softplus == log1p(exp(rho)) but does not overflow to inf for large rho.
        sigma_weight = F.softplus(self.rho_weight)
        # ``.data.normal_()`` refreshes the shared noise buffer without autograd noticing. Multiply with a
        # private copy so that a later forward pass (e.g. several Monte Carlo samples before one backward)
        # cannot overwrite the noise autograd saved for this pass.
        eps_weight = self.eps_weight.data.normal_().clone()
        delta_weight = sigma_weight * eps_weight

        # get kl divergence
        kl = None
        if return_kl:
            kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
                             self.prior_weight_sigma)

        bias = None
        if self.mu_bias is not None:
            sigma_bias = F.softplus(self.rho_bias)
            bias = sigma_bias * self.eps_bias.data.normal_().clone()
            if return_kl:
                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                      self.prior_bias_sigma)

        # linear outputs
        outputs = F.linear(x, self.mu_weight, self.mu_bias)
        # Random signs; empty_like avoids copying data that uniform_ overwrites anyway.
        sign_input = torch.empty_like(x).uniform_(-1, 1).sign()
        sign_output = torch.empty_like(outputs).uniform_(-1, 1).sign()
        x_tmp = x * sign_input
        perturbed_outputs_tmp = F.linear(x_tmp, delta_weight, bias)
        perturbed_outputs = perturbed_outputs_tmp * sign_output
        out = outputs + perturbed_outputs

        if self.quant_prepare:
            # Only ``out`` is used after this point; the other stubs are called for their side effects
            # (presumably so attached observers can record value ranges -- confirm against the quantization flow).
            # quint8 quantstub: input / intermediate activations, then the final output
            quint_inputs = (x, outputs, sign_input, sign_output, x_tmp,
                            perturbed_outputs_tmp, perturbed_outputs)
            for stub, tensor in zip(self.quint_quant, quint_inputs):
                stub(tensor)
            out = self.quint_quant[7](out)

            # qint8 quantstub: weights, the random variable and their product
            qint_inputs = (sigma_weight, self.mu_weight, eps_weight, delta_weight)
            for stub, tensor in zip(self.qint_quant, qint_inputs):
                stub(tensor)

        # returning outputs + perturbations
        if return_kl:
            return out, kl
        return out


class AmorNet(nn.Module):
    """
    Single fully connected layer mapping ``in_feat`` input features to ``out_feat`` outputs.

    :param in_feat: number of input features (default 8)
    :param out_feat: number of output features (default 128)
    :param device: forwarded to :class:`torch.nn.Linear`
    :param dtype: forwarded to :class:`torch.nn.Linear`
    """

    def __init__(self, in_feat=8, out_feat=128, device=None, dtype=None):
        super().__init__()
        self.in_feat, self.out_feat = in_feat, out_feat
        factory_kwargs = {"device": device, "dtype": dtype}
        self.e1 = nn.Linear(in_feat, out_feat, **factory_kwargs)

    def forward(self, x):
        """Return the affine projection ``e1(x)``."""
        projected = self.e1(x)
        return projected


class ContextualLinearFlipout(_BaseVariationalLayer):
    """
    Flipout linear layer whose posterior parameters are predicted from the input itself.

    Four :class:`AmorNet` modules map ``x`` to the posterior mean and ``rho`` of a per-sample weight matrix
    (reshaped to ``(-1, out_features, in_features)``) and, optionally, of a per-sample bias. The posterior
    deviation is ``softplus(rho)``.

    :param in_features: size of the last dimension of the input
    :param out_features: size of the last dimension of the output
    :param prior_mean: value the ``prior_*_mu`` buffers are filled with
    :param prior_variance: value the ``prior_*_sigma`` buffers are filled with. NOTE: despite its name this value
        is handed to :meth:`kl_div` as ``sigma_p``; the parameter name is kept for backward compatibility.
    :param posterior_mu_init: stored on the instance; not used by this class
    :param posterior_rho_init: stored on the instance; not used by this class
    :param bias: when ``False`` no bias networks are created and the bias buffers are registered as ``None``
    """

    def __init__(self,
                 in_features,
                 out_features,
                 prior_mean=0,
                 prior_variance=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-3.0,
                 bias=True):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        self.mu_weight = AmorNet(in_feat=in_features, out_feat=out_features * in_features)
        self.rho_weight = AmorNet(in_feat=in_features, out_feat=out_features * in_features)

        weight_shape = (out_features, in_features)
        # The noise buffer is overwritten on every forward pass, so it may start uninitialised.
        self.register_buffer('eps_weight', torch.empty(weight_shape), persistent=False)
        # The prior buffers must hold the configured prior; leaving them uninitialised would make the KL term
        # depend on whatever happened to be in memory.
        self.register_buffer('prior_weight_mu',
                             torch.empty(weight_shape).fill_(prior_mean),
                             persistent=False)
        self.register_buffer('prior_weight_sigma',
                             torch.empty(weight_shape).fill_(prior_variance),
                             persistent=False)

        if bias:
            self.mu_bias = AmorNet(in_feat=in_features, out_feat=out_features)
            self.rho_bias = AmorNet(in_feat=in_features, out_feat=out_features)
            self.register_buffer('prior_bias_mu',
                                 torch.empty(out_features).fill_(prior_mean),
                                 persistent=False)
            self.register_buffer('prior_bias_sigma',
                                 torch.empty(out_features).fill_(prior_variance),
                                 persistent=False)
            self.register_buffer('eps_bias', torch.empty(out_features), persistent=False)
        else:
            self.register_buffer('prior_bias_mu', None, persistent=False)
            self.register_buffer('prior_bias_sigma', None, persistent=False)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None, persistent=False)

        self.apply(self._init_weights)

    @torch.no_grad()
    def _init_weights(self, m: nn.Module):
        """Initialise one submodule; meant to be passed to :meth:`torch.nn.Module.apply`."""
        if isinstance(m, nn.Linear):
            # Kaiming-uniform weights with negative slope sqrt(5); bias uniform in +-1/sqrt(fan_in).
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
                # Guard against a zero-width input, which would otherwise divide by zero.
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(m.bias, -bound, bound)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def _weight_posterior(self, x):
        """Return ``(mu, sigma)`` of the weights for ``x``, each shaped ``(-1, out_features, in_features)``."""
        weight_shape = (-1, self.out_features, self.in_features)
        mu_weight = self.mu_weight(x).reshape(weight_shape)
        # softplus == log1p(exp(rho)) but does not overflow to inf for large rho.
        sigma_weight = F.softplus(self.rho_weight(x)).reshape(weight_shape)
        return mu_weight, sigma_weight

    def _bias_posterior(self, x):
        """Return ``(mu, sigma)`` of the bias for ``x``; only valid when the layer has a bias."""
        return self.mu_bias(x), F.softplus(self.rho_bias(x))

    def kl_div_stable(self, mu_q, sigma_q, mu_p, sigma_p):
        """
        Float64 variant of :meth:`kl_div` with a small epsilon inside the logs and the denominator.

        :param mu_q: mean of distribution Q (tensor)
        :param sigma_q: deviation of distribution Q (tensor)
        :param mu_p: mean of distribution P (number or tensor)
        :param sigma_p: deviation of distribution P (number or tensor)

        :return: the **sum** (not the mean) of the element-wise KL divergence, in float64.
        """
        eps = 1e-6
        mu_q64 = mu_q.to(torch.float64)
        sigma_q64 = sigma_q.to(torch.float64)
        # as_tensor accepts python numbers as well as tensors of any size (math.log only takes scalars).
        sigma_p64 = torch.as_tensor(sigma_p, dtype=torch.float64, device=sigma_q64.device)
        kl = (torch.log(sigma_p64 + eps) - torch.log(sigma_q64 + eps) +
              (sigma_q64 ** 2 + (mu_q64 - mu_p) ** 2) / (2 * (sigma_p64 ** 2) + eps) - 0.5)
        return kl.sum()

    def kl_loss(self, x):
        """
        KL divergence between the posterior predicted for ``x`` and the prior.

        :param x: input tensor whose last dimension is ``in_features``

        :return: ``kl_div`` of the weights, plus ``kl_div`` of the bias when the layer has one.
        """
        # The reshape is required for the (out_features, in_features) prior buffers to broadcast.
        mu_weight, sigma_weight = self._weight_posterior(x)
        kl = self.kl_div(mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
        if self.mu_bias is not None:
            mu_bias, sigma_bias = self._bias_posterior(x)
            kl = kl + self.kl_div(mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
        return kl

    def forward(self, x, return_kl=True):
        """
        Apply the layer with one fresh noise sample.

        :param x: input tensor whose last dimension is ``in_features``; the matmul below is written for a
            2-D ``(batch, in_features)`` input
        :param return_kl: whether to also return the KL term; forced to ``False`` when ``dnn_to_bnn_flag`` is set
            (same convention as :class:`LinearFlipout`)

        :return: ``(out, kl)`` when the KL term is requested, otherwise ``out`` only.
        """
        if self.dnn_to_bnn_flag:
            return_kl = False

        # sampling delta_W
        mu_weight, sigma_weight = self._weight_posterior(x)
        # ``.data.normal_()`` refreshes the shared noise buffer without autograd noticing. Multiply with a
        # private copy so that a later forward pass (e.g. several Monte Carlo samples before one backward)
        # cannot overwrite the noise autograd saved for this pass.
        eps_weight = self.eps_weight.data.normal_().clone()
        delta_weight = sigma_weight * eps_weight

        # get kl divergence
        kl = None
        if return_kl:
            kl = self.kl_div(mu_weight, sigma_weight, self.prior_weight_mu,
                             self.prior_weight_sigma)

        mu_bias = None
        bias = None
        if self.mu_bias is not None:
            mu_bias, sigma_bias = self._bias_posterior(x)
            bias = sigma_bias * self.eps_bias.data.normal_().clone()
            if return_kl:
                kl = kl + self.kl_div(mu_bias, sigma_bias, self.prior_bias_mu,
                                      self.prior_bias_sigma)

        # linear outputs: one matrix-vector product per row of the reshaped weights
        outputs = torch.matmul(mu_weight, x.unsqueeze(-1)).squeeze(-1)
        if mu_bias is not None:
            outputs = outputs + mu_bias

        # Random signs; empty_like avoids copying data that uniform_ overwrites anyway.
        sign_input = torch.empty_like(x).uniform_(-1, 1).sign()
        sign_output = torch.empty_like(outputs).uniform_(-1, 1).sign()
        x_tmp = x * sign_input
        perturbed_outputs_tmp = torch.matmul(delta_weight, x_tmp.unsqueeze(-1)).squeeze(-1)
        if bias is not None:
            perturbed_outputs_tmp = perturbed_outputs_tmp + bias
        perturbed_outputs = perturbed_outputs_tmp * sign_output
        out = outputs + perturbed_outputs

        # returning outputs + perturbations
        if return_kl:
            return out, kl
        return out