import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import ot
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image.inception import InceptionScore
from sklearn.cluster import KMeans
from vendi_score import vendi

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import get_named_beta_schedule
from Scheduler import GradualWarmupScheduler

from vae import *
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_pinwheel import *
from dataloader.dataloader_mnist import *
from dataloader.dataloader_hvg import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy


# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)
# Print tensors with three digits of precision.
torch.set_printoptions(precision=3)


def first_eigen_proj(x):
    """Project the rows of ``x`` onto the leading eigenvector of their covariance.

    Args:
        x: Array of shape ``(n_samples, n_features)``.

    Returns:
        1-D real array of length ``n_samples`` containing ``x @ v`` where ``v``
        is the unit eigenvector of the sample covariance with the largest
        eigenvalue. The data is projected un-centred (as before), and the sign
        of ``v`` is arbitrary.
    """
    x = np.asarray(x)

    # np.cov removes the column means itself, so no explicit centring is needed.
    # atleast_2d keeps the single-feature case (0-d covariance) decomposable.
    cov_matrix = np.atleast_2d(np.cov(x, rowvar=False))

    # A covariance matrix is symmetric: eigh guarantees real eigenpairs, whereas
    # the general-purpose eig may return complex values due to round-off.
    _, eigenvectors = np.linalg.eigh(cov_matrix)

    # eigh sorts eigenvalues in ascending order, so the last column is the
    # direction of maximal variance.
    first_eigenvector = eigenvectors[:, -1]

    return np.dot(x, first_eigenvector)

