import math
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as F

# Short alias for torch.Tensor, used in the type annotations throughout this module.
T = torch.Tensor


class ScaleMixtureGaussian(nn.Module):
    """Zero-mean, two-component scale-mixture Gaussian prior.

    The density is ``pi * N(x; 0, sigma_one**2) + (1 - pi) * N(x; 0, sigma_two**2)``
    with ``sigma_one = exp(log_sigma_one)`` and ``sigma_two = exp(log_sigma_two)``.

    Args:
        pi: Mixing weight of the first component; expected to lie in ``[0, 1]``.
        log_sigma_one: Natural log of the first component's standard deviation.
        log_sigma_two: Natural log of the second component's standard deviation.
    """

    def __init__(
        self,
        pi: float = 0.5,
        log_sigma_one: float = -0.,
        log_sigma_two: float = -6.
    ):
        super().__init__()
        self.pi = pi
        self.sigma_one = math.exp(log_sigma_one)
        self.sigma_two = math.exp(log_sigma_two)
        self.gaussian1 = torch.distributions.Normal(0, self.sigma_one)
        self.gaussian2 = torch.distributions.Normal(0, self.sigma_two)
        # Not read by this class; kept so any external code touching it still works.
        self.device_set = False

    @staticmethod
    def _log_weight(weight: float) -> float:
        """Return ``log(weight)``, mapping non-positive weights to ``-inf``."""
        return math.log(weight) if weight > 0 else -math.inf

    def log_prob(self, x: T) -> T:
        """Return the mixture log density of ``x``, summed over all elements.

        The mixture is evaluated in log space with ``logsumexp``. The naive
        ``log(pi * p1 + (1 - pi) * p2)`` returns ``-inf`` once both component
        densities underflow, which happens for values far in the tails. This
        form returns the correct large negative finite value instead.
        ``pi`` of exactly 0 or 1 is supported: the unused component gets a
        ``-inf`` log weight and drops out.
        """
        lp1 = self.gaussian1.log_prob(x) + self._log_weight(self.pi)
        lp2 = self.gaussian2.log_prob(x) + self._log_weight(1 - self.pi)
        return torch.logsumexp(torch.stack((lp1, lp2)), dim=0).sum()


