import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


# Public API for ``from <module> import *``: only the ``vae`` factory is
# exported; the ``VAE`` class remains importable by explicit name.
__all__ = ["vae"]


class VAE(nn.Module):
    """Fully connected variational autoencoder with a Poisson reconstruction loss.

    The encoder passes ``x`` through one ``Linear -> BatchNorm1d -> ReLU``
    stage per hidden width, followed by a ``Linear`` layer emitting
    ``2 * latent_dim`` values: the posterior mean, then the log-variance.
    The decoder mirrors the hidden widths in reverse order and ends with
    ``Linear -> Softplus``, so reconstructions are non-negative rates for the
    Poisson likelihood used in :meth:`compute_reconstruction_loss`.

    Inputs are expected to be 2-D ``(batch, input_dim)`` (the encoder output
    is sliced as ``out[:, ...]`` and normalised with ``BatchNorm1d``).

    Args:
        input_dim: Number of input features.
        latent_dim: Dimensionality of the latent code ``z``.
        hidden_dim: Iterable of hidden-layer widths for the encoder (the
            decoder uses them in reverse). ``None`` selects
            ``[input_dim, latent_dim]``. Tuples and empty iterables are
            accepted; an empty one means no hidden stages.

    Attributes:
        hidden_dims: After construction this equals
            ``[input_dim, *hidden widths]`` (note the leading ``input_dim``),
            kept this way for backward compatibility.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=None):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Copy into a fresh list so tuples/other iterables work and the
        # caller's sequence is never aliased.
        if hidden_dim is None:
            hidden_dims = [input_dim, latent_dim]
        else:
            hidden_dims = list(hidden_dim)

        # Registration order (encoder, then decoder) fixes the state_dict key
        # layout and the order in which init_weight visits the layers.
        self.encoder = self._build_encoder(input_dim, latent_dim, hidden_dims)
        self.decoder = self._build_decoder(input_dim, latent_dim, hidden_dims)

        # Backward-compatible public value: input_dim prepended to the widths.
        self.hidden_dims = [input_dim] + hidden_dims

        self.init_weight()

    @staticmethod
    def _hidden_stage(in_features, out_features):
        """Return one ``Linear -> BatchNorm1d -> ReLU`` stage."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
        )

    @classmethod
    def _build_encoder(cls, input_dim, latent_dim, hidden_dims):
        """Return hidden stages followed by a Linear emitting ``2 * latent_dim`` values."""
        modules = []
        in_features = input_dim
        for width in hidden_dims:
            modules.append(cls._hidden_stage(in_features, width))
            in_features = width
        # ``in_features`` is the last hidden width, or ``input_dim`` when there
        # are no hidden stages -- which makes an empty ``hidden_dims`` valid.
        modules.append(nn.Linear(in_features, latent_dim * 2))
        return nn.Sequential(*modules)

    @classmethod
    def _build_decoder(cls, input_dim, latent_dim, hidden_dims):
        """Return hidden stages in reverse width order, then ``Linear -> Softplus``."""
        modules = []
        in_features = latent_dim
        for width in reversed(hidden_dims):
            modules.append(cls._hidden_stage(in_features, width))
            in_features = width
        # Softplus yields non-negative Poisson rates; zeros from underflow are
        # clamped away in compute_reconstruction_loss.
        modules.append(nn.Sequential(nn.Linear(in_features, input_dim), nn.Softplus()))
        return nn.Sequential(*modules)

    def encode(self, x):
        """Encode ``x`` into Gaussian posterior parameters.

        Args:
            x: Input batch, expected ``(batch, input_dim)``.

        Returns:
            Tuple ``(z_mu, z_log_var)``, each ``(batch, latent_dim)``.
        """
        out = self.encoder(x)
        # First half of the features is the mean, second half the log-variance.
        z_mu = out[:, :self.latent_dim]
        z_log_var = out[:, self.latent_dim:]
        return z_mu, z_log_var

    def reparameterize(self, z_mu, z_log_var):
        """Sample ``z ~ N(z_mu, exp(z_log_var))`` with the reparameterization trick.

        Scaling/shifting standard-normal noise keeps ``z`` differentiable with
        respect to ``z_mu`` and ``z_log_var``.
        """
        epsilon = torch.randn_like(z_mu)
        # exp(0.5 * log_var) is the standard deviation.
        return z_mu + torch.exp(0.5 * z_log_var) * epsilon

    def decode(self, z):
        """Map latent codes ``(batch, latent_dim)`` to non-negative reconstructions."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent code and decode it.

        Returns:
            dict with keys ``"x_recons"``, ``"z_mu"`` and ``"z_log_var"``.
        """
        z_mu, z_log_var = self.encode(x)
        z = self.reparameterize(z_mu, z_log_var)
        x_recons = self.decode(z)
        return {
            "x_recons": x_recons,
            "z_mu": z_mu,
            "z_log_var": z_log_var,
        }

    def compute_reconstruction_loss(self, x, x_recons):
        """Poisson negative log-likelihood, summed over elements and averaged over the batch.

        ``x_recons`` is treated as the Poisson rate (``log_input=False``) and is
        clamped to ``[1e-7, 1e7]`` to keep the log finite.
        """
        x_recons = torch.clamp(x_recons, min=1e-7, max=1e7)
        return F.poisson_nll_loss(x_recons, x, log_input=False, reduction="sum") / x.size(0)

    def compute_loss(self, **kwargs):
        """Return the weighted negative ELBO and its components.

        Keyword Args:
            x: Target tensor, same shape as ``x_recons``.
            x_recons: Decoder output (Poisson rates).
            z_mu: Posterior mean from :meth:`encode`.
            z_log_var: Posterior log-variance from :meth:`encode`.
            kld_weight: Multiplier on the KL term; defaults to 1.0 (the
                standard ELBO). Extra keyword arguments are ignored.

        Returns:
            OrderedDict with ``"loss"`` (differentiable) and the detached
            ``"reconstruction_loss"`` and ``"kld_loss"``, each a per-sample
            value averaged over the batch.

        Raises:
            KeyError: if any of ``x``, ``x_recons``, ``z_mu``, ``z_log_var``
                is missing.
        """
        x = kwargs["x"]
        x_recons = kwargs["x_recons"]
        z_mu = kwargs["z_mu"]
        z_log_var = kwargs["z_log_var"]
        kld_weight = kwargs.get("kld_weight", 1.0)

        reconstruction_loss = self.compute_reconstruction_loss(x, x_recons)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims.
        kld_loss = 0.5 * torch.mean(torch.sum(-1 + z_mu ** 2 + torch.exp(z_log_var) - z_log_var, dim=-1))
        loss = reconstruction_loss + kld_weight * kld_loss
        return OrderedDict([
            ("loss", loss),
            ("reconstruction_loss", reconstruction_loss.detach()),
            ("kld_loss", kld_loss.detach()),
        ])

    def init_weight(self):
        """Kaiming-normal init for Linear/Conv2d weights (zero bias); BatchNorm scale 1, shift 0."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


def vae(**kwargs):
    """Build a :class:`VAE`, forwarding every keyword argument to its constructor."""
    model = VAE(**kwargs)
    return model