class PlotLossCallback(Callback):
    """Collect the ``train_loss`` / ``val_loss`` callback metrics and redraw a loss curve."""

    def __init__(self, save_path="loss_plot.png", update_interval=5):
        """
        Args:
            save_path: File the figure is written to.
            update_interval: The figure is redrawn whenever the epoch index is
                a multiple of this value.
        """
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        self.train_losses = []
        self.val_losses = []
        self.epochs = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the latest losses (when both are available) and refresh the plot if due."""
        epoch = trainer.current_epoch
        metrics = trainer.callback_metrics
        train_loss = metrics.get("train_loss")
        val_loss = metrics.get("val_loss")

        # Only record epochs for which both curves have a value, so the
        # three lists always stay the same length.
        if not (train_loss is None or val_loss is None):
            self.epochs.append(epoch)
            self.train_losses.append(train_loss.cpu().item())
            self.val_losses.append(val_loss.cpu().item())

        if epoch % self.update_interval != 0:
            return
        self.plot_and_save()

    def plot_and_save(self):
        """Draw both loss histories against the recorded epochs and save to ``self.save_path``."""
        plt.figure(figsize=(10, 6))
        curves = (
            (self.train_losses, "Training Loss"),
            (self.val_losses, "Validation Loss"),
        )
        for history, label in curves:
            plt.plot(self.epochs, history, label=label, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.save_path)
        plt.close()

class PlotLogLikelihoodCallback(Callback):
    """Track two logged log-likelihood metrics and plot them side by side."""

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Args:
            save_path: Path that ``llk.png`` is joined onto to form the output file.
            log_keys (tuple): Metric names to read from ``trainer.callback_metrics``;
                the first feeds the left ("x") panel, the second the right ("z") panel.
        """
        super().__init__()
        self.save_path = os.path.join(save_path, 'llk.png')
        self.log_keys = log_keys
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Append whichever of the two metrics is currently available, then redraw and save.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained.
        """
        histories = (self.log_likelihoods_x, self.log_likelihoods_z)
        for key, history in zip(self.log_keys[:2], histories):
            if key in trainer.callback_metrics:
                history.append(trainer.callback_metrics[key].item())

        panels = (
            (1, self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (2, self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )

        plt.figure(figsize=(12, 6))
        for position, history, y_label, title in panels:
            plt.subplot(1, 2, position)
            plt.plot(history, marker="o")
            plt.xlabel("Epoch")
            plt.ylabel(y_label)
            plt.title(title)
            plt.grid()

        plt.tight_layout()
        plt.savefig(self.save_path)
        plt.close()


class VAELightningModule(pl.LightningModule):
    def __init__(self, decoder: nn.Module, encoder: nn.Module, priorz, config, args):
        """Keep references to the networks, the latent prior and the run configuration.

        Args:
            decoder: Network stored as ``self.decoder``.
            encoder: Network stored as ``self.encoder``.
            priorz: Latent prior object stored as ``self.priorz``.
            config: Configuration object; ``flow``, ``data`` are read here.
            args: Command-line style arguments object, stored as ``self.args``.
        """
        super().__init__()
        self.config = config
        self.args = args

        # Assignment order is kept: it determines sub-module registration order.
        self.decoder = decoder
        self.encoder = encoder
        self.priorz = priorz

        # Latent size: use ``feature_dim`` when the config has it, else ``z_dim``.
        flow_cfg = self.config.flow
        self.d = flow_cfg.feature_dim if hasattr(flow_cfg, "feature_dim") else flow_cfg.z_dim
        self.p = self.config.data.size

        # The optimiser and scheduler are stepped by hand in this module.
        self.automatic_optimization = False
        self.last_validation_batch = None


    def setup(self, stage=None):
        """Create the ELBO loss object from the decoder, encoder and prior.

        Args:
            stage: Passed by Lightning; not used here.
        """
        self.elbo_loss = ELBO(self.decoder, self.encoder, self.priorz)


    def training_step(self, batch, batch_idx):
        """Run one manually-optimised training step on the ELBO loss.

        Args:
            batch: ``(X, y)`` pair; only ``X`` is used.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss tensor produced by ``self.elbo_loss``.
        """
        X, _ = batch
        if not self.config.model.cnn:
            # Non-CNN models consume flattened inputs.
            X = X.view(-1, self.p**2)

        loss = self.elbo_loss(X.to(torch.float64))

        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        g_opt = self.optimizers()
        g_opt.zero_grad()

        self.manual_backward(loss)

        # Clip BEFORE stepping: clipping after optimizer.step() has no effect on
        # the update that was just applied (and the gradients are zeroed again
        # at the start of the next step).
        self.clip_gradients(g_opt, gradient_clip_val=self.config.training.clipval, gradient_clip_algorithm="norm")
        g_opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the ELBO loss for one validation batch.

        The final batch of the first validation dataloader is cached in
        ``self.last_validation_batch`` for use at the end of the validation epoch.

        Returns:
            The validation loss tensor.
        """
        inputs, labels = batch
        if not self.config.model.cnn:
            inputs = inputs.view(-1, self.p**2)
        inputs = inputs.to(torch.float64)

        val_loss = self.elbo_loss(inputs)

        last_index = self.trainer.num_val_batches[0] - 1
        if batch_idx == last_index:
            # Kept so on_validation_epoch_end has data to sample/plot from.
            self.last_validation_batch = {"X": inputs, "y": labels}

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        return val_loss

    def on_train_epoch_end(self):
        """Advance the learning-rate scheduler once per training epoch (stepped by hand)."""
        self.lr_schedulers().step()
        

    def on_validation_epoch_end(self):
        """Every ``snapshot_freq`` epochs, draw samples/plots from the cached batch and log likelihoods.

        The cached validation batch is always cleared afterwards.
        """
        cached = self.last_validation_batch
        snapshot_due = (
            cached is not None
            and self.current_epoch % self.config.training.snapshot_freq == 0
        )

        if snapshot_due:
            X, y = cached["X"], cached["y"]
            print("sampling")

            # Pick the plotting routine from the config file name; unknown
            # names simply skip plotting.
            plotters = {
                "mnist.yml": lambda: self.generate_img(X, y),
                "pinwheel.yml": self.generate,
                "hvg.yml": lambda: self.plot_latent_hvg(X, y),
            }
            plotter = plotters.get(self.args.config)
            if plotter is not None:
                plotter()

            # Snapshot log-likelihoods on the cached batch.
            log_post_z, log_lik = self.sample_and_log(X, y)
            print("log_post_z", log_post_z)
            print("log_lik", log_lik)
            self.log('tra_log_post_z', log_post_z, on_step=False, on_epoch=True, sync_dist=True, logger=True)
            self.log('tra_log_lik', log_lik, on_step=False, on_epoch=True, sync_dist=True, logger=True)

        # Reset so a stale batch is never reused in a later epoch.
        self.last_validation_batch = None

    def configure_optimizers(self):
        """Build AdamW over decoder+encoder parameters with warm-up followed by cosine annealing.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler is the
            warm-up wrapper, stepped per epoch.
        """
        # The prior's parameters are intentionally not part of the optimised set.
        trainable = itertools.chain(self.decoder.parameters(), self.encoder.parameters())
        optimizer = torch.optim.AdamW(
            trainable,
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
        )

        n_epochs = self.config.training.n_epochs
        cosine = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=n_epochs // 10 * 9,
            eta_min=0,
            last_epoch=-1,
        )
        warmup = GradualWarmupScheduler(
            optimizer=optimizer,
            multiplier=self.config.optim.multiplier,
            warm_epoch=n_epochs // 10,
            after_scheduler=cosine,
            last_epoch=self.current_epoch,
        )

        scheduler_config = {
            'scheduler': warmup,
            'monitor': 'train_loss',
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [scheduler_config]

    def generate(self):
        """Sample from the prior, decode, re-encode, and save a 2-D scatter of the samples.

        Points are the first two coordinates of the decoded samples, coloured by
        the re-encoded latent (its projection on the leading covariance
        eigenvector when the latent has more than one dimension). The figure is
        written to ``args.log_sample_path`` as ``image_grid_<epoch>.png``.
        Decoder and encoder are switched to eval mode while sampling and back to
        train mode afterwards.
        """
        self.decoder.eval()
        self.encoder.eval()

        # Move the prior draw to the module's device, as generate_img does, so
        # the decoder does not receive a tensor on a different device.
        z0 = self.priorz.sample((self.config.sample.n_gen,)).to(self.device)

        xnew = self.decoder.sample(z0)
        zc = self.encoder.sample(xnew)

        znp = zc.cpu().detach().numpy()
        xnp = xnew.cpu().detach().numpy()

        # Use self.d (feature_dim with z_dim fallback, resolved in __init__)
        # instead of config.flow.feature_dim, which may not exist.
        if self.d > 1:
            znp = first_eigen_proj(znp)
        else:
            znp = znp.squeeze()

        plt.figure()
        plt.scatter(xnp[:, 0], xnp[:, 1], c=znp, marker=".", cmap=plt.colormaps['gist_rainbow'])
        plt.colorbar()
        plt.savefig(os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(self.current_epoch)))
        plt.close()

        self.decoder.train()
        self.encoder.train()

    def generate_img(self, x, y, in_sample=True):
        """Plot generated images per class and a t-SNE of the latents; optionally score them.

        For every class index below ``config.data.n_classes`` that is present in
        ``y``, one real image and seven decoder samples (drawn from the prior)
        are shown in a row, and SSIM between the repeated real image and the
        samples is recorded. Classes absent from the batch get an empty row and
        contribute no SSIM value. A 3-D t-SNE of ``encoder.sample(x)`` is saved
        as well.

        Args:
            x: Batch of inputs.
            y: Integer labels for ``x``.
            in_sample: When False, 20 clusters are used for k-means instead of
                ``config.data.n_classes``.

        Returns:
            ``(mean_ssim, nmi, ari, vendi_score)`` when the module is not in
            training mode; ``None`` otherwise. ``mean_ssim`` is ``nan`` when no
            class was present.
        """
        k = self.config.data.n_classes if in_sample else 20
        n_rows = self.config.data.n_classes
        n_cols = 8
        n_gen = n_cols - 1  # first column is the real image

        self.decoder.eval()
        self.encoder.eval()

        # Keep the metric on the same device as the images it is fed.
        ssim = StructuralSimilarityIndexMeasure().to(self.device)
        loss_ssim = []

        # squeeze=False keeps `axes` 2-D even for a single class.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
        for row_idx, row_axes in enumerate(axes):
            mask = y == row_idx
            if not mask.any():
                # Nothing to show for this class: leave the row blank instead of
                # reusing images from a previous row.
                for ax in row_axes:
                    ax.axis("off")
                continue

            x_k = x[mask][0]
            if self.config.model.cnn:
                x0 = x_k.repeat(n_gen, 1, 1, 1)
            else:
                x0 = x_k.repeat(n_gen, 1)
            z1 = self.priorz.sample((n_gen,)).to(self.device)
            x0_new = self.decoder.sample(z1).to(self.device)

            # Bring both to single-channel image layout so plotting and SSIM
            # operate on real images.
            x0 = inv_transform(x0).reshape(-1, 1, self.p, self.p)
            x0_new = inv_transform(x0_new).reshape(-1, 1, self.p, self.p)
            x0_np = x0_new.cpu().detach().numpy()
            x_np = inv_transform(x_k).cpu().detach().numpy().reshape((1, self.p, self.p))

            for col_idx, ax in enumerate(row_axes):
                if col_idx == 0:
                    ax.imshow(np.transpose(x_np, (1, 2, 0)), cmap='gray')
                    ax.set_ylabel("y={}".format(row_idx))
                else:
                    ax.imshow(np.transpose(x0_np[col_idx - 1], (1, 2, 0)), cmap='gray')
                ax.axis("off")

            # Replicate the grey channel three times before scoring.
            real_img = x0.repeat(1, 3, 1, 1).to(torch.float32)
            gen_img = x0_new.repeat(1, 3, 1, 1).to(torch.float32)
            loss_ssim.append(ssim(real_img, gen_img).detach())

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, f'image_grid_epoch_{self.current_epoch}.png'))
        plt.close()

        # Latent embedding of the whole batch.
        z0 = self.encoder.sample(x).to(self.device)
        z0_np = z0.cpu().detach().numpy()
        y_np = y.cpu().detach().numpy()
        cmap = plt.colormaps['tab10']

        tsne = TSNE(n_components=3, perplexity=30, random_state=0)
        z0_proj = tsne.fit_transform(z0_np)

        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111, projection='3d')
        scatter = ax.scatter(z0_proj[:, 0], z0_proj[:, 1], z0_proj[:, 2], c=y_np, cmap=cmap, s=10)
        plt.colorbar(scatter, ax=ax, pad=0.1, orientation='vertical', shrink=0.5)
        ax.set_title("Generated latent z given x")

        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, f'postz_grid_epoch_{self.current_epoch}.png'))
        plt.close()

        self.decoder.train()
        self.encoder.train()

        if not self.training:
            vs = vendi.score_dual(z0_np, normalize=True)

            # Cluster the latents and compare with the true labels.
            kmeans = KMeans(n_clusters=k, random_state=42, n_init='auto')
            kmeans.fit(z0_np)
            k0_np = kmeans.labels_

            ari = adjusted_rand_score(y_np, k0_np)
            nmi = normalized_mutual_info_score(y_np, k0_np)

            # Average on the tensor side: works on any device, unlike np.array
            # over a list of tensors.
            if loss_ssim:
                mean_ssim = torch.stack(loss_ssim).mean().item()
            else:
                mean_ssim = float("nan")

            return mean_ssim, nmi, ari, vs


    def sample_and_log(self, x, y):
        """Return the batch-mean encoder and decoder log-probabilities for ``x``.

        A latent is drawn with ``encoder.sample(x)``; the encoder's
        ``log_prob(x, z)`` and the decoder's ``log_prob(z, x)`` are averaged.
        Both networks are put in eval mode for the computation and in train
        mode afterwards. ``y`` is not used.

        Returns:
            ``(log_post_z, log_lik)``.
        """
        networks = (self.decoder, self.encoder)
        for net in networks:
            net.eval()

        latent = self.encoder.sample(x)
        log_post_z = self.encoder.log_prob(x, latent).mean()
        log_lik = self.decoder.log_prob(latent, x).mean()

        for net in networks:
            net.train()

        return log_post_z, log_lik

    def plot_latent(self, x, y, in_sample=True):
        """Save a 2-D t-SNE scatter of ``encoder.sample(x)`` coloured by ``y``.

        The file is ``test_plot_latent_mnist.png`` when ``in_sample`` is true and
        ``test_plot_latent_emnist.png`` otherwise, under ``args.log_sample_path``.
        """
        name = "mnist" if in_sample else "emnist"

        latent_np = self.encoder.sample(x).cpu().detach().numpy()
        labels_np = y.cpu().detach().numpy()

        embedding = TSNE(n_components=2, perplexity=50, random_state=0).fit_transform(latent_np)

        figure = plt.figure(figsize=(8, 6))
        axis = figure.add_subplot(111)
        points = axis.scatter(
            embedding[:, 0], embedding[:, 1],
            c=labels_np, cmap=plt.colormaps['tab20'], s=10,
        )
        plt.colorbar(points, ax=axis, pad=0.1, orientation='vertical', shrink=0.5)

        plt.tight_layout()
        out_file = os.path.join(self.args.log_sample_path, f'test_plot_latent_{name}.png')
        plt.savefig(out_file)
        plt.close()

    def plot_latent_hvg(self, x, y, in_sample=True):
        """Save a 2-D t-SNE scatter of the latents and, outside training mode, score them.

        The figure is written to ``args.log_sample_path`` as
        ``test_plot_latent_<epoch>.png``. ``in_sample`` is not used.

        Returns:
            ``(nmi, ari, vendi_score)`` when ``self.training`` is False, where
            NMI/ARI compare k-means labels (``config.data.n_classes`` clusters)
            with ``y``; ``None`` when the module is in training mode.
        """
        latent_np = self.encoder.sample(x).cpu().detach().numpy()
        labels_np = y.cpu().detach().numpy()

        embedding = TSNE(n_components=2, perplexity=50, random_state=0).fit_transform(latent_np)

        figure = plt.figure(figsize=(8, 6))
        axis = figure.add_subplot(111)
        points = axis.scatter(
            embedding[:, 0], embedding[:, 1],
            c=labels_np, cmap=plt.colormaps['tab20'], s=10,
        )
        plt.colorbar(points, ax=axis, pad=0.1, orientation='vertical', shrink=0.5)

        plt.tight_layout()
        out_file = os.path.join(self.args.log_sample_path, f'test_plot_latent_{self.current_epoch}.png')
        plt.savefig(out_file)
        plt.close()

        if self.training:
            return None

        vs = vendi.score_dual(latent_np, normalize=True)

        # Cluster the latents and compare against the provided labels.
        clusterer = KMeans(n_clusters=self.config.data.n_classes, random_state=42, n_init='auto')
        clusterer.fit(latent_np)
        cluster_ids = clusterer.labels_

        ari = adjusted_rand_score(labels_np, cluster_ids)
        nmi = normalized_mutual_info_score(labels_np, cluster_ids)
        return nmi, ari, vs




    def test_generate(self, n):
        """Draw ``n`` latents from the prior and decode them.

        Args:
            n: Number of samples to generate.

        Returns:
            The output of ``decoder.sample`` for the ``n`` prior draws.
        """
        # Move the draw to the module's device (as generate_img does) so the
        # decoder is not handed a tensor living on a different device.
        z0 = self.priorz.sample((n,)).to(self.device)
        return self.decoder.sample(z0)

    def test_inference(self, x, N=100, q_vals=torch.tensor([0.025, 0.975]), cmap='gray', in_sample=True):
        """Rank inputs by the spread of their sampled latents and plot the extremes.

        Each input is encoded ``N`` times; the per-dimension variance of those
        latent samples is summed to one score per input. A histogram of the
        scores and two image strips (highest and lowest scores, up to 8 each)
        are saved under ``args.log_sample_path``.

        Args:
            x: Batch of inputs, first dimension is the batch.
            N: Number of latent samples per input.
            q_vals: Unused; kept for interface compatibility.
            cmap: Matplotlib colormap for the image strips.
            in_sample: Selects the ``mnist_`` (True) or ``emnist_`` (False)
                prefix of the strip file names.
        """
        B = x.shape[0]
        if B == 0:
            return

        # Replicate the batch N times along a new leading axis, then fold that
        # axis into the batch so the encoder sees (N*B, ...) inputs.
        x_rep = x.unsqueeze(0).expand(N, *x.shape).reshape(N * B, *x.shape[1:])
        z0 = self.encoder.sample(x_rep).to(self.device).reshape(N, B, self.d)

        # Variance over the N draws, summed over latent dimensions -> (B,)
        z0_var = torch.var(z0, dim=0).sum(-1).detach().cpu()

        plt.figure()
        plt.hist(z0_var.numpy(), bins=100)
        plt.savefig(os.path.join(self.args.log_sample_path, 'mnist_interval_width.png'))
        plt.close()

        # Order the batch from most to least spread-out latent.
        max_sorted_areas, max_sorted_idx = torch.sort(z0_var, descending=True, dim=0)
        order = max_sorted_idx.numpy()
        areas = max_sorted_areas.numpy()

        x_np = inv_transform(x).cpu().detach().numpy()
        if not self.config.model.cnn:
            # The module has no channel attribute of its own; read it from the
            # config and assume a single channel when it is not given.
            channels = getattr(self.config.data, "channel", 1)
            x_np = x_np.reshape((-1, channels, self.p, self.p))

        n_show = min(8, B)
        prefix = "mnist" if in_sample else "emnist"

        def save_strip(positions, file_name):
            # One row of images, titled with their variance score.
            _, axes = plt.subplots(1, len(positions), figsize=(10, 2), squeeze=False)
            for ax, pos in zip(axes[0], positions):
                ax.imshow(np.transpose(x_np[order[pos]], (1, 2, 0)), cmap=cmap)
                ax.set_title("{:.2f}".format(areas[pos]))
                ax.axis("off")
            plt.tight_layout()
            plt.savefig(os.path.join(self.args.log_sample_path, file_name))
            plt.close()

        # Highest scores first; lowest scores listed starting from the minimum.
        save_strip(list(range(n_show)), f'{prefix}_top8_uncertain.png')
        save_strip([-i for i in range(1, n_show + 1)], f'{prefix}_bot8_uncertain.png')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch in either sampling or inference mode.

        With ``args.sample`` set, samples are generated and config-specific
        metrics are logged (clustering scores and SSIM for ``mnist.yml``;
        clustering scores, ELBO loss and an earth-mover distance for
        ``hvg.yml``). Otherwise, with ``args.inference`` set,
        ``test_inference`` is run for both ``in_sample`` settings.
        """
        X, y = batch
        if not self.config.model.cnn:
            X = X.view(-1, self.p**2)

        loss = self.elbo_loss(X.to(torch.float64))
        n = len(X)

        def log_metric(key, value, prog_bar=True):
            # Shared logging options for every test metric.
            self.log(key, value, on_step=True, on_epoch=True, sync_dist=True, prog_bar=prog_bar, logger=True)

        if not self.args.sample:
            if self.args.inference:
                self.test_inference(X, in_sample=True)
                self.test_inference(X, in_sample=False)
            return

        xnew = self.test_generate(n).cpu().detach().numpy()
        xorg = X.cpu().detach().numpy()

        config_name = self.args.config
        if config_name == "mnist.yml":
            ssim, nmi, ari, vs = self.generate_img(X.to(torch.float64), y, in_sample=self.config.data.in_sample)

            log_metric('test_nmi', nmi)
            log_metric('test_ari', ari)
            log_metric('test_vendi', vs, prog_bar=False)
            log_metric('test_ssim', ssim)
        elif config_name == "hvg.yml":
            nmi, ari, vs = self.plot_latent_hvg(X.to(torch.float64), y)
            log_metric('test_nmi', nmi)
            log_metric('test_ari', ari)
            log_metric('test_vendi', vs, prog_bar=False)

            # Distance between the real batch and the generated samples,
            # using uniform weights on both sides.
            weights = np.full(n, 1 / n)

            M = ot.dist(xorg, xnew, "euclidean")
            M /= M.max() * 0.1
            d_emd = ot.emd2(weights, weights, M)

            log_metric('test_loss', loss)
            log_metric('test_emd', d_emd)


class VAERunner():
    def __init__(self, args, config):
        """Create the networks, prior, callbacks and Lightning trainer for this runner.

        Side effects: sets ``args.log_sample_path`` to ``<log_path>/samples``
        (creating the directory) and forces ``config.flow.cnf`` to False.
        """
        self.args = args
        self.config = config

        sample_dir = os.path.join(args.log_path, 'samples')
        args.log_sample_path = sample_dir
        os.makedirs(sample_dir, exist_ok=True)

        self.config.flow.cnf = False

        self.d = self.config.flow.z_dim
        self.p = self.config.data.size

        softplus_widths = [1024, 2048]
        self.decoder = GaussianNet(self.p, self.d, hidden_features=softplus_widths, fct=nn.Softplus())
        self.encoder = GaussianNet(self.d, self.p, hidden_features=softplus_widths[::-1], fct=nn.Softplus())

        # With cnf forced off above, the diagonal normal prior is what is used.
        if not self.config.flow.cnf:
            self.priorz = DiagNormal(torch.zeros(self.d), torch.ones(self.d))
        else:
            self.priorz = mCNF(self.d, hidden_features=[])

        # Keep only the checkpoint with the lowest validation loss.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=self.args.log_path,
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
            save_top_k=1,
            mode='min',
        )
        plot_loss_callback = PlotLossCallback(
            save_path=os.path.join(self.args.log_sample_path, 'loss.png'),
            update_interval=1,
        )
        llk_callback = PlotLogLikelihoodCallback(
            save_path=self.args.log_sample_path,
            log_keys=("tra_log_lik", "tra_log_post_z"),
        )

        # DDP on GPU, library defaults on CPU.
        has_cuda = torch.cuda.is_available()
        accelerator = 'gpu' if has_cuda else 'cpu'
        strategy = "ddp" if has_cuda else "auto"
        devices = "auto"

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            callbacks=[plot_loss_callback, llk_callback, self.checkpoint_callback],
        )

    def train(self):
        """Build the dataloaders for the selected config and fit the model.

        Supported ``args.config`` values are ``"pinwheel.yml"`` and
        ``"hvg.yml"``. With ``args.resume_training`` set, training resumes from
        a fixed checkpoint path.

        Raises:
            ValueError: If ``args.config`` is not one of the supported names.
        """
        if self.args.config == "pinwheel.yml":
            dataset, test_dataset = get_dataset(self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
            train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                          num_workers=self.config.data.num_workers)
            val_dataloader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                        num_workers=self.config.data.num_workers, drop_last=True)
        elif self.args.config == "hvg.yml":
            train_dataloader, val_dataloader, _ = get_hvg_dataloaders(
                dat_dir="data", batch_size=self.config.training.batch_size,
                num_workers=self.config.data.num_workers)
        else:
            # Fail with a clear message instead of an UnboundLocalError on the
            # dataloaders further down.
            raise ValueError(
                f"VAERunner.train does not support config {self.args.config!r}; "
                "expected 'pinwheel.yml' or 'hvg.yml'"
            )

        model = VAELightningModule(self.decoder, self.encoder, self.priorz, self.config, self.args)

        if not self.args.resume_training:
            ckpt_path = None
        else:
            # NOTE(review): hard-coded checkpoint from an earlier run; consider
            # making this configurable.
            ckpt_path = "exp/logs/may13-vae-hvg/best-checkpoint-epoch=07-val_loss=-40.94.ckpt"

        self.trainer.fit(
            model, train_dataloader, val_dataloader, ckpt_path=ckpt_path
            )

    def sample(self):
        """Run the Lightning test loop on the HVG validation loader from a fixed checkpoint."""
        loaders = get_hvg_dataloaders(
            dat_dir="data", batch_size=1226, num_workers=self.config.data.num_workers)
        # Only the second (validation) loader is evaluated here.
        _, eval_loader, _ = loaders

        module = VAELightningModule(self.decoder, self.encoder, self.priorz, self.config, self.args)

        checkpoint = "exp/logs/may13-vae-hvg-2/best-checkpoint-epoch=11-val_loss=-57.05.ckpt"
        self.trainer.test(module, dataloaders=eval_loader, ckpt_path=checkpoint)


class VAERunner_MNIST():
    def __init__(self, args, config):
        """Create the MNIST networks, prior, callbacks and Lightning trainer.

        Side effects: sets ``args.log_sample_path`` to ``<log_path>/samples``
        (creating the directory) and overrides ``config.model.cnn``,
        ``config.training.n_epochs``, ``config.optim.lr`` and
        ``config.training.clipval`` with fixed values.
        """
        self.args = args
        self.config = config

        sample_dir = os.path.join(args.log_path, 'samples')
        args.log_sample_path = sample_dir
        os.makedirs(sample_dir, exist_ok=True)

        self.d = self.config.flow.z_dim
        self.c = self.config.data.channel
        self.p = self.config.data.size

        # Fixed settings for this runner, applied on top of the loaded config.
        self.config.model.cnn = False
        self.config.training.n_epochs = 200
        self.config.optim.lr = 0.0005
        self.config.training.clipval = 1.

        flat_dim = self.p**2*self.c
        self.priorz = DiagNormal(torch.zeros(self.d), torch.ones(self.d))
        self.decoder = LLKNet(flat_dim, self.d, fct=nn.ReLU(), hidden_features=[128, 512, 512])
        self.encoder = GaussianNet(self.d, flat_dim, fct=nn.ReLU(), hidden_features=[512, 512, 128])

        # Keep only the checkpoint with the lowest validation loss.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=self.args.log_path,
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
            save_top_k=1,
            mode='min',
        )

        # DDP on GPU, library defaults on CPU.
        has_cuda = torch.cuda.is_available()
        accelerator = 'gpu' if has_cuda else 'cpu'
        strategy = "ddp" if has_cuda else "auto"
        devices = "auto"

        plot_loss_callback = PlotLossCallback(
            save_path=os.path.join(self.args.log_sample_path, 'loss.png'),
            update_interval=1,
        )
        plot_llk_callback = PlotLogLikelihoodCallback(
            save_path=self.args.log_sample_path,
            log_keys=("tra_log_lik", "tra_log_post_z"),
        )

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            callbacks=[plot_loss_callback, plot_llk_callback, self.checkpoint_callback],
        )

    def train(self):
        """Build the MNIST dataloaders and fit the model.

        With ``args.resume_training`` set, training resumes from the checkpoint
        callback's best path or, when that is empty (as on a freshly created
        callback), from the most recently modified ``best-checkpoint-*.ckpt``
        in ``args.log_path``. If none exists, training starts from scratch and
        a warning is logged.
        """
        dataset, val_dataset, sampler, val_sampler = get_mnist(
            self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize,
            logtran=True)
        train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
                                      num_workers=self.config.data.num_workers, sampler=sampler)
        val_dataloader = DataLoader(val_dataset, batch_size=self.config.training.batch_size,
                                    num_workers=self.config.data.num_workers, sampler=val_sampler, drop_last=True)

        model = VAELightningModule(self.decoder, self.encoder, self.priorz, self.config, self.args)

        ckpt_path = None
        if self.args.resume_training:
            # best_model_path is an empty string until this callback instance
            # has saved something, so look on disk when it is empty.
            ckpt_path = self.checkpoint_callback.best_model_path
            if not ckpt_path:
                pattern = os.path.join(glob.escape(self.args.log_path), 'best-checkpoint-*.ckpt')
                saved = glob.glob(pattern)
                ckpt_path = max(saved, key=os.path.getmtime) if saved else None
            if ckpt_path is None:
                logging.warning(
                    "resume_training requested but no checkpoint found in %s; training from scratch",
                    self.args.log_path,
                )

        self.trainer.fit(
            model, train_dataloader, val_dataloader, ckpt_path=ckpt_path
            )

    def sample(self):
        """Run the Lightning test loop on MNIST (in-sample) or EMNIST (out-of-sample) data.

        The checkpoint is the callback's best path or, when that is empty (as
        on a freshly created callback), the most recently modified
        ``best-checkpoint-*.ckpt`` in ``args.log_path``.

        Raises:
            FileNotFoundError: If no checkpoint can be located.
        """
        if self.config.data.in_sample:
            dataset, _, sampler, _ = get_mnist(
                self.config.data.n_classes, "data", 300, 50,
                logtran=True)
            dataloader = DataLoader(dataset, batch_size=3000,
                                    num_workers=self.config.data.num_workers, sampler=sampler, drop_last=True)
        else:
            # 0123456789ABCDEFGHIJ
            dataset, sampler = get_emnist(
                20, "data", 300, 50, split="balanced")
            dataloader = DataLoader(dataset, batch_size=6000,
                                    num_workers=self.config.data.num_workers, sampler=sampler)

        model = VAELightningModule(self.decoder, self.encoder, self.priorz, self.config, self.args)

        # best_model_path is an empty string unless this callback instance has
        # already saved a checkpoint, so fall back to what is on disk.
        ckpt_path = self.checkpoint_callback.best_model_path
        if not ckpt_path:
            pattern = os.path.join(glob.escape(self.args.log_path), 'best-checkpoint-*.ckpt')
            saved = glob.glob(pattern)
            if not saved:
                raise FileNotFoundError(
                    f"no checkpoint matching 'best-checkpoint-*.ckpt' found in {self.args.log_path!r}"
                )
            ckpt_path = max(saved, key=os.path.getmtime)

        self.trainer.test(model, dataloaders=dataloader, ckpt_path=ckpt_path)

    def inference(self):
        """Run the Lightning test loop for latent-uncertainty inference.

        Uses the MNIST validation split when ``config.data.in_sample`` is true
        and the EMNIST "letters" split otherwise. The checkpoint is the
        callback's best path or, when that is empty (as on a freshly created
        callback), the most recently modified ``best-checkpoint-*.ckpt`` in
        ``args.log_path``.

        Raises:
            FileNotFoundError: If no checkpoint can be located.
        """
        if self.config.data.in_sample:
            _, val_dataset, _, val_sampler = get_mnist(
                self.config.data.n_classes, "data", 50, 50, logtran=False)
            dataloader = DataLoader(val_dataset, batch_size=500,
                                    num_workers=self.config.data.num_workers, sampler=val_sampler, drop_last=True)
        else:
            dataset, sampler = get_emnist(
                10, "data", 50, 50, split="letters")
            dataloader = DataLoader(dataset, batch_size=500,
                                    num_workers=self.config.data.num_workers, sampler=sampler)

        model = VAELightningModule(self.decoder, self.encoder, self.priorz, self.config, self.args)

        # best_model_path is an empty string unless this callback instance has
        # already saved a checkpoint, so fall back to what is on disk.
        ckpt_path = self.checkpoint_callback.best_model_path
        if not ckpt_path:
            pattern = os.path.join(glob.escape(self.args.log_path), 'best-checkpoint-*.ckpt')
            saved = glob.glob(pattern)
            if not saved:
                raise FileNotFoundError(
                    f"no checkpoint matching 'best-checkpoint-*.ckpt' found in {self.args.log_path!r}"
                )
            ckpt_path = max(saved, key=os.path.getmtime)

        self.trainer.test(model, dataloaders=dataloader, ckpt_path=ckpt_path)







