## Reproduced from the original implementation: https://github.com/nerdslab/SwapVAE


import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


# Public API: only the `swap_vae` factory is exported by `from <module> import *`.
__all__ = ["swap_vae"]


class SwapVAE(nn.Module):
    """VAE whose latent code is split into a deterministic "content" part and a
    Gaussian "style" part, trained on pairs of views (x1, x2).

    Each view is decoded from its own latent and from a "swapped" latent made of
    the OTHER view's content plus its own style sample.

    Args:
        input_dim: Number of input features; also the size of every reconstruction.
        content_dim: Size of the deterministic content code.
        style_dim: Size of the Gaussian style code (mean and log-variance each).
        hidden_dim: Optional sequence of encoder hidden-layer widths. The decoder
            mirrors it in reverse. Defaults to ``[input_dim, content_dim + style_dim]``.
            Any sequence (list, tuple, ...) is accepted and copied. An empty
            sequence yields a single linear encoder head and a single linear decoder layer.

    After construction ``self.hidden_dims`` is ``[input_dim, *hidden_dim]``
    (the input size is prepended for the decoder mirror).
    """

    def __init__(self, input_dim, content_dim, style_dim, hidden_dim=None):
        super().__init__()
        self.input_dim = input_dim
        self.content_dim = content_dim
        self.style_dim = style_dim
        self.latent_dim = self.content_dim + self.style_dim
        if hidden_dim is None:
            hidden_dims = [self.input_dim, self.latent_dim]
        else:
            # Copy into a fresh list: avoids aliasing the caller's object and lets
            # tuples / other sequences work with the list concatenation below.
            hidden_dims = list(hidden_dim)

        # encoder: [Linear -> BatchNorm1d -> ReLU] per hidden width, then a linear
        # head producing content (content_dim) + style mean and log-var (style_dim each).
        modules = []
        prev_dim = self.input_dim
        for width in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(prev_dim, width),
                    nn.BatchNorm1d(width),
                    nn.ReLU()
                )
            )
            prev_dim = width
        # prev_dim is the last hidden width, or input_dim when hidden_dims is empty.
        modules.append(nn.Linear(prev_dim, self.content_dim + self.style_dim * 2))
        self.encoder = nn.Sequential(*modules)

        # decoder: latent -> hidden widths in reverse -> input_dim, ending in Softplus
        # so the output is non-negative (it is used as a Poisson rate in the loss).
        self.hidden_dims = [self.input_dim] + hidden_dims
        modules = []
        prev_dim = self.latent_dim
        for width in reversed(self.hidden_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.Linear(prev_dim, width),
                    nn.BatchNorm1d(width),
                    nn.ReLU()
                )
            )
            prev_dim = width
        modules.append(
            nn.Sequential(
                nn.Linear(prev_dim, self.hidden_dims[0]),
                nn.Softplus()
            )
        )
        self.decoder = nn.Sequential(*modules)

        self.init_weight()

    def encode(self, x):
        """Encode ``x`` and split the head output column-wise.

        Returns:
            Tuple ``(z_content, z_style_mu, z_style_log_var)`` with last-dim sizes
            ``content_dim``, ``style_dim`` and ``style_dim`` respectively.
        """
        out = self.encoder(x)
        z_content = out[:, :self.content_dim]
        z_style_mu = out[:, self.content_dim: self.latent_dim]
        z_style_log_var = out[:, self.latent_dim:]

        return z_content, z_style_mu, z_style_log_var

    def reparameterize(self, z_mu, z_log_var):
        """Draw ``z = mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``.

        NOTE: this samples in both train and eval mode (there is no
        ``self.training`` check).
        """
        epsilon = torch.randn_like(z_mu)
        z = z_mu + torch.exp(0.5 * z_log_var) * epsilon

        return z

    def forward(self, x1, x2):
        """Encode both views, build original and content-swapped latents, and decode.

        Returns:
            Dict with the four reconstructions (``x{1,2}_recons_ori``,
            ``x{1,2}_recons_swap``) and the encoder outputs (``z{1,2}_content``,
            ``z{1,2}_style_mu``, ``z{1,2}_style_log_var``). It can be splatted into
            :meth:`compute_loss` together with ``x1``, ``x2`` and the loss weights.
        """
        z1_content, z1_style_mu, z1_style_log_var = self.encode(x1)
        z1_style = self.reparameterize(z1_style_mu, z1_style_log_var)
        z2_content, z2_style_mu, z2_style_log_var = self.encode(x2)
        z2_style = self.reparameterize(z2_style_mu, z2_style_log_var)

        z1_original = torch.cat((z1_content, z1_style), dim=-1)
        z2_original = torch.cat((z2_content, z2_style), dim=-1)

        # Swap content between the views while keeping each view's own style.
        z1_swap = torch.cat((z2_content, z1_style), dim=-1)
        z2_swap = torch.cat((z1_content, z2_style), dim=-1)

        # Separate decoder calls: each BatchNorm layer sees one batch at a time.
        x1_recons_ori = self.decoder(z1_original)
        x2_recons_ori = self.decoder(z2_original)

        x1_recons_swap = self.decoder(z1_swap)
        x2_recons_swap = self.decoder(z2_swap)

        output = {
            "x1_recons_ori": x1_recons_ori,
            "x2_recons_ori": x2_recons_ori,
            "x1_recons_swap": x1_recons_swap,
            "x2_recons_swap": x2_recons_swap,
            "z1_content": z1_content,
            "z1_style_mu": z1_style_mu,
            "z1_style_log_var": z1_style_log_var,
            "z2_content": z2_content,
            "z2_style_mu": z2_style_mu,
            "z2_style_log_var": z2_style_log_var
        }
        return output

    def compute_reconstruction_loss(self, x, x_recons):
        """Poisson negative log-likelihood of ``x`` under rate ``x_recons``.

        The loss is summed over all elements and divided by the batch size (dim 0).
        The rate is clamped to ``[1e-7, 1e7]`` for numerical safety.
        """
        x_recons = torch.clamp(x_recons, min=1e-7, max=1e7)
        return F.poisson_nll_loss(x_recons, x, log_input=False, reduction="sum") / x.size(0)

    @staticmethod
    def _gaussian_kld(mu, log_var):
        """Batch-mean KL(N(mu, exp(log_var)) || N(0, I)), summed over the last dim."""
        return 0.5 * torch.mean(torch.sum(-1 + mu ** 2 + torch.exp(log_var) - log_var, dim=-1))

    def compute_loss(self, **kwargs):
        """Combine the reconstruction, KL and content-alignment terms.

        Expected keyword arguments: ``x1``, ``x2``, every key returned by
        :meth:`forward`, plus the scalars ``kld_weight`` and ``align_weight``.

        Returns:
            OrderedDict with ``loss`` (differentiable total) followed by the detached
            components ``reconstruction_ori_loss``, ``reconstruction_swap_loss``,
            ``kld_loss`` and ``align_loss``.
        """
        x1 = kwargs["x1"]
        x2 = kwargs["x2"]
        x1_recons_ori = kwargs["x1_recons_ori"]
        x2_recons_ori = kwargs["x2_recons_ori"]
        x1_recons_swap = kwargs["x1_recons_swap"]
        x2_recons_swap = kwargs["x2_recons_swap"]
        z1_content = kwargs["z1_content"]
        z1_style_mu = kwargs["z1_style_mu"]
        z1_style_log_var = kwargs["z1_style_log_var"]
        z2_content = kwargs["z2_content"]
        z2_style_mu = kwargs["z2_style_mu"]
        z2_style_log_var = kwargs["z2_style_log_var"]
        kld_weight = kwargs["kld_weight"]
        align_weight = kwargs["align_weight"]

        reconstruction_ori_loss = (
            self.compute_reconstruction_loss(x1, x1_recons_ori)
            + self.compute_reconstruction_loss(x2, x2_recons_ori)
        ) / 2
        reconstruction_swap_loss = (
            self.compute_reconstruction_loss(x1, x1_recons_swap)
            + self.compute_reconstruction_loss(x2, x2_recons_swap)
        ) / 2

        kld_loss = (
            self._gaussian_kld(z1_style_mu, z1_style_log_var)
            + self._gaussian_kld(z2_style_mu, z2_style_log_var)
        ) / 2

        # 1 - cosine similarity between the two views' content codes, averaged over the batch.
        align_loss = torch.mean(1 - torch.sum(F.normalize(z1_content, p=2, dim=-1) * F.normalize(z2_content, p=2, dim=-1), dim=-1))

        loss = reconstruction_ori_loss + reconstruction_swap_loss + kld_weight * kld_loss + align_weight * align_loss

        return OrderedDict([
            ("loss", loss),
            ("reconstruction_ori_loss", reconstruction_ori_loss.detach()),
            ("reconstruction_swap_loss", reconstruction_swap_loss.detach()),
            ("kld_loss", kld_loss.detach()),
            ("align_loss", align_loss.detach()),
        ])

    def init_weight(self):
        """Re-initialize parameters of all sub-modules.

        Linear/Conv2d weights use Kaiming-normal (PyTorch defaults) and biases are
        zeroed. BatchNorm1d/2d weights are set to 1 and biases to 0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


def swap_vae(**kwargs):
    """Factory for :class:`SwapVAE`.

    Every keyword argument is forwarded unchanged to the ``SwapVAE`` constructor,
    and the freshly built model is returned.
    """
    model = SwapVAE(**kwargs)
    return model
