"""
SlowVAE.py
- Adapted from https://github.com/bethgelab/slow_disentanglement
- Beta-VAE variant that uses the contrastive-learning data
- Uses neither a contrastive-learning objective nor conditioning
"""
import torch
import numpy as np
import lightning.pytorch as pl
from ..metrics.correlation import compute_r2
import torch.nn.functional as F

"SlowVAE list"
from .net import SlowVAEMLP, NLayerLeakyMLP

def reconstruction_loss(x, x_recon, distribution):
    """Return the summed reconstruction loss between ``x`` and ``x_recon``, divided by batch size.

    Args:
        x: target tensor; dim 0 is the batch dimension.
        x_recon: reconstruction with the same shape as ``x``. For
            ``'bernoulli'`` it is interpreted as logits.
        distribution: ``'bernoulli'`` (binary cross-entropy with logits) or
            ``'gaussian'`` (squared error).

    Returns:
        Scalar tensor: loss summed over all elements and divided by
        ``x.size(0)``, or ``None`` if ``distribution`` is not recognized.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    # reduction='sum' is the supported equivalent of the deprecated
    # size_average=False, which emits a warning on every call.
    if distribution == 'bernoulli':
        recon_loss = F.binary_cross_entropy_with_logits(
            x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        recon_loss = None

    return recon_loss

def compute_cross_ent_normal(mu, logvar):
    """Elementwise cross-entropy of N(mu, exp(logvar)) with respect to a standard normal.

    Computes E_{z ~ N(mu, var)}[-log N(z; 0, 1)] = 0.5 * (mu^2 + var) + log(sqrt(2 * pi)).
    """
    log_partition = np.log(np.sqrt(2 * np.pi))
    second_moment = mu ** 2 + torch.exp(logvar)
    return 0.5 * second_moment + log_partition

def compute_ent_normal(logvar):
    """Elementwise differential entropy of a normal with log-variance ``logvar``.

    H = 0.5 * log(2 * pi * e * var) = 0.5 * (logvar + log(2 * pi * e)).
    """
    log_two_pi_e = np.log(2 * np.pi * np.e)
    shifted = logvar + log_two_pi_e
    return 0.5 * shifted

def compute_sparsity(mu, normed=True):
    """Return the mean absolute difference between coupled rows of ``mu``.

    Rows are assumed to come in couples (0, 1), (2, 3), ...; with an odd
    number of rows the two slices below differ in length.

    Args:
        mu: 2-D tensor of latents, rows paired as described above.
        normed: if True, divide each couple's difference by its L2 norm
            first (couples whose difference is all zeros are left as zeros).

    Returns:
        Scalar tensor: mean of ``|difference|`` over all couples and dims.
    """
    diff = mu[::2] - mu[1::2]
    if normed:
        norm = torch.norm(diff, dim=1, keepdim=True)
        # Replace zero norms out of place: writing into `norm` in place would
        # modify a tensor autograd saved for the norm's backward pass and make
        # .backward() fail.
        safe_norm = torch.where(norm == 0, torch.ones_like(norm), norm)
        diff = diff / safe_norm
    return torch.mean(torch.abs(diff))


class SlowVAE(pl.LightningModule):
    """SlowVAE baseline with an auxiliary reward decoder.

    ``self.net`` (a ``SlowVAEMLP``) encodes and reconstructs observations, and
    ``self.rew_dec`` decodes the reward from the first half of the latent
    ``z`` returned by ``self.net``. The objective is::

        2 * recon(x) + 2 * recon(r) + beta * KL_normal + gamma * KL_laplace

    ``KL_normal`` is taken against a standard normal prior. ``KL_laplace``
    uses a Laplace prior on the difference between coupled latent rows
    ``(2k, 2k + 1)`` (see ``compute_cross_ent_combined``).
    """

    def __init__(self,
                 input_dim,
                 z_dim,
                 hidden_dim,
                 beta,
                 gamma,
                 lr,
                 beta1,
                 beta2,
                 rate_prior,
                 correlation):
        """
        Args:
            input_dim: size of the last axis of the observations; inputs are
                flattened to ``(-1, input_dim)`` before encoding.
            z_dim: latent size. For R2 evaluation it is split into groups of
                ``int(z_dim / 4)`` dims (the last group takes any remainder);
                the reward decoder reads the first ``int(z_dim / 2)`` dims.
            hidden_dim: hidden width passed to ``SlowVAEMLP``.
            beta: weight of the KL term to the standard normal prior.
            gamma: weight of the Laplace term.
            lr, beta1, beta2: AdamW learning rate and betas.
            rate_prior: rate of the Laplace prior, stored as a 1-element tensor.
            correlation: stored as ``self.correlation``; not used in this class.
        """
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.z_dim = z_dim
        self.z1_dim = self.z2_dim = self.z3_dim = self.z4_dim = int(z_dim / 4)
        # Dimensionality of each ground-truth factor group in batch["s1"]["yt"].
        self.z_dim_true_list = [2, 2, 2, 2]
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.correlation = correlation
        self.decoder_dist = 'gaussian'
        # Plain tensor attribute (not a registered buffer); it is moved to the
        # input's device inside compute_cross_ent_laplace.
        self.rate_prior = rate_prior * torch.ones(1, requires_grad=False)

        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.net = SlowVAEMLP(self.input_dim, self.z_dim, self.hidden_dim)
        self.rew_dec = NLayerLeakyMLP(in_features=int(self.z_dim / 2), out_features=1, num_layers=2)

    def compute_cross_ent_laplace(self, mean, logvar, rate_prior):
        """Elementwise cross-entropy of N(mean, exp(logvar)) against Laplace(0, 1 / rate_prior).

        Closed form of E_{z ~ N(mean, var)}[-log p(z)] with
        p(z) = rate_prior / 2 * exp(-rate_prior * |z|), using
        E|z| = sigma * sqrt(2 / pi) * exp(-mean^2 / (2 var)) + mean * (2 * Phi(mean / sigma) - 1).

        The standard normal used for Phi has ``self.z_dim`` components, so the
        last dimension of ``mean`` must broadcast against ``z_dim``.
        """
        var = torch.exp(logvar)
        sigma = torch.sqrt(var)
        device = sigma.device
        rate_prior = rate_prior.to(device)
        normal_dist = torch.distributions.normal.Normal(
            torch.zeros(self.z_dim).to(device),
            torch.ones(self.z_dim).to(device))
        ce = - torch.log(rate_prior / 2) + rate_prior * sigma *\
             np.sqrt(2 / np.pi) * torch.exp(- mean**2 / (2 * var)) -\
             rate_prior * mean * (
                     1 - 2 * normal_dist.cdf(mean / sigma))
        return ce

    def compute_cross_ent_combined(self, mu, logvar):
        """Return ``[entropy, normal cross-entropy, Laplace cross-entropy]`` of the posterior.

        Rows of ``mu`` / ``logvar`` are treated as couples ``(2k, 2k + 1)``.
        The Laplace term is evaluated on each couple's difference in both
        directions. Each returned tensor is summed over latent dims (dim 1)
        and averaged over rows (over couples for the Laplace term), keeping
        a leading dim of size 1.
        NOTE(review): assumes an even number of rows; with an odd count the
        two slices differ in length.
        """
        normal_entropy = compute_ent_normal(logvar)
        cross_ent_normal = compute_cross_ent_normal(mu, logvar)
        mu0, mu1 = mu[::2], mu[1::2]
        logvar0, logvar1 = logvar[::2], logvar[1::2]
        cross_ent_laplace = (
            self.compute_cross_ent_laplace(mu0 - mu1, logvar0, self.rate_prior) +
            self.compute_cross_ent_laplace(mu1 - mu0, logvar1, self.rate_prior))
        return [t.sum(1).mean(0, True) for t in [normal_entropy,
                                                 cross_ent_normal,
                                                 cross_ent_laplace]]

    def _compute_vae_loss(self, batch):
        """Shared train/val objective for ``batch``; returns a 1-element loss tensor.

        Reads ``batch['s1']['xt']`` (observations) and ``batch['s1']['rt']``
        (rewards), both sliced along axis 1 into past (``:-1``) and current
        (``1:``) steps.
        """
        x = batch['s1']['xt']
        r = batch['s1']['rt']
        # All past rows first, then all current rows.
        cat = torch.cat((x[:, :-1, :].reshape(-1, self.input_dim),
                         x[:, 1:, :].reshape(-1, self.input_dim)),
                        dim=0)
        cat_r = torch.cat((r[:, :-1].reshape(-1, 1), r[:, 1:].reshape(-1, 1)), dim=0)
        # NOTE(review): compute_cross_ent_combined pairs rows (2k, 2k + 1).
        # With this concatenation those are consecutive rows within the past
        # half (or the current half), not (past, current) pairs. If xt is
        # (batch, time, input_dim), a pair can also straddle two sequences
        # when time - 1 is odd. Confirm this is intended.
        x_recon, mu, logvar, z = self.net(cat, return_z=True)

        recon_loss = reconstruction_loss(cat, x_recon, self.decoder_dist)
        # The reward decoder reads only the first half of the latent. Slicing
        # also works for odd z_dim, where the original two-way split failed.
        z_rew = z[..., :int(self.z_dim / 2)]
        r_recon = self.rew_dec(z_rew)
        recon_r_loss = reconstruction_loss(cat_r, r_recon, self.decoder_dist)

        normal_entropy, cross_ent_normal, cross_ent_laplace = \
            self.compute_cross_ent_combined(mu, logvar)
        kl_normal = cross_ent_normal - normal_entropy
        kl_laplace = cross_ent_laplace - normal_entropy
        return (2 * recon_loss + 2 * recon_r_loss
                + self.beta * kl_normal + self.gamma * kl_laplace)

    def training_step(self, batch, batch_idx):
        """Compute the objective on ``batch``, log it as ``train_vae_loss`` and return it."""
        vae_loss = self._compute_vae_loss(batch)
        self.log("train_vae_loss", vae_loss)
        return vae_loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and R2 scores between latent groups and true factors.

        The latent mean of every step in ``batch['s1']['xt']`` is split into
        four groups (``z1_dim`` .. ``z4_dim``), and ``batch['s1']['yt']`` into
        groups of ``z_dim_true_list``. Rows are halved: the first half is
        passed to ``compute_r2`` as the "train" rows and the second half as the
        "test" rows. ``r2k`` calls ``compute_r2(recon, true, ...)`` for group k,
        and ``r2kh`` calls it with recon/true swapped.
        """
        vae_loss = self._compute_vae_loss(batch)

        # R2 between latent-mean groups and ground-truth factor groups.
        _, mu_all, _ = self.net(batch['s1']['xt'].view(-1, self.input_dim))
        zt_recon = mu_all.view(-1, self.z_dim).detach().cpu().numpy()
        zt_true = batch["s1"]["yt"].view(-1, sum(self.z_dim_true_list)).detach().cpu().numpy()
        recon_groups = np.split(
            zt_recon, np.cumsum([self.z1_dim, self.z2_dim, self.z3_dim]), axis=-1)
        true_groups = np.split(
            zt_true, np.cumsum(self.z_dim_true_list[:3]), axis=-1)

        # Each entry: (train_recon, train_true, test_recon, test_true).
        # np.split(..., 2, axis=0) requires an even number of rows.
        splits = []
        for z_recon, z_true in zip(recon_groups, true_groups):
            train_recon, test_recon = np.split(z_recon, 2, axis=0)
            train_true, test_true = np.split(z_true, 2, axis=0)
            splits.append((train_recon, train_true, test_recon, test_true))

        r2s = [compute_r2(tr_rec, tr_true, te_rec, te_true)
               for tr_rec, tr_true, te_rec, te_true in splits]
        r2hs = [compute_r2(tr_true, tr_rec, te_true, te_rec)
                for tr_rec, tr_true, te_rec, te_true in splits]
        ave_r2 = sum(r2s) / len(r2s)
        ave_r2h = sum(r2hs) / len(r2hs)

        self.log("val_vae_loss", vae_loss)
        for k, r2 in enumerate(r2s, start=1):
            self.log(f"r2{k}", r2)
        self.log("ave_r2", ave_r2)
        for k, r2h in enumerate(r2hs, start=1):
            self.log(f"r2{k}h", r2h)
        self.log("ave_r2h", ave_r2h)
        return vae_loss

    def configure_optimizers(self):
        """Return AdamW over trainable parameters (``lr``, ``betas`` from ``__init__``, weight decay 1e-4)."""
        opt_v = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()),
                                  lr=self.lr, betas=(self.beta1, self.beta2), weight_decay=0.0001)
        return opt_v