import logging
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.util import torch_item
from pyro.nn import PyroModule
from pyro.optim import ClippedAdam
from pyro.util import torch_isnan

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class FullyConnected(nn.Sequential):
    """
    Fully connected multi-layer network with ELU activations.

    For ``sizes = [s0, s1, ..., sk]`` the network is
    ``Linear(s0, s1), ELU, ..., Linear(s{k-1}, sk)``: an ELU follows every
    linear layer except the last one.

    :param sizes: sequence of layer widths, input width first. With fewer
        than two entries no linear layer is built and the network reduces
        to ``final_activation`` alone (or to an empty ``nn.Sequential``
        that returns its input unchanged).
    :param final_activation: optional ``nn.Module`` appended after the last
        linear layer.
    """

    def __init__(self, sizes, final_activation=None):
        layers = []
        for in_size, out_size in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.ELU())
        # Drop the trailing ELU so the last Linear is left un-activated.
        # Guarded because ``layers`` is empty when ``sizes`` has < 2 entries,
        # in which case an unconditional pop would raise IndexError.
        if layers:
            layers.pop()
        if final_activation is not None:
            layers.append(final_activation)
        super().__init__(*layers)

    def append(self, layer):
        """Register ``layer`` as the next module, keyed by its position."""
        assert isinstance(layer, nn.Module)
        self.add_module(str(len(self)), layer)


class DistributionNet(nn.Module):
    """
    Base class for networks that produce distribution parameters.

    Concrete networks are discovered by name: a direct subclass called
    ``<Name>Net`` is selected by the lower-case string ``"<name>"``.
    """

    @staticmethod
    def get_class(dtype):
        """
        Return the direct subclass whose lower-cased name is ``dtype + "net"``.

        :param str dtype: distribution name, e.g. ``"normal"``.
        :raises ValueError: if no direct subclass matches.
        """
        matches = (
            candidate
            for candidate in DistributionNet.__subclasses__()
            if candidate.__name__.lower() == dtype + "net"
        )
        found = next(matches, None)
        if found is None:
            raise ValueError("dtype not supported: {}".format(dtype))
        return found


class BernoulliNet(DistributionNet):
    """
    Network emitting a single Bernoulli logit per input row.

    :param list sizes: input width followed by hidden widths; a final
        output width of 1 is appended internally.
    """

    def __init__(self, sizes):
        assert len(sizes) >= 1
        super().__init__()
        layer_sizes = sizes + [1]
        self.fc = FullyConnected(layer_sizes)

    def forward(self, x):
        """Return a 1-tuple ``(logits,)`` with logits clamped to [-10, 10]."""
        raw = self.fc(x)
        # The network's last dimension has size 1; drop it before clamping.
        logits = raw.squeeze(-1).clamp(min=-10, max=10)
        return (logits,)

    @staticmethod
    def make_dist(logits):
        """Build the Bernoulli distribution parameterised by ``logits``."""
        return dist.Bernoulli(logits=logits)


class NormalNet(DistributionNet):
    """
    Network emitting the ``(loc, scale)`` of a univariate Normal per input row.

    :param list sizes: input width followed by hidden widths; a final
        output width of 2 is appended internally.
    """

    def __init__(self, sizes):
        assert len(sizes) >= 1
        super().__init__()
        layer_sizes = sizes + [2]
        self.fc = FullyConnected(layer_sizes)

    def forward(self, x):
        """Return ``(loc, scale)`` read from the two output channels."""
        out = self.fc(x)
        raw_loc = out[..., 0]
        raw_scale = out[..., 1]
        loc = raw_loc.clamp(min=-1e6, max=1e6)
        # softplus keeps the scale positive; the clamp bounds it to [1e-3, 1e6].
        positive = nn.functional.softplus(raw_scale)
        scale = positive.clamp(min=1e-3, max=1e6)
        return loc, scale

    @staticmethod
    def make_dist(loc, scale):
        """Build the Normal distribution with the given ``loc`` and ``scale``."""
        return dist.Normal(loc, scale)


