import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
from sklearn.manifold import TSNE
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image.inception import InceptionScore
from scipy.stats import entropy
from scipy.spatial.distance import pdist, squareform
# from torchvision.utils import save_image

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import *
from Scheduler import GradualWarmupScheduler
from gmvae import *

# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint
from dataloader.dataloader_mnist import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy


torch.set_default_dtype(torch.float32)
torch.set_printoptions(precision=3)

def map_to_position(input_tensor, num_classes):
    """
    One-hot encode a tensor of class indices.

    Parameters:
    - input_tensor: tensor of integer class indices; entry i selects the
      column that is set to 1 in row i of the result.
    - num_classes: number of columns of the returned tensor.

    Returns:
    - A tensor of shape (input_tensor.size(0), num_classes), allocated with
      torch.zeros (default dtype and device), that is 0 everywhere except
      at the indexed positions, which are 1.
    """
    n = input_tensor.size(0)
    one_hot = torch.zeros((n, num_classes))
    row_ids = torch.arange(n)
    # Advanced indexing: pair every row id with its class index.
    one_hot[row_ids, input_tensor] = 1
    return one_hot

def first_three_eigen_proj(x):
    """
    Project the rows of `x` onto the three leading eigen-directions of the
    covariance matrix of the standardized data.

    Parameters:
    - x: 2-D numpy array with samples in rows and features in columns.

    Returns:
    - Real-valued array with one row per sample and one column per retained
      direction (at most three), ordered by decreasing eigenvalue.
    """
    # Step 1: standardize each column and compute the covariance matrix.
    # Constant columns have zero variance; divide those by 1 instead of 0 so
    # they contribute zeros rather than NaNs.
    std = np.sqrt(x.var(0, keepdims=True))
    std = np.where(std == 0, 1.0, std)
    x_standardized = (x - x.mean(0, keepdims=True)) / std
    cov_matrix = np.atleast_2d(np.cov(x_standardized, rowvar=False))

    # Step 2: eigen-decomposition. The covariance matrix is symmetric, so
    # `eigh` applies; unlike `eig` it always returns real eigenpairs.
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

    # Step 3: keep the eigenvectors of the three largest eigenvalues.
    leading = np.argsort(eigenvalues)[::-1][:3]
    top_eigenvectors = eigenvectors[:, leading]

    # Step 4: project onto the retained directions.
    # NOTE(review): the raw `x` (not the standardized data) is projected, as
    # in the original implementation -- confirm this is intended.
    return np.dot(x, top_eigenvectors)

def logit_trans(data, eps=1e-6):
    """
    Logit transform with boundary squeezing.

    `data` is first mapped affinely to eps + (1 - 2*eps) * data, which sends
    [0, 1] into [eps, 1 - eps], and then log(p) - log(1 - p) is returned.
    """
    squeezed = (1 - 2 * eps) * data + eps
    log_p = torch.log(squeezed)
    log_one_minus_p = torch.log1p(-squeezed)
    return log_p - log_one_minus_p

