import os

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
# from torchvision.utils import save_image
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import ot

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
# from diffusion import GaussianDiffusion
from utils import *
# from embedding import ConditionalEmbedding
from Scheduler import GradualWarmupScheduler
from sfa_discrete import *

# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy

from dataloader.dataloader_pinwheel import *

import gc
# Run one garbage-collection pass as soon as the module is imported.
gc.collect()


# From here on, floating-point tensors created without an explicit dtype are float64.
torch.set_default_dtype(torch.float64)


def siren_init(m):
    """Redraw the weights of a linear layer in place.

    Only ``nn.Linear`` modules are touched: their weight matrix is refilled
    with samples from ``U(-1 / fan_in, 1 / fan_in)``, where ``fan_in`` is the
    size of the last weight dimension. The bias and every other module type
    are left unchanged, so the function can be handed to ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    fan_in = m.weight.size(-1)
    bound = 1 / fan_in
    # In-place re-initialisation must not be tracked by autograd.
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)

def map_to_position(input_tensor, num_classes):
    """
    One-hot encode a 1-D tensor of class indices.

    Parameters:
    - input_tensor: A 1-D integer tensor of class indices; every value must be
      a valid column index, i.e. lie in ``[0, num_classes)``.
    - num_classes: Number of columns of the output (one column per class).

    Returns:
    - A ``(len(input_tensor), num_classes)`` tensor of the default floating
      dtype, allocated on the same device as ``input_tensor``, where row ``i``
      is all zeros except for a 1 in column ``input_tensor[i]``.
    """
    num_rows = input_tensor.size(0)
    # Build the output and the row index on the input's device so the indexed
    # assignment below also works when the indices live on an accelerator.
    device = input_tensor.device
    output_tensor = torch.zeros((num_rows, num_classes), device=device)
    output_tensor[torch.arange(num_rows, device=device), input_tensor] = 1

    return output_tensor


class PlotLogLikelihoodCallback(Callback):
    """Track two logged metrics and re-plot their history after every training epoch."""

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Callback to plot the log likelihoods logged during training for two variables (e.g., x and z).

        Args:
            save_path (str): Directory the figure is written to. The file name
                is always ``llk.png``; it is joined onto this path.
            log_keys (tuple): A tuple containing keys for the two log likelihoods.
                The first key feeds the left panel, the second the right panel.
        """
        super().__init__()
        self.save_path = os.path.join(save_path, 'llk.png')
        self.log_keys = log_keys
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each training epoch.

        Appends the current value of each tracked key (when present in
        ``trainer.callback_metrics``) to its history and rewrites the figure.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained (unused).
        """
        metrics = trainer.callback_metrics
        histories = (self.log_likelihoods_x, self.log_likelihoods_z)
        for key, history in zip(self.log_keys[:2], histories):
            if key in metrics:
                history.append(metrics[key].item())

        # Callbacks run on every rank under DDP; let only rank zero write the
        # file so that several processes never write the same path at once.
        if not trainer.is_global_zero:
            return

        panels = (
            (self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        try:
            for ax, (values, ylabel, title) in zip(axes, panels):
                ax.plot(values, marker="o")
                ax.set_xlabel("Epoch")
                ax.set_ylabel(ylabel)
                ax.set_title(title)
                ax.grid()
            fig.tight_layout()
            fig.savefig(self.save_path)
        finally:
            # Release the figure even when saving fails, so figures do not pile up.
            plt.close(fig)


class LightningModule(pl.LightningModule):
	def __init__(self, vt: nn.Module, Rt: nn.Module, rt: nn.Module, priorz, config, args):
		"""Store the networks, prior and configuration used by the training hooks.

		Args:
			vt, Rt, rt: Networks handed to the flow-matching loss in ``setup``.
			priorz: Prior object over the latent ``z``; also handed to the loss.
			config: Configuration namespace (``flow``, ``data``, ``training``, ... are read).
			args: Run arguments; ``log_sample_path`` is read when figures are saved.
		"""
		super().__init__()
		self.config, self.args = config, args

		# Assigned in the order vt, Rt, rt, priorz.
		self.vt, self.Rt, self.rt = vt, Rt, rt
		self.priorz = priorz

		flow_cfg = config.flow
		self.k = flow_cfg.k_dim
		self.d = flow_cfg.z_dim
		self.p = config.data.size

		# Optimizers and schedulers are stepped by hand in the training hooks.
		self.automatic_optimization = False
		# Filled by validation_step, consumed by on_validation_epoch_end.
		self.last_validation_batch = None


	def setup(self, stage=None):
		"""Create the prior over mixture logits and the flow-matching loss.

		``priorpi`` is a ``DiagNormal`` built from a zero vector and a ones
		vector of length ``k`` placed on ``self.device``. ``stage`` is not used.
		"""
		zeros_k = torch.zeros(self.k, device=self.device)
		ones_k = torch.ones(self.k, device=self.device)
		self.priorpi = DiagNormal(zeros_k, ones_k)

		flow_cfg = self.config.flow
		self.flow_matching_loss = FlowMatchingLoss_R(
			self.vt,
			self.Rt,
			self.rt,
			self.priorpi,
			self.priorz,
			k=self.k,
			fix_z=flow_cfg.fix_z,
			fix_pi=flow_cfg.fix_pi,
		)

	def training_step(self, batch, batch_idx):
		"""Compute the flow-matching loss for one batch and step both optimizers.

		The batch is an ``(inputs, labels)`` pair; only the inputs enter the loss.
		Returns the loss tensor.
		"""
		inputs, _ = batch
		loss = self.flow_matching_loss(inputs.to(torch.float64))

		self.log(
			'train_loss',
			loss,
			on_step=True,
			on_epoch=True,
			sync_dist=True,
			prog_bar=True,
			logger=True,
		)

		# Manual optimization: clear both optimizers, backpropagate once, step both.
		both_optimizers = tuple(self.optimizers())
		main_opt, latent_opt = both_optimizers
		for opt in (main_opt, latent_opt):
			opt.zero_grad()

		self.manual_backward(loss)

		for opt in (main_opt, latent_opt):
			opt.step()

		return loss

	def validation_step(self, batch, batch_idx):
		"""Compute and log the validation loss; remember the final batch for plotting."""
		inputs, labels = batch
		inputs = inputs.to(torch.float64)
		val_loss = self.flow_matching_loss(inputs)

		# Keep the last batch of the first validation dataloader; it is used by
		# on_validation_epoch_end for snapshot sampling.
		is_last_batch = batch_idx + 1 == self.trainer.num_val_batches[0]
		if is_last_batch:
			self.last_validation_batch = {"X": inputs, "y": labels}

		self.log(
			'val_loss',
			val_loss,
			on_step=True,
			on_epoch=True,
			sync_dist=True,
			prog_bar=True,
			logger=True,
		)

		return val_loss

	def on_train_epoch_end(self):
		"""Step both learning-rate schedulers once per epoch.

		Needed because automatic optimization is switched off in ``__init__``.
		"""
		main_scheduler, latent_scheduler = self.lr_schedulers()
		for scheduler in (main_scheduler, latent_scheduler):
			scheduler.step()
        

	def on_validation_epoch_end(self):
		"""Run snapshot sampling on the cached validation batch, then drop the cache.

		Sampling happens only when a batch was cached and the current epoch is a
		multiple of ``config.training.snapshot_freq``.
		"""
		cached = self.last_validation_batch
		snapshot_due = (
			cached is not None
			and self.current_epoch % self.config.training.snapshot_freq == 0
		)

		if snapshot_due:
			X, y = cached["X"], cached["y"]
			log_post_z, log_lik = self.sample_and_log(X, y)
			self.generate(X, y)
			logged = (('tra_log_post_z', log_post_z), ('tra_log_lik', log_lik))
			for metric_name, metric_value in logged:
				self.log(metric_name, metric_value, on_step=False, on_epoch=True, sync_dist=True, logger=True)

		# Clear the stored batch for next epoch
		self.last_validation_batch = None

	def configure_optimizers(self):
		"""Build the two optimizers and their warm-up + cosine schedulers.

		* ``optimizer`` (AdamW) updates ``vt``.
		* ``optimizer_latent`` (Adam) updates ``rt`` and ``Rt`` at the base
		  learning rate, or only ``Rt`` at a tenth of it when
		  ``config.training.discrete`` is set.

		Each optimizer gets a ``GradualWarmupScheduler`` that hands over to a
		``CosineAnnealingLR`` and is stepped once per epoch.
		"""
		optim_cfg = self.config.optim
		n_epochs = self.config.training.n_epochs

		optimizer = torch.optim.AdamW(
			itertools.chain(self.vt.parameters()),
			lr=optim_cfg.lr,
			weight_decay=optim_cfg.weight_decay,
		)

		if self.config.training.discrete:
			latent_params = itertools.chain(self.Rt.parameters())
			latent_lr = optim_cfg.lr/10
		else:
			latent_params = itertools.chain(self.rt.parameters(), self.Rt.parameters())
			latent_lr = optim_cfg.lr
		optimizer_latent = torch.optim.Adam(
			latent_params,
			lr=latent_lr,
			weight_decay=optim_cfg.weight_decay,
		)

		def warmup_then_cosine(opt):
			# Cosine annealing wrapped in a warm-up covering a tenth of the epochs.
			cosine = optim.lr_scheduler.CosineAnnealingLR(
				optimizer=opt,
				T_max=n_epochs,
				eta_min=0,
				last_epoch=-1,
			)
			return GradualWarmupScheduler(
				optimizer=opt,
				multiplier=optim_cfg.multiplier,
				warm_epoch=n_epochs // 10,
				after_scheduler=cosine,
				last_epoch=self.current_epoch,
			)

		schedulers = [warmup_then_cosine(optimizer), warmup_then_cosine(optimizer_latent)]

		scheduler_configs = [
			{'scheduler': scheduler, 'monitor': 'train_loss', "interval": "epoch", "frequency": 1}
			for scheduler in schedulers
		]
		return [optimizer, optimizer_latent], scheduler_configs

	def sample_and_log(self, x, y):
		"""Evaluate posterior and likelihood log-probabilities on a batch.

		Args:
			x: Batch of observations.
			y: Labels of the batch (unused; kept for a uniform signature).

		Returns:
			``(log_post_z, log_lik)``: batch means of ``rt.log_prob`` and
			``vt.log_prob`` evaluated at the inferred latents.
		"""
		self._set_flow_training(False)
		try:
			with torch.no_grad():
				logits1, z1 = self._sample_prior_latents(len(x), x.device)
				_, z0idx, z0 = self._infer_latents(x, logits1, z1)

				if self.config.flow.fix_z:
					log_post_z = self.rt.log_prob(z0idx, x, z0).mean()
				else:
					log_post_z = self.rt.log_prob(z0, z0idx, x, 0., self.priorz).mean()
				log_lik = self.vt.log_prob(x, z0, 0.).mean()
		finally:
			# Restore train mode even when sampling fails.
			self._set_flow_training(True)

		return log_post_z, log_lik

	def generate(self, x, y):
		"""Save a scatter plot of generated samples and per-class posterior histograms.

		Writes ``image_grid_<epoch>.png`` and ``post_grid_<epoch>.png`` into
		``args.log_sample_path``.

		Args:
			x: Batch of observations; provides the target device and the points
				whose posterior component is histogrammed.
			y: Integer labels of ``x``; used to pick examples per class.
		"""
		device = x.device
		n_gen = self.config.sample.n_gen

		self._set_flow_training(False)
		try:
			with torch.no_grad():
				# posterior predictive: noise and prior latents on the batch device
				x1 = torch.randn(n_gen, self.config.data.size).to(device)
				logits1, z1 = self._sample_prior_latents(n_gen, device)
				x0 = self.vt.decode(x1, z1)
				_, _, z0 = self._infer_latents(x0, logits1, z1)

				x0_numpy = x0.detach().cpu().numpy()
				z0_numpy = z0.detach().cpu().numpy()

				cmap = plt.colormaps['gist_rainbow']

				plt.figure()
				plt.scatter(x0_numpy[:,0], x0_numpy[:,1], c=z0_numpy.squeeze(), marker=".", cmap=cmap)
				plt.colorbar()
				plt.savefig(os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(self.current_epoch)))
				plt.close()

				self._plot_posterior_histograms(x, y)
		finally:
			# Restore train mode even when sampling or plotting fails.
			self._set_flow_training(True)

	def test_generate(self, n):
		"""Generate ``n`` samples and the posterior mixture weights inferred for them.

		Returns:
			``(x0, pi0)``: decoded samples and ``softmax(logits0 / beta)``.
		"""
		device = self.device
		x1 = torch.randn(n, self.config.data.size).to(device)
		logits1, z1 = self._sample_prior_latents(n, device)

		x0 = self.vt.decode(x1, z1)
		pi0, _, _ = self._infer_latents(x0, logits1, z1)
		return x0, pi0

	def _set_flow_training(self, mode):
		"""Put ``vt``, ``Rt`` and ``rt`` in train (``True``) or eval (``False``) mode."""
		for net in (self.vt, self.Rt, self.rt):
			net.train(mode)

	def _sample_prior_latents(self, n, device):
		"""Draw ``n`` prior logits and the matching ``priorz`` sample, both moved to ``device``."""
		logits1 = self.priorpi.sample((n,)).to(device)
		z1idx = F.gumbel_softmax(logits1, tau=self.config.flow.tau, hard=False)
		z1 = self.priorz.sample(z1idx, (1,)).to(device)
		return logits1, z1

	def _infer_latents(self, x, logits1, z1):
		"""Infer mixture weights and latent ``z`` for ``x``.

		The two configuration switches are independent:
		``fix_pi`` selects how the logits / component sample are obtained and
		``fix_z`` selects how ``z0`` is obtained from them.

		Returns:
			``(pi0, z0idx, z0)``.
		"""
		flow_cfg = self.config.flow

		if flow_cfg.fix_pi:
			logits0, z0idx = self.Rt.rsample(x)
		else:
			logits0 = self.Rt.decode(logits1, x)
			z0idx = F.gumbel_softmax(logits0, tau=flow_cfg.tau, hard=False)

		if flow_cfg.fix_z:
			z0 = self.rt.sample(z0idx, x)
		else:
			z0 = self.rt.decode(z1, z0idx, x)

		pi0 = softmax(logits0/self.config.training.beta, dim=-1)
		return pi0, z0idx, z0

	def _plot_posterior_histograms(self, x, y, n_draws=100, n_cols=3):
		"""Histogram the arg-max component over ``n_draws`` posterior draws per example.

		One row per class label (``config.data.n_classes`` rows), ``n_cols``
		examples per row; cells without an example are switched off.
		"""
		n_rows = self.config.data.n_classes
		# tight_layout() is applied below, so constrained_layout is not requested;
		# squeeze=False keeps ``axes`` two-dimensional for any grid size.
		fig, axes = plt.subplots(n_rows, n_cols, figsize=(5,5), sharex=True, squeeze=False)
		try:
			for i, rax in enumerate(axes):
				xk = x[y==i]
				for j, cax in enumerate(rax):
					if j<len(xk):
						# repeat same sample n_draws times
						x_rep = xk[j].unsqueeze(0).to(torch.float64).repeat(n_draws,1)

						if self.config.flow.fix_pi:
							logits, _ = self.Rt.rsample(x_rep)
						else:
							pi1 = self.priorpi.sample((n_draws,)).to(x_rep.device)
							logits = self.Rt.decode(pi1, x_rep)
						kidx = torch.argmax(logits, dim=-1).squeeze()
						# matplotlib needs host memory, not a (possibly CUDA) tensor
						cax.hist(kidx.detach().cpu().numpy(), bins=20, alpha=0.7)
					else:
						cax.axis("off")
					if j==0:
						cax.set_ylabel(f"y={i}")

			fig.tight_layout()
			fig.savefig(os.path.join(self.args.log_sample_path, 'post_grid_{}.png'.format(self.current_epoch)))
		finally:
			plt.close(fig)

	def test_step(self, batch, batch_idx):
		"""Plot the trajectories for one ``(inputs, labels)`` test batch."""
		features, labels = batch
		self.test_plot_path(features, labels)

	def test_plot_path(self, x, y, n_time=100):
		"""Plot latent and observed flow trajectories, grouped by class label.

		Writes ``eval_latent_path_mean.png`` (latent trajectories, up to 5 per
		class) and ``eval_obs_path.png`` (PCA-projected observation
		trajectories, up to 2 per class) into ``args.log_sample_path``.

		Args:
			x: Batch of observations.
			y: Integer class labels of ``x``.
			n_time: ``num_points`` passed to ``rt.decode_with_trajectory``.
		"""
		device = x.device
		sorted_idx = torch.argsort(y)
		xk = x[sorted_idx].to(torch.float64)
		# Labels are only used on the host (NumPy masks), so bring them to CPU first.
		yk = y[sorted_idx].detach().cpu().numpy()

		# sample
		logits1 = self.priorpi.rsample((len(x),)).to(device)
		_, z1idx = self.Rt.rsample(None, logits1)
		z1 = self.priorz.rsample(z1idx, (1,)).to(device)
		logitst, _ = self.Rt.rsample(xk)
		pi0 = softmax(logitst/self.config.flow.tau, dim=-1)
		z0, trajectories, time_points = self.rt.decode_with_trajectory(z1, pi0, xk, num_points=n_time)

		times = time_points.detach().cpu().numpy()
		_zt_path = trajectories.cpu().detach().numpy().reshape(len(x), len(times), self.config.flow.z_dim)
		self._plot_class_trajectories(
			_zt_path, times, yk,
			max_per_class=5,
			ylabel=r'Latent Trajectory $z_t$',
			filename='eval_latent_path_mean.png',
		)

		x1_new = torch.randn(len(xk), self.config.data.size).to(device)
		x0_new, x0_trajectories, time_points = self.vt.decode_with_trajectory(x1_new, z0)
		times = time_points.detach().cpu().numpy()
		_x0_new = x0_new.cpu().detach().numpy()
		_x0_trajectories = x0_trajectories.cpu().detach().numpy().reshape(len(xk), len(times), -1)
		_xt_path = pca_map(_x0_new, _x0_trajectories, k=1)
		self._plot_class_trajectories(
			_xt_path, times, yk,
			max_per_class=2,
			ylabel=r'PCA Projected Observed Trajectory $x_t$',
			filename='eval_obs_path.png',
		)

	def _plot_class_trajectories(self, paths, times, labels, max_per_class, ylabel, filename):
		"""Draw one subplot per class with the first ``max_per_class`` trajectories of that class.

		Args:
			paths: NumPy array indexed as ``paths[sample, time, component]``.
			times: 1-D NumPy array of time points (x axis).
			labels: NumPy array of class labels aligned with ``paths``.
			max_per_class: Maximum number of trajectories drawn per subplot.
			ylabel: y-axis label, set on the middle subplot only.
			filename: File name of the figure inside ``args.log_sample_path``.
		"""
		n_classes = self.config.data.n_classes
		cmap = plt.colormaps['tab10']

		# One row per class; squeeze=False keeps indexing valid for any class count.
		fig, axes = plt.subplots(n_classes, 1, figsize=(8, 8), sharex=True, squeeze=False)
		axes = axes[:, 0]
		try:
			fig.tight_layout(pad=3.0)  # Add some padding between subplots
			for class_label, ax in enumerate(axes):
				class_indices = np.where(labels == class_label)[0]
				top_indices = class_indices[:max_per_class]

				color = cmap(class_label)
				if len(top_indices) > 0:
					ax.set_xlim(times[0], times[-1])
				for idx in top_indices:
					ax.plot(times, paths[idx], c=color)

				# NOTE(review): the title says "Mean" but the value is the median of
				# the last-time-point first component; the text is kept unchanged.
				ax.set_title("Class {}, Mean Across [0,1] = {}".format(class_label, np.round(np.median(paths[top_indices,-1,0]), 3)))
				if class_label == n_classes // 2:
					ax.set_ylabel(ylabel)
			axes[-1].set_xlabel('Time')
			fig.savefig(os.path.join(self.args.log_sample_path, filename))
		finally:
			plt.close(fig)



def remap_checkpoint_state_dict(state_dict):
    """Return a copy of ``state_dict`` with the extra ``0`` level under ``Rt.fc`` removed.

    Keys of the form ``Rt.fc.0.X`` become ``Rt.fc.X`` and keys of the form
    ``flow_matching_loss.Rt.fc.0.X`` become ``flow_matching_loss.Rt.fc.X``.
    All other entries are copied unchanged. The input mapping is not modified.

    Parameters:
    - state_dict: Mapping from parameter name to value.

    Returns:
    - A new dict with the renamed keys.
    """
    # Match on the key's start only: a substring test would also hit the
    # nested ``flow_matching_loss.`` keys and strip the wrong path component.
    wrapped_prefixes = ("flow_matching_loss.Rt.fc.", "Rt.fc.")
    extra_level = "0."

    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        for prefix in wrapped_prefixes:
            if key.startswith(prefix + extra_level):
                new_key = prefix + key[len(prefix) + len(extra_level):]
                break
        new_state_dict[new_key] = value

    return new_state_dict



class PinWheelRunner():
	def __init__(self, args, config):
		"""Create the networks, the checkpoint callback and the Lightning trainer.

		Side effects: sets ``args.log_sample_path`` to ``<args.log_path>/samples``
		and creates that directory.
		"""
		self.args = args
		self.config = config
		args.log_sample_path = os.path.join(args.log_path, 'samples')
		os.makedirs(args.log_sample_path, exist_ok=True)

		data_cfg, flow_cfg = config.data, config.flow
		x_dim, z_dim, k_dim = data_cfg.size, flow_cfg.z_dim, flow_cfg.k_dim

		self.k = data_cfg.n_classes
		# Construction order (priorz, vt, Rt, rt) is kept as is.
		self.priorz = update_GaussianMixtureComponent(k_dim, z_dim, hidden_features=[])
		self.vt = LLK_R(x_dim, z_dim, k_dim, hidden_features=[config.model.ngf]*5)

		if flow_cfg.fix_pi:
			self.Rt = CatNF_fixed(x_dim, k_dim, hidden_features=[flow_cfg.ngf], batch_norm=True)
		else:
			self.Rt = CatNF_R(x_dim, k_dim, hidden_features=[flow_cfg.ngf])

		if flow_cfg.fix_z:
			self.rt = GaussianMixtureComponent(k_dim, z_dim, x_dim, hidden_features=[flow_cfg.ngf]*2)
		else:
			self.rt = CNF_R(x_dim, z_dim, k_dim, hidden_features=[flow_cfg.ngf])

		# Keep only the single checkpoint with the lowest validation loss.
		self.checkpoint_callback = ModelCheckpoint(
			monitor='val_loss',
			dirpath=self.args.log_path,
			filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
			save_top_k=1,
			mode='min',
		)

		# GPU runs use DDP; CPU runs let Lightning pick the strategy.
		use_gpu = torch.cuda.is_available()
		accelerator = 'gpu' if use_gpu else 'cpu'
		strategy = DDPStrategy(find_unused_parameters=True) if use_gpu else "auto"
		devices = "auto"

		llk_callback = PlotLogLikelihoodCallback(
			save_path=self.args.log_sample_path,
			log_keys=("tra_log_lik", "tra_log_post_z"),
		)

		self.trainer = pl.Trainer(
			max_epochs=self.config.training.n_epochs,
			accelerator=accelerator,
			devices=devices,
			strategy=strategy,
			callbacks=[llk_callback, self.checkpoint_callback],
		)

	def train(self):
		"""Fit the model, optionally resuming from the newest saved checkpoint.

		Raises:
			FileNotFoundError: If ``args.resume_training`` is set but no
				checkpoint can be located.
		"""
		# load data
		dataset, test_dataset = get_dataset(self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
		train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
		                        num_workers=self.config.data.num_workers)
		val_dataloader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
		                         num_workers=self.config.data.num_workers, drop_last=True)

		# Initialize the Lightning model
		model = LightningModule(self.vt, self.Rt, self.rt, self.priorz, self.config, self.args)

		ckpt_path = self._resolve_checkpoint_path() if self.args.resume_training else None

		# Run the training loop
		self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

	def sample(self):
		"""Load the newest checkpoint into a fresh model and run the test loop.

		Raises:
			FileNotFoundError: If no checkpoint can be located.
		"""
		_, test_dataset = get_dataset(self.config.data.n_classes, "data", 500, 500)

		test_dataloader = DataLoader(test_dataset, batch_size=2500, shuffle=True,
		                 num_workers=self.config.data.num_workers, drop_last=True)

		model = LightningModule(self.vt, self.Rt, self.rt, self.priorz, self.config, self.args)

		ckpt_path = self._resolve_checkpoint_path()
		# The checkpoint is a locally produced Lightning file holding more than
		# plain tensors (optimizer / scheduler state), hence weights_only=False.
		# NOTE(review): this unpickles arbitrary objects; only load trusted files.
		checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

		# When loading:
		state_dict = checkpoint['state_dict']
		remapped_state_dict = remap_checkpoint_state_dict(state_dict)
		# Try loading with the remapped state dict
		model.load_state_dict(remapped_state_dict, strict=False)

		self.trainer.test(model, dataloaders=test_dataloader)

	def _resolve_checkpoint_path(self):
		"""Return the path of the checkpoint to load.

		Uses the callback's ``best_model_path`` when it is set (i.e. training ran
		in this process); otherwise falls back to the most recently modified
		``best-checkpoint-*.ckpt`` file inside ``args.log_path``.

		Raises:
			FileNotFoundError: If neither source yields a checkpoint.
		"""
		best_path = self.checkpoint_callback.best_model_path
		if best_path:
			return best_path

		pattern = os.path.join(glob.escape(self.args.log_path), 'best-checkpoint-*.ckpt')
		candidates = glob.glob(pattern)
		if not candidates:
			raise FileNotFoundError(f"No checkpoint matching {pattern!r} was found.")
		return max(candidates, key=os.path.getmtime)
