import pyro
from pyro import distributions as pdist
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from torch import nn
from torch.distributions import constraints
from pyro.infer.predictive import _predictive
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam, SGD, ClippedAdam
import math

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init as init
import logging
from sbsep.util import linear_aux
from sbsep.dist import bnn_helper

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Linear(nn.Module):
    """Fully connected layer whose forward pass is delegated to ``linear_aux``.

    Parameters are laid out and initialised with the same scheme that
    ``torch.nn.Linear`` uses: ``weight`` has shape ``(out_features, in_features)``
    and is Kaiming-uniform initialised with ``a=sqrt(5)``. ``bias`` (optional) has
    shape ``(out_features,)`` and is drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised; reset_parameters() below fills the values in.
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if not bias:
            # Register the slot explicitly so ``self.bias`` exists and is None.
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise ``weight`` and, if present, ``bias`` in place."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the layer by calling ``linear_aux(input, weight, bias)``."""
        return linear_aux(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        has_bias = self.bias is not None
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={has_bias}"
        )


class BNN(PyroModule):
    """Bayesian MLP with Normal(0, 1) priors on every layer's weight and bias.

    The network applies ``len(dims) - 1`` :class:`Linear` layers to ``x``, with
    ``self.activation`` (ReLU) between layers and no activation after the last
    one. It models ``y`` as ``Normal(mu, sigma)`` with an InverseGamma prior on
    ``sigma``. :meth:`guide` is a hand-written Normal / InverseGamma guide for SVI.

    All sample-site and param-store names are prefixed with ``name``.
    """

    def __init__(self, dims=(1, 32, 32, 1), name="a"):
        super().__init__()
        self.name = name
        self.dims = dims
        self.nlayers = len(dims) - 1
        self.activation = nn.ReLU()
        self.nsigma = torch.tensor(1e-4)  # not referenced by the code in this class
        for k, (da, db) in enumerate(zip(dims, dims[1:])):
            layer = PyroModule[Linear](da, db)
            # Turn the deterministic parameters into prior sample sites. The
            # guide samples sites named "<name>#layer_<k>.weight" / ".bias" to
            # match them.
            layer.weight = PyroSample(
                pdist.Normal(0.0, 1.0).expand([db, da]).to_event(2)
            )
            layer.bias = PyroSample(pdist.Normal(0.0, 1.0).expand([db]).to_event(1))
            setattr(self, f"{self.name}#layer_{k}", layer)

    def forward(self, data):
        """Model: sample ``sigma``, run the network on ``x`` and observe ``y``.

        Args:
            data: ``(x, y)`` pair; ``x.shape[0]`` sets the size of the data plate.
        """
        x, y = data
        sigma = pyro.sample(
            f"{self.name}#sigma",
            pdist.InverseGamma(torch.tensor(30.0), torch.tensor(1.45)),
        )

        with pyro.plate("data", x.shape[0]):
            # NOTE(review): the layer weights are sampled when the PyroSample
            # attributes are accessed, i.e. inside this data plate (the guide
            # mirrors this). Confirm that is intended and that linear_aux
            # handles the resulting weight shapes.
            h = x
            for k in range(self.nlayers):
                layer = getattr(self, f"{self.name}#layer_{k}")
                h = layer(h)
                # Activation only between layers; the output layer stays linear.
                # Written this way so a single-layer network (len(dims) == 2)
                # still produces ``mu``.
                if k < self.nlayers - 1:
                    h = self.activation(h)
            mu = h

            pyro.sample("obs", pdist.Normal(mu, sigma).to_event(1), obs=y)

    def guide(self, data):
        """Guide: learnable InverseGamma for ``sigma`` and Normal per layer parameter.

        Args:
            data: ``(x, y)`` pair; only ``x`` is used (to size the data plate).
        """
        x, _ = data
        dims = self.dims

        sigma_alpha = pyro.param(
            f"{self.name}#sigma_alpha",
            torch.tensor(100.0),
            constraint=constraints.positive,
        )
        sigma_beta = pyro.param(
            f"{self.name}#sigma_beta",
            torch.tensor(1.0),
            constraint=constraints.positive,
        )

        pyro.sample(f"{self.name}#sigma", pdist.InverseGamma(sigma_alpha, sigma_beta))

        with pyro.plate("data", x.shape[0]):
            for k, (da, db) in enumerate(zip(dims, dims[1:])):
                guess_weight_loc = pyro.param(
                    f"{self.name}#weight_loc_{k}", torch.zeros(db, da)
                )
                guess_weight_scale = pyro.param(
                    f"{self.name}#weight_scale_{k}",
                    0.01 * torch.ones(db, da),
                    constraint=constraints.positive,
                )
                pyro.sample(
                    f"{self.name}#layer_{k}.weight",
                    pdist.Normal(guess_weight_loc, guess_weight_scale).to_event(2),
                )

                guess_bias_loc = pyro.param(
                    f"{self.name}#bias_loc_{k}", torch.zeros(db)
                )
                guess_bias_scale = pyro.param(
                    f"{self.name}#bias_scale_{k}",
                    0.01 * torch.ones(db),
                    constraint=constraints.positive,
                )
                pyro.sample(
                    f"{self.name}#layer_{k}.bias",
                    pdist.Normal(guess_bias_loc, guess_bias_scale).to_event(1),
                )

    def infer_parameters(
        self,
        xdata,
        ydata,
        lr,
        momentum=0.9,
        num_epochs=1000,
        elbo="mf",
        opt="sgd",
    ):
        """Fit the guide to ``(xdata, ydata)`` with SVI.

        Args:
            xdata: inputs, passed to model and guide as ``data[0]``.
            ydata: targets, passed to model and guide as ``data[1]``.
            lr: learning rate.
            momentum: Nesterov momentum; only used when ``opt == "sgd"``.
            num_epochs: number of SVI steps, each on the full data set.
            elbo: ``"mf"`` selects ``TraceMeanField_ELBO``; anything else ``Trace_ELBO``.
            opt: ``"sgd"`` (Nesterov SGD), ``"ca"`` (ClippedAdam), anything else Adam.

        Returns:
            List with the loss returned by each ``svi.step``.
        """
        if opt == "sgd":
            optim = SGD({"lr": lr, "momentum": momentum, "nesterov": True})
        elif opt == "ca":
            optim = ClippedAdam({"lr": lr})
        else:
            optim = Adam({"lr": lr})
        loss_fn = TraceMeanField_ELBO() if elbo == "mf" else Trace_ELBO()

        svi = SVI(self, self.guide, optim, loss_fn)
        # forward() and guide() both take a single ``data`` argument, so the
        # pair must be passed under that name (not as separate x=/y= keywords).
        data = (xdata, ydata)
        losses = []
        for step in range(num_epochs):
            losses.append(svi.step(data=data))
            if step % 100 == 0:
                print(".", end="")
        return losses

    def get_samples(self, n_samples, param_names, data):
        """Draw ``n_samples`` from the guide and return them flattened.

        Args:
            n_samples: number of guide draws.
            param_names: site names to keep, or ``None`` to keep every site.
            data: ``(x, y)`` pair forwarded to the guide.

        Returns:
            Dict of NumPy arrays as produced by ``get_posteriors``.
        """
        samples = get_posteriors(
            guide=self.guide,
            model_kwargs={"data": data},
            num_samples=n_samples,
            rv_extract=param_names,
        )
        return samples