class DiagNormalNet(nn.Module):
    """
    Network emitting ``loc`` and ``scale`` vectors of width ``sizes[-1]``.

    :param list sizes: input width, hidden widths, then the output width
        ``dim``. The underlying network's last layer has ``2 * dim`` outputs:
        the first ``dim`` feed ``loc``, the remainder feed ``scale``.
    """

    def __init__(self, sizes):
        assert len(sizes) >= 2
        self.dim = sizes[-1]
        super().__init__()
        layer_sizes = sizes[:-1] + [self.dim * 2]
        self.fc = FullyConnected(layer_sizes)

    def forward(self, x):
        """Return ``(loc, scale)``, each with last dimension ``self.dim``."""
        out = self.fc(x)
        raw_loc = out[..., :self.dim]
        raw_scale = out[..., self.dim:]
        loc = raw_loc.clamp(min=-1e2, max=1e2)
        # softplus + 1e-3 keeps the scale strictly positive; cap it at 1e2.
        scale = nn.functional.softplus(raw_scale).add(1e-3).clamp(max=1e2)
        return loc, scale


class DiagBernoulliNet(nn.Module):
    """
    Network emitting Bernoulli logits of width ``sizes[-1]``.

    :param list sizes: input width, hidden widths, then the output width.
    """

    def __init__(self, sizes):
        assert len(sizes) >= 2
        self.dim = sizes[-1]
        super().__init__()
        layer_sizes = sizes[:-1] + [self.dim]
        self.fc = FullyConnected(layer_sizes)

    def forward(self, x):
        """Return the network output, squeezed on the last axis and clamped."""
        raw = self.fc(x)
        # NOTE(review): the lower bound of 0 means sigmoid(logit) >= 0.5 for
        # every output, and the range differs from BernoulliNet's [-10, 10].
        # Confirm this asymmetry is intentional.
        return raw.squeeze(-1).clamp(min=0, max=11)

    @staticmethod
    def make_dist(logits):
        """Build the Bernoulli distribution parameterised by ``logits``."""
        return dist.Bernoulli(logits=logits)


class Guide(PyroModule):
    """
    Guide that samples the latent site ``"zf"`` from columns of ``x``.

    :param dict config: reads ``latent_dim_Zf``, ``feature_dim``,
        ``hidden_dim`` and ``num_layers``.

    NOTE(review): the encoder input width is ``feature_dim - 1`` while the
    column slice fed to it is the fixed ``x[:, 1:9]``; the two agree only
    when ``feature_dim == 9``.
    """

    def __init__(self, config):
        self.latent_dim_Zf = config["latent_dim_Zf"]

        super().__init__()

        input_width = config["feature_dim"] - 1
        hidden_width = config["hidden_dim"]
        extra_layers = config["num_layers"] - 1
        # Shared trunk ending in an ELU, followed by the loc/scale head.
        self.zf_nn = FullyConnected(
            [input_width] + [hidden_width] * extra_layers,
            final_activation=nn.ELU(),
        )
        self.zf0_nn = DiagNormalNet([config["hidden_dim"], config["latent_dim_Zf"]])

    def forward(self, x, t=None, y=None, size=None):
        """
        Sample ``"zf"`` inside the ``"data"`` plate.

        ``t`` and ``y`` are accepted but not used. ``size`` is the plate
        size and defaults to the number of rows in ``x``.
        """
        plate_size = x.size(0) if size is None else size
        with pyro.plate("data", plate_size, subsample=x):
            posterior = self.zf_dist(x[:, 0:1], x[:, 1:9])
            pyro.sample("zf", posterior)

    def zf(self, x):
        """Sample and return ``"zf"`` for every row of ``x``."""
        with pyro.plate("data", x.size(0)):
            posterior = self.zf_dist(x[:, 0:1], x[:, 1:9])
            sample = pyro.sample("zf", posterior)
        return sample

    def zf_dist(self, t, x):
        """
        Build the diagonal Normal over ``zf`` from ``x``.

        ``t`` is part of the signature but is not read.
        """
        features = self.zf_nn(x.float())
        loc, scale = self.zf0_nn(features)
        return dist.Normal(loc, scale).to_event(1)


