import os, importlib
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision

import pytorch_lightning as pl
from pytorch_lightning import LightningModule

from . import datasets, sampler
from .util import save_sample_images

from hydra import compose, initialize
from hydra.utils import instantiate
from scipy.special import gamma
from copy import deepcopy

class AutoEncoder_abstract(LightningModule):
    """Base Lightning module for an auto-encoder trained on reconstruction error.

    ``encoder`` and ``decoder`` start out as ``nn.Identity`` placeholders, and
    ``encoder_trainable`` / ``decoder_trainable`` list the networks whose
    parameters are handed to the optimizer.  Presumably concrete subclasses
    replace these placeholders with real networks -- confirm against the
    concrete model definitions.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Read hyper-parameters from ``cfg`` and install placeholder networks.

        Args:
            cfg: Mapping-like config.  ``cfg['train_info']`` must provide
                ``z_dim``, ``z_sampler``, ``lr``, ``beta1`` and ``epoch``;
                ``cfg['path_info']`` must provide ``save_img_path``.
            log: Logger-like object (only ``info`` is used here); kept as
                ``self.hydra_log``.
            verbose: When equal to 1, every key/value of both config sections
                is written to ``log``.
        """
        super().__init__()
        self.hydra_log = log
        if verbose == 1:
            self.hydra_log.info('------------------------------------------------------------')
            for section in ('train_info', 'path_info'):
                for key in cfg[section]:
                    self.hydra_log.info('%s : %s' % (key, cfg[section][key]))

        self.encoder = nn.Identity()
        self.decoder = nn.Identity()

        self.z_dim = int(cfg['train_info']['z_dim'])
        # Looked up by name in the project-local ``sampler`` module; draws prior samples.
        self.z_sampler = getattr(sampler, cfg['train_info']['z_sampler'])

        self.lr = float(cfg['train_info']['lr'])
        self.beta1 = float(cfg['train_info']['beta1'])
        self.num_epoch = int(cfg['train_info']['epoch'])

        # self.tb_logs = cfg['path_info']['tb_logs']
        self.save_img_path = cfg['path_info']['save_img_path']

        self.get_recon_flag = True

        self.encoder_trainable = [self.encoder]
        self.decoder_trainable = [self.decoder]

    def log_architecture(self):
        """Write every trainable encoder and decoder network to the logger."""
        self.hydra_log.info('------------------------------------------------------------')
        for net in self.encoder_trainable:
            self.hydra_log.info(net)
        for net in self.decoder_trainable:
            self.hydra_log.info(net)

    def encode(self, x):
        """Return ``self.encoder(x)``."""
        return self.encoder(x)

    def decode(self, x):
        """Return ``self.decoder(x)``."""
        return self.decoder(x)

    def forward(self, x):
        """Encode then decode ``x`` using the raw ``encoder`` / ``decoder`` modules."""
        return self.decoder(self.encoder(x))

    def configure_optimizers(self):
        """Build one Adam optimizer over all encoder and decoder parameters.

        Returns:
            ``{"optimizer": optimizer}`` with learning rate ``self.lr`` and
            betas ``(self.beta1, 0.999)``.
        """
        # Encoder parameters first, then decoder parameters (same order as before).
        networks = list(self.encoder_trainable) + list(self.decoder_trainable)
        params = [p for net in networks for p in net.parameters()]
        optimizer = optim.Adam(params, lr = self.lr, betas = (self.beta1, 0.999))
        return {"optimizer": optimizer}

    def _get_reconstruction_loss(self, batch):
        """Return the reconstruction loss (squared error) for a batch.

        Args:
            batch: Pair ``(x, label)``; the label is ignored.  The reduction
                sums over dims 1, 2, 3, so ``x`` must be 4-dimensional.

        Returns:
            Squared error summed per sample and averaged over the batch.
        """
        x, _ = batch  # When batch returns both image and label
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none").sum(dim=[1, 2, 3]).mean(dim=[0])
        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training reconstruction loss."""
        loss = self._get_reconstruction_loss(batch)

        # Progress Bar
        self.log("recon", loss, prog_bar=True, logger = False)
        # TensorBoard
        self.log("train/recon", loss, on_step = False, on_epoch = True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation reconstruction loss."""
        loss = self._get_reconstruction_loss(batch)

        # Progress Bar
        self.log("recon", loss, prog_bar=True, logger = False)
        # TensorBoard: validation metrics go under "test/", matching the other
        # modules in this file, so they do not collide with "train/recon".
        self.log("test/recon", loss, on_step = False, on_epoch = True)

        return loss