class BayesianLayer(nn.Module):
    """
    Wrap a layer so that its weight (and optional bias) are drawn from a learned
    diagonal Gaussian on every forward pass.

    This should work with any layer that has a ``weight`` and, optionally, a
    ``bias`` attribute. If the bias is missing or ``None``, it is left alone and
    no bias distribution is learned.

    In training mode, every forward pass records a single-sample estimate of the
    log posterior ``log q(w, b)`` and the log prior ``log p(w, b)`` of the drawn
    parameters. ``kl`` averages the recorded estimates and clears them.

    WARNING: If you have size mismatches, this may cause some weird errors coming
    from low-level C++ backends. I believe this is because we are manually setting
    the weight and bias as a tensor with ``setattr`` and we therefore skip any nicely
    formatted PyTorch errors.

    Args:
        base_layer: Layer to wrap. Its ``weight`` (and non-``None`` ``bias``)
            parameters are deleted and replaced by sampled tensors at forward time.
        weight_mu_rng: ``(low, high)`` range for the uniform init of the weight means.
        weight_log_rho_rng: ``(low, high)`` range for the uniform init of the weight
            ``rho``. Despite the name, the values are used directly as ``rho``;
            the standard deviation is ``softplus(rho)``.
        bias_mu_rng: Same as ``weight_mu_rng``, for the bias.
        bias_log_rho_rng: Same as ``weight_log_rho_rng``, for the bias.
        weight_prior: Object with a ``log_prob(tensor)`` method used as the weight
            prior, or ``None`` for no weight prior. When omitted, a fresh
            ``ScaleMixtureGaussian()`` is created for this layer.
        bias_prior: Same as ``weight_prior``, for the bias.

    Raises:
        ValueError: If both ``weight_prior`` and ``bias_prior`` are ``None``. This
            is checked before ``base_layer`` is modified.
    """

    # Sentinel default for the prior arguments. Writing ``ScaleMixtureGaussian()``
    # in the signature would evaluate it once and share one instance across every
    # layer. The sentinel lets each layer build its own prior, while an explicit
    # ``None`` still means "no prior".
    _DEFAULT_PRIOR = object()

    def __init__(
        self,
        base_layer: nn.Module,
        weight_mu_rng: Tuple[float, float] = (-0.2, 0.2),
        weight_log_rho_rng: Tuple[float, float] = (-5., -4.),
        bias_mu_rng: Tuple[float, float] = (-0.2, 0.2),
        bias_log_rho_rng: Tuple[float, float] = (-5., -4.),
        weight_prior: Optional["ScaleMixtureGaussian"] = _DEFAULT_PRIOR,  # type: ignore[assignment]
        bias_prior: Optional["ScaleMixtureGaussian"] = _DEFAULT_PRIOR,  # type: ignore[assignment]
    ) -> None:
        super().__init__()

        if weight_prior is BayesianLayer._DEFAULT_PRIOR:
            weight_prior = ScaleMixtureGaussian()
        if bias_prior is BayesianLayer._DEFAULT_PRIOR:
            bias_prior = ScaleMixtureGaussian()
        # Validate before touching base_layer so a rejected call leaves it intact.
        if weight_prior is None and bias_prior is None:
            raise ValueError("weight and bias prior cannot be None at the same time")

        w_size = base_layer.weight.size()  # type: ignore
        self.weight_mu = nn.Parameter(torch.zeros(*w_size).uniform_(*weight_mu_rng))
        self.weight_rho = nn.Parameter(torch.zeros(*w_size).uniform_(*weight_log_rho_rng))
        # Remove the registered Parameter so forward() can assign a plain sampled
        # tensor in its place.
        delattr(base_layer, 'weight')

        self.bias_mu = None
        self.bias_rho = None

        # getattr also supports layers that have no ``bias`` attribute at all.
        base_bias = getattr(base_layer, 'bias', None)
        if base_bias is not None:
            b_size = base_bias.size()
            self.bias_mu = nn.Parameter(torch.zeros(*b_size).uniform_(*bias_mu_rng))
            self.bias_rho = nn.Parameter(torch.zeros(*b_size).uniform_(*bias_log_rho_rng))
            delattr(base_layer, 'bias')

        self.weight_prior = weight_prior
        self.bias_prior = bias_prior

        # One scalar entry per training-mode forward pass; consumed by kl().
        self.log_posterior: List[T] = []
        self.log_prior: List[T] = []
        self.layer = base_layer

    def kl(self) -> Tuple[T, T]:
        """Return the mean recorded ``(log_posterior, log_prior)`` and clear the buffers.

        If no training-mode forward pass has been recorded since the last call,
        both values are zero scalars instead of failing inside ``torch.stack``.
        """
        if not self.log_posterior:
            zero = self.weight_mu.new_zeros(())
            return zero, zero.clone()
        post = torch.stack(self.log_posterior).mean(dim=0)
        prior = torch.stack(self.log_prior).mean(dim=0)
        self.log_posterior, self.log_prior = [], []
        return post, prior

    def forward(self, x: T, samples: int = 1) -> T:
        """Sample the weight (and bias), install them on the wrapped layer, and apply it to ``x``.

        In training mode, this also appends a single-sample log-posterior and
        log-prior estimate for the drawn parameters. A prior set to ``None``
        contributes nothing.

        ``samples`` is accepted for interface compatibility but is not used;
        exactly one parameter sample is drawn per call.
        """
        w_dist = Normal(self.weight_mu, F.softplus(self.weight_rho))
        w = w_dist.rsample()  # reparameterized: gradients reach mu and rho
        setattr(self.layer, 'weight', w)

        b: Optional[T] = None
        b_dist: Optional[Normal] = None
        if self.bias_mu is not None:
            b_dist = Normal(self.bias_mu, F.softplus(self.bias_rho))
            b = b_dist.rsample()
            setattr(self.layer, 'bias', b)

        if self.training:
            log_post = w_dist.log_prob(w).sum()
            # Start from a tensor zero so the entry is always stackable, even
            # when no prior term applies.
            log_prior = w.new_zeros(())
            if self.weight_prior is not None:
                log_prior = log_prior + self.weight_prior.log_prob(w).sum()
            if b_dist is not None and b is not None:
                log_post = log_post + b_dist.log_prob(b).sum()
                if self.bias_prior is not None:
                    log_prior = log_prior + self.bias_prior.log_prob(b).sum()
            self.log_prior.append(log_prior)
            self.log_posterior.append(log_post)

        return self.layer(x)


class BayesianNet(nn.Module):
    """Base class for networks assembled from ``BayesianLayer`` modules.

    ``mc`` calls ``self(x)``, so subclasses must implement ``forward``.
    """

    def kl(self) -> T:
        """Return the summed ``log_posterior - log_prior`` over every ``BayesianLayer``.

        Calls ``kl()`` on each ``BayesianLayer`` yielded by ``self.modules()``, in
        that order, which also clears each layer's recorded estimates.
        """
        bayes_layers = [mod for mod in self.modules() if isinstance(mod, BayesianLayer)]
        estimates = [bl.kl() for bl in bayes_layers]
        total_post = sum((post for post, _ in estimates), 0.0)
        total_prior = sum((prior for _, prior in estimates), 0.0)
        return total_post - total_prior  # type: ignore

    def mc(self, x: T, samples: int = 1) -> T:
        """Average ``samples`` stochastic forward passes over ``x``.

        Each pass draws fresh weights. In training mode, each pass also records
        its own KL estimates in every ``BayesianLayer``.
        """
        predictions = [self(x) for _ in range(samples)]
        return torch.stack(predictions).mean(dim=0)