class TraceCausalEffect_ELBO(Trace_ELBO):
    """
    ``Trace_ELBO`` variant that handles observed sample sites in the guide
    separately: they are hidden from the parent's per-particle loss and their
    ``log_prob_sum`` is then subtracted from the loss explicitly.
    """
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        """
        Compute ``(loss, surrogate_loss)`` for one particle.

        :param model_trace: trace of the model, passed through to the parent.
        :param guide_trace: trace of the guide; it is copied, not mutated.
        :returns: tuple ``(loss, surrogate_loss)``.
        """

        # Names of guide sites that are sample statements marked as observed.
        blocked_names = [name for name, site in guide_trace.nodes.items()
                         if site["type"] == "sample" and site["is_observed"]]
        # Work on a copy so the caller's guide trace keeps those sites.
        blocked_guide_trace = guide_trace.copy()
        for name in blocked_names:
            del blocked_guide_trace.nodes[name]
        # Parent loss computed with the observed guide sites removed.
        loss, surrogate_loss = super()._differentiable_loss_particle(
            model_trace, blocked_guide_trace)

        # Subtract each removed site's log-probability from both results.
        # NOTE(review): relies on "log_prob_sum" already being stored on the
        # guide trace nodes by the time this runs; confirm against Pyro.
        for name in blocked_names:
            log_q = guide_trace.nodes[name]["log_prob_sum"]
            # `loss` gets the torch_item() value; `surrogate_loss` gets log_q itself.
            loss = loss - torch_item(log_q)
            surrogate_loss = surrogate_loss - log_q

        return loss, surrogate_loss

    @torch.no_grad()
    def loss(self, model, guide, *args, **kwargs):
        """
        Evaluate ``differentiable_loss`` with gradients disabled and return
        the result converted by ``torch_item``.
        """
        return torch_item(self.differentiable_loss(model, guide, *args, **kwargs))


class Model(PyroModule):
    """
    Generative model: samples ``"zf"`` conditioned on the first five columns
    of ``x`` and scores columns 5..8 of ``x`` at the observed site
    ``"x_cont"``.

    :param dict config: reads ``latent_dim_Zf``, ``hidden_dim``,
        ``num_layers`` and ``outcome_dist``.
    """

    def __init__(self, config):
        self.latent_dim_Zf = config["latent_dim_Zf"]

        super().__init__()
        # Decoder: zf -> loc/scale of the 4 continuous covariates.
        decoder_sizes = (
            [config["latent_dim_Zf"]]
            + [config["hidden_dim"]] * config["num_layers"]
            + [4]
        )
        self.x_nn = DiagNormalNet(decoder_sizes)
        # The looked-up class is not used; the call is kept because it raises
        # ValueError for an unsupported ``outcome_dist``.
        DistributionNet.get_class(config["outcome_dist"])

        # Conditional prior: 5 input columns -> loc/scale of zf.
        prior_sizes = (
            [5]
            + [config["hidden_dim"]] * config["num_layers"]
            + [config["latent_dim_Zf"]]
        )
        self.zf_nn = DiagNormalNet(prior_sizes)

    def forward(self, x, t=None, y=None, size=None):
        """
        Run the model on ``x`` and return the value of the ``"x_cont"`` site.

        ``t`` and ``y`` are accepted but not used. ``size`` is the plate
        size and defaults to the number of rows in ``x``.
        """
        plate_size = x.size(0) if size is None else size
        with pyro.plate("data", plate_size, subsample=x):
            prior = self.zf_dist(x[:, 0:5])
            zf = pyro.sample("zf", prior)
            likelihood = self.x_dist_continuous(zf)
            x_continuous = pyro.sample("x_cont", likelihood, obs=x[:, 5:9])

        return x_continuous

    def zf_dist(self, t_w):
        """Diagonal Normal over ``zf`` given the conditioning columns ``t_w``."""
        params = self.zf_nn(t_w.float())
        return dist.Normal(*params).to_event(1)

    def x_dist_continuous(self, zf):
        """Diagonal Normal over the continuous covariates given ``zf``."""
        params = self.x_nn(zf.float())
        return dist.Normal(*params).to_event(1)


