import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image.inception import InceptionScore
from sklearn.cluster import KMeans
from vendi_score import vendi
# Alternative (RBF) similarity: np.exp(-np.linalg.norm(a - b)**2/2)
def kernel(a, b):
	"""Scaled linear similarity kernel: the dot product of ``a`` and ``b`` divided by 100."""
	return np.dot(a, b) / 100
import pandas as pd
import ot

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import *
from Scheduler import WarmUpScheduler # GradualWarmupScheduler
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_mnist import *
from dataloader.dataloader_cifar import *
from dataloader.dataloader_pendulum import *
from dataloader.dataloader_hvg import *
from dataloader.dataloader_pinwheel import *

from runner_mnist import PlotLogLikelihoodCallback, PlotLossCallback, map_to_position
from LatentFM.latentfm import *
from sfa_lds import LatentDynamicalSystem

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy

# Process-wide torch settings, applied at import time:
# float32 becomes the default floating-point dtype for new tensors ...
torch.set_default_dtype(torch.float32)
# ... and tensors are printed with 3 digits of precision.
torch.set_printoptions(precision=3)

class LatentFlowLightning(pl.LightningModule):
	def __init__(self, vt: nn.Module, encoder: nn.Module, decoder: nn.Module, config, args, stage="recon"):
		"""Bundle the flow network, encoder and decoder into one Lightning module.

		Args:
			vt: latent flow network (exposes ``decode``; see ``sample``).
			encoder: maps inputs to latent codes.
			decoder: maps latent codes back to input space.
			config: experiment configuration; ``flow.z_dim``, ``data.channel``
				and ``data.size`` are read here.
			args: command-line namespace, stored for later use by the hooks.
			stage: ``"recon"`` or ``"flow"``; selects the loss used in the
				training/validation steps.
		"""
		super().__init__()

		# Run-level settings.
		self.config, self.args, self.stage = config, args, stage

		# Sub-networks, registered in the order vt -> encoder -> decoder.
		self.vt = vt
		self.encoder = encoder
		self.decoder = decoder

		# Latent dimension, channel count and image side length.
		self.d = config.flow.z_dim
		self.c = config.data.channel
		self.p = config.data.size
		self.sig_min = 1e-4

		# The optimizer is stepped by hand in ``training_step``.
		self.automatic_optimization = False
		# Filled by ``validation_step`` with the final batch of each validation run.
		self.last_validation_batch = None

	def setup(self, stage=None):
		"""Create the latent prior and both loss objects on the module's device.

		The prior is a ``DiagNormal`` built from a zero vector and a ones vector
		of length ``self.d`` (presumably a standard normal -- confirm against
		``DiagNormal``).
		"""
		zeros = torch.zeros(self.d, device=self.device)
		ones = torch.ones(self.d, device=self.device)
		self.prior = DiagNormal(zeros, ones)

		self.flow_matching_loss = FlowMatchingLoss(self.vt, self.encoder, self.prior, self.sig_min)
		self.recon_loss = ReconstructionLoss(self.encoder, self.decoder)

	def configure_optimizers(self):
		"""Build one AdamW optimizer and a warm-up-wrapped cosine LR schedule.

		The optimizer receives the parameters of all three sub-networks
		(flow, encoder, decoder), whether or not they are currently frozen.

		Returns:
			``([optimizer], [scheduler])`` in the form Lightning expects.
		"""
		optim_cfg = self.config.optim
		train_cfg = self.config.training

		all_params = itertools.chain(
			self.vt.parameters(),
			self.encoder.parameters(),
			self.decoder.parameters(),
		)
		adamw = torch.optim.AdamW(
			all_params,
			lr=optim_cfg.lr,
			weight_decay=optim_cfg.weight_decay,
		)

		cosine = optim.lr_scheduler.CosineAnnealingLR(
			optimizer=adamw,
			T_max=train_cfg.n_epochs,
			eta_min=1e-5,
			last_epoch=-1,
		)

		# Number of batches per epoch, passed to the warm-up wrapper as `len_loader`.
		batches_per_epoch = self.config.data.samplesize // train_cfg.batch_size
		warmup = WarmUpScheduler(
			optimizer=adamw,
			lr_scheduler=cosine,
			warmup_steps=train_cfg.n_epochs // 10,
			warmup_start_lr=0.00005,
			len_loader=batches_per_epoch,
		)

		return [adamw], [warmup]


	def training_step(self, batch, batch_idx):
		"""Run one manually-optimised training step.

		Args:
			batch: ``(X, y)`` pair; only ``X`` is used.
			batch_idx: index of the batch within the epoch (unused).

		Returns:
			The loss of the active stage (reconstruction or flow matching).

		Raises:
			ValueError: if ``self.stage`` is neither ``'recon'`` nor ``'flow'``.
		"""
		X, _ = batch
		if not self.config.model.cnn:
			# Non-CNN encoders take flattened inputs.
			X = X.view(-1, self.c*self.p**2)

		if self.stage == 'recon':
			loss = self.recon_loss(X)
		elif self.stage == 'flow':
			loss = self.flow_matching_loss(X)
		else:
			# Fail loudly instead of hitting an unbound `loss` further down.
			raise ValueError(f"Unknown stage {self.stage!r}; expected 'recon' or 'flow'.")
		self.log('train_loss', loss, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

		# automatic_optimization is off, so the optimizer is driven by hand.
		opt = self.optimizers()
		opt.zero_grad()
		self.manual_backward(loss)
		opt.step()

		return loss
    
	def validation_step(self, batch, batch_idx):
		"""Compute and log the validation loss of the active stage.

		The final batch of the first validation dataloader is cached in
		``self.last_validation_batch`` (inputs stored after any flattening) so
		that ``on_validation_epoch_end`` can produce snapshots from it.

		Args:
			batch: ``(X, y)`` pair.
			batch_idx: index of the batch within the validation run.

		Returns:
			The validation loss.

		Raises:
			ValueError: if ``self.stage`` is neither ``'recon'`` nor ``'flow'``.
		"""
		X, y = batch
		if not self.config.model.cnn:
			# Non-CNN encoders take flattened inputs.
			X = X.view(-1, self.c*self.p**2)

		if self.stage == 'recon':
			val_loss = self.recon_loss(X)
		elif self.stage == 'flow':
			val_loss = self.flow_matching_loss(X)
		else:
			# Fail loudly instead of hitting an unbound `val_loss` further down.
			raise ValueError(f"Unknown stage {self.stage!r}; expected 'recon' or 'flow'.")
		self.log('val_loss', val_loss, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

		if batch_idx == self.trainer.num_val_batches[0] - 1:
			self.last_validation_batch = {"X": X, "y": y}

		return val_loss

	def on_train_epoch_end(self):
		"""Step the LR scheduler once per epoch (required under manual optimisation)."""
		self.lr_schedulers().step()

	def on_validation_epoch_end(self):
		"""Produce snapshot plots from the cached validation batch, then drop the cache.

		Snapshots are only generated for the ``mnist.yml`` config and only on
		epochs that are a multiple of ``config.training.snapshot_freq``:
		latent plot + reconstructions in the ``recon`` stage, prior samples in
		the ``flow`` stage.
		"""
		cached = self.last_validation_batch
		wants_snapshot = cached is not None and self.args.config == "mnist.yml"

		if wants_snapshot:
			X, y = cached["X"], cached["y"]
			if not self.config.model.cnn:
				X = X.view(-1, self.c*self.p**2)

			on_snapshot_epoch = self.current_epoch % self.config.training.snapshot_freq == 0
			if on_snapshot_epoch:
				if self.stage == "recon":
					self.plot_latent(X, y)
					self.generate_img(X, y)
				elif self.stage == "flow":
					self.sample(100)

		# The cached batch is consumed once per validation run.
		self.last_validation_batch = None


	def sample(self, n):
		"""Draw ``n`` prior samples, decode them to images, save a grid and score them.

		Prior samples are pushed through ``vt.decode`` and the decoder; up to
		80 of the resulting images are written to
		``sampx_grid_epoch_<epoch>.png`` and an Inception Score is computed
		over all ``n`` images.

		Args:
			n: number of samples to draw.

		Returns:
			``(is_score_mean, is_score_std)`` as returned by
			``InceptionScore.compute``.
		"""
		modules = (self.encoder, self.decoder, self.vt)
		# Remember each sub-module's mode so it can be restored on exit; forcing
		# train() afterwards would flip them to train mode inside the test loop.
		previous_modes = [m.training for m in modules]
		for m in modules:
			m.eval()

		try:
			# The metric holds its own network; keep it on the same device as the images.
			inception = InceptionScore().to(self.device)

			with torch.no_grad():
				z1 = self.prior.sample((n,)).to(self.device)
				z0 = self.vt.decode(z1)
				x0 = self.decoder(z0)

				x0_new = inv_transform(x0)
				if not self.config.model.cnn:
					# Non-CNN decoders emit flat vectors (cf. generate_img); restore image shape.
					x0_new = x0_new.reshape(-1, self.c, self.p, self.p)
				x0_np = x0_new.cpu().detach().numpy()

				# Plot the sampled images on a fixed 10x8 grid.
				fig, axes = plt.subplots(10, 8, figsize=(10, 10))
				for i, ax in enumerate(axes.flatten()):
					if i < len(x0_np):  # leave cells blank when n < 80
						ax.imshow(np.transpose(x0_np[i], (1, 2, 0)), cmap='gray')
					ax.axis('off')  # hides ticks and labels

				plt.tight_layout()
				plt.savefig(os.path.join(self.args.log_sample_path, f'sampx_grid_epoch_{self.current_epoch}.png'))
				plt.close(fig)

				# Tile the channel axis x3 for the Inception network
				# (assumes single-channel images, c == 1 -- TODO confirm).
				gen_img = x0_new.repeat(1, 3, 1, 1).to(torch.float32)
				# NOTE(review): by default InceptionScore expects uint8 images in [0, 255];
				# confirm inv_transform returns that range rather than [0, 1].
				inception.update(gen_img.to(torch.uint8))  # computed over all classes
				is_score_mean, is_score_std = inception.compute()
		finally:
			for m, was_training in zip(modules, previous_modes):
				m.train(was_training)

		return is_score_mean, is_score_std
		
	def generate(self, plot=True):
		"""Generate ``config.sample.n_gen`` points from the prior and optionally plot them.

		Prior samples are transported by ``vt.decode`` and then decoded, the
		same pipeline used by ``sample`` and ``test_step``. The scatter plot
		shows the first two output coordinates, coloured by the latent code
		that was decoded.

		Args:
			plot: when true, save the scatter plot to
				``image_grid_<epoch>.png`` in the sample directory.
		"""
		modules = (self.encoder, self.decoder, self.vt)
		# Restore each sub-module's original mode on exit instead of forcing train().
		previous_modes = [m.training for m in modules]
		for m in modules:
			m.eval()

		try:
			with torch.no_grad():
				# Use the module's own device so this also works under multi-device runs.
				z1 = self.prior.sample((self.config.sample.n_gen,)).to(self.device)
				z0 = self.vt.decode(z1)
				# Decode the transported latent z0; feeding z1 would bypass the flow entirely.
				x0 = self.decoder(z0)

			xnp = x0.cpu().detach().numpy()
			znp = z0.cpu().detach().numpy().squeeze()

			if plot:
				plt.figure()
				plt.scatter(xnp[:,0], xnp[:,1], c=znp, marker=".", cmap=plt.colormaps['gist_rainbow'])
				plt.colorbar()
				plt.savefig(os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(self.current_epoch)))
				plt.close()
		finally:
			for m, was_training in zip(modules, previous_modes):
				m.train(was_training)

	def generate_img(self, x, y):
		"""Plot per-class reconstructions and return their mean SSIM.

		For each class ``0..9`` that occurs in ``y``, the first matching input
		is shown in column 0 and its encoder/decoder reconstruction in the
		remaining columns. Rows for classes that are absent from the batch are
		left blank and do not contribute to the score. The figure is saved as
		``image_grid_epoch_<epoch>.png``.

		Args:
			x: batch of inputs (flattened when the model is not a CNN).
			y: integer class labels aligned with ``x``.

		Returns:
			Mean SSIM between inputs and reconstructions over the plotted
			classes (NaN if no class was present).
		"""
		modules = (self.encoder, self.decoder, self.vt)
		# Restore each sub-module's original mode on exit instead of forcing train().
		previous_modes = [m.training for m in modules]
		for m in modules:
			m.eval()

		# Keep the metric's state on the same device as the images.
		ssim = StructuralSimilarityIndexMeasure().to(x.device)
		is_cnn = self.config.model.cnn
		img_shape = (self.c, self.p, self.p)

		ssim_scores = []
		try:
			with torch.no_grad():
				fig, axes = plt.subplots(10, 8, figsize=(10, 10))
				for row_idx, row_axes in enumerate(axes):
					for ax in row_axes:
						ax.axis("off")

					mask = y == row_idx
					if mask.sum() == 0:
						# No example of this class: skip rather than reuse the previous row's data.
						continue

					n_recon = len(row_axes) - 1
					x_k = x[mask][0]
					if is_cnn:
						x0 = x_k.repeat(n_recon, 1, 1, 1)
					else:
						x0 = x_k.repeat(n_recon, 1)

					x0_new = inv_transform(self.decoder(self.encoder(x0)))
					x0 = inv_transform(x0)
					x_np = inv_transform(x_k).cpu().detach().numpy()
					if not is_cnn:
						# Flat vectors -> (N, C, H, W) so plotting and SSIM see images.
						x0 = x0.reshape(-1, *img_shape)
						x0_new = x0_new.reshape(-1, *img_shape)
						x_np = x_np.reshape(img_shape)
					x0_np = x0_new.cpu().detach().numpy()

					for col_idx, ax in enumerate(row_axes):
						if col_idx == 0:
							ax.imshow(np.transpose(x_np, (1, 2, 0)), cmap='gray')
							ax.set_ylabel("y={}".format(row_idx))
						else:
							ax.imshow(np.transpose(x0_np[col_idx-1], (1, 2, 0)), cmap='gray')

					real_img = x0.repeat(1, 3, 1, 1).to(torch.float32)
					gen_img = x0_new.repeat(1, 3, 1, 1).to(torch.float32)
					# .item() yields a plain float, so averaging also works for CUDA tensors.
					ssim_scores.append(ssim(real_img, gen_img).item())

				plt.tight_layout()
				plt.savefig(os.path.join(self.args.log_sample_path, f'image_grid_epoch_{self.current_epoch}.png'))
				plt.close(fig)
		finally:
			for m, was_training in zip(modules, previous_modes):
				m.train(was_training)

		return float(np.mean(ssim_scores)) if ssim_scores else float("nan")


	def plot_latent(self, x, y):
		"""Plot a t-SNE projection of the encoded batch and, outside training, score it.

		When ``self.training`` is true the plot is saved as
		``postz_grid_epoch_<epoch>.png`` and nothing is returned. Otherwise it
		is saved as ``test_latent_mnist.png`` / ``test_latent_emnist.png``
		(depending on ``config.data.in_sample``) and clustering metrics are
		computed on the latent codes.

		NOTE(review): ``self.training`` is presumably also false inside
		Lightning's validation loop, which would route validation snapshots
		through the "test" branch -- confirm, and consider switching on the
		trainer state instead.

		Args:
			x: batch of inputs accepted by the encoder.
			y: integer labels aligned with ``x``.

		Returns:
			``(nmi, ari, vendi_score)`` when the module is not in training
			mode, otherwise ``None``.
		"""
		modules = (self.encoder, self.decoder, self.vt)
		# Restore each sub-module's original mode on exit instead of forcing train().
		previous_modes = [m.training for m in modules]
		for m in modules:
			m.eval()

		evaluating = not self.training
		k = None
		try:
			with torch.no_grad():
				z0_np = self.encoder(x).cpu().detach().numpy()
			y_np = y.cpu().detach().numpy()

			# t-SNE requires perplexity < n_samples; clamp so small batches do not crash.
			perplexity = min(50, max(1, len(z0_np) - 1))
			tsne = TSNE(n_components=2, perplexity=perplexity, random_state=0)
			z0_proj = tsne.fit_transform(z0_np)

			fig = plt.figure(figsize=(8, 6))
			ax = fig.add_subplot(111)
			scatter = ax.scatter(z0_proj[:,0], z0_proj[:,1], c=y_np, cmap=plt.colormaps['tab20'], s=10)
			plt.colorbar(scatter, ax=ax, pad=0.1, orientation='vertical', shrink=0.5)
			plt.tight_layout()

			if not evaluating:
				file_name = f'postz_grid_epoch_{self.current_epoch}.png'
			elif self.config.data.in_sample:
				k = self.config.data.n_classes
				file_name = 'test_latent_mnist.png'
			else:
				k = 20  # number of classes used for the out-of-sample (EMNIST) data
				file_name = 'test_latent_emnist.png'
			plt.savefig(os.path.join(self.args.log_sample_path, file_name))
			plt.close(fig)
		finally:
			for m, was_training in zip(modules, previous_modes):
				m.train(was_training)

		if not evaluating:
			return None

		vs = vendi.score_dual(z0_np, normalize=True)

		# Cluster the latent codes and compare the clusters with the true labels.
		kmeans = KMeans(n_clusters=k, random_state=42, n_init='auto')
		kmeans.fit(z0_np)
		k0_np = kmeans.labels_

		ari = adjusted_rand_score(y_np, k0_np)
		nmi = normalized_mutual_info_score(y_np, k0_np)
		return nmi, ari, vs

	def test_regenerate(self, x, y):
		"""Reconstruct ``x`` through the autoencoder and return its Vendi score.

		Args:
			x: batch of inputs accepted by the encoder.
			y: labels aligned with ``x``; currently unused, kept so existing
				callers keep working.

		Returns:
			``vendi.score_dual`` of the reconstructions (``normalize=True``).
		"""
		with torch.no_grad():
			x0 = self.decoder(self.encoder(x))

		x0_np = x0.cpu().detach().numpy()
		return vendi.score_dual(x0_np, normalize=True)

	def test_step(self, batch, batch_idx):
		"""Evaluate one test batch and log the metrics for the active config.

		Nothing is computed unless ``args.sample`` is set. Depending on
		``args.config``:

		* ``mnist.yml``: Inception Score of prior samples, SSIM of
		  reconstructions, NMI / ARI / Vendi score of the latent codes.
		* ``hvg.yml``: Vendi score of reconstructions, latent metrics, and the
		  earth mover's distance between the batch and generated samples.
		* ``pinwheel.yml``: a scatter plot of generated points and the same
		  earth mover's distance.

		Args:
			batch: ``(X, y)`` pair.
			batch_idx: index of the batch (unused).
		"""
		X, y = batch
		n = len(X)
		if not self.config.model.cnn:
			X = X.view(-1, self.c*self.p**2)

		if not self.args.sample:
			return

		log_kwargs = {"on_step": True, "on_epoch": True, "sync_dist": True, "logger": True}
		config_name = self.args.config

		if config_name == "mnist.yml":
			is_loss, is_loss_std = self.sample(100)
			self.log('test_is_loss', is_loss, **log_kwargs)
			self.log('test_is_loss_std', is_loss_std, **log_kwargs)

			ssim = self.generate_img(X, y)
			self.log('test_ssim', ssim, **log_kwargs)

			nmi, ari, vs = self.plot_latent(X, y)
			self.log('test_nmi', nmi, **log_kwargs)
			self.log('test_ari', ari, **log_kwargs)
			self.log('test_vendi', vs, **log_kwargs)

		elif config_name == "hvg.yml":
			vs_x = self.test_regenerate(X, y)
			nmi, ari, vs = self.plot_latent(X, y)
			self.log('test_nmi', nmi, **log_kwargs)
			self.log('test_ari', ari, **log_kwargs)
			self.log('test_vendi', vs, **log_kwargs)
			self.log('test_vendi_x', vs_x, **log_kwargs)

			self.log('test_emd', self._emd_to_generated(X, n), prog_bar=True, **log_kwargs)

		elif config_name == "pinwheel.yml":
			self.generate()

			self.log('test_emd', self._emd_to_generated(X, n), prog_bar=True, **log_kwargs)

	def _emd_to_generated(self, X, n):
		"""Return the earth mover's distance between ``X`` and ``n`` generated samples.

		Samples are drawn from the prior, transported by ``vt.decode`` and
		decoded. Both point clouds get uniform weights ``1/n`` and the
		Euclidean cost matrix is rescaled by ``0.1 * max`` before solving.
		"""
		z1 = self.prior.sample((n,)).to(self.device)
		z0 = self.vt.decode(z1)
		xnew = self.decoder(z0).cpu().detach().numpy()
		xorg = X.cpu().detach().numpy()

		w = 1/n * np.ones(n)

		M = ot.dist(xorg, xnew, "euclidean")
		scale = M.max() * 0.1
		if scale > 0:
			# Skip the rescaling when every distance is zero to avoid dividing by zero.
			M /= scale
		return ot.emd2(w, w, M)


class MNISTRunner():
	def __init__(self, args, config):
		"""Build the networks, callbacks and trainer for the image experiments.

		Side effects: sets ``args.log_sample_path`` to ``<log_path>/samples``
		and creates that directory.

		Args:
			args: command-line namespace (``log_path`` and ``flowmatch`` are read here).
			config: experiment configuration.
		"""
		self.args, self.config = args, config

		sample_dir = os.path.join(args.log_path, 'samples')
		args.log_sample_path = sample_dir
		os.makedirs(sample_dir, exist_ok=True)

		self.d = config.flow.z_dim
		self.c = config.data.channel
		self.p = config.data.size

		# Autoencoder: convolutional or fully connected, chosen by config.model.cnn.
		if config.model.cnn:
			self.encoder = CNNEncoder(self.c, self.d, fct=nn.Softplus())
			self.decoder = CNNDecoder(self.c, self.d, fct=nn.Softplus(), target_size=self.p)
		else:
			width = config.model.ngf
			self.encoder = MLPEncoder(
				self.p, self.d, self.c,
				fct=nn.Softplus(),
				hidden_features=[width] * 3,
			)
			self.decoder = MLPDecoder(
				self.p, self.d, self.c,
				fct=nn.Softplus(),
				hidden_features=[width] * 3,
			)

		# Latent flow network.
		self.vt = LatentCNF(
			self.d, freqs=10,
			fct=nn.Softplus(),
			hidden_features=[512] * 2,
		)

		# DDP on GPU, Lightning defaults on CPU.
		has_gpu = torch.cuda.is_available()
		accelerator = 'gpu' if has_gpu else 'cpu'
		strategy = DDPStrategy(find_unused_parameters=True) if has_gpu else "auto"
		devices = "auto"

		# Keep only the checkpoint with the lowest validation loss.
		self.checkpoint_callback = ModelCheckpoint(
			monitor='val_loss',
			dirpath=self.args.log_path,
			filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
			save_top_k=1,
			mode='min',
		)

		loss_plot_path = os.path.join(self.args.log_sample_path, f'loss_{self.args.flowmatch * "flowmatch"}.png')
		plot_loss_callback = PlotLossCallback(save_path=loss_plot_path, update_interval=1)

		self.trainer = pl.Trainer(
			max_epochs=self.config.training.n_epochs,
			accelerator=accelerator,
			devices=devices,
			strategy=strategy,
			callbacks=[self.checkpoint_callback, plot_loss_callback],
		)

	def train(self):
		"""Run the requested training stages on MNIST.

		``args.recon`` trains the autoencoder with the flow network frozen;
		``args.flowmatch`` then loads the autoencoder weights from a
		checkpoint, freezes encoder/decoder and trains the flow network.

		Raises:
			ValueError: if a stage is requested for a config other than ``mnist.yml``.
			FileNotFoundError: if the flow stage cannot find any checkpoint.
		"""
		if self.args.config == "mnist.yml":
			dataset, val_dataset, sampler, val_sampler = get_mnist(
				self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)

			train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size,
				num_workers=self.config.data.num_workers, sampler=sampler)
			val_dataloader = DataLoader(val_dataset, batch_size=self.config.training.batch_size,
				num_workers=self.config.data.num_workers, sampler=val_sampler, drop_last=True)
		elif self.args.recon or self.args.flowmatch:
			# Without this the stages below would fail on undefined dataloaders.
			raise ValueError(f"MNISTRunner.train does not support config {self.args.config!r}")

		if self.args.recon:
			# Train the autoencoder on reconstruction only; the flow network stays frozen.
			ae_model = LatentFlowLightning(self.vt, self.encoder, self.decoder, self.config, self.args, stage='recon')
			for p in ae_model.vt.parameters():
				p.requires_grad = False

			self.trainer.fit(ae_model, train_dataloader, val_dataloader)

		if self.args.flowmatch:
			best_ckpt_path = self._latest_checkpoint()
			print("ckpt path", best_ckpt_path)

			ckpt = torch.load(best_ckpt_path, map_location='cpu')
			state_dict = ckpt['state_dict']
			# Drop every entry that belongs to a `vt` sub-module. Matching whole dotted
			# components avoids discarding unrelated keys that merely contain "vt".
			filtered_state_dict = {
				k: v for k, v in state_dict.items()
				if 'vt' not in k.split('.')
			}

			flow_model = LatentFlowLightning(self.vt, self.encoder, self.decoder, self.config, self.args, stage='flow')
			flow_model.load_state_dict(filtered_state_dict, strict=False)

			# Freeze encoder/decoder; only the flow network is trained in this stage.
			for p in flow_model.encoder.parameters():
				p.requires_grad = False
			for p in flow_model.decoder.parameters():
				p.requires_grad = False

			# NOTE(review): the same Trainer (and checkpoint callback) is reused after the
			# recon fit; confirm its epoch counter does not end this second fit immediately.
			self.trainer.fit(flow_model, train_dataloader, val_dataloader)

	def _latest_checkpoint(self):
		"""Return the best checkpoint of this run, else the newest ``.ckpt`` in the log dir."""
		best = self.checkpoint_callback.best_model_path
		if best:
			return best

		# No fit happened in this process (e.g. flow stage launched on its own):
		# fall back to whatever an earlier run left in the log directory.
		candidates = glob.glob(os.path.join(self.args.log_path, '*.ckpt'))
		if not candidates:
			raise FileNotFoundError(f"No .ckpt file found in {self.args.log_path!r}")
		return max(candidates, key=os.path.getmtime)

	def sample(self):
		"""Run the Lightning test loop on MNIST (in-sample) or EMNIST (out-of-sample).

		The dataset is selected by ``config.data.in_sample``. Weights come from
		this run's best checkpoint or, if no fit happened in this process, from
		the newest ``.ckpt`` file in ``args.log_path``.

		Raises:
			FileNotFoundError: if no checkpoint can be found.
		"""
		if self.config.data.in_sample:
			dataset, _, sampler, _ = get_mnist(
				self.config.data.n_classes, "data", 300, 50)
			batch_size = 3000
		else:
			# EMNIST "balanced" split restricted to 20 classes: 0123456789ABCDEFGHIJ
			dataset, sampler = get_emnist(
				20, "data", 300, 50, split="balanced")
			batch_size = 6000
		dataloader = DataLoader(dataset, batch_size=batch_size,
			num_workers=self.config.data.num_workers, sampler=sampler)

		model = LatentFlowLightning(self.vt, self.encoder, self.decoder, self.config, self.args)

		ckpt_path = self.checkpoint_callback.best_model_path
		if not ckpt_path:
			# Fresh process: the callback has not tracked anything yet, so look on disk.
			candidates = glob.glob(os.path.join(self.args.log_path, '*.ckpt'))
			if not candidates:
				raise FileNotFoundError(f"No .ckpt file found in {self.args.log_path!r}")
			ckpt_path = max(candidates, key=os.path.getmtime)

		self.trainer.test(model, dataloaders=dataloader, ckpt_path=ckpt_path)
		