class Classifier_abstract(AutoEncoder_abstract):
    """Classifier built on the encoder/decoder pair, trained with cross-entropy."""

    def __init__(self, cfg, log, verbose = 1):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(cfg, log, verbose)

    def _get_losses(self, batch, acc = False):
        """Return the cross-entropy loss and, optionally, the batch accuracy.

        Args:
            batch: Pair ``(x, y)`` of inputs and targets.  A target of shape
                ``(N, 1)`` is squeezed to ``(N,)`` before use.
            acc: When truthy, also compute the fraction of correct predictions.

        Returns:
            ``(loss, accuracy)`` where ``accuracy`` is a Python float, or
            ``None`` when ``acc`` is falsy.
        """
        inputs, targets = batch  # batch carries both the input and its label
        logits = self.decode(self.encode(inputs))

        if len(targets.shape) == 2 and targets.shape[1] == 1:
            targets = targets.squeeze(1)

        ce_loss = F.cross_entropy(logits, targets)
        if not acc:
            return ce_loss, None

        n_correct = (logits.max(dim = 1).indices == targets).sum().item()
        return ce_loss, n_correct / len(inputs)

    def _log_classification(self, prefix, loss, acc):
        """Log loss (and accuracy when available) to the progress bar and under ``prefix``."""
        self.log("CEloss", loss, prog_bar=True, logger = False)
        self.log(prefix + "/CEloss", loss, on_step = False, on_epoch = True)
        if acc is None:
            return
        self.log("acc", acc, prog_bar=True, logger = False)
        self.log(prefix + "/acc", acc, on_step = False, on_epoch = True)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it under ``train/`` and return it."""
        loss, acc = self._get_losses(batch, True)
        self._log_classification("train", loss, acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss, log it under ``test/`` and return it."""
        loss, acc = self._get_losses(batch, True)
        self._log_classification("test", loss, acc)
        return loss

# class WAE_MMD_abstract(AutoEncoder_abstract):
#     def __init__(self, cfg, log, verbose = 1):
#         super().__init__(cfg, log, verbose)
#         self.z_sampler = getattr(sampler, cfg['train_info']['z_sampler'])
#         self.lamb = float(cfg['train_info']['lambda'])

#     def k(self, x, y, diag = True):
#         stat = 0.
#         for scale in [.1, .2, .5, 1., 2., 5., 10.]:
#             C = scale*2*self.z_dim*2
#             kernel = (C/(C + (x.unsqueeze(0) - y.unsqueeze(1)).pow(2).sum(dim = 2)))
#             if diag:
#                 stat += kernel.sum()
#             else:
#                 stat += kernel.sum() - kernel.diag().sum()
#         return stat
    
#     def penalty_loss(self, x, y, n):
#         return (self.k(x,x, False) + self.k(y,y, False))/(n*(n-1)) - 2*self.k(x,y, True)/(n*n)

#     def _get_losses(self, batch):
#         """Given a batch of images, this function returns the reconstruction loss (MSE in our case)"""
#         x, _ = batch  # When batch returns both image and label
#         n = len(x)

#         fake_latent = self.encode(x)
#         prior_z = self.z_sampler(n, self.z_dim).type_as(fake_latent)
#         x_hat = self.decode(fake_latent)

#         loss = F.mse_loss(x, x_hat, reduction="none").sum(dim=[1, 2, 3]).mean(dim=[0])
#         penalty = self.penalty_loss(fake_latent, prior_z, n)
#         return loss, penalty

#     def training_step(self, batch, batch_idx):
#         loss, penalty = self._get_losses(batch)
#         # tqdm_dict = {"recon":loss.detach(), "penalty":penalty.detach()}

#         # Progress Bar
#         self.log("recon", loss, prog_bar=True, logger = False)
#         self.log("penalty", penalty, prog_bar=True, logger = False)

#         # TensorBoard
#         self.log("train/recon", loss, on_step = False, on_epoch = True)
#         self.log("train/penalty", penalty, on_step = False, on_epoch = True)

#         # return {"loss": loss + self.lamb * penalty, "progress_bar": tqdm_dict}
#         return loss + self.lamb * penalty
    
#     def validation_step(self, batch, batch_idx):
#         loss, penalty = self._get_losses(batch)
#         # tqdm_dict = {"recon":loss.detach(), "penalty":penalty.detach()}

#         origin = None
#         recon = None
#         if self.get_recon_flag:
#             self.hydra_log.debug(f'Epoch {self.current_epoch} - test loss: {loss:.4f} D: {penalty:.4f}')
#             self.get_recon_flag = False
#             x, _ = batch
#             with torch.no_grad():
#                 recon = self.decode(self.encode(x)).detach()
#             origin = x.detach()

