import os

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image.inception import InceptionScore

# from torchvision.utils import save_image
import pandas as pd

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import *
from Scheduler import WarmUpScheduler # GradualWarmupScheduler
from sfa_discrete import *
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_mnist import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy


# Process-wide torch settings applied at import time.
# New floating-point tensors are created as float32 unless a dtype is given.
torch.set_default_dtype(torch.float32)
# Printed tensors show three digits after the decimal point.
torch.set_printoptions(precision=3)


def siren_init(m):
    """
    Re-initialise the weight of a linear layer in place.

    Modules that are not ``nn.Linear`` are left untouched, so the function can
    be applied to every sub-module of a network. Biases are not modified.

    Parameters:
    - m: The module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return

    fan_in = m.weight.size(-1)
    bound = 1 / fan_in
    with torch.no_grad():
        # Every weight is drawn from U(-1/fan_in, 1/fan_in).
        m.weight.uniform_(-bound, bound)

def map_to_position(input_tensor, num_classes):
    """
    Maps each element in the input tensor to a one-hot encoded row
    with a 1 in the position specified by the value in the input tensor.

    Parameters:
    - input_tensor: A 1-D tensor of integer class indices; each value is used
      as a column index and must therefore be a valid index for
      ``num_classes`` columns.
    - num_classes: Number of columns of the one-hot encoding.

    Returns:
    - A tensor of shape ``(len(input_tensor), num_classes)`` in the default
      floating-point dtype, created on the same device as ``input_tensor``,
      where each row has a 1 in the position given by the corresponding
      element of ``input_tensor`` and 0 elsewhere.
    """
    num_rows = input_tensor.size(0)
    # Allocate on the input's device so the index tensors and the indexed
    # tensor never live on different devices.
    device = input_tensor.device
    output_tensor = torch.zeros((num_rows, num_classes), device=device)
    row_index = torch.arange(num_rows, device=device)
    output_tensor[row_index, input_tensor] = 1

    return output_tensor

def first_three_eigen_proj(x):
    """
    Project ``x`` onto the three leading eigen-directions of its standardised covariance.

    Parameters:
    - x: A 2-D NumPy array with one sample per row and one feature per column.

    Returns:
    - An array with one row per sample and (at most) three columns: ``x``
      multiplied by the eigenvectors belonging to the three largest
      eigenvalues. The sign of each column is arbitrary, as eigenvectors are
      only defined up to sign.
    """
    # Step 1: Standardise each feature and compute the covariance matrix.
    # Constant features have zero standard deviation; divide those by 1
    # instead so they contribute zeros rather than NaNs.
    std = np.sqrt(x.var(0, keepdims=True))
    std = np.where(std == 0, 1.0, std)
    x_standardized = (x - x.mean(0, keepdims=True)) / std
    # atleast_2d: np.cov collapses to a scalar when there is a single feature.
    cov_matrix = np.atleast_2d(np.cov(x_standardized, rowvar=False))

    # Step 2: Eigen decomposition. The covariance matrix is symmetric, so use
    # eigh, which guarantees real eigenvalues/eigenvectors (eig may return
    # complex values because of round-off).
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

    # Step 3: Keep the eigenvectors of the three largest eigenvalues.
    top_indices = np.argsort(eigenvalues)[::-1][:3]
    top_eigenvectors = eigenvectors[:, top_indices]

    # Step 4: Project onto the selected directions.
    # NOTE(review): the raw ``x`` is projected, not the standardised data the
    # directions were estimated from - confirm this is intended.
    projected_array = np.dot(x, top_eigenvectors)

    return projected_array

def logit_trans(data, eps=1e-6):
    """
    Apply a logit transform after shrinking the input away from 0 and 1.

    Parameters:
    - data: A tensor; it is mapped affinely to ``eps + (1 - 2 * eps) * data``
      before the logit is taken, so inputs of exactly 0 or 1 stay finite.
    - eps: Size of the margin kept from 0 and 1.

    Returns:
    - ``log(d) - log(1 - d)`` of the shrunk data ``d``, element-wise.
    """
    squeezed = data * (1 - 2 * eps) + eps
    return squeezed.log() - (-squeezed).log1p()

class PlotLogLikelihoodCallback(Callback):
    """
    Callback that records two logged metrics per epoch and plots them side by side.

    After every training epoch the current values of the two metrics named in
    ``log_keys`` are appended to an in-memory history (when present in
    ``trainer.callback_metrics``) and the figure ``llk.png`` is rewritten.
    """

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Args:
            save_path (str): Path that ``llk.png`` is joined onto, i.e. it is
                used as the directory of the output figure.
            log_keys (tuple): Metric names for the two log likelihoods; the
                first feeds the left panel, the second the right panel.
        """
        super().__init__()
        self.save_path = os.path.join(save_path, 'llk.png')
        self.log_keys = log_keys
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Append the latest metric values and redraw the two-panel figure.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained (unused).
        """
        metrics = trainer.callback_metrics
        histories = (self.log_likelihoods_x, self.log_likelihoods_z)

        # A metric that was not logged this epoch simply adds no point.
        for position, history in enumerate(histories):
            key = self.log_keys[position]
            if key in metrics:
                history.append(metrics[key].item())

        # (subplot index, series, y-axis label, title) for each panel
        panels = (
            (1, self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (2, self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )

        plt.figure(figsize=(12, 6))
        for subplot_index, series, y_label, title in panels:
            plt.subplot(1, 2, subplot_index)
            plt.plot(series, marker="o")
            plt.xlabel("Epoch")
            plt.ylabel(y_label)
            plt.title(title)
            plt.grid()

        plt.tight_layout()
        plt.savefig(self.save_path)
        plt.close()



class PlotLossCallback(Callback):
    """
    Callback that tracks the ``train_loss`` / ``val_loss`` metrics and plots them.

    Args:
        save_path (str): File the loss figure is written to.
        update_interval (int): The figure is rewritten on epochs whose index
            is a multiple of this value.
    """

    def __init__(self, save_path="loss_plot.png", update_interval=5):
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        self.train_losses = []
        self.val_losses = []
        self.epochs = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the epoch's losses (when both are available) and refresh the plot if due."""
        epoch = trainer.current_epoch
        metrics = trainer.callback_metrics

        train_loss = metrics.get("train_loss")
        val_loss = metrics.get("val_loss")

        # Only store a point when both curves have a value for this epoch,
        # so the three lists always stay the same length.
        if not (train_loss is None or val_loss is None):
            self.epochs.append(epoch)
            self.train_losses.append(train_loss.cpu().item())
            self.val_losses.append(val_loss.cpu().item())

        if epoch % self.update_interval != 0:
            return
        self.plot_and_save()

    def plot_and_save(self):
        """Draw both loss curves against the recorded epochs and save the figure."""
        curves = (
            (self.train_losses, "Training Loss"),
            (self.val_losses, "Validation Loss"),
        )

        plt.figure(figsize=(10, 6))
        for values, label in curves:
            plt.plot(self.epochs, values, label=label, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.save_path)
        plt.close()

