from torch import nn
import torch
import numpy as np
from v2.swag_repo.swag.models.preresnet import add_tag

class StochasticMLP(nn.Module):
    """MLP classifier with a stochastic, diagonal-Gaussian bottleneck.

    Two parallel MLPs with identical layer widths map the input to the
    per-dimension mean (``pred_mu``) and standard deviation (``pred_sigma``,
    made positive by a final ``Softplus``) of a Normal distribution over a
    code ``z``. A single linear layer (``cls``) maps a sampled code to
    ``arch[-1]`` outputs.

    Args:
        arch: Sequence of layer widths. ``arch[0]`` is the input width,
            ``arch[-2]`` the input width of the classifier head and
            ``arch[-1]`` its output width. With three or more entries the
            encoders end at width ``arch[-2]``, matching the head. With
            exactly two entries the encoders emit ``arch[1]`` features while
            the head expects ``arch[0]``, so the classification path of
            ``forward`` only works when the two widths are equal.

    Every module of the two encoders is passed to ``add_tag`` with tag ``0``
    and every module of the head with tag ``1``; what a tag means is defined
    by ``add_tag`` itself.
    """

    def __init__(self, arch):
        super().__init__()
        self.arch = arch
        self.z_sz = arch[-2]

        # Build order matters: parameter initialisation draws from the global
        # RNG as each Linear is created, so the mean branch is built fully
        # first, then the std branch, then the head.
        self.pred_mu = nn.Sequential(*self._encoder_layers(arch))
        # Softplus keeps the predicted standard deviations positive.
        self.pred_sigma = nn.Sequential(*self._encoder_layers(arch), nn.Softplus())
        self.cls = nn.Sequential(
            nn.Linear(arch[-2], arch[-1]),
        )

        self.pred_mu.apply(lambda module: add_tag(module, 0))
        self.pred_sigma.apply(lambda module: add_tag(module, 0))
        self.cls.apply(lambda module: add_tag(module, 1))

    @staticmethod
    def _encoder_layers(arch):
        """Return a fresh Linear/ReLU stack covering every width but the last."""
        layers = [nn.Linear(arch[0], arch[1])]
        # idx runs over 1 .. len(arch) - 3, so the final Linear ends at
        # arch[-2]; e.g. five widths give idx = 1, 2. The last entry of arch
        # is left for the classifier head.
        for idx in range(1, len(arch) - 2):
            layers += [nn.ReLU(), nn.Linear(arch[idx], arch[idx + 1])]
        return layers

    # NOTE: the parameter name ``repr`` shadows the builtin inside this
    # method; it is kept because it is part of the public signature.
    def forward(self, x: torch.Tensor, repr: bool = False):
        """Sample a code for ``x`` and either classify it or return it.

        The code is drawn with the reparameterisation
        ``z = mean + std * eps`` where ``eps`` is standard-normal noise, so a
        fresh sample is taken on every call.

        Args:
            x: Input batch accepted by the encoders.
            repr: When true, return the sampled code instead of classifying it.

        Returns:
            If ``repr`` is true: ``(z, logprob)`` where ``logprob`` is the
            log-density of each sampled code under its own predicted Normal,
            summed over dimension 1.
            Otherwise: ``(self.cls(z), mean_std)`` where ``mean_std`` is the
            mean of all predicted standard deviations as a Python float.
        """
        means = self.pred_mu(x)
        stds = self.pred_sigma(x)

        noise = torch.randn_like(means)
        z = means + stds * noise

        if repr:
            distr = torch.distributions.normal.Normal(means, stds)
            logprob = distr.log_prob(z).sum(dim=1)
            return z, logprob

        return self.cls(z), stds.mean().item()

    def log_marg_prob(self, z: torch.Tensor, d_x: torch.Tensor, jensen: bool):
        """Estimate the log marginal density of each code from a batch of inputs.

        Every row of ``z`` is scored under the Normal distribution predicted
        for every row of ``d_x``; the per-dimension log-densities are summed
        over the code dimension and then reduced over the ``d_x`` rows.

        Args:
            z: Codes of shape ``(batch_sz, L)``.
            d_x: Inputs whose first dimension is ``batch_sz2``; the encoders
                must map them to ``(batch_sz2, L)``.
            jensen: If true, average the log-densities over ``d_x`` rows (by
                Jensen's inequality a lower bound on the log of the averaged
                density). Otherwise return the log of the averaged density,
                computed stably as ``logsumexp - log(batch_sz2)``.

        Returns:
            Tensor of shape ``(batch_sz,)``.
        """
        batch_sz, L = z.shape
        batch_sz2 = d_x.shape[0]

        means = self.pred_mu(d_x)
        stds = self.pred_sigma(d_x)

        # Pair every code (axis 0) with every reference distribution (axis 1).
        pair_shape = (batch_sz, batch_sz2, L)
        means = means.unsqueeze(0).expand(*pair_shape)
        stds = stds.unsqueeze(0).expand(*pair_shape)
        z = z.unsqueeze(1).expand(*pair_shape)

        distr = torch.distributions.normal.Normal(means, stds)
        logprob = distr.log_prob(z)
        assert logprob.shape == pair_shape

        # Dimensions are independent, so the joint log-density of a code is
        # the sum over its L dimensions -> (batch_sz, batch_sz2).
        logprob = logprob.sum(dim=2)
        if jensen:
            log_margprob = logprob.mean(dim=1)
        else:
            log_margprob = - np.log(batch_sz2) + torch.logsumexp(logprob, dim=1)

        assert log_margprob.shape == (batch_sz,)

        return log_margprob