#         # Progress Bar
#         self.log("recon", loss, prog_bar=True, logger = False)
#         self.log("penalty", penalty, prog_bar=True, logger = False)

#         # TensorBoard
#         self.log("test/recon", loss, on_step = False, on_epoch = True)
#         self.log("test/penalty", penalty, on_step = False, on_epoch = True)

#         # return {"loss": loss + self.lamb * penalty, "progress_bar": tqdm_dict}
#         return {"loss": loss + self.lamb * penalty, "x":origin, "recon":recon}

#     def validation_epoch_end(self, outputs) -> None:
#         # sample reconstruction
#         x = outputs[0]["x"]
#         recon = outputs[0]["recon"]
#         x_recon = torch.cat((x[0:32],recon[0:32]), dim = 0)
#         self.get_recon_flag = True

#         # sample generate
#         z = self.z_sampler(64, self.z_dim).type_as(self.decoder[0].weight)
#         gen_img = self.decode(z)
#         if self.save_img_path is not None:
#             save_sample_images(self.save_img_path, "recon", self.current_epoch, (x_recon.to('cpu').detach().numpy()[0:64]))
#             save_sample_images(self.save_img_path, "gen", self.current_epoch, (gen_img.to('cpu').detach().numpy()[0:64]))

#         grid = torchvision.utils.make_grid(x_recon)
#         self.logger.experiment.add_image("reconstructed_images", grid, self.current_epoch)
#         grid = torchvision.utils.make_grid(gen_img)
#         self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

        
class WAE_GAN_abstract(AutoEncoder_abstract):
    """Auto-encoder with an adversarial latent penalty (WAE-GAN style).

    A discriminator ``disc`` is trained to tell prior samples from encoded
    latents; the encoder/decoder minimise reconstruction error plus
    ``lamb`` times the discriminator-based penalty.  ``disc`` starts out as
    an ``nn.Identity`` placeholder, presumably replaced by subclasses.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Read the adversarial hyper-parameters and set up the sample loader.

        Args:
            cfg: Config as for the parent class; ``cfg['train_info']`` must
                additionally provide ``lambda``, ``lr_adv`` and ``beta1_adv``,
                and may provide ``gen_data`` (a hydra-instantiable dataset).
            log: Logger-like object, forwarded to the parent class.
            verbose: Forwarded to the parent class.
        """
        super().__init__(cfg, log, verbose)
        try:
            self.gen_dataset = instantiate(cfg['train_info']['gen_data'])
            self.gen_dataloader = DataLoader(self.gen_dataset, 32, num_workers = 5, shuffle = True)
        except Exception as err:
            # Best effort: the sample loader is optional at construction time,
            # but the failure is reported instead of being silently dropped.
            # validation_epoch_end still needs ``self.gen_dataloader`` to exist.
            self.hydra_log.warning('gen_data loader could not be built (%s); self.gen_dataloader must be set before validation_epoch_end runs' % (err,))

        self.z_sampler = getattr(sampler, cfg['train_info']['z_sampler'])
        self.lamb = float(cfg['train_info']['lambda'])

        self.lr_adv = float(cfg['train_info']['lr_adv'])
        self.beta1_adv = float(cfg['train_info']['beta1_adv'])

        self.disc = nn.Identity()
        self.disc_trainable = [self.disc]

    def log_architecture(self):
        """Write every trainable encoder, decoder and discriminator network to the logger."""
        self.hydra_log.info('------------------------------------------------------------')
        for net in self.encoder_trainable:
            self.hydra_log.info(net)
        for net in self.decoder_trainable:
            self.hydra_log.info(net)
        for net in self.disc_trainable:
            self.hydra_log.info(net)

    def discriminate(self, z):
        """Return the discriminator output (used as logits) for latent ``z``."""
        return self.disc(z)

    def _adv_loss(self, batch):
        """Discriminator loss: prior samples are labelled 1, encoded latents 0."""
        x, _ = batch
        q = self.encode(x)
        p = self.z_sampler(len(q), self.z_dim).type_as(q)
        pz = self.discriminate(p)
        qz = self.discriminate(q)
        return F.binary_cross_entropy_with_logits(pz, torch.ones_like(pz)) + F.binary_cross_entropy_with_logits(qz, torch.zeros_like(qz))

    def penalty_loss(self, q):
        """Encoder-side penalty: BCE of the discriminator output on ``q`` against label 1."""
        qz = self.discriminate(q)
        return F.binary_cross_entropy_with_logits(qz, torch.ones_like(qz))

    def _get_losses(self, batch):
        """Return ``(reconstruction_loss, penalty)`` for a batch.

        The reconstruction loss is the squared error summed over dims 1, 2, 3
        and averaged over the batch; the penalty is ``penalty_loss`` applied
        to the encoded latent.
        """
        x, _ = batch  # When batch returns both image and label

        fake_latent = self.encode(x)
        x_hat = self.decode(fake_latent)

        loss = F.mse_loss(x, x_hat, reduction="none").sum(dim=[1, 2, 3]).mean(dim=[0])
        penalty = self.penalty_loss(fake_latent)
        return loss, penalty

    def configure_optimizers(self):
        """Return two Adam optimizers: index 0 for the discriminator, index 1 for encoder + decoder."""
        disc_params = [p for net in self.disc_trainable for p in net.parameters()]
        ae_params = [p for net in list(self.encoder_trainable) + list(self.decoder_trainable) for p in net.parameters()]
        opt1 = optim.Adam(disc_params, lr = self.lr_adv, betas = (self.beta1_adv, 0.999))
        opt2 = optim.Adam(ae_params, lr = self.lr, betas = (self.beta1, 0.999))
        return ({"optimizer": opt1}, {"optimizer":opt2})

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by ``optimizer_idx``.

        Index 0 returns the weighted discriminator loss; index 1 returns
        reconstruction loss plus ``lamb`` times the penalty and logs both
        terms on the global-zero rank.
        """
        if optimizer_idx == 0:
            return self.lamb * self._adv_loss(batch)

        if optimizer_idx == 1:
            loss, penalty = self._get_losses(batch)

            if self.trainer.is_global_zero:
                # Progress Bar
                self.log("recon", loss, prog_bar=True, logger = False, rank_zero_only=True)
                self.log("penalty", penalty, prog_bar=True, logger = False, rank_zero_only=True)

                # TensorBoard
                self.log("train/recon", loss, on_step = False, on_epoch = True, rank_zero_only=True)
                self.log("train/penalty", penalty, on_step = False, on_epoch = True, rank_zero_only=True)

            return loss + self.lamb * penalty

    def validation_step(self, batch, batch_idx):
        """Compute, log (under ``test/``) and return the validation objective."""
        loss, penalty= self._get_losses(batch)

        # Progress Bar
        self.log("recon", loss, prog_bar=True, logger = False, sync_dist=True)
        if self.lamb > 0.0:
            self.log("penalty", penalty, prog_bar=True, logger = False, sync_dist=True)

        # TensorBoard
        self.log("test/recon", loss, on_step = False, on_epoch = True, sync_dist=True)
        if self.lamb > 0.0:
            self.log("test/penalty", penalty, on_step = False, on_epoch = True, sync_dist=True)

        return loss + self.lamb * penalty

    def validation_epoch_end(self, outputs) -> None:
        """Log (and optionally save) reconstructed and generated sample images.

        One batch is drawn from ``self.gen_dataloader``; originals and
        reconstructions are concatenated into one image grid.  When
        ``lamb > 0`` a second grid is decoded from 32 prior samples.
        ``outputs`` is not used.
        """
        #sample generate
        with torch.no_grad():
            # sample reconstruction
            x, _ = next(iter(self.gen_dataloader))
            # Assumes ``self.decoder`` is indexable and its first layer has a
            # ``weight`` -- used only to pick the target dtype/device.
            x = x.type_as(self.decoder[0].weight)
            fake_latent = self.encode(x)
            recon = self.decode(fake_latent)

            x_recon = self.all_gather(torch.cat((x,recon), dim = 0))
            # A gathered result with an extra leading axis is reduced to its first slice.
            if len(x_recon.shape) > 4:
                x_recon = x_recon[0]

            grid = torchvision.utils.make_grid(x_recon)
            self.logger.experiment.add_image("reconstructed_images", grid, self.current_epoch)

            if self.lamb > 0.0:
                z = self.z_sampler(32, self.z_dim).type_as(self.decoder[0].weight)
                # ``all_gather`` takes no ``dim`` keyword; passing one raised TypeError.
                gen_img = self.all_gather(self.decode(z))
                if len(gen_img.shape) > 4:
                    gen_img = gen_img[0]

                grid = torchvision.utils.make_grid(gen_img)
                self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

        if self.save_img_path is not None:
            save_sample_images(self.save_img_path, "recon", self.current_epoch, (x_recon.to('cpu').detach().numpy()))
            if self.lamb > 0.0:
                save_sample_images(self.save_img_path, "gen", self.current_epoch, (gen_img.to('cpu').detach().numpy()))

class WFAE_abstract(WAE_GAN_abstract):
    """WAE-GAN variant whose latent is the encoder output joined with a frozen embedding.

    ``encode`` returns ``[encoder(x), embed_network.encode(x)]`` concatenated
    along dim 1.  The first ``z_dim`` columns receive the adversarial penalty,
    and an HSIC term is computed between them and the remaining columns.
    ``embed_network`` is ``None`` after construction and must be assigned
    before ``encode`` is called.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Read ``y_dim`` and ``lambda_hsic`` from ``cfg['train_info']`` on top of the parent setup."""
        super().__init__(cfg, log, verbose)
        self.y_dim = cfg["train_info"]["y_dim"]
        self.lamb_hsic = cfg["train_info"]["lambda_hsic"]
        self.embed_network = None

    def encode(self, x):
        """Return the encoder output concatenated (dim 1) with the embedding of ``x``.

        The embedding network is switched to eval mode and run without
        gradient tracking, so only ``self.encoder`` receives gradients.
        """
        self.embed_network.eval()
        with torch.no_grad():
            y = self.embed_network.encode(x)
        return torch.cat((self.encoder(x), y), dim = 1)

    # HSIC helpers.
    # Refers to original Tensorflow implementation: https://github.com/romain-lopez/HCV
    # Refers to original implementations
    #     - https://github.com/kacperChwialkowski/HSIC
    #     - https://cran.r-project.org/web/packages/dHSIC/index.html
    def bandwidth(self, d):
        """Return the kernel coefficient ``1 / (2 * gz**2)`` for dimension ``d``.

        ``gz = 2 * Gamma((d + 1) / 2) / Gamma(d / 2)``.  The ratio is evaluated
        in log-space because both Gamma values overflow to ``inf`` for large
        ``d`` (roughly ``d >= 343``), which would make the direct ratio NaN.
        """
        import numpy as np
        from scipy.special import gammaln

        gz = 2 * np.exp(gammaln(0.5 * (d + 1)) - gammaln(0.5 * d))
        return 1. / (2. * gz**2)

    def knl(self, x, y, gam=1.):
        """Return the kernel matrix ``exp(-gam * ||x_i - y_j||^2)`` for 2-D inputs ``x`` and ``y``."""
        dist_table = (x.unsqueeze(0) - y.unsqueeze(1)).pow(2).sum(dim = 2)
        return (-gam * dist_table).exp().transpose(0,1)

    def hsic(self, x, y):
        """Return the square root of an HSIC-style statistic between 2-D tensors ``x`` and ``y``.

        Both inputs must have the same number of rows.  The statistic is
        clamped to at least ``1e-16`` before the square root.
        """
        dx = x.shape[1]
        dy = y.shape[1]

        xx = self.knl(x, x, gam=self.bandwidth(dx))
        yy = self.knl(y, y, gam=self.bandwidth(dy))

        res = ((xx*yy).mean()) + (xx.mean()) * (yy.mean())
        res -= 2*((xx.mean(dim=1))*(yy.mean(dim=1))).mean()
        return res.clamp(min = 1e-16).sqrt()

    def _adv_loss(self, batch):
        """Discriminator loss computed on the first ``z_dim`` latent columns only."""
        x, _ = batch
        q = self.encode(x)[:,0:self.z_dim]
        p = self.z_sampler(len(q), self.z_dim).type_as(q)
        pz = self.discriminate(p)
        qz = self.discriminate(q)
        return F.binary_cross_entropy_with_logits(pz, torch.ones_like(pz)) + F.binary_cross_entropy_with_logits(qz, torch.zeros_like(qz))

    def _get_losses(self, batch):
        """Return ``(reconstruction_loss, adversarial_penalty, hsic_penalty)`` for a batch.

        The HSIC penalty is taken between the first ``z_dim`` latent columns
        and the remaining (embedding) columns.
        """
        x, y = batch  # When batch returns both image and label

        fake_latent = self.encode(x)
        # ``encode`` has just put the embedding network into eval mode.
        assert self.embed_network.training == False
        x_hat = self.decode(fake_latent)

        loss = F.mse_loss(x, x_hat, reduction="none").sum(dim=[1, 2, 3]).mean(dim=[0])
        penalty = self.penalty_loss(fake_latent[:,0:self.z_dim])
        penalty2 = self.hsic(fake_latent[:, 0:self.z_dim], fake_latent[:, self.z_dim:])

        return loss, penalty, penalty2

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by ``optimizer_idx``.

        Index 0 trains the discriminator and returns ``None`` when
        ``lamb <= 0``.  Index 1 returns reconstruction loss plus the weighted
        adversarial and HSIC penalties.
        """
        if optimizer_idx == 0:
            if self.lamb > 0.0:
                return self.lamb * self._adv_loss(batch)
            else:
                return None

        if optimizer_idx == 1:
            loss, penalty, hsic = self._get_losses(batch)

            # if self.trainer.is_global_zero:
            # Progress Bar
            self.log("recon", loss, prog_bar=True, logger = False, sync_dist=True)
            if self.lamb > 0.0:
                self.log("penalty", penalty, prog_bar=True, logger = False, sync_dist=True)
                self.log("hsic", hsic, prog_bar=True, logger = False, sync_dist=True)

            # TensorBoard
            self.log("train/recon", loss, on_step = False, on_epoch = True, sync_dist=True)
            if self.lamb > 0.0:
                self.log("train/penalty", penalty, on_step = False, on_epoch = True, sync_dist=True)
                self.log("train/hsic", hsic, on_step = False, on_epoch = True, sync_dist=True)

            return loss + self.lamb * penalty + self.lamb_hsic * hsic

    def validation_step(self, batch, batch_idx):
        """Compute, log (under ``test/``) and return the validation objective."""
        loss, penalty, hsic = self._get_losses(batch)

        # Progress Bar
        self.log("recon", loss, prog_bar=True, logger = False, sync_dist=True)
        if self.lamb > 0.0:
            self.log("penalty", penalty, prog_bar=True, logger = False, sync_dist=True)
            self.log("hsic", hsic, prog_bar=True, logger = False, sync_dist=True)

        # TensorBoard
        self.log("test/recon", loss, on_step = False, on_epoch = True, sync_dist=True)
        if self.lamb > 0.0:
            self.log("test/penalty", penalty, on_step = False, on_epoch = True, sync_dist=True)
            self.log("test/hsic", hsic, on_step = False, on_epoch = True, sync_dist=True)

        # return {"loss": loss + self.lamb * penalty + self.lamb_hsic * hsic, "x":origin, "fake_latent":fake_latent, "recon":recon}
        return loss + self.lamb * penalty + self.lamb_hsic * hsic

    def validation_epoch_end(self, outputs) -> None:
        """Log (and optionally save) reconstructed and generated sample images.

        One batch is drawn from ``self.gen_dataloader``.  Generated images
        pair fresh prior samples with the embedding columns of that batch, so
        the number of prior samples follows the batch size.  ``outputs`` is
        not used.
        """
        #sample generate
        with torch.no_grad():
            # sample reconstruction
            x, _ = next(iter(self.gen_dataloader))
            # Assumes ``self.decoder`` is indexable and its first layer has a
            # ``weight`` -- used only to pick the target dtype/device.
            x = x.type_as(self.decoder[0].weight)
            fake_latent = self.encode(x)
            recon = self.decode(fake_latent)

            x_recon = self.all_gather(torch.cat((x,recon), dim = 0))
            # A gathered result with an extra leading axis is reduced to its first slice.
            if len(x_recon.shape) > 4:
                x_recon = x_recon[0]

            grid = torchvision.utils.make_grid(x_recon)
            self.logger.experiment.add_image("reconstructed_images", grid, self.current_epoch)

            if self.lamb > 0.0:
                # One prior sample per row of the batch: a fixed count of 32
                # breaks the concatenation whenever the batch is not 32 rows.
                z = self.z_sampler(len(x), self.z_dim).type_as(self.decoder[0].weight)
                gen_img = self.all_gather(self.decode(torch.cat((z, fake_latent[:, self.z_dim:]), dim = 1)))
                if len(gen_img.shape) > 4:
                    gen_img = gen_img[0]

                grid = torchvision.utils.make_grid(gen_img)
                self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

        if self.save_img_path is not None:
            save_sample_images(self.save_img_path, "recon", self.current_epoch, (x_recon.to('cpu').detach().numpy()))
            if self.lamb > 0.0:
                save_sample_images(self.save_img_path, "gen", self.current_epoch, (gen_img.to('cpu').detach().numpy()))


