import pyro
import torch
from pyro import distributions as dist
from torch import nn
import torch.nn.functional as F
from pyro.distributions.torch_distribution import TorchDistribution


def linear_(input, weight, bias=None):
    """Apply an affine map with an ``(in_features, out_features)`` weight.

    Unlike ``torch.nn.functional.linear`` (which expects ``(out, in)``), the
    weight here is contracted over its FIRST dimension, i.e. ``input @ weight``
    for any number of leading batch dimensions on ``input``.

    Args:
        input: tensor of shape ``(..., k)``.
        weight: tensor of shape ``(k, j)``.
        bias: optional tensor added to the result (in place).

    Returns:
        Tensor of shape ``(..., j)``.
    """
    out = torch.einsum("...k,kj->...j", input, weight)
    if bias is None:
        return out
    # In-place add: keeps the einsum result's dtype and shape.
    return out.add_(bias)


class BNNLayerDist(TorchDistribution):
    """Output distribution of one Bayesian linear layer (local reparameterization).

    Given a factorized Gaussian over the layer's weights and biases, each
    output unit's pre-activation is Gaussian with

        mean     = input @ weight_loc + bias_loc
        variance = input**2 @ weight_scale**2 + bias_scale**2

    and this distribution samples/scores that output directly instead of
    sampling the weights themselves.

    Args:
        params: 4-tuple ``(weight_loc, weight_scale, bias_loc, bias_scale)``;
            weights have shape ``(in_features, out_features)`` and biases
            ``(out_features,)``.
        data: input tensor of shape ``(..., in_features)``; its leading
            dimensions become the batch shape.
        final: stored on the instance; not used by this class.
    """

    has_rsample = True

    def __init__(self, params, data, final=False):
        weight_loc, weight_scale, bias_loc, bias_scale = params
        da, db = weight_loc.shape
        (dc,) = bias_loc.shape
        assert weight_loc.shape == weight_scale.shape
        assert bias_loc.shape == bias_scale.shape
        assert dc == db
        # The input width must match the weight's contracted dimension.
        assert data.shape[-1] == da

        self.params = params
        self.input = data
        # Kept for backward compatibility; not used inside this class.
        self.activation = nn.ReLU()
        self.final = final
        batch_shape = data.shape[:-1]
        event_shape = (db,)
        super().__init__(batch_shape, event_shape)

    def _loc_scale(self):
        """Return the per-unit mean and standard deviation of the layer output."""
        weight_loc, weight_scale, bias_loc, bias_scale = self.params
        loc = linear_(self.input, weight_loc, bias_loc)
        var = torch.matmul(torch.pow(self.input, 2), torch.pow(weight_scale, 2))
        scale = torch.pow(var + torch.pow(bias_scale, 2), 0.5)
        return loc, scale

    def sample(self, sample_shape=torch.Size()):
        return self.rsample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples of shape ``sample_shape + batch + event``."""
        loc, scale = self._loc_scale()
        # Forward sample_shape: previously it was dropped, so a request for
        # several samples silently returned a single one.
        return dist.Normal(loc, scale).to_event(1).rsample(sample_shape)

    def log_prob(self, value, *args, **kwargs):
        """Log-density of ``value``, summed over the output (event) dimension."""
        loc, scale = self._loc_scale()
        # event_shape == (db,), so summing the independent per-unit Gaussian
        # log-densities over the last dim is correct; this equals
        # ``dist.Normal(loc, scale).to_event(1).log_prob(value)``.
        return dist.Normal(loc, scale).log_prob(value).sum(-1)


def bnn_helper(gparams, data, name="bnn0", activation=nn.ReLU(), transforms=None):
    """Sample the outputs of a stack of Bayesian linear layers inside a data plate.

    Each layer's pre-activation output is drawn from ``BNNLayerDist`` at the
    site ``f"{name}#layer_{k}"``. ``activation`` is applied after every layer
    except the last; the last layer's distribution is wrapped in a
    ``TransformedDistribution`` using ``transforms``.

    Args:
        gparams: non-empty sequence of per-layer 4-tuples
            ``(weight_loc, weight_scale, bias_loc, bias_scale)``; layer ``k+1``'s
            input width must equal layer ``k``'s output width.
        data: input tensor whose last dimension equals the first layer's
            ``weight_loc.shape[0]``. The plate is sized by ``data.size(0)``
            (presumably ``data`` is ``(N, in_features)`` — TODO confirm).
        name: prefix for the plate and sample-site names.
        activation: callable applied to each hidden layer's sample.
        transforms: transforms for the final layer; ``None`` means none.

    Returns:
        The sample drawn at the final layer's site.
    """
    # Avoid a shared mutable default: TransformedDistribution keeps the list.
    transforms = [] if transforms is None else transforms

    # All layers must carry the same number of parameters (also rejects empty).
    assert len({len(item) for item in gparams}) == 1
    prev_width = None
    for params in gparams:
        weight_loc, weight_scale, bias_loc, bias_scale = params
        da, db = weight_loc.shape
        (dc,) = bias_loc.shape
        assert weight_loc.shape == weight_scale.shape
        assert bias_loc.shape == bias_scale.shape
        assert dc == db
        # Consecutive layers must chain: input width == previous output width.
        assert prev_width is None or da == prev_width
        prev_width = db

    assert data.shape[-1] == gparams[0][0].shape[0]

    last = len(gparams) - 1
    # Seed with the data so a single-layer network also works (previously the
    # final branch read ``r`` before any assignment).
    r = data
    with pyro.plate(f"{name}#data", data.size(0)):
        for k, params in enumerate(gparams):
            layer = BNNLayerDist(params, r)
            if k < last:
                r = activation(pyro.sample(f"{name}#layer_{k}", layer))
            else:
                cdist = dist.TransformedDistribution(layer, transforms)
                r = pyro.sample(f"{name}#layer_{k}", cdist)
    return r
