"""
FactorVAE.py
- https://github.com/1Konny/FactorVAE
- Contrastive Learning & beta-VAE
- No conditional information
"""
import torch
import lightning.pytorch as pl

import torch.nn.functional as F
import numpy as np

"utils file (SAME)"
from ..metrics.correlation import compute_r2
"FactorVAE list"
from .net import FactorVAEMLP, FactorVAECNN, FactorVAEKP, Discriminator, NLayerLeakyMLP
from .ops import recon_loss, kl_divergence, permute_dims
import ipdb as pdb

class FactorVAE(pl.LightningModule):
    """FactorVAE with an auxiliary reward decoder, trained with manual optimization.

    The module owns three sub-networks:

    * ``self.VAE``     -- ``FactorVAEMLP(input_dim, z_dim, hidden_dim)``; called
      as ``VAE(x)`` it returns ``(x_recon, mu, logvar, z)`` and as
      ``VAE(x, no_dec=True)`` it returns a single latent tensor.
    * ``self.D``       -- ``Discriminator(z_dim, hidden_dim)``; its output is
      used as class logits for ``F.cross_entropy`` with labels 0/1 and as
      ``D_z[:, :1] - D_z[:, 1:]`` for the total-correlation term.
    * ``self.rew_dec`` -- ``NLayerLeakyMLP`` mapping the first ``z_dim // 2``
      latent dimensions to one scalar that is compared against ``batch['s1']['rt']``.

    Two optimizers are used (see :meth:`configure_optimizers`), so
    ``automatic_optimization`` is disabled and :meth:`training_step` steps both
    by hand.

    Args:
        input_dim: Size of a flattened observation; inputs are reshaped to
            ``(-1, input_dim)``.
        z_dim: Latent size. The code splits it into two halves (reward head)
            and four quarters (R2 evaluation), so it is expected to be an
            integer divisible by 4.
        hidden_dim: Hidden width forwarded to ``FactorVAEMLP`` and ``Discriminator``.
        gamma: Weight of the total-correlation term in the VAE loss.
        lr_VAE, beta1_VAE, beta2_VAE: AdamW learning rate and betas for the
            VAE and reward decoder.
        lr_D, beta1_D, beta2_D: Adam learning rate and betas for the discriminator.
        correlation: Stored on the instance; not used inside this class.
    """

    def __init__(self,
                 input_dim,
                 z_dim,
                 hidden_dim,
                 gamma,
                 lr_VAE,
                 beta1_VAE,
                 beta2_VAE,
                 lr_D,
                 beta1_D,
                 beta2_D,
                 correlation):
        super().__init__()
        self.z_dim = z_dim
        # Each of the four estimated latent groups gets a quarter of z_dim.
        self.z1_dim = self.z2_dim = self.z3_dim = self.z4_dim = z_dim // 4
        # Sizes of the four ground-truth latent groups read from batch['s1']['yt'].
        self.z_dim_true_list = [2, 2, 2, 2]
        self.gamma = gamma
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.correlation = correlation

        # Optimizer hyper-parameters, consumed by configure_optimizers().
        self.lr_VAE = lr_VAE
        self.beta1_VAE = beta1_VAE
        self.beta2_VAE = beta2_VAE
        self.lr_D = lr_D
        self.beta1_D = beta1_D
        self.beta2_D = beta2_D

        # Networks.
        self.VAE = FactorVAEMLP(self.input_dim, self.z_dim, self.hidden_dim)
        self.D = Discriminator(self.z_dim, self.hidden_dim)
        self.rew_dec = NLayerLeakyMLP(in_features=self.z_dim // 2, out_features=1, num_layers=2)

        # Two optimizers are stepped manually inside training_step().
        self.automatic_optimization = False

    # ------------------------------------------------------------------ helpers

    def _vae_losses(self, x_true, r_true, x_recon, mu, logvar, z):
        """Compute the individual VAE loss terms and their weighted sum.

        Returns:
            Tuple ``(recon_loss, reward_loss, kld, tc_loss, total)`` where
            ``total = recon_loss + reward_loss + kld + gamma * tc_loss``.
        """
        d_z = self.D(z)
        # Only the first half of the latent feeds the reward decoder.
        half = self.z_dim // 2
        z_rew, _ = torch.split(z, [half, half], dim=-1)
        r_recon = self.rew_dec(z_rew)

        vae_recon_r = recon_loss(r_true, r_recon)
        vae_recon_loss = recon_loss(x_true, x_recon)
        vae_kld = kl_divergence(mu, logvar)
        # Difference between the first discriminator column and the remaining
        # column(s), averaged over the batch.
        vae_tc_loss = (d_z[:, :1] - d_z[:, 1:]).mean()
        vae_loss = vae_recon_loss + vae_recon_r + vae_kld + self.gamma * vae_tc_loss
        return vae_recon_loss, vae_recon_r, vae_kld, vae_tc_loss, vae_loss

    @staticmethod
    def _split_groups(array, group_dims):
        """Split ``array`` along its last axis into consecutive groups of ``group_dims`` widths.

        As with the explicit index list used previously, the last group takes
        whatever columns remain after the preceding groups.
        """
        boundaries = np.cumsum(group_dims)[:-1].tolist()
        return np.split(array, boundaries, axis=-1)

    @staticmethod
    def _halve_rows(array):
        """Split ``array`` into two equal halves along axis 0.

        ``np.split`` raises ``ValueError`` when the number of rows is odd.
        """
        first, second = np.split(array, 2, axis=0)
        return first, second

    # ----------------------------------------------------------------- training

    def training_step(self, batch, batch_idx):
        """Run one discriminator update followed by one VAE update.

        Reads ``batch['s1']['xt']``, ``batch['s2']['xt']`` and
        ``batch['s1']['rt']``. Logs the discriminator loss and every VAE loss
        term under ``train_*`` keys.

        Returns:
            The discriminator loss tensor (unchanged from the previous
            behaviour).
        """
        x_true1 = batch['s1']['xt'].reshape(-1, self.input_dim)
        x_true2 = batch['s2']['xt'].reshape(-1, self.input_dim)
        r_true1 = batch['s1']['rt'].reshape(-1, 1)
        opt_v, opt_d = self.optimizers()

        batch_size = x_true1.shape[0]
        ones = torch.ones(batch_size, dtype=torch.long, device=x_true1.device)
        zeros = torch.zeros(batch_size, dtype=torch.long, device=x_true1.device)

        x_recon, mu, logvar, z = self.VAE(x_true1)

        # ---- Discriminator update ----
        # z is detached so this step only produces gradients for D.
        D_z = self.D(z.detach())
        # The second sample only supplies permuted latents for D and is
        # detached afterwards, so no autograd graph is needed for this pass.
        with torch.no_grad():
            z_prime = self.VAE(x_true2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
        D_z_pperm = self.D(z_pperm)
        # Label 0 for latents as encoded, label 1 for permuted latents.
        D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))
        self.log("train_d_tc_loss", D_tc_loss)
        opt_d.zero_grad()
        self.manual_backward(D_tc_loss)
        opt_d.step()

        # ---- VAE update ----
        # D is re-evaluated on the non-detached z (inside _vae_losses) using
        # the freshly updated discriminator weights.
        vae_recon_loss, vae_recon_r, vae_kld, vae_tc_loss, vae_loss = self._vae_losses(
            x_true1, r_true1, x_recon, mu, logvar, z
        )
        opt_v.zero_grad()
        self.manual_backward(vae_loss)
        opt_v.step()

        self.log("train_vae_loss", vae_loss)
        self.log("train_vae_recon_loss", vae_recon_loss)
        self.log("train_vae_r_loss", vae_recon_r)
        self.log("train_vae_kld", vae_kld)
        self.log("train_vae_tc_loss", vae_tc_loss)
        return D_tc_loss

    # --------------------------------------------------------------- validation

    def validation_step(self, batch, batch_idx):
        """Compute validation losses and per-group R2 scores.

        Reads ``batch['s1']['xt']``, ``batch['s1']['rt']`` and
        ``batch['s1']['yt']``. The posterior mean ``mu`` and the ground-truth
        latents are each split into four column groups; every group is split
        into two equal row halves that are passed to ``compute_r2`` as the
        train and test portions (estimated -> true logged as ``r2k``,
        true -> estimated logged as ``r2kh``).

        Returns:
            The total VAE loss tensor.
        """
        x_true1 = batch['s1']['xt'].reshape(-1, self.input_dim)
        r_true1 = batch['s1']['rt'].reshape(-1, 1)

        x_recon, mu, logvar, z = self.VAE(x_true1)
        vae_recon_loss, vae_recon_r, vae_kld, vae_tc_loss, vae_loss = self._vae_losses(
            x_true1, r_true1, x_recon, mu, logvar, z
        )

        # ---- R2 between estimated and ground-truth latent groups ----
        zt_recon = mu.view(-1, self.z_dim).detach().cpu().numpy()
        recon_groups = self._split_groups(
            zt_recon, [self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim]
        )
        zt_true = batch["s1"]["yt"].view(-1, sum(self.z_dim_true_list)).detach().cpu().numpy()
        true_groups = self._split_groups(zt_true, self.z_dim_true_list)

        # (train_recon, test_recon, train_true, test_true) per latent group.
        halves = []
        for recon_g, true_g in zip(recon_groups, true_groups):
            train_recon, test_recon = self._halve_rows(recon_g)
            train_true, test_true = self._halve_rows(true_g)
            halves.append((train_recon, test_recon, train_true, test_true))

        # All forward scores are computed before the reverse ones, preserving
        # the original call order of compute_r2.
        r2_fwd = [compute_r2(tr_r, tr_t, te_r, te_t) for tr_r, te_r, tr_t, te_t in halves]
        ave_r2 = sum(r2_fwd) / 4.0
        r2_rev = [compute_r2(tr_t, tr_r, te_t, te_r) for tr_r, te_r, tr_t, te_t in halves]
        ave_r2h = sum(r2_rev) / 4.0

        self.log("val_vae_loss", vae_loss)
        self.log("val_vae_recon_loss", vae_recon_loss)
        # Previously logged as "train_vae_r_loss", which collided with the
        # training metric of the same name.
        self.log("val_vae_r_loss", vae_recon_r)
        self.log("val_vae_kld", vae_kld)
        self.log("val_vae_tc_loss", vae_tc_loss)
        for idx, score in enumerate(r2_fwd, start=1):
            self.log(f"r2{idx}", score)
        self.log("ave_r2", ave_r2)
        for idx, score in enumerate(r2_rev, start=1):
            self.log(f"r2{idx}h", score)
        self.log("ave_r2h", ave_r2h)
        return vae_loss

    # --------------------------------------------------------------- optimizers

    def configure_optimizers(self):
        """Create the VAE (AdamW) and discriminator (Adam) optimizers.

        The VAE optimizer also owns the reward decoder's parameters; without
        them ``rew_dec`` would never be updated even though its output
        contributes to the VAE loss.

        Returns:
            Tuple ``(opt_v, opt_d)`` in the order unpacked by ``training_step``.
        """
        vae_params = [
            p
            for p in list(self.VAE.parameters()) + list(self.rew_dec.parameters())
            if p.requires_grad
        ]
        opt_v = torch.optim.AdamW(vae_params,
                                  lr=self.lr_VAE,
                                  betas=(self.beta1_VAE, self.beta2_VAE),
                                  weight_decay=0.0001)
        disc_params = [p for p in self.D.parameters() if p.requires_grad]
        opt_d = torch.optim.Adam(disc_params,
                                 lr=self.lr_D,
                                 betas=(self.beta1_D, self.beta2_D))
        return opt_v, opt_d