class SimpleBNN(PyroModule):
    """Bayesian MLP whose weight sampling and forward pass live in ``bnn_helper``.

    Each layer is described by a ``(weight_loc, weight_scale, bias_loc,
    bias_scale)`` tuple. The model passes fixed prior tuples (zero locations,
    unit scales) to ``bnn_helper``; the guide passes tuples of learnable Pyro
    params with the same layout. Weight tensors here are shaped
    ``(dims[k], dims[k + 1])``, i.e. transposed relative to ``BNN``.
    """

    def __init__(self, dims=(1, 32, 1), name="a"):
        super().__init__()
        self.name = name
        self.dims = dims
        self.nlayers = len(dims) - 1
        self.activation = nn.ReLU()
        self.nsigma = torch.tensor(1e-4)

        # Prior hyper-parameters, one tuple per consecutive pair in ``dims``.
        self.prior_params = [
            (
                torch.zeros(n_in, n_out),
                torch.ones(n_in, n_out),
                torch.zeros(n_out),
                torch.ones(n_out),
            )
            for n_in, n_out in zip(dims, dims[1:])
        ]

    def forward(self, data):
        """Model: sample the noise scale, run ``bnn_helper`` on ``x`` and observe ``y``.

        Args:
            data: ``(x, y)`` pair.

        Returns:
            The value of the ``"obs"`` sample site.
        """
        x, y = data
        noise = pyro.sample(
            f"{self.name}#sigma",
            pdist.InverseGamma(torch.tensor(30.0), torch.tensor(1.45)),
        )
        loc = bnn_helper(self.prior_params, x, name=self.name)
        likelihood = pdist.Normal(loc, noise).to_event(1)
        return pyro.sample("obs", likelihood, obs=y)

    def guide(self, data):
        """Guide: learnable InverseGamma for sigma plus learnable per-layer tuples.

        Args:
            data: ``(x, y)`` pair; only ``x`` is used.
        """
        x, _ = data
        alpha = pyro.param(
            f"{self.name}#sigma_alpha",
            torch.tensor(100.0),
            constraint=constraints.positive,
        )
        beta = pyro.param(
            f"{self.name}#sigma_beta",
            torch.tensor(1.0),
            constraint=constraints.positive,
        )
        pyro.sample(f"{self.name}#sigma", pdist.InverseGamma(alpha, beta))

        layer_params = []
        for idx, (n_in, n_out) in enumerate(zip(self.dims, self.dims[1:])):
            w_loc = pyro.param(
                f"{self.name}#weight_loc_{idx}", torch.zeros(n_in, n_out)
            )
            w_scale = pyro.param(
                f"{self.name}#weight_scale_{idx}",
                0.01 * torch.ones(n_in, n_out),
                constraint=constraints.positive,
            )
            b_loc = pyro.param(f"{self.name}#bias_loc_{idx}", torch.zeros(n_out))
            b_scale = pyro.param(
                f"{self.name}#bias_scale_{idx}",
                0.01 * torch.ones(n_out),
                constraint=constraints.positive,
            )
            layer_params.append((w_loc, w_scale, b_loc, b_scale))
        # Called for its sample sites; the returned value is not used.
        bnn_helper(layer_params, x, name=self.name)