class WFAE_attr(WFAE_abstract):
    """WFAE variant whose batches carry an extra tensor ``s`` fed to the decoder.

    Batches are triples ``(x, y, s)``.  ``decode`` uses ``self.decoder_z1``,
    which is not defined in this class -- presumably supplied by concrete
    subclasses; confirm before instantiating this class directly.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Read ``s_dim`` from ``cfg['train_info']`` on top of the parent setup."""
        super().__init__(cfg, log, verbose)
        self.s_dim = cfg["train_info"]["s_dim"]

    def encode(self, x):
        """Return ``(encoder(x), embedding)`` as two separate tensors.

        The embedding network is switched to eval mode and run without
        gradient tracking.
        """
        self.embed_network.eval()
        with torch.no_grad():
            y = self.embed_network.encode(x)
        return self.encoder(x), y

    def decode(self, z, y, s):
        """Decode in two stages and return ``(reconstruction, z1)``.

        ``z1 = decoder_z1([z, y])`` and the reconstruction is
        ``decoder([z1, s])``; both concatenations are along dim 1.
        """
        z1 = self.decoder_z1(torch.cat((z,y), dim = 1))
        return self.decoder(torch.cat((z1, s), dim = 1)), z1

    def _adv_loss(self, batch):
        """Discriminator loss: prior samples are labelled 1, encoder outputs 0."""
        x, y, s = batch
        q, _ = self.encode(x)
        p = self.z_sampler(len(q), self.z_dim).type_as(q)
        pz = self.discriminate(p)
        qz = self.discriminate(q)
        return F.binary_cross_entropy_with_logits(pz, torch.ones_like(pz)) + F.binary_cross_entropy_with_logits(qz, torch.zeros_like(qz))

    def _get_losses(self, batch):
        """Return ``(reconstruction_loss, adversarial_penalty, hsic_penalty)``.

        The HSIC penalty is computed between the intermediate code ``z1`` and
        ``s``, so ``s`` must be 2-dimensional.
        """
        x, y, s = batch
        z, y_hat = self.encode(x)
        x_hat, z1 = self.decode(z, y_hat, s)

        loss = F.mse_loss(x, x_hat, reduction="none").sum(dim=[1, 2, 3]).mean(dim=[0])
        penalty = self.penalty_loss(z)
        penalty2 = self.hsic(z1, s)

        return loss, penalty, penalty2

    def validation_epoch_end(self, outputs) -> None:
        """Log (and optionally save) reconstructed and generated sample images.

        One batch is drawn from ``self.gen_dataloader``.  Generated images
        reuse that batch's embedding and ``s`` with fresh prior samples, one
        per row of the batch.  ``outputs`` is not used.
        """
        #sample generate
        with torch.no_grad():
            # sample reconstruction
            x, y, s = next(iter(self.gen_dataloader))
            # Assumes ``self.decoder`` is indexable and its first layer has a
            # ``weight`` -- used only to pick the target dtype/device.
            x = x.type_as(self.decoder[0].weight)
            s = s.type_as(self.decoder[0].weight)
            z, y_hat = self.encode(x)
            recon, _ = self.decode(z, y_hat, s)

            x_recon = self.all_gather(torch.cat((x,recon), dim = 0))
            # A gathered result with an extra leading axis is reduced to its first slice.
            if len(x_recon.shape) > 4:
                x_recon = x_recon[0]

            grid = torchvision.utils.make_grid(x_recon)
            self.logger.experiment.add_image("reconstructed_images", grid, self.current_epoch)

            if self.lamb > 0.0:
                # ``z``, ``y_hat`` and ``s`` must all have the same number of
                # rows; slicing only ``z`` and ``s`` to 32 while leaving
                # ``y_hat`` whole fails for any batch that is not 32 rows.
                z = self.z_sampler(len(x), self.z_dim).type_as(self.decoder[0].weight)
                # Only the image is gathered; the intermediate code is not needed here.
                gen_img, _ = self.decode(z, y_hat, s)
                gen_img = self.all_gather(gen_img)
                if len(gen_img.shape) > 4:
                    gen_img = gen_img[0]

                grid = torchvision.utils.make_grid(gen_img)
                self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

        if self.save_img_path is not None:
            save_sample_images(self.save_img_path, "recon", self.current_epoch, (x_recon.to('cpu').detach().numpy()))
            if self.lamb > 0.0:
                save_sample_images(self.save_img_path, "gen", self.current_epoch, (gen_img.to('cpu').detach().numpy()))