class PlotLogLikelihoodCallback(Callback):
    """
    Callback that tracks two logged metrics across training epochs and
    redraws a two-panel line plot of their histories after every epoch.
    """

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Args:
            save_path (str): Directory-like path; the figure is written to
                `os.path.join(save_path, 'llk.png')`.
            log_keys (tuple): Names of the two metrics to read from
                `trainer.callback_metrics`; the first feeds the left panel,
                the second the right panel.
        """
        super().__init__()
        self.save_path = os.path.join(save_path, 'llk.png')
        self.log_keys = log_keys
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Append the latest metric values (when present) and rewrite the plot.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained
                (not used here).
        """
        histories = (self.log_likelihoods_x, self.log_likelihoods_z)
        for idx, history in enumerate(histories):
            key = self.log_keys[idx]
            # A metric that has not been logged yet is simply skipped.
            if key in trainer.callback_metrics:
                history.append(trainer.callback_metrics[key].item())

        panels = (
            (1, self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (2, self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )

        plt.figure(figsize=(12, 6))
        for position, history, y_label, title in panels:
            plt.subplot(1, 2, position)
            plt.plot(history, marker="o")
            plt.xlabel("Epoch")
            plt.ylabel(y_label)
            plt.title(title)
            plt.grid()

        plt.tight_layout()
        plt.savefig(self.save_path)
        plt.close()



class PlotLossCallback(Callback):
    """
    Callback that collects `train_loss` / `val_loss` after validation and
    periodically saves a plot of both curves.
    """

    def __init__(self, save_path="loss_plot.png", update_interval=5):
        """
        Args:
            save_path (str): File the figure is written to.
            update_interval (int): The plot is redrawn whenever the current
                epoch is a multiple of this value.
        """
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        self.train_losses = []
        self.val_losses = []
        self.epochs = []

    def on_validation_end(self, trainer, pl_module):
        """Record the current losses (only when both exist) and maybe replot."""
        epoch = trainer.current_epoch

        train_loss = trainer.callback_metrics.get("train_loss")
        val_loss = trainer.callback_metrics.get("val_loss")

        # Only store a point when both curves have a value for this epoch,
        # so the three lists always stay the same length.
        if not (train_loss is None or val_loss is None):
            self.epochs.append(epoch)
            self.train_losses.append(train_loss.cpu().item())
            self.val_losses.append(val_loss.cpu().item())

        if epoch % self.update_interval == 0:
            self.plot_and_save()

    def plot_and_save(self):
        """Draw both loss curves against the recorded epochs and save the figure."""
        curves = (
            (self.train_losses, "Training Loss"),
            (self.val_losses, "Validation Loss"),
        )
        plt.figure(figsize=(10, 6))
        for values, label in curves:
            plt.plot(self.epochs, values, label=label, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.save_path)
        plt.close()


class GMVAELightningModule(pl.LightningModule):
    def __init__(self, decoder: nn.Module, encoderpi: nn.Module, encoderz: nn.Module, priorz, config, args):
        """
        Store the sub-networks and configuration and switch to manual optimization.

        Args:
            decoder: network used for `decoder.sample` / `decoder.log_prob`.
            encoderpi: network whose `rsample` yields the mixture component.
            encoderz: network whose `rsample` yields the latent z.
            priorz: prior over z, sampled via `priorz.rsample`.
            config: configuration object; `flow.pi_dim`, `flow.z_dim`,
                `data.channel` and `data.size` are read here.
            args: run arguments (kept for later use, e.g. output paths).
        """
        super().__init__()
        self.config = config
        self.args = args

        # Sub-networks; assignment order is kept as decoder, encoderz,
        # encoderpi, priorz.
        self.decoder = decoder
        self.encoderz = encoderz
        self.encoderpi = encoderpi
        self.priorz = priorz

        # k: number of mixture components, d: latent dimension.
        self.k = self.config.flow.pi_dim
        self.d = self.config.flow.z_dim
        # c: image channels, p: image side length.
        self.c = self.config.data.channel
        self.p = self.config.data.size

        # Optimizer and scheduler steps are issued by hand in the hooks.
        self.automatic_optimization = False
        # Holds the final validation batch of an epoch for snapshot plots.
        self.last_validation_batch = None


    def setup(self, stage=None):
        """
        Build the ELBO objective from the four sub-networks.

        Args:
            stage: accepted for hook compatibility; not used.
        """
        self.elbo_loss = ELBO(self.decoder, self.encoderz, self.encoderpi, self.priorz)


    def training_step(self, batch, batch_idx):
        """
        Run one manually-optimized training step.

        The images are flattened to rows of `size**2` values, the ELBO loss
        is computed and logged as `train_loss`, and both optimizers are
        zeroed, back-propagated through and stepped.

        Returns:
            The loss tensor for this batch.
        """
        images, _labels = batch
        flat = images.view(-1, self.config.data.size**2)
        loss = self.elbo_loss(flat)

        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        # Order follows configure_optimizers: decoder optimizer first,
        # latent (encoders + prior) optimizer second.
        decoder_opt, latent_opt = self.optimizers()
        update_order = (latent_opt, decoder_opt)

        for opt in update_order:
            opt.zero_grad()

        self.manual_backward(loss)

        for opt in update_order:
            opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Evaluate one validation batch.

        Logs `val_loss` (ELBO loss), `val_log_post_z` (mean of
        `encoderz.log_prob`) and `val_log_lik` (mean of `decoder.log_prob`),
        both evaluated at a fresh posterior sample. The last batch of the
        first validation dataloader is cached for end-of-epoch plotting.

        Returns:
            The validation loss tensor for this batch.
        """
        images, labels = batch
        flat = images.view(-1, self.config.data.size**2)
        val_loss = self.elbo_loss(flat)

        # Posterior sample used for both log-probability diagnostics.
        zc, zcidx, _prob = self.sample(flat)
        log_post_z = self.encoderz.log_prob(zcidx, flat, zc).mean()
        log_lik = self.decoder.log_prob(zc, flat).mean()

        is_last_batch = batch_idx == self.trainer.num_val_batches[0] - 1
        if is_last_batch:
            self.last_validation_batch = {"X": flat, "y": labels}

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
        epoch_only = (
            ('val_log_post_z', log_post_z),
            ('val_log_lik', log_lik),
        )
        for metric_name, metric_value in epoch_only:
            self.log(metric_name, metric_value, on_step=False, on_epoch=True, sync_dist=True, logger=True)

        return val_loss


    def on_train_epoch_end(self):
        """Step both learning-rate schedulers once; needed because optimization is manual."""
        decoder_sched, latent_sched = self.lr_schedulers()
        for scheduler in (decoder_sched, latent_sched):
            scheduler.step()

    def on_validation_epoch_end(self):
        """
        Produce snapshot figures from the cached validation batch.

        When a batch was cached and the current epoch is a multiple of
        `config.training.snapshot_freq`, `generate` and `plot_latent` are
        run on it. The cache is cleared afterwards in every case.
        """
        cached = self.last_validation_batch
        if cached is not None:
            X = cached["X"]
            y = cached["y"]
            X = X.view(-1, self.config.data.size**2)

            snapshot_due = self.current_epoch % self.config.training.snapshot_freq == 0
            if snapshot_due:
                print("sampling")
                self.generate(X, y)
                # Snapshot of the latent space for the same batch.
                self.plot_latent(X, y)

        # Reset so a stale batch is never reused in a later epoch.
        self.last_validation_batch = None

    def configure_optimizers(self):
        """
        Create two AdamW optimizers, each with warm-up followed by cosine decay.

        The first optimizer covers the decoder parameters; the second covers
        encoderz, encoderpi and priorz. Warm-up lasts `n_epochs // 10`
        epochs and the cosine schedule spans `n_epochs // 10 * 9`.

        Returns:
            `([optimizer, optimizer_latent], [scheduler_dict, scheduler_dict])`.
        """
        def build(params):
            # One optimizer plus its warm-up -> cosine schedule.
            opt = torch.optim.AdamW(
                params,
                lr=self.config.optim.lr,
                weight_decay=self.config.optim.weight_decay)
            cosine = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=opt,
                T_max=self.config.training.n_epochs // 10 * 9,
                eta_min=0,
                last_epoch=-1)
            warmup = GradualWarmupScheduler(
                optimizer=opt,
                multiplier=self.config.optim.multiplier,
                warm_epoch=self.config.training.n_epochs // 10,
                after_scheduler=cosine,
                last_epoch=self.current_epoch)
            return opt, warmup

        def per_epoch(scheduler):
            # Scheduler description in the dictionary form returned to Lightning.
            return {'scheduler': scheduler,
                    'monitor': 'train_loss',
                    "interval": "epoch",
                    "frequency": 1}

        decoder_params = itertools.chain(
            self.decoder.parameters(),
        )
        optimizer, warmUpScheduler = build(decoder_params)

        latent_params = itertools.chain(
            self.encoderz.parameters(),
            self.encoderpi.parameters(),
            self.priorz.parameters(),
        )
        optimizer_latent, warmUpScheduler_latent = build(latent_params)

        optimizers = [optimizer, optimizer_latent]
        schedulers = [per_epoch(warmUpScheduler), per_epoch(warmUpScheduler_latent)]
        return optimizers, schedulers

    def generate(self, x, y):
        """
        Save three snapshot figures for the current epoch.

        All four sub-networks are put in eval mode, the figures are drawn
        under `torch.no_grad()`, and the sub-networks are switched back to
        train mode at the end. Files are written to
        `self.args.log_sample_path`:

        - `priorgen_grid_epoch_{epoch}.png`: 10x8 grid; row r shows decoder
          samples whose latent is drawn from `priorz` for component r.
        - `image_grid_epoch_{epoch}.png`: 10x8 grid; for every label r that
          occurs in `y`, column 0 shows the first matching input and the
          remaining columns show posterior-predictive samples for it.
        - `postpi_grid_epoch_{epoch}.png`: 10x3 grid; for up to three inputs
          per label, a histogram of the argmax of 100 `encoderpi` samples.

        Args:
            x: batch of flattened inputs (one row per image).
            y: labels aligned with `x`; compared against row indices 0..9.
        """
        # generate posterior predictive
        self.decoder.eval()
        self.encoderz.eval()
        self.encoderpi.eval()
        self.priorz.eval()

        with torch.no_grad():
            # --- Figure 1: samples from the prior, one mixture component per row ---
            fig, axes = plt.subplots(10, 8, figsize=(10, 10))
            for row_idx, row_axes in enumerate(axes):
                m = len(row_axes)
                # One-hot indicator for component `row_idx`, repeated once per column.
                z1idx = map_to_position(torch.ones_like(y[0], device=self.device).unsqueeze(0)*row_idx, self.k).repeat(m,1).to(self.device)
                # print(z1idx)
                z1 = self.priorz.rsample(z1idx, None)
                
                y0_new = self.decoder.sample(z1)

                x0_np = inv_transform(y0_new).cpu().detach().numpy()
                
                # Back to image layout (N, C, H, W) for display.
                x0_np = x0_np.reshape((-1,self.c,self.p,self.p))
                
                for col_idx, ax in enumerate(row_axes):   
                    # Only the first channel is displayed.
                    ax.imshow(x0_np[col_idx][0], cmap='gray')
                    if col_idx == 0:
                        ax.set_ylabel(r"$\xi$={}".format(row_idx))
                    else:
                        ax.set_ylabel("")
                    ax.set_xlabel("")
                    ax.get_xaxis().set_ticks([])
                    ax.get_yaxis().set_ticks([])

            plt.tight_layout()
            # plt.savefig(os.path.join(self.args.log_sample_path, '{}_sampels.png'.format(ckpt_file)))
            plt.savefig(os.path.join(self.args.log_sample_path, f'priorgen_grid_epoch_{self.current_epoch}.png'))
            plt.close()

            # --- Figure 2: one real image per label followed by posterior-predictive samples ---
            fig, axes = plt.subplots(10, 8, figsize=(10,10))
            for row_idx, row_axes in enumerate(axes):
                m = len(row_axes)
                mask = y==row_idx
                if mask.sum() == 0:
                    # No example of this label in the batch: the row is left untouched.
                    pass
                else:
                    # First example of this label.
                    x_k = x[mask][0]
                    # NOTE(review): y_k is not used below.
                    y_k = map_to_position(y[y==row_idx], self.k)[0]
                    # One copy of the example per generated column.
                    x0 = x_k.repeat(m-1,1)

                    # generate posterior
                    pic, logits, prob = self.encoderpi.rsample(x0)
                    zc = self.encoderz.rsample(pic, x0)
                    xnew = self.decoder.sample(zc)

                    xnew_np = inv_transform(xnew).cpu().detach().numpy()
                    x_np = inv_transform(x_k).cpu().detach().numpy()

                    xnew_np = xnew_np.reshape((-1,self.c,self.p,self.p))
                    x_np = x_np.reshape((self.c,self.p,self.p))

                    for col_idx, ax in enumerate(row_axes):   
                        if col_idx == 0:
                            # Column 0: the real input.
                            ax.imshow(x_np[0], cmap='gray')
                            ax.set_ylabel("y={}".format(row_idx))
                        else:
                            # Remaining columns: generated samples.
                            ax.imshow(xnew_np[col_idx-1][0], cmap='gray')
                            ax.set_ylabel("")
                        ax.set_xlabel("")
                        ax.get_xaxis().set_ticks([])
                        ax.get_yaxis().set_ticks([])

            plt.tight_layout()
            # plt.savefig(os.path.join(self.args.log_sample_path, '{}_sampels.png'.format(ckpt_file)))
            plt.savefig(os.path.join(self.args.log_sample_path, f'image_grid_epoch_{self.current_epoch}.png'))
            plt.close()

            # --- Figure 3: histograms of sampled component indices ---
            # posterior sample from pit
            fig, axes = plt.subplots(10, 3, figsize=(10,10), sharex=True, constrained_layout=True)
            for row_idx, row_axes in enumerate(axes):
                # print(f"Row {row_idx}")
                m = len(row_axes)
                mask = y==row_idx
                if mask.sum() == 0:
                    pass
                else:
                    x_k = x[mask]
                    for col_idx, ax in enumerate(row_axes):   
                        if col_idx < len(x_k):
                            # 100 copies of one input -> 100 draws from encoderpi.
                            x0 = x_k[col_idx].repeat(100,1)

                            pic, logits, prob = self.encoderpi.rsample(x0)
                            
                            # Index of the largest entry of each draw.
                            kidx = torch.argmax(pic, dim=-1).squeeze()
                            kidx_np = kidx.cpu().detach().numpy()

                            # ax.hist(z0, bins=20, alpha=0.7)
                            ax.hist(kidx_np, bins=20, alpha=0.7)
                        else:
                            # Fewer than three examples of this label: hide the spare panel.
                            ax.axis('off')
                        # ax.set_title(f"y={y[i]}")
                        if col_idx == 0:
                            ax.set_ylabel(f"y={row_idx}")

            # plt.tight_layout(pad=3.0)
            plt.savefig(os.path.join(self.args.log_sample_path, 'postpi_grid_epoch_{}.png'.format(self.current_epoch)))
            plt.close()

        # Switch the sub-networks back to train mode.
        self.decoder.train()
        self.encoderz.train()
        self.encoderpi.train()
        self.priorz.train()


    def sample(self, x):
        """
        Draw one posterior sample for every row of `x`.

        `encoderpi.rsample(x)` provides the component sample and its
        probabilities; `encoderz.rsample` then draws z given that component.

        Returns:
            `(z sample, component sample, component probabilities)`.
        """
        component, _logits, prob = self.encoderpi.rsample(x)
        latent = self.encoderz.rsample(component, x)
        return latent, component, prob

    def plot_latent(self, x, y):
        """
        Scatter-plot posterior samples of z and save the figure as
        `postz_grid_epoch_{epoch}.png` in `self.args.log_sample_path`.

        - z_dim == 2: 2-D scatter coloured by sampled component index.
        - z_dim == 3: 3-D scatter coloured by sampled component index.
        - otherwise: 3-D t-SNE embedding coloured by the labels `y`.

        The sub-networks are put in eval mode for the duration of the call
        and switched to train mode afterwards.
        """
        networks = (self.decoder, self.encoderz, self.encoderpi, self.priorz)
        for net in networks:
            net.eval()

        with torch.no_grad():
            # sample from posterior
            zc, zcidx, _prob = self.sample(x)

            latent_np = zc.cpu().detach().numpy()
            component_np = zcidx.cpu().detach().numpy().argmax(axis=1)
            labels_np = y.cpu().detach().numpy()

            palette = plt.colormaps['tab10']
            latent_dim = self.config.flow.z_dim

            def save_and_close():
                plt.tight_layout()
                plt.savefig(os.path.join(self.args.log_sample_path, f'postz_grid_epoch_{self.current_epoch}.png'))
                plt.close()

            if latent_dim == 2:
                points = latent_np.squeeze()
                plt.figure(figsize=(8, 6))
                plt.scatter(points[:, 0], points[:, 1], c=component_np, cmap=palette, s=10)
                plt.colorbar(shrink=0.5, orientation='vertical')
                plt.title("Generated latent z given x")
                save_and_close()
            else:
                if latent_dim == 3:
                    # Already three-dimensional: plot as is.
                    points = latent_np.squeeze()
                    colours = component_np
                else:
                    # Any other dimensionality is embedded into 3-D with t-SNE.
                    embedder = TSNE(n_components=3, perplexity=30, random_state=0)
                    points = embedder.fit_transform(latent_np)
                    colours = labels_np

                figure = plt.figure(figsize=(8, 8))
                axis3d = figure.add_subplot(111, projection='3d')
                cloud = axis3d.scatter(points[:, 0], points[:, 1], points[:, 2], c=colours, cmap=palette, s=10)
                plt.colorbar(cloud, ax=axis3d, pad=0.1, orientation='vertical', shrink=0.5)
                axis3d.set_title("Generated latent z given x")
                save_and_close()

        for net in networks:
            net.train()

    def test_generate(self, x, y):
        """
        Draw test-time sample figures and compute image-quality scores.

        Three figures are written to `self.args.log_sample_path`:
        `test_gen.png` (decoder samples from the prior), `test_regenerate.png`
        (one real image per label followed by posterior-predictive samples)
        and `test_postpi_grid.png` (histograms of sampled component indices).

        Args:
            x: batch of flattened inputs (one row per image).
            y: labels aligned with `x`; compared against row indices 0..9.

        Returns:
            Tuple `(mean of the per-label SSIM scores, inception-score mean,
            inception-score std)`.
        """
        inception = InceptionScore()
        inception.inception = inception.inception.to(torch.float32)
        with torch.no_grad():
            # --- Figure 1: samples from the prior ---
            fig, axes = plt.subplots(10, 8, figsize=(10, 10))
            # Created on the module's device, as `generate` does for its indices.
            z1idx = torch.rand((10*10, self.k), device=self.device)
            z1 = self.priorz.rsample(z1idx, None)
            y0_new = self.decoder.sample(z1)
            y0_new = inv_transform(y0_new)
            x0_np = y0_new.cpu().detach().numpy()
            x0_np = x0_np.reshape((-1, self.c, self.p, self.p))
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(x0_np[i][0], cmap='gray')
                ax.set_ylabel("")
                ax.set_xlabel("")
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
            plt.tight_layout()
            plt.savefig(os.path.join(self.args.log_sample_path, 'test_gen.png'))
            plt.close()

            # Bring the samples into (N, C, H, W) layout before tiling the
            # channel axis, exactly as is done for the SSIM inputs below.
            gen_img = y0_new.reshape((-1, self.c, self.p, self.p)).repeat(1, 3, 1, 1)
            # NOTE(review): the cast to uint8 presumes inv_transform yields
            # values on a 0-255 scale -- confirm against inv_transform.
            inception.update(gen_img.to(torch.uint8))  # computed over all classes
            is_score_mean, is_score_std = inception.compute()

            ssim = StructuralSimilarityIndexMeasure()
            loss_ssim = []

            # --- Figure 2: real image per label + posterior-predictive samples ---
            fig, axes = plt.subplots(10, 8, figsize=(10, 10))
            for row_idx, row_axes in enumerate(axes):
                m = len(row_axes)
                mask = y == row_idx
                if mask.sum() == 0:
                    # No example of this label in the batch: leave the row as is.
                    continue

                x_k = x[mask][0]
                x0 = x_k.repeat(m-1, 1)
                # generate posterior
                pic, logits, prob = self.encoderpi.rsample(x0)
                zc = self.encoderz.rsample(pic, x0)
                xnew = self.decoder.sample(zc)
                xnew = inv_transform(xnew).reshape((-1, self.c, self.p, self.p))

                # `xnew` is already inverse-transformed; applying
                # inv_transform a second time would distort the plotted images.
                xnew_np = xnew.cpu().detach().numpy()
                x_np = inv_transform(x_k).cpu().detach().numpy()
                x0 = inv_transform(x0).reshape((-1, self.c, self.p, self.p))

                x_np = x_np.reshape((self.c, self.p, self.p))

                for col_idx, ax in enumerate(row_axes):
                    if col_idx == 0:
                        ax.imshow(x_np[0], cmap='gray')
                        ax.set_ylabel("y={}".format(row_idx))
                    else:
                        ax.imshow(xnew_np[col_idx-1][0], cmap='gray')
                        ax.set_ylabel("")

                    ax.set_xlabel("")
                    ax.get_xaxis().set_ticks([])
                    ax.get_yaxis().set_ticks([])

                # SSIM between the (repeated) real image and its regenerations.
                real_img = x0.repeat(1, 3, 1, 1).to(torch.float32)
                gen_img = xnew.repeat(1, 3, 1, 1).to(torch.float32)
                ssim_score = ssim(real_img, gen_img)
                loss_ssim.append(ssim_score)

            plt.tight_layout()
            plt.savefig(os.path.join(self.args.log_sample_path, 'test_regenerate.png'))
            plt.close()

            # --- Figure 3: histograms of sampled component indices ---
            fig, axes = plt.subplots(10, 3, figsize=(10, 10), sharex=True, constrained_layout=True)
            for row_idx, row_axes in enumerate(axes):
                mask = y == row_idx
                if mask.sum() == 0:
                    continue

                x_k = x[mask]
                for col_idx, ax in enumerate(row_axes):
                    if col_idx < len(x_k):
                        # 100 copies of one input -> 100 draws from encoderpi.
                        x0 = x_k[col_idx].repeat(100, 1)
                        pic, logits, prob = self.encoderpi.rsample(x0)

                        kidx = torch.argmax(pic, dim=-1).squeeze()
                        kidx_np = kidx.cpu().detach().numpy()

                        ax.hist(kidx_np, bins=20, alpha=0.7)
                    else:
                        # Fewer than three examples of this label: hide the panel.
                        ax.axis('off')
                    if col_idx == 0:
                        ax.set_ylabel(f"y={row_idx}")

            plt.savefig(os.path.join(self.args.log_sample_path, 'test_postpi_grid.png'))
            plt.close()

        return np.array(loss_ssim).mean(), is_score_mean, is_score_std

    def test_generation_figure(self, x, y):
        """
        Save `eval_display.png`: a 3x10 figure with, per column, a real
        image (row 0), one posterior-predictive sample (row 1) and a bar
        chart of the component probabilities from `encoderpi` (row 2).

        The inputs are ordered by label first; the caller is expected to
        pass exactly one example of each digit (assumption stated by the
        original author, not checked here).
        """
        order = torch.argsort(y)
        xk = x[order]
        print("xk", xk.shape)
        with torch.no_grad():
            fig, axes = plt.subplots(3, 10, figsize=(10,4), gridspec_kw={'height_ratios': [1, 1, 1.5]})

            # Posterior sample -> reconstruction.
            pic, _logits, prob = self.encoderpi.rsample(xk)
            zc = self.encoderz.rsample(pic, xk)
            xnew = self.decoder.sample(zc)

            image_shape = (-1, self.c, self.p, self.p)
            xnew = inv_transform(xnew).reshape(image_shape)
            xk = inv_transform(xk).reshape(image_shape)

            generated_np = xnew.cpu().detach().numpy()
            real_np = xk.cpu().detach().numpy()
            prob_np = prob.cpu().detach().numpy()

            for col in range(10):
                real_ax = axes[0, col]
                generated_ax = axes[1, col]
                prob_ax = axes[2, col]

                real_ax.imshow(real_np[col][0], cmap='gray')
                generated_ax.imshow(generated_np[col][0], cmap='gray')
                prob_ax.bar(range(self.k), prob_np[col])

                for image_ax in (real_ax, generated_ax):
                    image_ax.get_xaxis().set_ticks([])
                    image_ax.get_yaxis().set_ticks([])

                prob_ax.set_xlim(-1,self.k)
                prob_ax.set_ylim(0,1)

                if col == 0:
                    # Row labels only on the left-most column.
                    real_ax.set_ylabel("Real")
                    generated_ax.set_ylabel("Generated")
                    prob_ax.set_ylabel(r"$p(\xi|x)$")
                else:
                    prob_ax.get_yaxis().set_ticks([])

            plt.tight_layout()
            plt.savefig(os.path.join(self.args.log_sample_path, f'eval_display.png'))
            plt.close()


    def test_plot_latent(self, x, y):
        """
        Plot a 2-D t-SNE embedding of posterior z samples coloured by label
        and score how well the sampled components agree with the labels.

        The figure is saved as `eval_grid_epoch_{epoch}_50.png` in
        `self.args.log_sample_path`.

        Returns:
            `(nmi, ard, nmi_soft)`: normalized mutual information and
            adjusted Rand score between `y` and the argmax component, and
            the value of `soft_nmi` on the component probabilities.
        """
        # sample from posterior
        zc, zcidx, pic = self.sample(x)

        latent_np = zc.cpu().detach().numpy()
        hard_assignment = zcidx.cpu().detach().numpy().argmax(axis=1)
        soft_assignment = pic.cpu().detach().numpy()
        labels_np = y.cpu().detach().numpy()

        palette = plt.colormaps['tab10']

        embedder = TSNE(n_components=2, perplexity=50, random_state=14, n_iter=1000)
        embedding = embedder.fit_transform(latent_np)

        plt.figure(figsize=(8,6))
        plt.scatter(embedding[:, 0], embedding[:, 1], c=labels_np, cmap=palette, s=10)
        plt.colorbar(shrink=0.5)

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, f'eval_grid_epoch_{self.current_epoch}_50.png'))
        plt.close()

        nmi = normalized_mutual_info_score(labels_np, hard_assignment)
        ard = adjusted_rand_score(labels_np, hard_assignment)
        nmi_soft = soft_nmi(soft_assignment, labels_np)

        return nmi, ard, nmi_soft

    def test_step(self, batch, batch_idx):
        """
        Evaluate one test batch.

        With `args.figure` set, only the display figure is produced.
        Otherwise the generation scores (SSIM, inception score mean/std) and
        the clustering scores (NMI, adjusted Rand, soft NMI) are computed
        and logged.
        """
        X, y = batch
        X = X.view(-1, self.config.data.size**2)

        if self.args.figure:
            self.test_generation_figure(X, y)
            return

        def log_all(named_values):
            for metric_name, metric_value in named_values:
                self.log(metric_name, metric_value, on_step=True, on_epoch=True, sync_dist=True, logger=True)

        ssim_loss, is_score_mean, is_score_std = self.test_generate(X, y)
        log_all((
            ('test_ssim_loss', ssim_loss),
            ('test_is_loss', is_score_mean),
            ('test_is_loss_std', is_score_std),
        ))

        # Latent-space plot and clustering agreement.
        nmi, ard, nmi_soft = self.test_plot_latent(X, y)
        log_all((
            ('test_post_clus_nmi', nmi),
            ('test_post_clus_ard', ard),
            ('test_post_clus_nmi_soft', nmi_soft),
        ))



class GMVAERunner:
    """
    Builds the GMVAE sub-networks and a Lightning Trainer, and exposes
    entry points for training (`train`), evaluation (`sample`) and drawing
    the display figure (`draw_figure`).
    """

    def __init__(self, args, config):
        """
        Args:
            args: run arguments; `log_path` and `tb_path` are read, and
                `log_sample_path` is set to `<log_path>/samples` (the
                directory is created if missing).
            config: configuration object providing the `flow`, `data`,
                `model` and `training` sections read below.
        """
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        # k: number of mixture components, d: latent dimension,
        # c: image channels, p: image side length.
        self.k, self.d = self.config.flow.pi_dim, self.config.flow.z_dim
        self.c, self.p = self.config.data.channel, self.config.data.size

        self.decoder = LLKNet(
            self.p**2, self.d,
            hidden_features=[self.config.model.ngf]*5,
            fct=nn.Softplus(),
        )
        self.encoderz = GaussianNet(
            self.k, self.d, self.p**2,
            hidden_features=[self.config.model.ngf]*2,
            fct=nn.Softplus(),
        )
        # NOTE(review): the width comes from config.flow.ngf here but from
        # config.model.ngf for the other networks -- confirm this is intended.
        self.encoderpi = CatNet(
            self.p**2, self.k,
            hidden_features=[self.config.flow.ngf]*2,
            fct=nn.Softplus(),
        )
        self.priorz = GaussianNet(self.k, self.d, hidden_features=[], fct=nn.Softplus(), batch_norm=True)

        # NOTE(review): this callback is stored but is not part of the
        # Trainer's `callbacks` list below. It is left that way on purpose:
        # the checkpoint lookup in this class reads from the TensorBoard
        # version directory, not from `log_path`.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',  # Metric to monitor
            dirpath=self.args.log_path,  # Directory where checkpoints will be saved
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',  # Filename convention
            save_top_k=5,  # Keep the five best checkpoints by val_loss
            mode='min'  # Minimize the validation loss
        )

        if torch.cuda.is_available():
            accelerator = 'gpu'
            strategy = "ddp"
            devices = "auto"
        else:
            accelerator = 'cpu'
            devices = "auto"
            strategy = "auto"

        plot_loss_callback = PlotLossCallback(save_path=os.path.join(self.args.log_sample_path, 'loss.png'), update_interval=1)
        plot_llk_callback = PlotLogLikelihoodCallback(save_path=self.args.log_sample_path, log_keys=("val_log_lik", "val_log_post_z"))

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            logger=pl.loggers.TensorBoardLogger(save_dir=self.args.tb_path, name="mnist"),
            callbacks=[plot_loss_callback, plot_llk_callback],
        )

    def _latest_checkpoint(self):
        """
        Return the path of the most recently modified `.ckpt` file in
        `<tb_path>/mnist/version_0/checkpoints`.

        `os.listdir` returns entries in arbitrary order, so the newest file
        is selected explicitly by modification time.

        Raises:
            FileNotFoundError: if the directory is missing or holds no
                `.ckpt` file.
        """
        ckpt_dir = os.path.join(self.args.tb_path, "mnist/version_0/checkpoints")
        candidates = [
            os.path.join(ckpt_dir, name)
            for name in os.listdir(ckpt_dir)
            if name.endswith('.ckpt')
        ]
        if not candidates:
            raise FileNotFoundError(f"No .ckpt files found in {ckpt_dir}")
        return max(candidates, key=os.path.getmtime)

    def _build_model(self):
        """Wrap the sub-networks in a fresh GMVAELightningModule."""
        return GMVAELightningModule(self.decoder, self.encoderpi, self.encoderz, self.priorz, self.config, self.args)

    def train(self):
        """Fit the model, resuming from the latest checkpoint when `args.resume_training` is set."""
        # load data
        dataset, val_dataset, sampler, val_sampler = get_mnist(
            self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
        train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
                                      num_workers=self.config.data.num_workers, sampler=sampler)
        val_dataloader = DataLoader(val_dataset, batch_size=self.config.training.batch_size,
                                    num_workers=self.config.data.num_workers, sampler=val_sampler, drop_last=True)

        model = self._build_model()

        ckpt_path = self._latest_checkpoint() if self.args.resume_training else None

        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

    def sample(self):
        """
        Run the test loop on the latest checkpoint.

        Returns:
            A dict copy of `trainer.callback_metrics` after testing, i.e.
            the metrics logged by the module's `test_step`.
        """
        _, test_dataset, _, test_sampler = get_mnist(
            self.config.data.n_classes, "data", 500, 500)
        test_dataloader = DataLoader(test_dataset, batch_size=5000,
                                     num_workers=self.config.data.num_workers, sampler=test_sampler)

        ckpt_path = self._latest_checkpoint()
        model = self._build_model()

        self.trainer.test(model, dataloaders=test_dataloader, ckpt_path=ckpt_path)

        # Hand back whatever was actually logged instead of indexing metric
        # names that the test step never writes (which raised KeyError).
        return dict(self.trainer.callback_metrics)

    def draw_figure(self):
        """Run the test loop on a 10-image batch from the latest checkpoint (used for the display figure)."""
        _, test_dataset, _, test_sampler = get_mnist(
            self.config.data.n_classes, "data", 1, 1)
        test_dataloader = DataLoader(test_dataset, batch_size=10,
                                     num_workers=self.config.data.num_workers, sampler=test_sampler)

        ckpt_path = self._latest_checkpoint()
        model = self._build_model()

        self.trainer.test(model, dataloaders=test_dataloader, ckpt_path=ckpt_path)




        