## Reproduced from the original implementation: https://github.com/zhd96/pi-vae


import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


__all__ = ["pivae"]


class LabelPriorDiscrete(nn.Module):
    """Label-conditioned prior for integer labels, stored as lookup tables.

    Every one of the ``label_num`` label indices owns a learnable row of
    width ``latent_dim`` in a mean table and in a log-variance table.
    """

    def __init__(self, label_num, latent_dim):
        super().__init__()

        # The mean table is created before the log-variance table; keeping
        # this order keeps seeded initialisation reproducible.
        self.embed_mu = nn.Embedding(
            num_embeddings=label_num, embedding_dim=latent_dim
        )
        self.embed_log_var = nn.Embedding(
            num_embeddings=label_num, embedding_dim=latent_dim
        )

    def forward(self, u):
        """Look up the prior parameters for the label indices ``u``.

        Returns:
            ``(mean, log_var)``; each has the shape of ``u`` with a trailing
            axis of size ``latent_dim`` appended.
        """
        return self.embed_mu(u), self.embed_log_var(u)


class LabelPriorContinuous(nn.Module):
    """Label-conditioned prior for real-valued label vectors.

    A two-layer tanh MLP embeds the label; two separate linear heads then map
    that embedding to the prior mean and the prior log-variance.
    """

    def __init__(self, label_dim, latent_dim, hidden_dim=20):
        super().__init__()

        # Shared trunk: label_dim -> hidden_dim -> hidden_dim, tanh after each.
        trunk = [
            nn.Linear(in_features=label_dim, out_features=hidden_dim),
            nn.Tanh(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.Tanh(),
        ]
        self.hidden_layers = nn.Sequential(*trunk)

        # Output heads, both reading the same trunk features.
        self.linear_mu = nn.Linear(
            in_features=hidden_dim, out_features=latent_dim
        )
        self.linear_log_var = nn.Linear(
            in_features=hidden_dim, out_features=latent_dim
        )

    def forward(self, u):
        """Return ``(mean, log_var)`` of the prior for the label batch ``u``.

        ``u`` must have ``label_dim`` entries on its last axis; both outputs
        have ``latent_dim`` entries on their last axis.
        """
        features = self.hidden_layers(u)
        return self.linear_mu(features), self.linear_log_var(features)


class Encoder(nn.Module):
    """Map an observation to the mean and log-variance of its latent code.

    Two independent three-layer tanh MLPs with identical layouts are used,
    one per output; they share no parameters.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=60):
        super().__init__()

        # The mean network is built first, then the log-variance network, so
        # parameter creation order (and seeded init) is unchanged.
        self.enc_mu = self._make_head(input_dim, hidden_dim, latent_dim)
        self.enc_log_var = self._make_head(input_dim, hidden_dim, latent_dim)

    @staticmethod
    def _make_head(in_dim, width, out_dim):
        """Build one ``in_dim -> width -> width -> out_dim`` tanh MLP."""
        return nn.Sequential(
            nn.Linear(in_dim, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, out_dim),
        )

    def forward(self, x):
        """Return ``(z_mu, z_log_var)`` for the observation batch ``x``."""
        return self.enc_mu(x), self.enc_log_var(x)


class Permutation(nn.Module):
    """Fixed random shuffle of the last axis, with its inverse.

    The shuffle is drawn once at construction time and stored as a buffer,
    so it is saved in the ``state_dict`` but is not a trainable parameter.
    """

    def __init__(self, input_dim):
        super().__init__()

        shuffle = torch.randperm(input_dim)
        self.register_buffer("p_index", shuffle)
        # Sorting a permutation's values yields the indices that undo it.
        self.register_buffer("invp_index", shuffle.argsort())

    def forward(self, x):
        """Reorder the last axis of ``x`` according to ``p_index``."""
        order = self.p_index
        return x[..., order]

    def backward(self, x):
        """Undo ``forward``: restore the original order of the last axis."""
        order = self.invp_index
        return x[..., order]


class AffineCouplingFunction(nn.Module):
    """Affine coupling layer whose log-scales sum to zero.

    The first ``input_dim // 2`` features pass through unchanged and feed a
    small ReLU MLP. That MLP predicts a log-scale and a shift for the
    remaining features. The output places the transformed features first and
    the untouched ones last, so stacking two layers transforms both halves.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.split_dim = input_dim // 2
        self.hidden_dim = input_dim // 4

        transformed_dim = input_dim - self.split_dim
        # The network emits (transformed_dim - 1) free log-scales followed by
        # transformed_dim shifts; the last log-scale is derived in forward().
        self.affine_coupling_function = nn.Sequential(
            nn.Linear(self.split_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 2 * transformed_dim - 1),
        )

    def forward(self, x):
        """Apply the coupling transform along the last axis of ``x``."""
        n_free = self.input_dim - self.split_dim - 1

        passthrough = x[..., :self.split_dim]
        transformed = x[..., self.split_dim:]

        params = self.affine_coupling_function(passthrough)
        shift = params[..., n_free:]

        # Each free log-scale is squashed into (-0.1, 0.1).
        log_scale = 0.1 * torch.tanh(params[..., :n_free])
        # The final log-scale cancels the others so that they sum to zero.
        closing = -log_scale.sum(dim=-1, keepdim=True)
        log_scale = torch.cat((log_scale, closing), dim=-1)

        transformed = transformed * log_scale.exp() + shift
        return torch.cat((transformed, passthrough), dim=-1)


class GINBlock(nn.Module):
    """Two ``AffineCouplingFunction`` layers applied back to back.

    Each coupling layer moves its untouched half to the end of the feature
    axis, so the second layer transforms the features the first one left
    unchanged.
    """

    def __init__(self, input_dim):
        super().__init__()

        # Built in application order; indices 0 and 1 in the state_dict.
        couplings = [AffineCouplingFunction(input_dim) for _ in range(2)]
        self.layers = nn.Sequential(*couplings)

    def forward(self, x):
        """Run ``x`` through both coupling layers in sequence."""
        out = self.layers(x)
        return out


class Decoder(nn.Module):
    """Map a latent code to an ``input_dim``-wide output.

    The latent vector is first extended to ``input_dim`` entries by
    concatenating it with ``input_dim - latent_dim`` features predicted from
    it by a ReLU MLP. The extended vector then passes through two
    (permutation, GIN block) stages. For the ``"poisson"`` observation model
    the result is made positive with softplus and clamped to ``[1e-7, 1e7]``;
    for any other value it is returned as is.
    """

    def __init__(self, input_dim, latent_dim, observation_model="poisson"):
        super().__init__()
        self.observation_model = observation_model

        width = input_dim // 4
        # Predicts the extra features that pad z up to input_dim entries.
        self.first_layers = nn.Sequential(
            nn.Linear(latent_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, input_dim - latent_dim),
        )

        # Two stages, each a fresh random permutation followed by a GIN block.
        stages = []
        for _ in range(2):
            stages.append(Permutation(input_dim))
            stages.append(GINBlock(input_dim))
        self.gin_blocks = nn.Sequential(*stages)

    def forward(self, z):
        """Decode the latent batch ``z`` along its last axis."""
        padding = self.first_layers(z)
        lifted = torch.cat((z, padding), dim=-1)
        decoded = self.gin_blocks(lifted)

        if self.observation_model != "poisson":
            return decoded

        # Poisson rates must be positive; the clamp bounds them on both sides.
        rate = F.softplus(decoded)
        return torch.clamp(rate, min=1e-7, max=1e7)


class PIVAE(nn.Module):
    """Label-conditioned VAE: encoder, label prior and flow-style decoder.

    The encoder's Gaussian over the latent code is combined with a Gaussian
    prior that depends on the label ``u``; a sample from the combined
    posterior is decoded into the output ``rate``.

    Args:
        input_dim: Width of an observation ``x``.
        label_dim: Number of distinct labels when ``discrete_prior`` is true
            (it sizes the embedding tables); otherwise the width of a label
            vector.
        latent_dim: Width of the latent code.
        discrete_prior: Selects ``LabelPriorDiscrete`` (true) or
            ``LabelPriorContinuous`` (false).
        observation_model: Forwarded to the decoder; ``"poisson"`` makes the
            decoder output positive.
    """

    def __init__(self, input_dim, label_dim, latent_dim, discrete_prior=True, observation_model="poisson"):
        super().__init__()

        if discrete_prior:
            self.prior = LabelPriorDiscrete(label_dim, latent_dim)
        else:
            self.prior = LabelPriorContinuous(label_dim, latent_dim, hidden_dim=latent_dim // 2)

        self.encoder = Encoder(input_dim, latent_dim, hidden_dim=input_dim // 2)
        self.decoder = Decoder(input_dim, latent_dim, observation_model=observation_model)

        self.observation_model = observation_model

        self.init_weight()

    def reparameterize(self, z_mu, z_log_var):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``z_mu`` and ``z_log_var``.
        """
        epsilon = torch.randn_like(z_mu)
        z = z_mu + torch.exp(0.5 * z_log_var) * epsilon

        return z

    def compute_posterior(self, z_mu, z_log_var, z_label_prior_mu, z_label_prior_log_var):
        """Combine the encoder Gaussian and the label prior by precision weighting.

        With variances ``v = exp(z_log_var)`` and ``p = exp(z_label_prior_log_var)``:

            post_mu  = (z_mu * p + z_label_prior_mu * v) / (v + p)
            post_var = v * p / (v + p)

        Both are evaluated in log space. Exponentiating the log-variances
        directly overflows for large values and yields ``inf / inf = NaN``.

        Returns:
            ``(z_post_mu, z_post_log_var)``.
        """
        diff_log_var = z_log_var - z_label_prior_log_var

        # p / (v + p) = sigmoid(-diff) and v / (v + p) = sigmoid(diff).
        z_post_mu = (
            z_mu * torch.sigmoid(-diff_log_var)
            + z_label_prior_mu * torch.sigmoid(diff_log_var)
        )
        # 1 / post_var = 1 / v + 1 / p, so log post_var = -logaddexp(-log v, -log p).
        z_post_log_var = -torch.logaddexp(-z_log_var, -z_label_prior_log_var)

        return z_post_mu, z_post_log_var

    def forward(self, x, u):
        """Encode ``x``, fuse with the prior for ``u``, sample and decode.

        Returns:
            A dict with keys ``rate``, ``z_mu``, ``z_log_var``,
            ``z_label_prior_mu``, ``z_label_prior_log_var``, ``z_post_mu`` and
            ``z_post_log_var``. It can be unpacked into ``compute_loss``
            together with ``x`` and ``kld_weight``.
        """
        z_mu, z_log_var = self.encoder(x)
        z_label_prior_mu, z_label_prior_log_var = self.prior(u)

        z_post_mu, z_post_log_var = self.compute_posterior(z_mu, z_log_var, z_label_prior_mu, z_label_prior_log_var)
        z = self.reparameterize(z_post_mu, z_post_log_var)

        r = self.decoder(z)

        output = {
            "rate": r,
            "z_mu": z_mu,
            "z_log_var": z_log_var,
            "z_label_prior_mu": z_label_prior_mu,
            "z_label_prior_log_var": z_label_prior_log_var,
            "z_post_mu": z_post_mu,
            "z_post_log_var": z_post_log_var
        }
        return output

    def compute_loss(self, **kwargs):
        """Return the reconstruction + weighted-KL training loss.

        Required keyword arguments: ``x``, ``rate``, ``z_post_mu``,
        ``z_post_log_var``, ``z_label_prior_mu``, ``z_label_prior_log_var``
        and ``kld_weight``. Extra keys are ignored.

        Returns:
            ``OrderedDict`` with ``loss`` (differentiable) and the detached
            ``reconstruction_loss`` and ``kld_loss`` terms.
        """
        x = kwargs["x"]
        r = kwargs["rate"]
        z_post_mu = kwargs["z_post_mu"]
        z_post_log_var = kwargs["z_post_log_var"]
        z_label_prior_mu = kwargs["z_label_prior_mu"]
        z_label_prior_log_var = kwargs["z_label_prior_log_var"]
        kld_weight = kwargs["kld_weight"]

        # NOTE(review): the Poisson NLL is applied whatever observation_model
        # is; confirm whether a non-Poisson model needs its own likelihood.
        # Summed over all elements, then averaged over the first axis of x.
        reconstruction_loss = F.poisson_nll_loss(r, x, log_input=False, reduction="sum")
        reconstruction_loss = reconstruction_loss / x.size(0)

        # KL(N(post) || N(prior)) per latent dimension. The variance ratio is
        # formed as exp(a - b), not exp(a) / exp(b), so large log-variances do
        # not overflow.
        mean_gap_sq = (z_post_mu - z_label_prior_mu) ** 2
        kld_per_dim = (
            -1
            + mean_gap_sq * torch.exp(-z_label_prior_log_var)
            + torch.exp(z_post_log_var - z_label_prior_log_var)
            - z_post_log_var
            + z_label_prior_log_var
        )
        kld_loss = 0.5 * torch.mean(torch.sum(kld_per_dim, dim=-1))

        loss = reconstruction_loss + kld_weight * kld_loss
        return OrderedDict([("loss", loss), ("reconstruction_loss", reconstruction_loss.detach()), ("kld_loss", kld_loss.detach())])

    def init_weight(self):
        """Re-initialise every submodule's weights.

        Linear, convolution and embedding weights get Kaiming-normal values;
        batch-norm weights are set to 1; all biases are set to 0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


def pivae(**kwargs):
    """Build a ``PIVAE`` model, forwarding every keyword argument unchanged."""
    model = PIVAE(**kwargs)
    return model