class VFAE_abstract(WFAE_abstract):
    """WFAE variant with two HSIC penalties on batches ``(x, y, s)``.

    The encoder consumes ``[embed_data(x), s, y]`` and yields ``z2``;
    ``decoder_z1`` maps ``[z2, y]`` to ``z1`` and the decoder maps ``[z1, s]``
    to the reconstruction.  ``embed_data`` and ``decoder_z1`` are not defined
    in this class -- presumably supplied by concrete subclasses.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Read ``s_dim``, ``iter_per_epoch`` and ``lambda_hsic2`` on top of the parent setup."""
        super().__init__(cfg, log, verbose)
        self.s_dim = cfg["train_info"]["s_dim"]
        self.iter_per_epoch = cfg['train_info']['iter_per_epoch']
        self.lamb_hsic2 = cfg['train_info']['lambda_hsic2']

    def _encode_batch(self, batch):
        """Unpack ``(x, y, s)``, make ``y`` and ``s`` 2-D, and return ``(x, y, s, z2)``."""
        x, y, s = batch
        # 1-D side information becomes a single column so it can be concatenated.
        if len(s.shape) == 1:
            s = s.unsqueeze(1)
        if len(y.shape) == 1:
            y = y.unsqueeze(1)
        xx = self.embed_data(x)
        z2 = self.encoder(torch.cat((xx, s, y), dim = 1))
        return x, y, s, z2

    def _adv_loss(self, batch):
        """Discriminator loss: prior samples are labelled 1, encoded ``z2`` 0."""
        _, _, _, z2 = self._encode_batch(batch)
        qz = self.discriminate(z2)

        p = self.z_sampler(len(z2), self.z_dim).type_as(z2)
        pz = self.discriminate(p)

        return F.binary_cross_entropy_with_logits(pz, torch.ones_like(pz)) + F.binary_cross_entropy_with_logits(qz, torch.zeros_like(qz))

    def _get_losses(self, batch):
        """Return ``(recon_loss, adv_penalty, hsic(z1, s), hsic(z2, [y, s]))``.

        Unlike the image models in this file, the reconstruction loss here is
        the element-wise mean squared error (default ``mse_loss`` reduction).
        """
        x, y, s, z2 = self._encode_batch(batch)
        z1 = self.decoder_z1(torch.cat((z2, y), dim = 1))
        x_hat = self.decoder(torch.cat((z1, s), dim = 1))

        loss = F.mse_loss(x, x_hat)
        penalty = self.penalty_loss(z2)
        penalty2 = self.hsic(z1, s)
        penalty3 = self.hsic(z2, torch.cat((y, s), dim=1))

        return loss, penalty, penalty2, penalty3

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by ``optimizer_idx``.

        Index 0 trains the discriminator and returns ``None`` when
        ``lamb <= 0``.  Index 1 returns reconstruction loss plus the weighted
        adversarial penalty and both HSIC penalties.
        """
        if optimizer_idx == 0:
            if self.lamb > 0.0:
                return self.lamb * self._adv_loss(batch)
            else:
                return None

        if optimizer_idx == 1:
            loss, penalty, hsic, hsic2 = self._get_losses(batch)

            # Progress Bar
            self.log("recon", loss, prog_bar=True, logger = False, sync_dist=True)
            if self.lamb > 0.0:
                self.log("penalty", penalty, prog_bar=True, logger = False, sync_dist=True)
                self.log("hsic", hsic, prog_bar=True, logger = False, sync_dist=True)

            # TensorBoard
            self.log("train/recon", loss, on_step = False, on_epoch = True, sync_dist=True)
            if self.lamb > 0.0:
                self.log("train/penalty", penalty, on_step = False, on_epoch = True, sync_dist=True)
                self.log("train/hsic", hsic, on_step = False, on_epoch = True, sync_dist=True)
                self.log("train/hsic2", hsic2, on_step = False, on_epoch = True, sync_dist=True)

            return loss + self.lamb * penalty + self.lamb_hsic * hsic + self.lamb_hsic2 * hsic2

    def validation_step(self, batch, batch_idx):
        """Compute, log (under ``test/``) and return the validation objective.

        All four metrics are logged with ``sync_dist=True`` so they are
        reduced the same way across processes.
        """
        loss, penalty, hsic, hsic2 = self._get_losses(batch)

        # Progress Bar
        self.log("recon", loss, prog_bar=True, logger = False, sync_dist=True)
        self.log("penalty", penalty, prog_bar=True, logger = False, sync_dist=True)
        self.log("hsic", hsic, prog_bar=True, logger = False, sync_dist=True)
        self.log("hsic2", hsic2, prog_bar=True, logger = False, sync_dist=True)

        # TensorBoard
        self.log("test/recon", loss, on_step = False, on_epoch = True, sync_dist=True)
        self.log("test/penalty", penalty, on_step = False, on_epoch = True, sync_dist=True)
        self.log("test/hsic", hsic, on_step = False, on_epoch = True, sync_dist=True)
        self.log("test/hsic2", hsic2, on_step = False, on_epoch = True, sync_dist=True)

        return loss + self.lamb * penalty + self.lamb_hsic * hsic + self.lamb_hsic2 * hsic2

    def validation_epoch_end(self, outputs) -> None:
        """Do nothing: this overrides the parent's image logging with a no-op."""
        pass