def devectorize(vectorize_params: dict) -> dict:
    """Split multi-dimensional arrays into one entry per index of their last axis.

    For every ``name -> arr`` pair: if ``arr`` has more than one dimension and
    more than one row along axis 0, each slice ``arr.T[i]`` is stored under
    ``f"{name}_{i}"``; for a 2-D array that is column ``i``. Otherwise ``arr``
    is copied through unchanged under ``name``.

    NOTE(review): the split test looks at ``shape[0]``, which presumably is the
    sample count, so a single-draw input is never split. Confirm that is intended.
    """
    flat = {}
    for name, arr in vectorize_params.items():
        splittable = len(arr.shape) > 1 and arr.shape[0] > 1
        if not splittable:
            flat[name] = arr
            continue
        for idx, piece in enumerate(arr.T):
            flat[f"{name}_{idx}"] = piece
    return flat


def get_posteriors(guide, num_samples, rv_extract, model_kwargs):
    """Draw samples from ``guide`` and return them as flattened NumPy arrays.

    Args:
        guide: callable whose sample sites are drawn (via Pyro's ``_predictive``
            with no fixed posterior samples).
        num_samples: number of draws.
        rv_extract: collection of site names to keep, or ``None`` to keep all.
        model_kwargs: keyword arguments passed to ``guide``.

    Returns:
        Dict of NumPy arrays, split per column by ``devectorize``.

    NOTE(review): ``_predictive`` is a private Pyro helper; confirm that it
    still returns a plain ``site -> tensor`` mapping for the Pyro version in use.
    """
    draws = _predictive(
        guide, {}, num_samples, parallel=False, model_kwargs=model_kwargs
    )
    keep_all = rv_extract is None
    as_numpy = {
        site: value.detach().cpu().numpy()
        for site, value in draws.items()
        if keep_all or site in rv_extract
    }
    return devectorize(as_numpy)