class EuclideanRunner():
	def __init__(self, args, config):
		"""Build the networks, callbacks and trainer for the vector-valued experiments.

		Side effects: sets ``args.log_sample_path`` to ``<log_path>/samples``
		and creates that directory.

		Args:
			args: command-line namespace (``log_path`` and ``flowmatch`` are read here).
			config: experiment configuration.
		"""
		self.args, self.config = args, config

		sample_dir = os.path.join(args.log_path, 'samples')
		args.log_sample_path = sample_dir
		os.makedirs(sample_dir, exist_ok=True)

		self.d = config.flow.z_dim
		self.p = config.data.size

		# Autoencoder with two hidden layers of width config.model.ngf each.
		ae_width = config.model.ngf
		self.encoder = Encoder(
			self.p, self.d,
			fct=nn.Softplus(),
			hidden_features=[ae_width] * 2,
		)
		self.decoder = Decoder(
			self.p, self.d,
			fct=nn.Softplus(),
			hidden_features=[ae_width] * 2,
		)

		# Latent flow network; width and depth come from config.flow.
		self.vt = LatentCNF(
			self.d, freqs=10,
			fct=nn.Softplus(),
			hidden_features=[config.flow.ngf] * config.flow.z_depth,
		)

		# DDP on GPU, Lightning defaults on CPU.
		has_gpu = torch.cuda.is_available()
		accelerator = 'gpu' if has_gpu else 'cpu'
		strategy = DDPStrategy(find_unused_parameters=True) if has_gpu else "auto"
		devices = "auto"

		# Keep only the checkpoint with the lowest validation loss.
		self.checkpoint_callback = ModelCheckpoint(
			monitor='val_loss',
			dirpath=self.args.log_path,
			filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
			save_top_k=1,
			mode='min',
		)

		loss_plot_path = os.path.join(self.args.log_sample_path, f'loss_{self.args.flowmatch * "flowmatch"}.png')
		plot_loss_callback = PlotLossCallback(save_path=loss_plot_path, update_interval=1)

		self.trainer = pl.Trainer(
			max_epochs=self.config.training.n_epochs,
			accelerator=accelerator,
			devices=devices,
			strategy=strategy,
			callbacks=[self.checkpoint_callback, plot_loss_callback],
		)

	def train(self):
		"""Run the requested training stages on the pinwheel or HVG data.

		``args.recon`` trains the autoencoder with the flow network frozen;
		``args.flowmatch`` then loads the autoencoder weights from a
		checkpoint, freezes encoder/decoder and trains the flow network.

		Raises:
			ValueError: if a stage is requested for an unsupported config.
			FileNotFoundError: if the flow stage cannot find any checkpoint.
		"""
		if self.args.config == "pinwheel.yml":
			dataset, test_dataset = get_dataset(self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
			train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
				num_workers=self.config.data.num_workers)
			val_dataloader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
				num_workers=self.config.data.num_workers, drop_last=True)
		elif self.args.config == "hvg.yml":
			train_dataloader, val_dataloader, _ = get_hvg_dataloaders(
				dat_dir="data", batch_size=self.config.training.batch_size,
				num_workers=self.config.data.num_workers)
		elif self.args.recon or self.args.flowmatch:
			# Without this the stages below would fail on undefined dataloaders.
			raise ValueError(f"EuclideanRunner.train does not support config {self.args.config!r}")

		if self.args.recon:
			# Train the autoencoder on reconstruction only; the flow network stays frozen.
			ae_model = LatentFlowLightning(self.vt, self.encoder, self.decoder, self.config, self.args, stage='recon')
			for p in ae_model.vt.parameters():
				p.requires_grad = False

			self.trainer.fit(ae_model, train_dataloader, val_dataloader)

		if self.args.flowmatch:
			best_ckpt_path = self._latest_checkpoint()
			ckpt = torch.load(best_ckpt_path, map_location='cpu')
			state_dict = ckpt['state_dict']
			# Drop every entry that belongs to a `vt` sub-module. Matching whole dotted
			# components avoids discarding unrelated keys that merely contain "vt".
			filtered_state_dict = {
				k: v for k, v in state_dict.items()
				if 'vt' not in k.split('.')
			}

			flow_model = LatentFlowLightning(self.vt, self.encoder, self.decoder, self.config, self.args, stage='flow')
			flow_model.load_state_dict(filtered_state_dict, strict=False)

			# Freeze encoder/decoder; only the flow network is trained in this stage.
			for p in flow_model.encoder.parameters():
				p.requires_grad = False
			for p in flow_model.decoder.parameters():
				p.requires_grad = False

			# NOTE(review): the same Trainer (and checkpoint callback) is reused after the
			# recon fit; confirm its epoch counter does not end this second fit immediately.
			self.trainer.fit(flow_model, train_dataloader, val_dataloader)

	def _latest_checkpoint(self):
		"""Return the best checkpoint of this run, else the newest ``.ckpt`` in the log dir."""
		best = self.checkpoint_callback.best_model_path
		if best:
			return best

		# No fit happened in this process (e.g. flow stage launched on its own):
		# fall back to whatever an earlier run left in the log directory.
		candidates = glob.glob(os.path.join(self.args.log_path, '*.ckpt'))
		if not candidates:
			raise FileNotFoundError(f"No .ckpt file found in {self.args.log_path!r}")
		return max(candidates, key=os.path.getmtime)

	def sample(self):
		"""Run the Lightning test loop on the pinwheel or HVG evaluation data.

		Weights come from this run's best checkpoint or, if no fit happened in
		this process, from the newest ``.ckpt`` file in ``args.log_path``.

		Raises:
			ValueError: if ``args.config`` is neither ``pinwheel.yml`` nor ``hvg.yml``.
			FileNotFoundError: if no checkpoint can be found.
		"""
		if self.args.config == "pinwheel.yml":
			_, val_dataset = get_dataset(self.config.data.n_classes, "data", 2000, 2000)

			val_dataloader = DataLoader(val_dataset, batch_size=500, shuffle=True,
				num_workers=self.config.data.num_workers, drop_last=True)

		elif self.args.config == "hvg.yml":
			_, val_dataloader, _ = get_hvg_dataloaders(
				dat_dir="data", batch_size=1226,
				num_workers=self.config.data.num_workers)

		else:
			# Without this the test call below would fail on an undefined dataloader.
			raise ValueError(f"EuclideanRunner.sample does not support config {self.args.config!r}")

		model = LatentFlowLightning(self.vt, self.encoder, self.decoder, self.config, self.args)

		ckpt_path = self.checkpoint_callback.best_model_path
		if not ckpt_path:
			# Fresh process: the callback has not tracked anything yet, so look on disk.
			candidates = glob.glob(os.path.join(self.args.log_path, '*.ckpt'))
			if not candidates:
				raise FileNotFoundError(f"No .ckpt file found in {self.args.log_path!r}")
			ckpt_path = max(candidates, key=os.path.getmtime)

		self.trainer.test(model, dataloaders=val_dataloader, ckpt_path=ckpt_path)

