import os
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
import matplotlib.gridspec as gridspec
# from torchvision.utils import save_image

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import get_named_beta_schedule
from Scheduler import WarmUpScheduler # GradualWarmupScheduler
from sfa_lds import *
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_pendulum import *
from utils  import * 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy

import gc
# Process-wide runtime configuration applied once when this module is imported.
# Tensors created without an explicit dtype will be float64 from here on.
torch.set_default_dtype(torch.float64)
# Show three digits after the decimal point when printing tensors.
torch.set_printoptions(precision=3)
# Run a full garbage-collection pass at import time.
gc.collect()



def map_to_position(input_tensor, num_classes):
    """
    Build a one-hot matrix from a 1-D tensor of class indices.

    Row ``i`` of the result holds a single 1 in column ``input_tensor[i]``
    and zeros everywhere else.

    Parameters:
    - input_tensor: 1-D tensor of class indices in ``[0, num_classes)``.
      Non-integer dtypes are cast to ``long`` before being used as indices.
    - num_classes: number of columns (classes) of the output.

    Returns:
    - A ``(len(input_tensor), num_classes)`` tensor of the default floating
      dtype, allocated on the same device as ``input_tensor``.
    """
    num_rows = input_tensor.size(0)
    device = input_tensor.device
    # Allocate on the input's device; a CPU output indexed with CUDA indices
    # raises a device-mismatch error.
    output_tensor = torch.zeros((num_rows, num_classes), device=device)
    rows = torch.arange(num_rows, device=device)
    # Advanced indexing requires integer indices.
    output_tensor[rows, input_tensor.long()] = 1
    return output_tensor

def scaledsigmoid(x):
    """Logistic sigmoid stretched from the interval (0, 1) onto (-1, 1)."""
    squashed = torch.sigmoid(x)
    return 2 * squashed - 1