class FlowMatchingLightningModule(pl.LightningModule):
    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorz, embx, config, args):
        """
        Store the networks, the latent prior and the run configuration.

        Args:
            vt: Network stored as ``self.vt``.
            Rt: Network stored as ``self.Rt``.
            rt: Network stored as ``self.rt``.
            priorz: Prior over the latent ``z``, stored as ``self.priorz``.
            embx: Stored as ``self.embx``.
            config: Configuration object; the ``flow``, ``data`` and ``model``
                sections are read here.
            args: Run arguments, stored as ``self.args``.
        """
        super().__init__()
        self.config = config
        self.args = args

        # Components are assigned in this fixed order.
        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        self.priorz = priorz
        self.embx = embx

        # Frequently used sizes, cached from the configuration.
        flow_cfg = self.config.flow
        data_cfg = self.config.data
        pi_dim, z_dim = flow_cfg.pi_dim, flow_cfg.z_dim
        channels, size = data_cfg.channel, data_cfg.size
        self.k = pi_dim
        self.d = z_dim
        self.c = channels
        self.p = size
        self.xemb_dim = self.config.model.ngf
        self.sig_min = 1e-4

        # Optimisers are stepped by hand in training_step.
        self.automatic_optimization = False
        # Filled by validation_step with the last batch of the epoch.
        self.last_validation_batch = None

    def setup(self, stage=None):
        """
        Create the priors on the module's device and pick the flow-matching loss.

        ``self.priorpi`` and ``self.priory`` are always (re)built here. The
        loss ``self.flow_matching_loss`` is only created when
        ``config.model.cnn`` is set, and its class depends on the
        ``fix_z`` / ``fix_pi`` flags; for ``fix_z`` without ``fix_pi`` no loss
        is assigned by this method.

        Args:
            stage: Unused.
        """
        device = self.device
        is_cnn = self.config.model.cnn

        self.priorpi = DiagNormal(
            torch.zeros(self.k, device=device),
            torch.ones(self.k, device=device),
        )

        # Image-shaped prior for the CNN model, flat (p*p) prior otherwise.
        y_shape = (self.c, self.p, self.p) if is_cnn else (self.p**2,)
        self.priory = Normal(
            torch.zeros(*y_shape).to(device),
            torch.ones(*y_shape).to(device),
        )

        if not is_cnn:
            return

        flow_cfg = self.config.flow
        extra_kwargs = {}
        if flow_cfg.fix_z and flow_cfg.fix_pi:
            loss_cls = FlowMatchingLossCNN_fixed
            extra_kwargs["cnn"] = flow_cfg.cnn
        elif flow_cfg.fix_pi:
            # fix_pi only (fix_z is off here)
            loss_cls = FlowMatchingLossCNN_fixpi
        elif not flow_cfg.fix_z:
            # neither flag set
            loss_cls = FlowMatchingLossCNN
        else:
            # fix_z without fix_pi: no matching loss
            return

        self.flow_matching_loss = loss_cls(
            self.vt, self.Rt, self.rt, self.priorpi, self.priorz, self.priory,
            k=self.k,
            beta=flow_cfg.beta,
            alpha=self.config.training.alpha,
            **extra_kwargs,
        )
       

    def forward(self, pi0, n):
        """
        Generate ``n`` samples by decoding base noise conditioned on sampled latents.

        Args:
            pi0: Passed as the first argument to ``self.priorz.sample`` to draw
                the latents.
            n (int): Number of samples to generate.

        Returns:
            tuple: ``(x0_np, z1_np)`` as NumPy arrays - the generated data
            (reshaped to ``(-1, c, p, p)`` for the non-CNN model) and the
            latents they were decoded from.
        """
        use_cnn = self.config.model.cnn

        # Base noise: image-shaped for the CNN model, drawn from the flat
        # prior otherwise.
        if use_cnn:
            y1 = torch.randn(n, self.c, self.p, self.p).to(self.device)
        else:
            y1 = self.priory.sample((n,)).to(self.device)

        z1 = self.priorz.sample(pi0, (n,)).to(self.device)

        # Decode the noise given the latents, then map back to data space.
        y0 = self.vt.decode(y1, z1)
        x0 = inv_transform(y0)

        z1_np = z1.cpu().detach().numpy()
        x0_np = x0.cpu().detach().numpy()
        if not use_cnn:
            # The non-CNN model works on flattened images.
            x0_np = x0_np.reshape((-1, self.c, self.p, self.p))

        return x0_np, z1_np

    def training_step(self, batch, batch_idx):
        """
        Run one manual-optimisation step on both optimisers.

        The flow-matching loss is computed on the batch inputs, logged as
        ``train_loss``, back-propagated once, and then both optimisers are
        stepped after their gradients have been norm-clipped to
        ``config.training.clipval``.

        Args:
            batch: A ``(X, y)`` pair; only ``X`` is used.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The loss computed for this batch.
        """
        X, _ = batch
        if not self.config.model.cnn:
            # The non-CNN model expects flattened images.
            X = X.view(-1, self.c*self.p**2)
        loss = self.flow_matching_loss(X)

        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        # Same order as returned by configure_optimizers.
        flow_opt, latent_opt = self.optimizers()
        latent_opt.zero_grad()
        flow_opt.zero_grad()

        self.manual_backward(loss)

        # Clip BEFORE stepping: clipping after optimizer.step() has no effect
        # on the update, because the gradients are zeroed at the start of the
        # next step.
        clip_val = self.config.training.clipval
        self.clip_gradients(latent_opt, gradient_clip_val=clip_val, gradient_clip_algorithm="norm")
        self.clip_gradients(flow_opt, gradient_clip_val=clip_val, gradient_clip_algorithm="norm")

        latent_opt.step()
        flow_opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Compute and log the validation loss for one batch.

        The final batch of the first validation dataloader is cached in
        ``self.last_validation_batch`` so the epoch-end hook can use it.

        Args:
            batch: A ``(X, y)`` pair.
            batch_idx (int): Index of the batch within the validation loader.

        Returns:
            The validation loss for this batch.
        """
        inputs, labels = batch
        if not self.config.model.cnn:
            # The non-CNN model expects flattened images.
            inputs = inputs.view(-1, self.c*self.p**2)

        val_loss = self.flow_matching_loss(inputs)

        # Keep the last batch around for the epoch-end snapshot.
        last_index = self.trainer.num_val_batches[0] - 1
        if batch_idx == last_index:
            self.last_validation_batch = {"X": inputs, "y": labels}

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        return val_loss

    def configure_optimizers(self):
        """
        Build two AdamW optimisers, each with a warm-up + cosine schedule.

        The first optimiser covers ``self.vt``; the second covers ``self.Rt``
        and ``self.rt``. Both use the same learning rate, weight decay and
        schedule settings.

        Returns:
            tuple: ``([optimizer, optimizer_latent], [cfg, cfg_latent])`` where
            each ``cfg`` is a scheduler dictionary with an epoch interval.
        """
        optim_cfg = self.config.optim
        n_epochs = self.config.training.n_epochs
        steps_per_epoch = self.config.data.samplesize//self.config.training.batch_size

        def build(params):
            # One AdamW optimiser wrapped by cosine annealing and a warm-up.
            opt = torch.optim.AdamW(
                params,
                lr=optim_cfg.lr,
                weight_decay=optim_cfg.weight_decay,
            )
            cosine = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=opt,
                T_max=n_epochs,
                eta_min=1e-5,
                last_epoch=-1,
            )
            warmup = WarmUpScheduler(
                optimizer=opt,
                lr_scheduler=cosine,
                warmup_steps=n_epochs // 10,
                warmup_start_lr=0.00005,
                len_loader=steps_per_epoch,
            )
            return opt, warmup

        def as_scheduler_config(scheduler):
            return {
                'scheduler': scheduler,
                'monitor': 'train_loss',
                "interval": "epoch",
                "frequency": 1,
            }

        optimizer, warm_up = build(itertools.chain(self.vt.parameters()))
        optimizer_latent, warm_up_latent = build(
            itertools.chain(self.Rt.parameters(), self.rt.parameters())
        )

        return [optimizer, optimizer_latent], [
            as_scheduler_config(warm_up),
            as_scheduler_config(warm_up_latent),
        ]
        
        # return {"optimizer": optimizer, "lr_scheduler": warmUpScheduler, "interval":"epoch", "frequency":1}
    
    def on_train_epoch_end(self):
        """Advance both learning-rate schedulers by one step at the end of each training epoch."""
        # Schedulers are stepped by hand, like the optimisers.
        flow_scheduler, latent_scheduler = self.lr_schedulers()
        for scheduler in (flow_scheduler, latent_scheduler):
            scheduler.step()


    def on_validation_epoch_end(self):
        """
        Produce snapshot plots and likelihood metrics from the cached validation batch.

        On epochs that are a multiple of ``config.training.snapshot_freq`` the
        cached batch is used to plot the latents, generate samples, and log
        ``val_log_post_z`` and ``val_log_lik``. The cached batch is cleared
        afterwards in every case.
        """
        cached = self.last_validation_batch
        if cached is not None:
            X, y = cached["X"], cached["y"]

            if not self.config.model.cnn:
                X = X.view(-1, self.c*self.p**2)

            is_snapshot_epoch = self.current_epoch % self.config.training.snapshot_freq == 0
            if is_snapshot_epoch:
                # Snapshot plots
                self.plot_latent(X, y)
                self.generate(X, y)

                # Evaluate the posterior of z and the likelihood of X
                z0, z0idx, pi0 = self.sample(X)
                X_flat = X.flatten(start_dim=1)
                if self.config.flow.fix_z:
                    log_post_z = self.rt.log_prob(z0idx, X_flat, z0).mean()
                else:
                    log_post_z = self.rt.log_prob(z0, z0idx, X_flat, 0., self.priorz).mean()
                log_lik = self.vt.log_prob(X, z0, 0., self.priory).mean()

                epoch_metrics = (
                    ('val_log_post_z', log_post_z),
                    ('val_log_lik', log_lik),
                )
                for name, value in epoch_metrics:
                    self.log(name, value, on_step=False, on_epoch=True, sync_dist=True, logger=True)

                print("log_post", log_post_z)
                print("log_lik", log_lik)

        # Clear the stored batch for next epoch
        self.last_validation_batch = None

    def sample(self, x):
        """
        Draw a latent ``z`` and its component assignment for each input in ``x``.

        Args:
            x: Batch of inputs; flattened from dimension 1 unless
                ``config.flow.cnn`` is set.

        Returns:
            tuple: ``(z0, z0idx, pi0)`` - the sampled latents, the component
            assignment they were drawn with, and the component probabilities
            normalised to sum to one along the last dimension.
        """
        beta = self.config.flow.beta
        _x = x if self.config.flow.cnn else x.flatten(start_dim=1)

        # Component logits at the prior end, one row per input.
        logits1 = self.priorpi.sample((len(x),)).to(self.device)

        if self.config.flow.fix_pi:
            pi1 = F.softmax(logits1/beta, dim=-1)
            _, z1idx = self.Rt.rsample(None, pi1)
            logits0, z0idx = self.Rt.rsample(_x)
        else:
            # Soft (relaxed) assignments at both ends of the logit flow.
            z1idx = F.gumbel_softmax(logits1, tau=beta, hard=False)
            logits0 = self.Rt.decode(logits1, _x)
            z0idx = F.gumbel_softmax(logits0, tau=beta, hard=False)

        pi0 = F.softmax(logits0/beta, dim=-1)

        if self.config.flow.fix_z:
            z0 = self.rt.sample(z0idx, _x).to(self.device)
        else:
            # Draw z from the prior and decode it given the assignment and x.
            z1 = self.priorz.rsample(z1idx, (1,)).to(self.device)
            z0 = self.rt.decode(z1, z0idx, _x)

        pi0_normalised = pi0/pi0.sum(-1, keepdims=True)
        return z0, z0idx, pi0_normalised



    def generate(self, x, y, cmap="gray"):
        """
        Save two diagnostic figures for the batch ``(x, y)``.

        * ``image_grid_epoch_<epoch>.png`` - for each label 0-9 present in
          ``y``: one input image followed by samples generated from it.
        * ``postpi_grid_epoch_<epoch>.png`` - for up to three inputs per label:
          a histogram of the component index inferred over 100 draws.

        The networks ``vt``, ``Rt`` and ``rt`` are switched to evaluation mode
        while sampling and are returned to whatever mode they were in on
        entry, also when an exception is raised.

        Args:
            x: Batch of inputs.
            y: Integer labels for ``x``; compared against row indices 0-9.
            cmap (str): Matplotlib colormap used for the images.
        """
        modules = (self.vt, self.Rt, self.rt)
        previous_modes = [module.training for module in modules]
        for module in modules:
            module.eval()
        try:
            with torch.no_grad():
                self._generate_image_grid(x, y, cmap)
                self._generate_posterior_histograms(x, y)
        finally:
            # Restore the caller's mode instead of forcing train mode, so a
            # call from the validation loop leaves the networks in eval mode.
            for module, was_training in zip(modules, previous_modes):
                module.train(was_training)

    def _generate_image_grid(self, x, y, cmap):
        """Save the 10x8 grid of one real image per label plus samples generated from it."""
        fig, axes = plt.subplots(10, 8, figsize=(10, 10))
        for row_idx, row_axes in enumerate(axes):
            n_gen = len(row_axes) - 1  # first column shows the real image
            mask = y == row_idx
            if mask.sum() == 0:
                # No example of this label in the batch: leave the row empty.
                continue

            x_k = x[mask][0]
            if self.config.model.cnn:
                y0 = x_k.repeat(n_gen, 1, 1, 1)
            else:
                y0 = x_k.repeat(n_gen, 1)
            _y0 = y0 if self.config.flow.cnn else y0.flatten(start_dim=1)

            logits1 = self.priorpi.sample((n_gen,)).to(self.device)
            if not self.config.flow.fix_pi:
                z1idx = F.gumbel_softmax(logits1, tau=self.config.flow.beta, hard=False)
                logits0 = self.Rt.decode(logits1, _y0)
                z0idx = F.gumbel_softmax(logits0, tau=self.config.flow.beta, hard=False)
            else:
                # NOTE(review): sample() passes softmax(logits1 / beta) here,
                # this method passes the raw logits - confirm which is intended.
                _, z1idx = self.Rt.rsample(None, logits1)
                logits0, z0idx = self.Rt.rsample(_y0)

            if self.config.flow.fix_z:
                z0 = self.rt.sample(z0idx, _y0)
            else:
                z1 = self.priorz.sample(z1idx, (1,)).to(self.device)
                z0 = self.rt.decode(z1, z0idx, _y0)

            # Decode fresh base noise conditioned on the inferred latents.
            y1_new = self.priory.sample((n_gen,)).to(self.device)
            y0_new = self.vt.decode(y1_new, z0)

            x0_np = inv_transform(y0_new).cpu().detach().numpy()
            x_np = inv_transform(x_k).cpu().detach().numpy()

            if not self.config.model.cnn:
                x0_np = x0_np.reshape((-1, self.c, self.p, self.p))
                x_np = x_np.reshape((self.c, self.p, self.p))

            for col_idx, ax in enumerate(row_axes):
                if col_idx == 0:
                    # channels-first -> channels-last for imshow
                    ax.imshow(np.transpose(x_np, (1, 2, 0)), cmap=cmap)
                    ax.set_ylabel("y={}".format(row_idx))
                else:
                    ax.imshow(np.transpose(x0_np[col_idx-1], (1, 2, 0)), cmap=cmap)
                    ax.set_ylabel("")
                ax.set_xlabel("")
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, f'image_grid_epoch_{self.current_epoch}.png'))
        plt.close()

    def _generate_posterior_histograms(self, x, y):
        """Save the 10x3 grid of inferred-component histograms (100 draws per input)."""
        n_draws = 100
        fig, axes = plt.subplots(10, 3, figsize=(10, 10), sharex=True, constrained_layout=True)
        for row_idx, row_axes in enumerate(axes):
            x_k = x[y == row_idx]
            for col_idx, ax in enumerate(row_axes):
                if col_idx < len(x_k):
                    if self.config.model.cnn:
                        y0 = x_k[col_idx].repeat(n_draws, 1, 1, 1)
                    else:
                        y0 = x_k[col_idx].repeat(n_draws, 1)
                    _y0 = y0 if self.config.flow.cnn else y0.flatten(start_dim=1)

                    if self.config.flow.fix_pi:
                        logits0, z0idx = self.Rt.rsample(_y0)
                    else:
                        logits1 = self.priorpi.sample((n_draws,)).to(self.device)
                        logits0 = self.Rt.decode(logits1, _y0)
                        z0idx = F.gumbel_softmax(logits0, tau=self.config.flow.beta, hard=False)

                    # Most likely component for each draw.
                    kidx = torch.argmax(z0idx, dim=-1).squeeze()
                    ax.hist(kidx.cpu().detach().numpy(), bins=20, alpha=0.7)
                else:
                    # Fewer than three examples of this label in the batch.
                    ax.axis('off')
                if col_idx == 0:
                    ax.set_ylabel(f"y={row_idx}")

        plt.savefig(os.path.join(self.args.log_sample_path, 'postpi_grid_epoch_{}.png'.format(self.current_epoch)))
        plt.close()

    def plot_latent(self, x, y):
        """
        Scatter-plot the latents inferred for ``x`` and save the figure.

        The figure is written to ``postz_grid_epoch_<epoch>.png``. For
        ``z_dim == 2`` a 2-D scatter is drawn and for ``z_dim == 3`` a 3-D
        scatter, both coloured by the inferred component index. For any other
        dimension the latents are embedded in 3-D with t-SNE and coloured by
        the labels ``y``.

        The networks ``vt``, ``Rt`` and ``rt`` are switched to evaluation mode
        while sampling and are returned to whatever mode they were in on
        entry, also when an exception is raised.

        Args:
            x: Batch of inputs passed to ``self.sample``.
            y: Labels for ``x`` (used for colouring in the t-SNE case).
        """
        modules = (self.vt, self.Rt, self.rt)
        previous_modes = [module.training for module in modules]
        for module in modules:
            module.eval()
        try:
            with torch.no_grad():
                z0, z0idx, pi0 = self.sample(x)
        finally:
            # Restore the caller's mode instead of forcing train mode, so a
            # call from the validation loop leaves the networks in eval mode.
            for module, was_training in zip(modules, previous_modes):
                module.train(was_training)

        z0_np = z0.cpu().detach().numpy()
        # Inferred component = arg-max of the (soft) assignment.
        k0_np = np.argmax(z0idx.cpu().detach().numpy(), axis=1)

        cmap = plt.colormaps['tab10']
        z_dim = self.config.flow.z_dim

        if z_dim == 2:
            z0_proj = z0_np.squeeze()
            plt.figure(figsize=(8, 6))
            plt.scatter(z0_proj[:, 0], z0_proj[:, 1], c=k0_np, cmap=cmap, s=10)
            plt.colorbar(shrink=0.5, orientation='vertical')
            plt.title("Generated latent z given x")
        else:
            if z_dim == 3:
                z0_proj = z0_np.squeeze()
                colors = k0_np
                figsize = (8, 8)
            else:
                # t-SNE requires perplexity < number of samples; cap it so a
                # small final validation batch does not abort the run.
                perplexity = min(30, max(z0_np.shape[0] - 1, 1))
                tsne = TSNE(n_components=3, perplexity=perplexity, random_state=0)
                z0_proj = tsne.fit_transform(z0_np)
                # NOTE(review): this branch colours by the true labels while
                # the 2-D/3-D branches colour by the inferred component -
                # confirm this is intended.
                colors = y.cpu().detach().numpy()
                figsize = (8, 6)

            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111, projection='3d')
            scatter = ax.scatter(z0_proj[:, 0], z0_proj[:, 1], z0_proj[:, 2], c=colors, cmap=cmap, s=10)
            plt.colorbar(scatter, ax=ax, pad=0.1, orientation='vertical', shrink=0.5)
            ax.set_title("Generated latent z given x")

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, f'postz_grid_epoch_{self.current_epoch}.png'))
        plt.close()


    def test_generate(self, x, y, cmap='gray', ncomp=10, prior="gaussian", conditional=True):
        # generate posterior predictive
        # fid = FrechetInceptionDistance()
        inception = InceptionScore()
        with torch.no_grad():
            
            if self.config.flow.fix_z and self.config.flow.fix_pi:
                logits0, z0idx = self.Rt.rsample(x.flatten(start_dim=1))
                pi0 = F.softmax(logits0/self.config.flow.beta, dim=-1)
                k0 = torch.argmax(z0idx, dim=-1) # get the maximum 
                # print(z0idx)
                z0 = self.rt.sample(z0idx, x.flatten(start_dim=1))

            fig, axes = plt.subplots(self.config.data.n_classes, 8, figsize=(10, 10))
            if conditional:
                for k, kax in enumerate(axes):
                    # get PCA first d dimension, then sample from Gaussian 
                    if prior=="gaussian":
                        z0_k = z0[k0==k]
                        z0_mean = z0_k.mean(dim=0, keepdim=True)
                        z0_centered = z0_k - z0_mean
                        U, S, Vh = torch.linalg.svd(z0_centered, full_matrices=False)
                        print("--- Running PCA Sampling and Reconstruction ---")
                        _ , z0_new = sample_and_reconstruct_pca(
                            X_original=z0_k,
                            mean_original=z0_mean.squeeze(), # Pass as 1D or 2D row vector
                            Vh_from_svd=Vh,
                            k_components=ncomp,
                            num_new_samples=10,
                            sampling_strategy="mimic_projection"
                        )

                    elif prior=="uniform":
                        lowerb = torch.min(z0[k0==k], 0)[0]
                        upperb = torch.max(z0[k0==k], 0)[0]

                        z0_new = torch.rand((10, self.d)) * (upperb - lowerb) + lowerb

                    x1_new = self.priory.sample((10,)).to(self.device)
                    x0_new = self.vt.decode(x1_new, z0_new)

                    x0_new = inv_transform(x0_new)
                    x0_np = x0_new.cpu().detach().numpy()

                    if not self.config.model.cnn:
                        x0_np = x0_np.reshape((-1,self.c,self.p,self.p))

                    for i, ax in enumerate(kax):
                        x0_np_tr = np.transpose(x0_np[i], (1, 2, 0))
                        ax.imshow(x0_np_tr, cmap=cmap)
                        ax.axis("off")

                plt.tight_layout()
                # plt.savefig(os.path.join(self.args.log_sample_path, '{}_sampels.png'.format(ckpt_file)))
                plt.savefig(os.path.join(self.args.log_sample_path, f'test_gen_grid_conditional_{prior}.png'))
                plt.close()
            else:
                if prior=="gaussian":
                    z0_mean = z0.mean(dim=0, keepdim=True)
                    z0_centered = z0 - z0_mean
                    U, S, Vh = torch.linalg.svd(z0_centered, full_matrices=False)
                    print("--- Running PCA Sampling and Reconstruction ---")
                    _ , z0_new = sample_and_reconstruct_pca(
                        X_original=z0,
                        mean_original=z0_mean.squeeze(), # Pass as 1D or 2D row vector
                        Vh_from_svd=Vh,
                        k_components=ncomp,
                        num_new_samples=10*10,
                        sampling_strategy="mimic_projection"
                    )

                elif prior=="uniform":
                    lowerb = torch.min(z0, 0)[0]
                    upperb = torch.max(z0, 0)[0]

                    z0_new = torch.rand((10*10, self.d)) * (upperb - lowerb) + lowerb

                x1_new = self.priory.sample((10*10,)).to(self.device)
                x0_new = self.vt.decode(x1_new, z0_new)

                x0_new = inv_transform(x0_new)
                x0_np = x0_new.cpu().detach().numpy()

                if not self.config.model.cnn:
                    x0_np = x0_np.reshape((-1,self.c,self.p,self.p))

                for i, ax in enumerate(axes.flatten()):
                    x0_np_tr = np.transpose(x0_np[i], (1, 2, 0))
                    ax.imshow(x0_np_tr, cmap=cmap)
                    ax.axis("off")

                plt.tight_layout()
                # plt.savefig(os.path.join(self.args.log_sample_path, '{}_sampels.png'.format(ckpt_file)))
                plt.savefig(os.path.join(self.args.log_sample_path, f'test_gen_grid_{prior}.png'))
                plt.close()

            gen_img = x0_new.repeat(1, 3, 1, 1).to(torch.float32)
            inception.update(gen_img.to(torch.uint8)) # comput this over all classes
            is_score_mean, is_score_std = inception.compute()


            fig, axes = plt.subplots(10, 8, figsize=(10,10))
            ssim = StructuralSimilarityIndexMeasure()
            
            loss_ssim = []
            for row_idx, row_axes in enumerate(axes):
                m = len(row_axes)
                mask = y==row_idx
                
                if mask.sum() == 0:
                    pass
                else:
                    x_k = x[mask][0]
                    y_k = map_to_position(y[y==row_idx], self.k)[0]
                    # x0 = inv_transform(x_k).repeat(m-1,1)
                    if self.config.model.cnn:
                        y0 = x_k.repeat(m-1,1,1,1)
                    else:
                        y0 = x_k.repeat(m-1,1)
                    if self.config.model.cnn:
                        y0 = y0.flatten(start_dim=1)
                        

                    if not self.config.flow.fix_pi:

                        logits1 = self.priorpi.sample((m-1,)).to(self.device)
                        pi1 = F.softmax(logits1/self.config.flow.beta, dim=-1)
                        z1idx = F.gumbel_softmax(logits1, tau=self.config.flow.beta, hard=False)
                        z1 = self.priorz.sample(z1idx, (1,)).to(self.device)
                        # print('z1', z1.shape)
                        logits0 = self.Rt.decode(logits1, y0.flatten(start_dim=1))
                        pi0 = F.softmax(logits0/self.config.flow.beta, dim=-1)
                        z0idx = F.gumbel_softmax(logits0, tau=self.config.flow.beta, hard=False)
                    else:
                        logits1 = self.priorpi.sample((m-1,)).to(self.device)
                        _, z1idx = self.Rt.rsample(None, logits1)
                        logits0, z0idx = self.Rt.rsample(y0.flatten(start_dim=1))
                    # z0idx = y_k.repeat(m-1,1).to(self.device)

                    if self.config.flow.fix_z:
                        z0 = self.rt.sample(z0idx, y0.flatten(start_dim=1))
                    else:

                        z1 = self.priorz.sample(z1idx, (1,)).to(self.device)
                        z0 = self.rt.decode(z1, z0idx, y0)
                        

                    y1_new = self.priory.sample((m-1,)).to(self.device)
                    y0_new = self.vt.decode(y1_new, z0)
                    y0_new = inv_transform(y0_new)
                    x0_np = y0_new.cpu().detach().numpy()

                    x_np = inv_transform(x_k).cpu().detach().numpy()
                    y0 = inv_transform(y0) # .cpu().detach().numpy()
                    y0 = y0.reshape((-1,self.c,self.p,self.p))
                    if not self.config.model.cnn:
                        x0_np = x0_np.reshape((-1,self.c,self.p,self.p))
                        
                        x_np = x_np.reshape((self.c,self.p,self.p))

                    for col_idx, ax in enumerate(row_axes):   
                        if col_idx == 0:
                            ax.imshow(np.transpose(x_np, (1,2,0)), cmap=cmap)
                            ax.set_ylabel("y={}".format(row_idx))
                        else:
                            ax.imshow(np.transpose(x0_np[col_idx-1],(1,2,0)), cmap=cmap)
                            ax.set_ylabel("")

                            # loss_r.append(np.sqrt(((x_np[0] - x0_np[col_idx-1][0])**2).sum()))

                        ax.set_xlabel("")
                        ax.get_xaxis().set_ticks([])
                        ax.get_yaxis().set_ticks([])

                    real_img = y0.repeat(1, 3, 1, 1).to(torch.float32)
                    gen_img = y0_new.repeat(1, 3, 1, 1).to(torch.float32)
                    ssim_score = ssim(real_img, gen_img)
                    
                    loss_ssim.append(ssim_score)
                    


            plt.tight_layout()
            # plt.savefig(os.path.join(self.args.log_sample_path, '{}_sampels.png'.format(ckpt_file)))
            plt.savefig(os.path.join(self.args.log_sample_path, f'eval_image_grid_epoch_{self.current_epoch}.png'))
            plt.close()


            # posterior sample from pit
            fig, axes = plt.subplots(10, 3, figsize=(10,10), sharex=True, constrained_layout=True)
            for row_idx, row_axes in enumerate(axes):
                # print(f"Row {row_idx}")
                x_k = x[y==row_idx]
                for col_idx, ax in enumerate(row_axes):   
                    if col_idx < len(x_k):
                        if self.config.model.cnn:              
                            y0 = x_k[col_idx].repeat(100,1,1,1)
                        else:
                            y0 = x_k[col_idx].repeat(100,1)

                        if self.config.model.cnn:
                            y0 = y0.flatten(start_dim=1)

                        if self.config.flow.fix_pi:
                            logits0, z0idx = self.Rt.rsample(y0)
                        else:
                            logits1 = self.priorpi.sample((100,))

                            logits0 = self.Rt.decode(logits1, y0)
                            pi0 = F.softmax(logits0/self.config.flow.beta, dim=-1)
                            z0idx = F.gumbel_softmax(logits0, tau=self.config.flow.beta, hard=False)

                        kidx = torch.argmax(z0idx, dim=-1).squeeze()
                        kidx_np = kidx.cpu().detach().numpy()

                        # ax.hist(z0, bins=20, alpha=0.7)
                        ax.hist(kidx_np, bins=20, alpha=0.7)
                    else:
                        ax.axis('off')
                    # ax.set_title(f"y={y[i]}")
                    if col_idx == 0:
                        ax.set_ylabel(f"y={row_idx}")
            # plt.tight_layout(pad=3.0)
            plt.savefig(os.path.join(self.args.log_sample_path, 'eval_postpi_grid_epoch_{}.png'.format(self.current_epoch)))
            plt.close()

        return np.array(loss_ssim).mean(), is_score_mean, is_score_std

    def test_plot_path(self, x, y, n_time=100):
        """
        Plot 1-D PCA projections of latent and observed decode trajectories.

        Samples are sorted by label. A latent trajectory is obtained from
        ``self.rt.decode_with_trajectory`` and an observation trajectory from
        ``self.vt.decode_with_trajectory``; each is projected with
        ``pca_map(..., k=1)`` and drawn per class label 0-9 (at most two
        samples per class). The figures are saved as ``eval_latent_path.png``
        and ``eval_obs_path.png`` under ``self.args.log_sample_path``.

        Parameters:
        - x: batch of inputs; flattened per sample when ``config.model.cnn``.
        - y: integer class labels aligned with ``x``.
        - n_time: forwarded as ``num_points`` to the latent trajectory decoder.
        """
        sorted_idx = torch.argsort(y)
        xk = x[sorted_idx]
        # Tensor.numpy() only works on host tensors, so move the labels first.
        yk = y[sorted_idx].detach().cpu().numpy()
        if self.config.model.cnn:
            xk = xk.flatten(start_dim=1)

        # Prior draw for the latent, then decode it conditioned on the data.
        logits1 = self.priorpi.rsample((len(x),))
        _, z1idx = self.Rt.rsample(None, logits1)
        z1 = self.priorz.rsample(z1idx, (1,))
        logitst, ztidx = self.Rt.rsample(xk.flatten(start_dim=1))
        pi0 = softmax(logitst/self.config.flow.beta, dim=-1)
        z0, trajectories, time_points = self.rt.decode_with_trajectory(z1, pi0, xk, num_points=n_time)
        _z0 = z0.cpu().detach().numpy()

        recoded_n_time = len(time_points)
        _trajectories = trajectories.cpu().detach().numpy().reshape(
            len(x), recoded_n_time, self.config.flow.z_dim)
        _zt_path = pca_map(_z0, _trajectories, k=1)

        self._plot_class_trajectories(
            _zt_path, time_points, yk,
            r'PCA Projected Latent Trajectory $z_t$', 'eval_latent_path.png')

        # Observation trajectory decoded from the posterior latent z0.
        x1_new = self.priory.sample((len(xk),)).to(self.device)
        x0_new, x0_trajectories, time_points = self.vt.decode_with_trajectory(x1_new, z0)
        _x0_new = x0_new.cpu().detach().numpy()

        recoded_n_time = len(time_points)
        n_pixels = self.config.data.size**2
        _x0_trajectories = x0_trajectories.cpu().detach().numpy().reshape(
            len(x), recoded_n_time, n_pixels)
        _xt_path = pca_map(_x0_new.reshape(len(x), n_pixels), _x0_trajectories, k=1)

        self._plot_class_trajectories(
            _xt_path, time_points, yk,
            r'PCA Projected Observed Trajectory $x_t$', 'eval_obs_path.png')

    def _plot_class_trajectories(self, paths, time_points, labels, ylabel, filename):
        """Draw up to two projected trajectories per class (one subplot each) and save the figure."""
        # Host copy of the time grid, computed once instead of per plotted line.
        t_np = time_points.detach().cpu().numpy()
        cmap = plt.colormaps['tab10']

        fig, axes = plt.subplots(10, 1, figsize=(10, 12), sharex=True)
        fig.tight_layout(pad=3.0)  # Add some padding between subplots
        for class_label in range(10):  # self.config.data.n_classes
            ax = axes[class_label]
            class_indices = np.where(labels == class_label)[0]
            top_indices = class_indices[:2]
            color = cmap(class_label)

            for idx in top_indices:
                ax.set_xlim(t_np[0], t_np[-1])
                ax.plot(t_np, paths[idx], c=color)
            ax.set_title("Class {}, Mean Across [0,1] = {}".format(
                class_label, np.round(np.mean(paths[top_indices,-1,0]), 3)))
            if class_label == 4:
                ax.set_ylabel(ylabel)
        axes[-1].set_xlabel('Time')
        plt.savefig(os.path.join(self.args.log_sample_path, filename))
        plt.close()


    def test_plot_path_discrete(self, x, y):
        """
        Visualise latent codes along a discretised noising path of ``x``.

        For each t in ``arange(0, 1, 0.1)`` the input is mixed with a sample
        from ``self.priory``, a latent is decoded with ``self.rt.decode`` at
        that t, and the resulting per-sample sequences are projected with
        ``tsne_map``. Up to two samples per class label 0-9 are drawn as 3-D
        curves (projection x/y against t) and shown with ``plt.show()``.

        Parameters:
        - x: batch of inputs.
        - y: integer class labels aligned with ``x``.
        """
        # The trajectories below are computed from `x` in its given order, so
        # the class lookup has to use the labels in that same order.
        labels = y.detach().cpu().numpy()
        n_samples = len(x)
        z_dim = self.config.flow.z_dim

        _t_lis = torch.arange(0, 1, 0.1)
        t_np = _t_lis.cpu().numpy()
        zt_steps = []
        cmap = plt.colormaps['tab10']

        for _t in _t_lis:
            t = torch.ones_like(x[..., 0, None]) * _t
            x1 = self.priory.sample((len(x),))
            xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1

            logits1 = self.priorpi.rsample((len(x),))
            _, z1idx = self.Rt.rsample(None, logits1)
            z1 = self.priorz.rsample(z1idx, (1,))

            logitst, ztidx = self.Rt.rsample(xt.flatten(start_dim=1))
            zt = self.rt.decode(z1, ztidx, xt.flatten(start_dim=1), _t)
            # One (n_samples, z_dim) slab per time step.
            zt_steps.append(zt.cpu().detach().numpy().reshape(n_samples, z_dim))

        # Stack along a new time axis -> (n_samples, n_times, z_dim).  A plain
        # reshape of the (n_times, n_samples, z_dim) stack would interleave
        # samples and time steps instead of transposing them.
        all_zt_path = np.stack(zt_steps, axis=1)
        _zt_path = tsne_map(all_zt_path[:,0,:], all_zt_path)

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        for class_label in range(10): # self.config.data.n_classes
            class_indices = np.where(labels == class_label)[0]
            color = cmap(class_label)

            for idx in class_indices[:2]:
                path = _zt_path[idx]
                # Plot the trajectory
                ax.plot(path[:, 0], path[:, 1], t_np, c=color)
                # Mark start and end points
                ax.scatter(path[0, 0], path[0, 1], t_np[0], color=color, s=20, marker='o')
                ax.scatter(path[-1, 0], path[-1, 1], t_np[-1], color=color, s=40, marker='x')

        plt.show()
        plt.close()

    def test_plot_latent(self, x, y):
        """
        Scatter a t-SNE embedding of posterior latents and score the clustering.

        Draws ``(z0, z0idx, pi0)`` from ``self.sample(x)``, embeds ``z0`` in 2-D
        with t-SNE, colours points by ``y`` and saves the plot as
        ``eval_grid_epoch_{current_epoch}.png`` under
        ``self.args.log_sample_path``.

        Parameters:
        - x: batch of inputs passed to ``self.sample``.
        - y: integer class labels aligned with ``x``.

        Returns:
        - ``(nmi, ari, nmi_soft)``: normalized mutual information and adjusted
          Rand index between ``y`` and ``argmax(z0idx)``, and ``soft_nmi`` of
          ``pi0`` against ``y``.
        """
        # sample from posterior
        z0, z0idx, pi0 = self.sample(x)
        z0_np = z0.cpu().detach().numpy()
        k0_np = np.argmax(z0idx.cpu().detach().numpy(), axis=1)
        y_np = y.cpu().detach().numpy()
        pi_np = pi0.cpu().detach().numpy()

        cmap = plt.colormaps['tab10']

        # t-SNE requires perplexity < n_samples; clamp so small batches still
        # embed.  The iteration count is left at the library default (1000):
        # the `n_iter` keyword was renamed in recent scikit-learn releases.
        perplexity = min(50, max(1, len(z0_np) - 1))
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=14)
        zc_proj = tsne.fit_transform(z0_np)

        plt.figure(figsize=(8,6))
        plt.scatter(zc_proj[:,0], zc_proj[:,1], c=y_np, cmap=cmap, s=10)
        plt.colorbar(shrink=0.5)

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, f'eval_grid_epoch_{self.current_epoch}.png'))
        plt.close()

        ari = adjusted_rand_score(y_np, k0_np)
        nmi = normalized_mutual_info_score(y_np, k0_np)
        nmi_soft = soft_nmi(pi_np, y_np)
        return nmi, ari, nmi_soft

    def test_generation_figure(self, x, y):
        """
        Save a three-row figure: real inputs, one generated sample per input,
        and the component probabilities ``pi0`` for each input.

        Inputs are sorted by label and are expected to hold one example per
        digit. The figure is written to ``eval_sfa_mixture_eval_display.png``
        under ``self.args.log_sample_path``.

        Parameters:
        - x: batch of inputs; flattened per sample when ``config.model.cnn``.
        - y: integer class labels aligned with ``x``.
        """
        sorted_idx = torch.argsort(y)
        xk = x[sorted_idx]
        if self.config.model.cnn:
            xk = xk.flatten(start_dim=1)
        print("xk", xk.shape)
        row_ratios = [1, 1, 1.5]
        with torch.no_grad():
            # squeeze=False keeps `axes` 2-D even when self.k == 1.
            fig, axes = plt.subplots(
                3, self.k, figsize=(10,4),
                gridspec_kw={'height_ratios': row_ratios}, squeeze=False)

            # Posterior over the mixture index.
            if not self.config.flow.fix_pi:
                logits1 = self.priorpi.sample((len(xk),)).to(self.device)
                z1idx = F.gumbel_softmax(logits1, tau=self.config.flow.beta, hard=False)
                # NOTE: this draw is replaced below when fix_z is False; it is
                # kept so the sequence of random draws is unchanged.
                z1 = self.priorz.sample(z1idx, (1,)).to(self.device)
                logits0 = self.Rt.decode(logits1, xk)
                pi0 = F.softmax(logits0/self.config.flow.beta, dim=-1)
                z0idx = F.gumbel_softmax(logits0, tau=self.config.flow.beta, hard=False)
            else:
                logits1 = self.priorpi.sample((len(xk),)).to(self.device)
                _, z1idx = self.Rt.rsample(None, logits1)
                logits0, z0idx = self.Rt.rsample(xk)
                pi0 = F.softmax(logits0/self.config.flow.beta, dim=-1)

            # Posterior over the continuous latent.
            if self.config.flow.fix_z:
                z0 = self.rt.sample(z0idx, xk)
            else:
                z1 = self.priorz.sample(z1idx, (1,)).to(self.device)
                z0 = self.rt.decode(z1, z0idx, xk)

            y1_new = self.priory.sample((len(xk),)).to(self.device)
            y0_new = self.vt.decode(y1_new, z0)
            xnew = inv_transform(y0_new).reshape((-1,self.c,self.p,self.p))

            xk = inv_transform(xk).reshape((-1,self.c,self.p,self.p))

            xnew_np = xnew.cpu().detach().numpy()
            xk_np = xk.cpu().detach().numpy()
            prob_np = pi0.cpu().detach().numpy()

            # At most ten columns are drawn, but never more than the grid has
            # columns (self.k) or the batch has samples, to avoid IndexError.
            n_cols = min(10, self.k, len(xk_np))
            for col_idx in range(n_cols):
                real_ax, gen_ax, prob_ax = axes[0, col_idx], axes[1, col_idx], axes[2, col_idx]
                real_ax.imshow(np.transpose(xk_np[col_idx], (1,2,0)), cmap="gray")
                gen_ax.imshow(np.transpose(xnew_np[col_idx], (1,2,0)), cmap="gray")
                prob_ax.bar(range(self.k), prob_np[col_idx])
                for image_ax in (real_ax, gen_ax):
                    image_ax.get_xaxis().set_ticks([])
                    image_ax.get_yaxis().set_ticks([])
                prob_ax.set_xlim(-1,self.k)
                prob_ax.set_ylim(0,1)
                if col_idx == 0:
                    real_ax.set_ylabel("Real")
                    gen_ax.set_ylabel("Generated")
                    prob_ax.set_ylabel(r"$p(\xi|x)$")
                else:
                    prob_ax.get_yaxis().set_ticks([])

            plt.tight_layout()
            plt.savefig(os.path.join(self.args.log_sample_path, 'eval_sfa_mixture_eval_display.png'))
            plt.close()


    def test_sample(self, n, x, y):
        """
        Save a grid of ``n`` conditional generations for one input per class.

        For each class label in ``range(self.k)`` the last sample of that class
        (after sorting by label) is repeated ``n`` times, a posterior latent is
        drawn for each copy and decoded with ``self.vt.decode``. The grid is
        written to ``eval_cond_generation.png`` under
        ``self.args.log_sample_path``.

        Parameters:
        - n: number of generations (columns) per class.
        - x: batch of inputs; flattened per sample when ``config.model.cnn``.
        - y: integer class labels aligned with ``x``.
        """
        sorted_idx = torch.argsort(y)
        xk = x[sorted_idx]
        # Tensor.numpy() only works on host tensors, so move the labels first.
        yk = y[sorted_idx].detach().cpu().numpy()
        if self.config.model.cnn:
            xk = xk.flatten(start_dim=1)

        # The rows are indexed by range(self.k) below, so the grid must have at
        # least self.k rows; squeeze=False keeps `axes` 2-D when n == 1.
        n_rows = max(10, self.k)
        fig, axes = plt.subplots(n_rows, n, figsize=(12,16), squeeze=False)

        # with one sample from each class, see the perturbation
        for class_label in range(self.k): # self.config.data.n_classes
            class_indices = np.where(yk == class_label)[0]
            if class_indices.size == 0:
                # Nothing to condition on for this class: leave the row blank.
                for ax in axes[class_label]:
                    ax.axis("off")
                continue
            top_index = class_indices[-1:]

            # NOTE(review): xk looks 2-D at this point, so repeat(n,1,1,1)
            # yields a 4-D tensor — confirm this is what Rt.rsample expects.
            _xk = xk[top_index].repeat(n,1,1,1)
            # then sample from p(z|x)
            logits1 = self.priorpi.rsample((n,)).to(self.device)
            _, z1idx = self.Rt.rsample(None, logits1)

            z1 = self.priorz.rsample(z1idx, (1,)).to(self.device)
            logits0, z0idx = self.Rt.rsample(_xk)
            if self.config.flow.fix_z:
                z0 = self.rt.sample(z0idx, _xk.flatten(start_dim=1))
            else:
                z0 = self.rt.decode(z1, z0idx, _xk)

            x1_new = self.priory.sample((n,)).to(self.device)
            x0_new = self.vt.decode(x1_new, z0)
            _x0_new = inv_transform(x0_new).cpu().detach().numpy()

            for j in range(n):
                axes[class_label, j].imshow(np.transpose(_x0_new[j], (1,2,0)))
                axes[class_label, j].axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, 'eval_cond_generation.png'))
        plt.close()
        


    def test_perturb_generation(self, x, y, kind="pca", conditional=None):
        """
        Decode perturbed posterior latents and save the resulting image grids.

        A posterior latent ``z0`` is drawn for every (label-sorted) input, then
        perturbed according to ``kind`` and decoded with ``self.vt.decode``:

        - ``"pca"``: shift standardised latents along each of the first four
          principal components; one figure per component
          (``eval_perturbed_generation_pca_{idx}.png`` or, when ``conditional``
          is given, ``eval_{conditional}cond_generation_pca_{idx}.png``).
        - ``"coord"``: shift along the first eight latent coordinates; one
          figure per class (``eval_perturbed_generation_onehot_{class}.png``).
        - ``"interpolation"``: blend each class latent with the first latent;
          a single figure (``eval_perturbed_interpolation.png``).

        Any other ``kind`` produces no output.

        Parameters:
        - x: batch of inputs; flattened per sample when ``config.model.cnn``.
        - y: integer class labels aligned with ``x``.
        - kind: perturbation mode, see above.
        - conditional: for ``"pca"`` only; when not None the PCA is fitted on
          the latents whose ``argmax(z0idx)`` equals this value.
        """
        sorted_idx = torch.argsort(y)
        xk = x[sorted_idx]
        # Tensor.numpy() only works on host tensors, so move the labels first.
        yk = y[sorted_idx].detach().cpu().numpy()
        if self.config.model.cnn:
            xk = xk.flatten(start_dim=1)

        # sample
        logits1 = self.priorpi.rsample((len(x),))
        _, z1idx = self.Rt.rsample(None, logits1)
        z1 = self.priorz.rsample(z1idx, (1,))
        logits0, z0idx = self.Rt.rsample(xk)
        pi0 = softmax(logits0/self.config.flow.beta, dim=-1)
        z0 = self.rt.decode(z1, pi0, xk)
        _z0 = z0.cpu().detach().numpy()
        _z0idx = z0idx.cpu().detach().numpy()

        n_classes = 10  # self.config.data.n_classes

        if kind=="pca":
            if conditional is None:
                _z0_sub = _z0
            else:
                print(_z0idx.argmax(-1)[[1,2]])
                _z0_sub = _z0[_z0idx.argmax(-1)==conditional]
            # normalize z
            scaler = StandardScaler()
            z0_normalized = scaler.fit_transform(_z0_sub)
            print(len(z0_normalized))

            pca = PCA(n_components=8).fit(z0_normalized)
            pcs = pca.components_

            alpha = np.arange(0, 20, 20./8.)

            for idx in range(4):
                pc = pcs[idx]
                fig, axes = plt.subplots(n_classes, len(alpha), figsize=(12,16))
                # with one sample from each class, see the perturbation
                for row, class_label in enumerate(range(n_classes)):
                    class_indices = np.where(yk == class_label)[0]
                    top_indices = class_indices[-2:]
                    if top_indices.size == 0:
                        continue  # class absent from the batch: leave the row empty
                    z0_origin = scaler.transform(_z0[top_indices])

                    for col, a in enumerate(alpha):
                        _x0_new = self._decode_perturbed_latent(
                            scaler.inverse_transform(z0_origin + a * pc))

                        axes[row, col].imshow(np.transpose(_x0_new[0], (1,2,0)))
                        if col == 0:
                            axes[row, col].set_ylabel(f"Class={class_label}")
                        # Label the bottom row of the grid (rows are classes here).
                        if row == n_classes - 1:
                            axes[row, col].set_xlabel(rf"$\alpha$={a}")

                fig.supxlabel('Latent PC Coordinate', fontsize=15)
                fig.supylabel('Perturbation Values', fontsize=15)
                plt.tight_layout()
                if conditional is None:
                    plt.savefig(os.path.join(self.args.log_sample_path, f'eval_perturbed_generation_pca_{idx}.png'))
                else:
                    plt.savefig(os.path.join(self.args.log_sample_path, f'eval_{conditional}cond_generation_pca_{idx}.png'))
                plt.close()
        elif kind=="coord":
            # Unit vectors along the first (up to) eight latent coordinates,
            # sized from the latent itself rather than a fixed dimension.
            pcs = np.eye(_z0.shape[-1])[:8]
            alpha = np.arange(0, 500, 50.)

            for class_label in range(n_classes):
                class_indices = np.where(yk == class_label)[0]
                top_indices = class_indices[:2]
                if top_indices.size == 0:
                    continue  # class absent from the batch: nothing to perturb

                z0_origin = _z0[top_indices]

                fig, axes = plt.subplots(len(alpha), len(pcs), figsize=(12,15), squeeze=False)
                for col, pc in enumerate(pcs):
                    for row, a in enumerate(alpha):
                        _x0_new = self._decode_perturbed_latent(z0_origin + a * pc)

                        axes[row, col].imshow(np.transpose(_x0_new[0], (1,2,0)))
                        if col == 0:
                            axes[row, col].set_ylabel(rf"$\alpha$={a}")
                        if row == len(alpha)-1:
                            axes[row, col].set_xlabel(f"Coordinate {col}")

                fig.supxlabel('Latent Coordinates', fontsize=15)
                fig.supylabel('Perturbation Values', fontsize=15)
                plt.tight_layout()
                plt.savefig(os.path.join(self.args.log_sample_path, f'eval_perturbed_generation_onehot_{class_label}.png'))
                plt.close()
        elif kind=="interpolation":
            # linear interpolation towards the first latent
            pc = _z0[0]
            alpha = np.arange(0, 1, 1/8)
            fig, axes = plt.subplots(n_classes, len(alpha), figsize=(12,15))
            for row, class_label in enumerate(range(n_classes)):
                class_indices = np.where(yk == class_label)[0]
                top_indices = class_indices[:2]
                if top_indices.size == 0:
                    continue  # class absent from the batch: leave the row empty
                z0_origin = _z0[top_indices]

                for col, a in enumerate(alpha):
                    _x0_new = self._decode_perturbed_latent(a * z0_origin + (1-a) * pc)

                    # First decoded image, channel-first -> channel-last, as in
                    # the other two modes.
                    axes[row, col].imshow(np.transpose(_x0_new[0], (1,2,0)))
                    if col == 0:
                        axes[row, col].set_ylabel(rf"$class$={class_label}")
                    # Label the bottom row of the grid (rows are classes here).
                    if row == n_classes - 1:
                        axes[row, col].set_xlabel(rf"$alpha$={np.round(a,3)}")
            fig.supxlabel('Latent Coordinates', fontsize=15)
            fig.supylabel('Perturbation Values', fontsize=15)
            plt.tight_layout()
            plt.savefig(os.path.join(self.args.log_sample_path, 'eval_perturbed_interpolation.png'))
            plt.close()

    def _decode_perturbed_latent(self, z_np):
        """Decode a NumPy batch of latents into images (NumPy) via ``self.vt.decode`` and ``inv_transform``."""
        x1_new = self.priory.sample((len(z_np),)).to(self.device)
        # NumPy arithmetic can promote the latent to float64 and leaves it on
        # the host; match the prior sample it is decoded together with.
        z = torch.tensor(z_np, dtype=x1_new.dtype, device=x1_new.device)
        x0_new = self.vt.decode(x1_new, z)
        return inv_transform(x0_new).cpu().detach().numpy()

    def test_step(self, batch, batch_idx):
        X, y = batch
        if not self.config.model.cnn:
            X = X.view(-1, self.config.data.size**2)
        
        if self.args.figure:
            self.test_generation_figure(X, y)
            
        else:
            ssim_loss, is_loss, is_loss_std = self.test_generate(X, y, conditional=False, ncomp=10)
            self.log('test_ssim_loss', ssim_loss, on_step=True, on_epoch=True, sync_dist=True, logger=True)
            self.log('test_is_loss', is_loss, on_step=True, on_epoch=True, sync_dist=True, logger=True)
            self.log('test_is_loss_std', is_loss_std, on_step=True, on_epoch=True, sync_dist=True, logger=True)
            """ Snapshot sampling at the end of every epoch """
            # if self.config.training.snapshot_sampling:
            nmi, ard, nmi_soft = self.test_plot_latent(X, y)
            self.log('test_post_clus_nmi', nmi, on_step=True, on_epoch=True, sync_dist=True, logger=True)
            self.log('test_post_clus_ard', ard, on_step=True, on_epoch=True, sync_dist=True, logger=True)
            self.log('test_post_clus_nmi_soft', nmi_soft, on_step=True, on_epoch=True, sync_dist=True, logger=True)


def remap_checkpoint_state_dict(state_dict):
    """
    Rename checkpoint keys so that ``...Rt.fc.0.<rest>`` becomes ``...Rt.fc.<rest>``.

    The ``0`` segment directly after ``Rt.fc`` is dropped wherever that
    sequence sits in the key, so prefixed keys such as
    ``flow_matching_loss.Rt.fc.0.weight`` are handled the same way as
    ``Rt.fc.0.weight``. All other keys are kept unchanged.

    Parameters:
    - state_dict: mapping from parameter names to values.

    Returns:
    - A new dict with remapped keys; the input mapping is not modified.
    """
    old_segment = "Rt.fc.0."
    new_segment = "Rt.fc."
    # Replace by substring rather than by split position: a position-based
    # splice removes the wrong segment as soon as the key carries a prefix.
    return {
        key.replace(old_segment, new_segment, 1): value
        for key, value in state_dict.items()
    }


class MNISTRunner():
    """
    Builds the flow-matching networks and a Lightning ``Trainer`` from
    ``args``/``config`` and exposes ``train``, ``sample`` and ``draw_figure``
    entry points.
    """

    def __init__(self, args, config):
        """
        Construct the networks, the checkpoint callback and the trainer.

        Parameters:
        - args: run arguments; ``args.log_path`` is read and
          ``args.log_sample_path`` is set to ``<log_path>/samples`` (created
          if missing).
        - config: configuration object with ``flow``, ``data``, ``model``,
          ``training`` and ``device`` entries.
        """
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        self.k, self.d = self.config.flow.pi_dim, self.config.flow.z_dim
        self.c, self.p = self.config.data.channel, self.config.data.size
        self.priorz = update_GaussianMixtureComponent(self.k, self.d, hidden_features=[], fct=nn.Softplus())

        self.vt = cnnLLK(
            self.p, self.d, in_ch=self.c,
            mod_ch=self.config.model.mod_ch, freqs=self.config.model.freqs,
            hidden_features=[self.config.flow.ngf]*3,
            fct=nn.Softplus(),
            )
        self.embx = None

        # Continuous-latent posterior: fixed mixture component or a CNF.
        if self.config.flow.fix_z:
            self.rt = GaussianMixtureComponent(
                self.k, self.d, self.c*self.p**2, hidden_features=[self.config.flow.ngf]*0
                , fct=nn.Tanh()
                ).to(self.config.device)
        else:
            self.rt = CNF(
                self.c*self.p**2, self.d, self.k, hidden_features=[]*1
                , fct=nn.Tanh()
                ).to(self.config.device)

        # Categorical posterior over mixture components.
        if self.config.flow.fix_pi:
            self.Rt = CatNF_fixed(
                self.c*self.p**2, self.k, hidden_features=[self.config.flow.ngf]*1
                ,temp=self.config.flow.beta
                , fct=nn.Tanh()
                ).to(self.config.device)
        else:
            self.Rt = CatNF(self.c*self.p**2, self.k, hidden_features=[self.config.flow.ngf]*0
            , fct=nn.Softplus()
            ).to(self.config.device)

        # Keep only the checkpoint with the lowest validation loss.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',  # Metric to monitor
            dirpath=self.args.log_path,  # Directory where checkpoints will be saved
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',  # Filename convention
            save_top_k=1,  # Only save the best model based on val_loss
            mode='min'  # Minimize the validation loss
        )

        if torch.cuda.is_available():
            accelerator = 'gpu'
            strategy = DDPStrategy(find_unused_parameters=True)
            devices = "auto"
        else:
            accelerator = 'cpu'
            devices = "auto"
            strategy = "auto"

        plot_loss_callback = PlotLossCallback(save_path=os.path.join(self.args.log_sample_path, 'loss.png'), update_interval=1)
        plot_llk_callback = PlotLogLikelihoodCallback(save_path=self.args.log_sample_path, log_keys=("val_log_lik", "val_log_post_z"))
        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            callbacks=[plot_loss_callback, plot_llk_callback, self.checkpoint_callback],
        )

    def _best_ckpt_path(self):
        """
        Return the checkpoint path to load, or None when there is none.

        ``checkpoint_callback.best_model_path`` is only filled in once a fit
        has saved a checkpoint in this process; on a freshly built runner it
        is an empty string. In that case the most recently modified
        ``best-checkpoint-*.ckpt`` file in ``args.log_path`` is used instead.
        """
        best = self.checkpoint_callback.best_model_path
        if best:
            return best

        pattern = os.path.join(glob.escape(self.args.log_path), 'best-checkpoint-*.ckpt')
        candidates = glob.glob(pattern)
        if not candidates:
            logging.warning("No checkpoint found in %s; continuing without loading weights.",
                            self.args.log_path)
            return None
        return max(candidates, key=os.path.getmtime)

    def _build_model(self):
        """Wrap the networks in a fresh ``FlowMatchingLightningModule``."""
        return FlowMatchingLightningModule(
            self.vt, self.Rt, self.rt, self.priorz, self.embx, self.config, self.args)

    def train(self):
        """Fit the model on MNIST, resuming from a checkpoint when ``args.resume_training`` is set."""
        # load data
        dataset, val_dataset, sampler, val_sampler = get_mnist(
            self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
        train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
                                num_workers=self.config.data.num_workers, sampler=sampler)
        val_dataloader = DataLoader(val_dataset, batch_size=self.config.training.batch_size,
                                 num_workers=self.config.data.num_workers, sampler=val_sampler, drop_last=True)
        model = self._build_model()

        ckpt_path = self._best_ckpt_path() if self.args.resume_training else None
        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

    def sample(self):
        """
        Run the test loop on MNIST (``config.data.in_sample``) or EMNIST.

        Returns:
        - A dict copy of ``trainer.callback_metrics`` after the test run.
        """
        if self.config.data.in_sample:
            dataset, _, sampler, _ = get_mnist(
                    self.config.data.n_classes, "data", 300, 50)
            dataloader = DataLoader(dataset, batch_size=5000,
                    num_workers=self.config.data.num_workers, sampler=sampler)
        else:
            # 0123456789ABCDEFGHIJ
            dataset, sampler = get_emnist(
                20, "data", 300, 50, split="balanced") # balanced
            dataloader = DataLoader(dataset, batch_size=6000,
                                num_workers=self.config.data.num_workers, sampler=sampler)
        ckpt_path = self._best_ckpt_path()

        model = self._build_model()

        # set to test mode
        self.trainer.test(model, dataloaders=dataloader, ckpt_path=ckpt_path)

        # Hand back whatever was aggregated instead of indexing fixed keys:
        # a key that the test loop did not log would raise KeyError.
        return dict(self.trainer.callback_metrics)

    def draw_figure(self):
        """Run the test loop on a ten-sample MNIST test loader (used for the display figure)."""
        _, test_dataset, _, test_sampler = get_mnist(
                self.config.data.n_classes, "data", 1, 1) # 1,1,10 # 500, 500, 5000
        test_dataloader = DataLoader(test_dataset, batch_size=10,
                num_workers=self.config.data.num_workers, sampler=test_sampler)

        ckpt_path = self._best_ckpt_path()

        model = self._build_model()

        self.trainer.test(model, dataloaders=test_dataloader, ckpt_path=ckpt_path)
        
       

