"""
BetaVAE.py
- https://github.com/1Konny/Beta-VAE
- No Transition, Contrastive Learning and Condition
"""
import torch
import numpy as np
import lightning.pytorch as pl
import torch.nn.functional as F

"utils file (SAME)"
from ..metrics.correlation import compute_r2
"BetaVAE list"
from .net import BetaVAEMLP, NLayerLeakyMLP

def reconstruction_loss(x, x_recon, distribution):
    """Return the reconstruction loss summed over elements, averaged over the batch.

    Args:
        x: Target tensor; the size of its first dimension is used as the
            batch size.
        x_recon: Reconstructed tensor compared against ``x``. For
            ``'bernoulli'`` it is interpreted as logits.
        distribution: ``'bernoulli'`` selects binary cross-entropy with
            logits, ``'gaussian'`` selects squared error.

    Returns:
        A scalar tensor holding the element-wise loss summed over every
        element and divided by the batch size, or ``None`` when
        ``distribution`` is neither of the two supported names.

    Raises:
        AssertionError: If the first dimension of ``x`` has size 0.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    # ``reduction='sum'`` is the supported spelling of the deprecated
    # ``size_average=False``: same value, but no deprecation warning per call.
    if distribution == 'bernoulli':
        recon_loss = F.binary_cross_entropy_with_logits(
            x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        # Unknown distribution names are reported to the caller as ``None``.
        recon_loss = None
    return recon_loss

def kl_divergence(mu, logvar):
    """Compute ``-0.5 * (1 + logvar - mu^2 - exp(logvar))`` and three reductions of it.

    Args:
        mu: Tensor of means. A 4-D input is viewed as
            ``(size(0), size(1))`` before use.
        logvar: Tensor of log-variances, handled the same way as ``mu``.

    Returns:
        A tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
        ``total_kld`` sums over dim 1 then averages over dim 0 (kept as a
        length-1 dimension); ``dimension_wise_kld`` averages over dim 0 only;
        ``mean_kld`` averages over dim 1 then over dim 0 (kept as a length-1
        dimension).

    Raises:
        AssertionError: If the first dimension of ``mu`` has size 0.
    """
    n_samples = mu.size(0)
    assert n_samples != 0

    def _as_matrix(tensor):
        # Only 4-D tensors are reshaped; everything else passes through.
        if tensor.dim() == 4:
            return tensor.view(tensor.size(0), tensor.size(1))
        return tensor

    mu = _as_matrix(mu)
    logvar = _as_matrix(logvar)

    per_element = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    summed_over_latents = per_element.sum(dim=1)
    averaged_over_latents = per_element.mean(dim=1)

    total_kld = summed_over_latents.mean(dim=0, keepdim=True)
    dimension_wise_kld = per_element.mean(dim=0)
    mean_kld = averaged_over_latents.mean(dim=0, keepdim=True)

    return total_kld, dimension_wise_kld, mean_kld

class BetaVAE(pl.LightningModule):
    """Beta-VAE LightningModule with an auxiliary decoder for ``rt``.

    The module wraps a ``BetaVAEMLP`` (``self.net``) and an ``NLayerLeakyMLP``
    (``self.rew_dec``). The latter predicts the one-dimensional ``rt`` target
    from the first half of the sampled latent ``z``. The training objective is
    ``recon(x) + recon(r) + beta * total_kld``.

    Args:
        input_dim: Width that ``batch['s1']['xt']`` is reshaped to.
        z_dim: Latent width passed to ``BetaVAEMLP``; it is cut in halves for
            the auxiliary decoder and in quarters for the R2 evaluation.
        hidden_dim: Hidden width passed to ``BetaVAEMLP``.
        beta: Weight applied to the total KL term.
        beta1: First AdamW beta coefficient.
        beta2: Second AdamW beta coefficient.
        lr: AdamW learning rate.
        correlation: Stored on the instance; not read by any method in this
            class.
    """

    def __init__(self,
                 input_dim,
                 z_dim,
                 hidden_dim,
                 beta,
                 beta1,
                 beta2,
                 lr,
                 correlation):
        # Networks & Optimizers
        super().__init__()
        self.beta = beta
        self.beta1 = beta1
        self.beta2 = beta2
        self.z_dim = z_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.correlation = correlation
        self.decoder_dist = 'gaussian'
        # The estimated latent is evaluated in four equally sized column groups.
        self.z1_dim = self.z2_dim = self.z3_dim = self.z4_dim = int(z_dim/4)
        # Column widths of the four ground-truth groups found in batch['s1']['yt'].
        self.z_dim_true_list = [2, 2, 2, 2]
        self.lr = lr
        self.net = BetaVAEMLP(self.input_dim, self.z_dim, self.hidden_dim)
        self.rew_dec = NLayerLeakyMLP(in_features=int(self.z_dim/2), out_features=1, num_layers=2)

    def _vae_loss(self, batch):
        """Run the forward pass shared by training and validation.

        Returns:
            ``(vae_loss, mu)`` where ``vae_loss`` is
            ``recon(x) + recon(r) + beta * total_kld`` and ``mu`` is the
            posterior mean returned by ``self.net``.
        """
        x = batch['s1']['xt'].reshape(-1, self.input_dim)
        r = batch['s1']['rt'].reshape(-1, 1)
        x_recon, mu, logvar, z = self.net(x, return_z=True)

        # Only the first half of z feeds the auxiliary decoder.
        half = int(self.z_dim/2)
        z_rew, _ = torch.split(z, [half, half], dim=-1)
        r_recon = self.rew_dec(z_rew)

        recon_r_loss = reconstruction_loss(r, r_recon, self.decoder_dist)
        recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
        total_kld, _, _ = kl_divergence(mu, logvar)
        vae_loss = recon_loss + recon_r_loss + self.beta * total_kld
        return vae_loss, mu

    @staticmethod
    def _split_groups(array, sizes):
        """Cut ``array`` into column groups, then halve each group along axis 0.

        The cut points are the running sums of ``sizes`` without the last one,
        so the final group receives every remaining column. Each group is
        returned as a ``[first_half, second_half]`` pair; ``np.split`` raises
        ``ValueError`` when the row count is odd.
        """
        boundaries = np.cumsum(sizes)[:-1].tolist()
        return [np.split(group, 2, axis=0)
                for group in np.split(array, boundaries, axis=-1)]

    def training_step(self, batch, batch_idx):
        """Compute and log the VAE loss for one training batch."""
        vae_loss, _ = self._vae_loss(batch)

        self.log("train_vae_loss", vae_loss)
        return vae_loss

    def validation_step(self, batch, batch_idx):
        """Compute the VAE loss and per-group R2 scores for one validation batch.

        The posterior means and the ground-truth latents ``batch['s1']['yt']``
        are each split into four column groups. Every group is halved along
        axis 0; the first half is passed to ``compute_r2`` in the "train"
        argument positions and the second half in the "test" positions.
        Scores are computed in both directions (estimate -> truth, logged as
        ``r2N``; truth -> estimate, logged as ``r2Nh``) together with their
        averages.
        """
        vae_loss, mu = self._vae_loss(batch)

        # Compute R2
        zt_recon = mu.view(-1, self.z_dim).detach().cpu().numpy()
        recon_groups = self._split_groups(
            zt_recon, [self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim])

        zt_true = batch["s1"]["yt"].view(-1, sum(self.z_dim_true_list)).detach().cpu().numpy()
        true_groups = self._split_groups(zt_true, self.z_dim_true_list)

        paired = list(zip(recon_groups, true_groups))
        r2_forward = [
            compute_r2(train_recon, train_true, test_recon, test_true)
            for (train_recon, test_recon), (train_true, test_true) in paired
        ]
        ave_r2 = sum(r2_forward) / 4.0
        r2_reverse = [
            compute_r2(train_true, train_recon, test_true, test_recon)
            for (train_recon, test_recon), (train_true, test_true) in paired
        ]
        ave_r2h = sum(r2_reverse) / 4.0

        self.log("val_vae_loss", vae_loss)
        for index, score in enumerate(r2_forward, start=1):
            self.log(f"r2{index}", score)
        self.log("ave_r2", ave_r2)
        for index, score in enumerate(r2_reverse, start=1):
            self.log(f"r2{index}h", score)
        self.log("ave_r2h", ave_r2h)
        return vae_loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over every parameter that requires grad."""
        opt_v = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()),
                                  lr=self.lr, betas=(self.beta1, self.beta2), weight_decay=0.0001)
        return opt_v