class CFDiVAE_Model(nn.Module):
    """
    Wrapper that owns a ``Model``/``Guide`` pair and trains them with SVI.

    :param int feature_dim: number of columns expected in ``x``.
    :param str outcome_dist: name passed to ``DistributionNet.get_class``
        by ``Model`` (e.g. ``"bernoulli"`` or ``"normal"``).
    :param int latent_dim_Zf: width of the latent ``zf``.
    :param int hidden_dim: width of the hidden layers.
    :param int num_layers: number of hidden layers.
    """

    def __init__(self, feature_dim, outcome_dist, latent_dim_Zf, hidden_dim, num_layers):
        config = dict(feature_dim=feature_dim,
                      latent_dim_Zf=latent_dim_Zf,
                      hidden_dim=hidden_dim,
                      num_layers=num_layers)

        config["outcome_dist"] = outcome_dist
        self.feature_dim = feature_dim

        super().__init__()
        self.model = Model(config)
        self.guide = Guide(config)
        # Move to the GPU when one exists; an unconditional .cuda() makes the
        # constructor fail on CPU-only hosts.
        if torch.cuda.is_available():
            self.cuda()

    def fit(self, x, t, y,
            num_epochs,
            batch_size,
            learning_rate,
            learning_rate_decay,
            weight_decay):
        """
        Train the model and guide with stochastic variational inference.

        The inputs are not moved to another device here, so they must already
        live on the same device as the module's parameters.

        :param x: 2-D tensor with ``feature_dim`` columns.
        :param t: 1-D tensor with one entry per row of ``x``.
        :param y: 1-D tensor with one entry per row of ``x``.
        :param int num_epochs: number of passes over the data.
        :param int batch_size: minibatch size (batches are not shuffled).
        :param float learning_rate: ``lr`` passed to ``ClippedAdam``.
        :param float learning_rate_decay: decay over the whole run; it is
            converted to the per-step ``lrd`` as
            ``learning_rate_decay ** (1 / num_steps)``.
        :param float weight_decay: ``weight_decay`` passed to ``ClippedAdam``.
        :returns: list with one loss per step, each divided by the dataset size.
        """
        assert x.dim() == 2 and x.size(-1) == self.feature_dim
        assert t.shape == x.shape[:1]
        # Compare against x: the former `y.shape == y.shape[:1]` only checked
        # that y is 1-D and never related its length to x.
        assert y.shape == x.shape[:1]

        dataset = TensorDataset(x, t, y)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info("Training with {} minibatches per epoch".format(len(dataloader)))
        num_steps = num_epochs * len(dataloader)
        # Spread the total decay evenly over all steps. max(..., 1) avoids a
        # ZeroDivisionError when there is nothing to train on.
        per_step_decay = learning_rate_decay ** (1 / max(num_steps, 1))
        optim = ClippedAdam({"lr": learning_rate,
                             "weight_decay": weight_decay,
                             "lrd": per_step_decay})
        svi = SVI(self.model, self.guide, optim, TraceCausalEffect_ELBO())
        losses = []
        num_rows = len(dataset)
        for epoch in range(num_epochs):
            # Distinct batch names so the full-data arguments are not shadowed.
            for batch_x, batch_t, batch_y in dataloader:
                loss = svi.step(batch_x, batch_t, batch_y, size=num_rows) / num_rows
                logger.debug("step {: >5d} loss = {:0.6g}".format(len(losses), loss))
                assert not torch_isnan(loss)
                losses.append(loss)
            print(epoch)
        return losses