class PlotLogLikelihoodCallback(Callback):
    """Plot two logged log-likelihood curves (e.g. for x and z) after every training epoch.

    Values are read from ``trainer.callback_metrics`` under ``log_keys`` and the
    two-panel figure is written to ``self.save_path``.
    """

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Args:
            save_path (str): Directory in which ``llk.png`` is written. A path that
                already ends in ``.png`` (such as the default) is used as the output
                file itself instead of producing ``loss_plot.png/llk.png``.
            log_keys (tuple): Metric keys of the two log likelihoods, in (x, z) order.
        """
        super().__init__()
        if str(save_path).lower().endswith(".png"):
            self.save_path = save_path
        else:
            self.save_path = os.path.join(save_path, "llk.png")
        self.log_keys = log_keys
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the current values of both metrics (when present) and redraw the figure.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained.
        """
        metrics = trainer.callback_metrics
        key_x, key_z = self.log_keys
        if key_x in metrics:
            self.log_likelihoods_x.append(metrics[key_x].item())
        if key_z in metrics:
            self.log_likelihoods_z.append(metrics[key_z].item())

        # Under multi-process (DDP) training only rank 0 writes the file, so ranks do
        # not race on the same path; skip drawing until something was logged.
        if not trainer.is_global_zero:
            return
        if not (self.log_likelihoods_x or self.log_likelihoods_z):
            return

        fig = plt.figure(figsize=(12, 6))

        # Left panel: log likelihood of x.
        plt.subplot(1, 2, 1)
        plt.plot(self.log_likelihoods_x, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Log p(x|z)")
        plt.title("Log Likelihood of x During Training")
        plt.grid()

        # Right panel: log likelihood of z.
        plt.subplot(1, 2, 2)
        plt.plot(self.log_likelihoods_z, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Log p(z|x)")
        plt.title("Log Likelihood of z During Training")
        plt.grid()

        plt.tight_layout()
        plt.savefig(self.save_path)
        plt.close(fig)

class PlotLossCallback(Callback):
    """Track epoch-level ``train_loss`` / ``val_loss`` and plot both curves to ``save_path``.

    Each curve keeps its own list of epoch numbers (``train_epochs`` / ``val_epochs``),
    so an epoch in which one metric is missing does not leave that curve's x and y
    arrays with different lengths (which would make ``plt.plot`` raise).
    """

    def __init__(self, save_path="loss_plot.png", update_interval=1):
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        self.train_losses = []
        self.val_losses = []
        # Every epoch seen by this callback.
        self.epochs = []
        # Epochs at which each loss was actually recorded (x-values of each curve).
        self.train_epochs = []
        self.val_epochs = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Record this epoch's losses (at most once per epoch) and redraw every ``update_interval`` epochs."""
        epoch_num = trainer.current_epoch
        metrics = trainer.callback_metrics

        self._record(metrics.get("train_loss"), epoch_num, self.train_losses, self.train_epochs)
        self._record(metrics.get("val_loss"), epoch_num, self.val_losses, self.val_epochs)

        if not self.epochs or self.epochs[-1] != epoch_num:
            self.epochs.append(epoch_num)

        # Only rank 0 writes the figure under multi-process (DDP) training;
        # an interval below 1 is treated as "every epoch".
        if trainer.is_global_zero and epoch_num % max(1, self.update_interval) == 0:
            self.plot_and_save()

    @staticmethod
    def _record(value, epoch_num, values, epochs):
        """Append a metric tensor's value for ``epoch_num`` once; ignore missing metrics."""
        if value is None or (epochs and epochs[-1] == epoch_num):
            return
        values.append(value.item())
        epochs.append(epoch_num)

    def plot_and_save(self):
        """Draw both loss curves against their own epochs and save to ``self.save_path``."""
        fig = plt.figure(figsize=(10, 6))
        plt.plot(self.train_epochs, self.train_losses, label="Training Loss", marker="o")
        plt.plot(self.val_epochs, self.val_losses, label="Validation Loss", marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.save_path)
        plt.close(fig)

class GradientNormPlotCallback(pl.Callback):
    """Plot, per epoch, the batch-averaged mean per-parameter gradient L2 norm.

    After every training batch, the L2 norm of each parameter's ``.grad`` is taken
    and averaged over the parameters that have a gradient. At epoch end those batch
    values are averaged and the full history is written to ``save_path``.
    Note: this is a mean of per-tensor norms, not the global (total) gradient norm.
    """

    def __init__(self, save_path="gradnorm_plot.png"):
        super().__init__()
        self.save_path = save_path
        # Mean gradient norm of each finished epoch.
        self.epoch_grad_norms = []
        # Per-batch values for the epoch in progress (kept on the callback, not the module).
        self.batch_grad_norms = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        norms = [p.grad.detach().norm(2) for p in pl_module.parameters() if p.grad is not None]
        if norms:
            # A single `.item()` per batch instead of one device sync per parameter;
            # cast to float64 so parameters of different dtypes can be stacked.
            avg_grad_norm = torch.stack([g.to(torch.float64) for g in norms]).mean().item()
        else:
            avg_grad_norm = 0.0
        self.batch_grad_norms.append(avg_grad_norm)

    def on_train_epoch_end(self, trainer, pl_module):
        batch_norms = self.batch_grad_norms
        epoch_mean = sum(batch_norms) / len(batch_norms) if batch_norms else 0.0
        self.epoch_grad_norms.append(epoch_mean)
        # Start the next epoch with an empty buffer.
        self.batch_grad_norms = []

        # Only rank 0 writes the figure under multi-process (DDP) training.
        if not trainer.is_global_zero:
            return

        fig = plt.figure(figsize=(8, 4))
        plt.plot(range(1, len(self.epoch_grad_norms) + 1), self.epoch_grad_norms, marker='o')
        plt.xlabel("Epoch")
        plt.ylabel("Mean Gradient Norm")
        plt.title("Mean Gradient Norm per Epoch")
        plt.grid(True)
        plt.savefig(self.save_path)
        plt.close(fig)


# Define your LightningModule
class FlowMatchingLightningModule(pl.LightningModule):
    """LightningModule that jointly trains a frame model ``vt`` and a latent model ``rt``.

    Optimisation is manual (``automatic_optimization = False``). ``vt`` and ``rt``
    each get their own AdamW optimizer with a warm-up + cosine LR schedule. The
    optimizers are stepped once per batch in ``training_step``; the schedulers are
    stepped once per epoch in ``on_train_epoch_end``.

    Args:
        vt: Frame model; must provide ``decodeS``.
        rt: Latent model; must provide ``decodeS`` and ``predictF``.
        prior: Latent prior exposing ``sample(n, length, device=...)``.
        config: Nested config (``data``, ``model``, ``flow``, ``training``, ``optim`` sections).
        args: Run arguments; ``log_sample_path``, ``sample`` and ``predict`` are read.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, config, args):
        super().__init__()
        self.config = config
        self.args = args
        self.lr = self.config.optim.lr
        self.prior = prior
        self.vt = vt
        self.rt = rt
        self.c, self.p = self.config.data.channel, self.config.data.p
        self.S = self.config.data.S

        self.automatic_optimization = False
        self.last_validation_batch = None

        # Side length of the q x q frame grid drawn by eval_latent.
        self.q = self.config.data.q

    def setup(self, stage=None):
        """Build ``priory`` and the flow-matching loss selected by ``config.model`` / ``config.flow``."""
        if not self.config.model.cnn:
            self.priory = Normal(torch.zeros(self.p**2).to(self.device), torch.ones(self.p**2).to(self.device))
        else:
            self.priory = Normal(torch.zeros(self.c, self.p, self.p).to(self.device), torch.ones(self.c, self.p, self.p).to(self.device))

        alpha = self.config.training.alpha
        if self.config.model.const:
            if not self.config.model.cnn:
                self.flow_matching_loss = FlowMatchingLossc(self.vt, self.rt, self.prior, alpha=alpha, flowcnn=self.config.flow.cnn, fixz=self.config.flow.fixed)
            else:
                self.flow_matching_loss = FlowMatchingLosscnn(self.vt, self.rt, self.prior, alpha=alpha, flowcnn=self.config.flow.cnn)
        else:
            self.flow_matching_loss = FlowMatchingLoss(self.vt, self.rt, self.prior, alpha=alpha, fixz=self.config.flow.fixed)

    # ------------------------------------------------------------------ helpers

    def _latent_condition(self, frames):
        """Return frames in the form ``rt`` consumes: as-is for a CNN flow, else flattened from dim 2."""
        return frames if self.config.flow.cnn else frames.flatten(start_dim=2)

    def _reconstruct(self, x, indices, n):
        """Infer latents ``z1`` from frames ``x`` with ``rt``, then regenerate frames ``x1`` with ``vt``.

        Fresh noise ``z0`` (prior sample) and ``x0`` (standard normal,
        ``(S, n, 1, p, p)``) are drawn on every call. Returns ``(x1, z1)``.
        """
        S, p = self.config.data.S, self.config.data.p
        z0 = self.prior.sample(n, S, device=self.device)  # S,n,p
        x0 = torch.randn(S, n, 1, p, p, device=self.device)  # channel=1
        z1 = self.rt.decodeS(z0, self._latent_condition(x), indices=indices)
        if self.config.model.const:
            x1 = self.vt.decodeS(x0, z1)
        else:
            x1 = self.vt.decodeS(x0, z1, indices=indices)
        return x1, z1

    def _to_numpy(self, x, x1, y, z1, length):
        """Move tensors to numpy as ``(-1, length, ...)`` arrays; ``x1`` goes through ``inv_transform``.

        Returns ``(x_np, x1_np, y_np, z1_np)``.
        """
        p = self.config.data.p
        x1_np = inv_transform(x1).cpu().detach().numpy().reshape((-1, length, p, p))
        z1_np = z1.cpu().detach().numpy().reshape((-1, length, self.config.flow.feature_dim))
        y_np = y.cpu().detach().numpy().reshape((-1, length, self.config.data.n))
        x_np = x.cpu().detach().numpy().reshape((-1, length, p, p))
        return x_np, x1_np, y_np, z1_np

    def _save_comparison_figure(self, x_np, x1_np, y_np, z1_np, q, length, filename):
        """Save a 2x2 figure (real/generated frame grids, real/generated latents) under ``args.log_sample_path``.

        The first ``q**2`` frames of one sequence are tiled into a ``q x q`` grid. The
        second sequence of the batch is shown, or the first if the batch holds only one.
        Only global rank 0 writes the file.
        """
        if self.global_rank != 0:
            return
        p = self.config.data.p
        k = min(1, x_np.shape[0] - 1)

        def _grid(frames):
            # (q**2, p, p) -> (p*q, p*q) image grid.
            stack = frames[k][:q**2].reshape(q, p * q, p)
            return stack.swapaxes(0, 1).reshape(p * q, p * q)

        fig, axes = plt.subplots(2, 2, figsize=(12, 5))
        axes[0, 0].imshow(_grid(x_np), cmap="gray", origin="lower", aspect=.2)
        axes[0, 1].imshow(_grid(x1_np), cmap="gray", origin="lower", aspect=.2)
        axes[1, 0].plot(np.arange(length), y_np[k])
        axes[1, 0].set_xmargin(0)
        axes[1, 1].plot(np.arange(length), z1_np[k])
        axes[1, 1].set_xmargin(0)

        axes[0, 0].set_xlabel("real data")
        axes[0, 1].set_xlabel("gen data")
        axes[1, 0].set_xlabel("real latent")
        axes[1, 1].set_xlabel("gen latent")
        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, filename))
        plt.close(fig)

    def _optimizer_with_schedule(self, params):
        """Create an AdamW optimizer over ``params`` and its warm-up + cosine scheduler config dict."""
        optimizer = torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.config.optim.weight_decay)
        cosine = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=self.config.training.n_epochs,
            eta_min=1e-6,
            last_epoch=-1,
        )
        warmup = WarmUpScheduler(
            optimizer=optimizer,
            lr_scheduler=cosine,
            warmup_steps=self.config.training.n_epochs // 10,
            warmup_start_lr=0.00005,
            len_loader=self.config.data.samplesize // self.config.training.batch_size,
        )
        return optimizer, {'scheduler': warmup,
                           'monitor': 'train_loss',
                           "interval": "epoch",
                           "frequency": 1}

    # --------------------------------------------------------------- lightning

    def forward(self, n, x, z, indices=None):
        """Generate ``n`` sequences: sample prior latents, decode frames, then re-infer latents.

        ``x`` and ``z`` are accepted for interface compatibility but are not used.

        Returns:
            ``(x1_np, z1_np, z0_np)`` numpy arrays reshaped to ``(-1, S, p, p)``,
            ``(-1, S, feature_dim)`` and ``(-1, S, feature_dim)``.
        """
        S, p = self.config.data.S, self.config.data.p
        z0 = self.prior.sample(n, S, device=self.device)
        x0 = torch.randn(S, n, 1, p, p, device=self.device)  # channel=1
        x1 = self.vt.decodeS(x0, z0)
        z1 = self.rt.decodeS(z0, self._latent_condition(x1), indices=indices)

        x1_np = inv_transform(x1).cpu().detach().numpy().reshape((-1, S, p, p))
        z1_np = z1.cpu().detach().numpy().reshape((-1, S, self.config.flow.feature_dim))
        z0_np = z0.cpu().detach().numpy().reshape((-1, S, self.config.flow.feature_dim))
        return x1_np, z1_np, z0_np

    def training_step(self, batch, batch_idx):
        """One manual-optimisation step that updates ``vt`` and ``rt`` from the shared loss."""
        X, y, indices = batch
        X = X.view(self.config.data.S, -1, 1, self.config.data.p, self.config.data.p).to(dtype=torch.float64)
        indices = indices.view(self.config.data.S, -1)

        loss = self.flow_matching_loss(X, indices=indices)
        self.log('train_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        g_opt, d_opt = self.optimizers()  # (vt optimizer, rt optimizer)
        d_opt.zero_grad()
        g_opt.zero_grad()
        self.manual_backward(loss)
        # Clip *before* stepping so the clipped gradients are the ones applied;
        # clipping after `step()` would have no effect on the update.
        clipval = self.config.training.clipval
        self.clip_gradients(d_opt, gradient_clip_val=clipval, gradient_clip_algorithm="norm")
        self.clip_gradients(g_opt, gradient_clip_val=clipval, gradient_clip_algorithm="norm")
        d_opt.step()
        g_opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and keep the last batch for plotting at epoch end."""
        X, y, indices = batch

        X = X.view(self.config.data.S, -1, 1, self.config.data.p, self.config.data.p).to(dtype=torch.float64)
        indices = indices.view(self.config.data.S, -1)
        val_loss = self.flow_matching_loss(X, indices=indices)

        if batch_idx == self.trainer.num_val_batches[0] - 1:
            self.last_validation_batch = {"X": X, "y": y, "indices": indices}

        self.log('val_loss', val_loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        return val_loss

    def configure_optimizers(self):
        """One optimizer/scheduler pair for ``vt`` and one for ``rt``, in that order (see ``training_step``)."""
        optimizer, scheduler = self._optimizer_with_schedule(self.vt.parameters())
        optimizer_latent, scheduler_latent = self._optimizer_with_schedule(self.rt.parameters())
        return [optimizer, optimizer_latent], [scheduler, scheduler_latent]

    def eval_latent(self, x, y, indices):
        """Infer latents for ``x``, regenerate frames, and save a comparison figure.

        ``x`` is expected in the ``(S, n, 1, p, p)`` layout produced by the step methods.
        ``vt``/``rt`` are put in eval mode for sampling and afterwards restored to the
        train/eval mode they had before the call.
        """
        vt_was_training, rt_was_training = self.vt.training, self.rt.training
        self.vt.eval()
        self.rt.eval()
        try:
            with torch.no_grad():
                n = x.shape[1]
                x1, z1 = self._reconstruct(x, indices, n)
                x_np, x1_np, y_np, z1_np = self._to_numpy(x, x1, y, z1, self.config.data.S)
                self._save_comparison_figure(
                    x_np, x1_np, y_np, z1_np, self.q, self.S,
                    f'image_eval_epoch_{self.current_epoch}.png')
        finally:
            self.vt.train(vt_was_training)
            self.rt.train(rt_was_training)

    def on_train_epoch_end(self):
        # Manual optimisation: the LR schedulers are stepped here, once per epoch.
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def on_validation_epoch_end(self):
        """Every ``snapshot_freq`` epochs, plot a reconstruction of the last validation batch."""
        if self.last_validation_batch is not None:
            X = self.last_validation_batch["X"]
            y = self.last_validation_batch["y"]
            indices = self.last_validation_batch["indices"]

            if self.current_epoch % self.config.training.snapshot_freq == 0:
                self.eval_latent(X, y, indices)

        self.last_validation_batch = None

    def test_step(self, batch, batch_idx):
        """Evaluate reconstruction (``args.sample``) or long-horizon prediction (``args.predict``).

        * ``args.sample``: reconstruct length-``S`` windows, log ``test_latent_loss`` /
          ``test_obs_loss`` (RMSE-style distances to ``y`` / ``x``), and save a figure.
        * ``args.predict``: use the first ``S`` frames as context, roll latents out to
          ``Stotal`` steps with ``rt.predictF``, decode them with ``vt``, and save
          ``predict_<Stotal>.png``.
        """
        x, y, indices = batch
        total = self.config.data.Stotal
        S, p = self.config.data.S, self.config.data.p

        if self.args.sample:
            x = x.view(S, -1, 1, p, p).to(dtype=torch.float64)
            indices = indices.view(S, -1)
            # Batch size taken from the reshaped (S, n, 1, p, p) tensor, as in eval_latent.
            n = x.shape[1]

            x1, z1 = self._reconstruct(x, indices, n)
            x_np, x1_np, y_np, z1_np = self._to_numpy(x, x1, y, z1, S)

            # evaluate L2 distance between generated and truth
            rmse_obs = np.sqrt((((x1_np - x_np) ** 2).sum((-1))).mean(1)).mean()
            rmse_latent = np.sqrt((((z1_np - y_np) ** 2).sum((-1))).mean(1)).mean()

            self.log('test_latent_loss', rmse_latent, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
            self.log('test_obs_loss', rmse_obs, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

            self.eval_latent(x, y, indices)
        elif self.args.predict:
            n = x.shape[0]
            q = int(np.floor(np.sqrt(total)))
            x = x.view(total, n, 1, p, p).to(dtype=torch.float64)

            z0 = self.prior.sample(n, total, device=self.device)
            # The first S frames serve as context for the forward prediction.
            context = x[:S]
            x0 = torch.randn(total, n, 1, p, p, device=self.device)  # channel=1
            z1 = self.rt.predictF(total, z0, self._latent_condition(context))

            # posterior forward prediction
            x1 = self.vt.decodeS(x0, z1)

            x_np, x1_np, y_np, z1_np = self._to_numpy(x, x1, y, z1, total)
            self._save_comparison_figure(x_np, x1_np, y_np, z1_np, q, total, f'predict_{total}.png')

        


def _as_displayable_image(img):
    """Convert one image (tensor, array, or array-like) into a layout ``plt.imshow`` accepts.

    Tensors are detached and moved to CPU numpy. A 3-D array whose first axis has
    size 1 or 3 is treated as channel-first: a single channel is dropped, giving
    (H, W), and RGB is moved to channel-last, giving (H, W, 3). Anything else is
    returned as-is.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[0] in (1, 3):
        # Drop only the channel axis; np.squeeze would also collapse an H or W of size 1.
        img = img[0] if img.shape[0] == 1 else np.transpose(img, (1, 2, 0))
    return img


def plot_image_sequences(input_sequence, output_sequence, titles=None, figsize=(6, 12)):
    """
    Plot two equally long sequences of images side by side (one row per step).

    Args:
        input_sequence (list or np.ndarray or torch.Tensor):
            Sequence of input images. Each element may be HxW, HxWxC, or
            channel-first CxHxW with C in {1, 3} (e.g. slices of an (N, C, H, W) tensor).
        output_sequence (list or np.ndarray or torch.Tensor):
            Sequence of output images, same length and formats as ``input_sequence``.
        titles (list or tuple, optional):
            Two column titles, e.g. ["Input Images", "Output Images"]. Defaults to None.
        figsize (tuple, optional):
            Figure size for the plot. Defaults to (6, 12).

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    n_steps = len(input_sequence)
    if n_steps == 0:
        print("Input sequences are empty, nothing to plot.")
        return
    if len(output_sequence) != n_steps:
        raise ValueError("Input and output sequences must have the same length.")

    # n_steps rows x 2 columns; squeeze=False keeps `axes` 2-D even for one row.
    fig, axes = plt.subplots(n_steps, 2, figsize=figsize, squeeze=False)

    if titles and len(titles) == 2:
        axes[0, 0].set_title(titles[0])
        axes[0, 1].set_title(titles[1])

    for i in range(n_steps):
        img_in = _as_displayable_image(input_sequence[i])
        img_out = _as_displayable_image(output_sequence[i])

        ax = axes[i, 0]
        ax.imshow(img_in, cmap='gray' if img_in.ndim == 2 else None)
        ax.set_ylabel(f"Step {i}")
        ax.set_xticks([])
        ax.set_yticks([])

        ax = axes[i, 1]
        ax.imshow(img_out, cmap='gray' if img_out.ndim == 2 else None)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout(rect=[0, 0.03, 1, 0.97] if titles else None)  # Adjust layout to prevent overlap
    plt.show()


class PENDULUMRunner():
    """Build the models, callbacks and Lightning ``Trainer`` for the pendulum experiment.

    ``train`` fits the model. ``sample`` and ``predict`` run ``Trainer.test`` from the
    best checkpoint; ``args.sample`` / ``args.predict`` choose the test-time branch.
    """

    def __init__(self, args, config):
        """Create the output directory, models, checkpoint callback and trainer.

        Side effect: sets ``args.log_sample_path`` to ``<args.log_path>/samples`` and creates it.
        """
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        self.S = self.config.data.S
        self.F = self.config.data.Stotal

        # Keep the build order vt -> rt -> prior: constructing modules presumably
        # consumes the RNG, so reordering would change seeded initialisation.
        self.vt = self._build_vt()
        self.rt = self._build_rt()
        self.prior = LatentDynamicalSystem(self.config.flow.feature_dim)

        # Keep only the checkpoint with the lowest validation loss.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=self.args.log_path,
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
            save_top_k=1,
            mode='min',
        )
        self.trainer = self._build_trainer()

    def _build_vt(self):
        """Construct the frame model selected by ``config.model.const`` / ``config.model.cnn``."""
        data, model_cfg = self.config.data, self.config.model
        feature_dim = self.config.flow.feature_dim
        if model_cfg.const:
            if not model_cfg.cnn:
                print("using LLKc")
                return LLKc(
                    data.p, feature_dim, self.S, self.F,
                    num_hidden=100,
                    hidden_features=[model_cfg.ngf] * model_cfg.depth,
                    fct=ShiftedTanh(),
                    freqs=10,
                )
            return LLKcnn(
                data.p, feature_dim, self.S, self.F,
                num_hidden=32,
                hidden_features=[model_cfg.ngf] * 3,
                freqs=10,
                fct=ShiftedTanh(),
            )
        if not model_cfg.cnn:
            return LLK(
                data.p, feature_dim, self.S, self.F,
                num_hidden=100,
                dsemb=2,
                hidden_features=[model_cfg.ngf] * model_cfg.depth,
                fct=nn.Softplus(),
            )
        raise NotImplementedError

    def _build_rt(self):
        """Construct the latent model selected by ``config.flow.AR`` / ``.cnn`` / ``.fixed``."""
        data, flow = self.config.data, self.config.flow
        hidden = [flow.ngf] * flow.depth
        if flow.AR:
            return ARCNF(
                data.p**2, flow.feature_dim, self.S, dsemb=flow.dsemb,
                num_hidden=128,
                hidden_features=hidden,
                fct=nn.Tanh(),
            )
        if flow.cnn:
            if flow.fixed:
                return fullGauss(
                    data.p, flow.feature_dim, self.S, self.F, dsemb=flow.dsemb,
                    num_hidden=16,
                    num_layers=1,
                    cnn=True,
                    attention=True,
                    emb_dim=100,
                    freqs=2,
                    hidden_features=hidden,
                    fct=nn.Tanh(),
                )
            return fullCNF(
                data.p, flow.feature_dim, self.S, self.F, dsemb=flow.dsemb,
                num_hidden=32,
                cnn=True,
                hidden_features=hidden,
                fct=nn.Softplus(),
            )
        if flow.fixed:
            print("using fullGauss")
            return fullGauss(
                data.p**2, flow.feature_dim, self.S, self.F, dsemb=flow.dsemb,
                num_hidden=16,
                num_layers=1,
                emb_dim=100,
                freqs=2,
                cnn=False,
                attention=True,
                hidden_features=hidden,
                fct=nn.Tanh(),
            )
        print("using fullCNF")
        return fullCNF(
            data.p**2, flow.feature_dim, self.S, self.F, dsemb=flow.dsemb,
            # Author's note: when num_hidden is too small the latent trend is not
            # captured (collapses to a horizontal line).
            num_hidden=16,
            num_layers=3,
            emb_dim=4,
            freqs=2,
            cnn=False,
            hidden_features=hidden,
        )

    def _build_trainer(self):
        """Create the ``pl.Trainer``: DDP on GPU when CUDA is available, otherwise CPU."""
        if torch.cuda.is_available():
            accelerator, strategy = 'gpu', "ddp"
        else:
            accelerator, strategy = 'cpu', "auto"

        # PlotLogLikelihoodCallback is not attached: the LightningModule here does not
        # appear to log the "val_log_lik" / "val_log_post" keys it would plot.
        plot_loss_callback = PlotLossCallback(save_path=os.path.join(self.args.log_sample_path, 'loss.png'), update_interval=1)
        plot_gradnorm_callback = GradientNormPlotCallback(save_path=os.path.join(self.args.log_sample_path, 'gradnorm.png'))
        return pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices="auto",
            strategy=strategy,
            callbacks=[plot_loss_callback, plot_gradnorm_callback, self.checkpoint_callback],
        )

    def _dataloaders(self, samplesize, test_samplesize, window_size):
        """Return ``(train_loader, test_loader)`` from ``get_pendulum_dataloader`` with the shared settings."""
        self.length = self.config.data.Stotal
        return get_pendulum_dataloader(
            samplesize, test_samplesize,
            self.config.data.p, self.length, "data",
            window_size=window_size, stride=1, batch_size=self.config.training.batch_size,
            batch_from_same_trajectory=self.config.data.samebatch,
            state_dim=self.config.data.n, gen=self.config.data.gen, shuffle=True, num_workers=self.config.data.num_workers)

    def _resolve_checkpoint(self):
        """Return the checkpoint path to resume or test from.

        ``checkpoint_callback.best_model_path`` is empty until a checkpoint has been
        saved in this process, e.g. when ``sample``/``predict``/resume run in a fresh
        process. In that case, fall back to the newest ``best-checkpoint-*.ckpt`` in
        ``args.log_path`` (the directory and name pattern configured in ``__init__``).
        If none exists, the empty path is passed to Lightning unchanged, as before.
        """
        best = self.checkpoint_callback.best_model_path
        if best:
            return best
        candidates = glob.glob(os.path.join(self.args.log_path, 'best-checkpoint-*.ckpt'))
        return max(candidates, key=os.path.getmtime) if candidates else best

    def train(self):
        """Fit the model; resume from the best checkpoint when ``args.resume_training`` is set."""
        train_dataloader, val_dataloader = self._dataloaders(
            self.config.data.samplesize, self.config.data.test_samplesize, self.S)

        model = FlowMatchingLightningModule(self.vt, self.rt, self.prior, self.config, self.args)
        ckpt_path = self._resolve_checkpoint() if self.args.resume_training else None

        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

    def sample(self):
        """Run ``Trainer.test`` on length-``S`` validation windows from the best checkpoint."""
        _, val_dataloader = self._dataloaders(
            self.config.data.samplesize, self.config.data.test_samplesize, self.S)

        model = FlowMatchingLightningModule(self.vt, self.rt, self.prior, self.config, self.args)
        self.trainer.test(model, val_dataloader, ckpt_path=self._resolve_checkpoint())

    def predict(self):
        """Run ``Trainer.test`` on full-length (``Stotal``) windows from the best checkpoint."""
        _, test_dataloader = self._dataloaders(1, 2, self.F)

        model = FlowMatchingLightningModule(self.vt, self.rt, self.prior, self.config, self.args)
        self.trainer.test(model, test_dataloader, ckpt_path=self._resolve_checkpoint())

