# Standard library imports
from argparse import ArgumentParser
import os, sys
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torchdiffeq import odeint
import matplotlib.pyplot as plt
import numpy as np 
from tqdm import tqdm
from torch.distributions import Normal
from sklearn.metrics import r2_score

# local application imports
from lag_caVAE.lag import Lag_Net
from lag_caVAE.leap import Leap_Net
from lag_caVAE.leap import TF_Block_EXP_Residual_TV_2
from lag_caVAE.leap import TransformerEncoderLayerCategoricalsCatPos
from lag_caVAE.leap import TransformerEncoderLayer_v2_CatPos
from lag_caVAE.leap import TransformerEncoderLayer_v2_CategoricalsCatPos
from lag_caVAE.leap import EstimatorNetwork
from lag_caVAE.leap import TransformerEncoderLayerCategoricals
from lag_caVAE.leap import TransformerEncoderLayer_v2_Categoricals
from lag_caVAE.leap import EstimatorNetwork_v2
from lag_caVAE.leap import EstimatorNetwork_NoScale
from lag_caVAE.leap import MLP_Mag
from lag_caVAE.leap import P_Neural_TIME_MultipleParameter
from lag_caVAE.leap import TransformerEncoderLayer
from lag_caVAE.leap import TransformerEncoderLayer_v2
from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD, Encoder
from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform
from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset

# ---------------------------------------------------------------------------
# Experiment configuration: module-level settings. Several of these are read
# directly by Model below (e.g. enable_attn, model_variant, NN_parameter,
# att_len, gap_interval, cutoff_freq_input, dec_size, NN_input_layer).
# ---------------------------------------------------------------------------

# Prediction length (number of predicted time steps)
T_pred = 100
# Number of leading time indices used to compute the loss values
Loss_first_index = 100
# Number of lowest FFT frequency indices to keep (ordered lowest to highest).
# Model.__init__ also uses it to bound the source-mask precomputation loop.
cutoff_freq_input = 15
# Dataset selector. Only values 0-5 are described in the legend below;
# NOTE(review): meaning of the current value (14) is not documented here.
dataset_type = 14 # 0->mu+L; 1->L; 2->mu; 3->mu (large); 4->mu (small, random); 5->mu (smaller, random)
# Learning rate
lr = 1e-3
# Batch size (number of samples per batch)
num_batch = 32 # (gap = 20 -> 50 batches)
# Gradient clipping value
gradient_clip = 1.0
# Non-linearity for the q (recognition) net and the reconstruction (obs) net
MLP_Encoder_nonLinear = 'tanh' # tanh; elu; softplus; relu
# Model type: 0->FFT input to ODE Net; 1->Time Series data to ODE Net
MODEL_TYPE = 0
# Enable plotting (1) or not (0)
Plot_enable = 1
# 1 -> build attention-based recognition networks in Model.__init__; 0 -> MLP encoder only
enable_attn = 0
# Recognition-network layout (only used when enable_attn is set):
# 0->separate attention nets for pos and vel; 1->pos and vel are all attention;
# 2->pos is MLP and vel is attention; 3->pos is MLP and vel uses
# TransformerEncoderLayerCategoricalsCatPos; 4->pos and vel combined together (single net)
model_variant = 4 # 1->pos and vel are all attention; 2->pos is MLP; and vel is attention; 4->pos and vel combined together
# Attention-layer variant used by model_variant 1, 2 and 4:
# 0->attention; 1->attention v2 with extra layer norm and skip connection; 2->based on 1, additional categoricals;
# 3->TransformerEncoderLayer_v2_CategoricalsCatPos; 4->TransformerEncoderLayer_v2_CatPos (3 and 4 only for model_variant 4)
model_variant_attn = 4
# Whether to use physics (1) or not (0)
enable_physics = 1
# Weight of the reconstruction loss
weight_recons = 1
# Weight of the loss that aligns the states between the ODE solver and the encoder
Time_loss_weight = 1
# Weight of the FFT (cosine-theta) loss term
FFT_loss_weight = 0
# Frame-velocity loss switch/weight
velocity_loss_enable = 0.0
# Step between the sequence-length offsets for which Model.__init__ precomputes source masks
gap_interval = 100
# Half-width of the banded source mask: each position may attend to neighbours within
# att_len steps, i.e. a band of 2*att_len+1 positions (fewer at the sequence edges)
att_len = 7 #value->size of mask: 1->3; 2->5; 7->15; 12->25; 50 -> 101
# Attention model parameters
d_model_attn = 300 # the final size after doing CNN and flattening
nhead_attn = 10
d_middle_attn = 100 # no use here
d_final_attn = 6 # no use here
dropout_attn = 0.0
pos_en_scale_attn = 1.0
attn_nonlinearity = 'relu'
attn_nonlinearity_1 = 'tanh'
cnn_pooling = 'max' #or 'avg'
# For the estimator network:
# non-linearity passed to the MLP_Mag networks built in Model.__init__
nonlinearity = 0 # 0 for tanh; 1 for softplus (does not work well); 2 for elu
# Decoder size: second argument of MLP_Encoder for obs_net (presumably its hidden width)
dec_size = 100

# Shared constructor arguments for the attention encoders built in Model.__init__
NN_parameter = {'d_model': d_model_attn,\
                'nhead': nhead_attn,\
                'd_middle': d_middle_attn,\
                'd_final': d_final_attn,\
                'dropout': dropout_attn,\
                'pos_en_scale': pos_en_scale_attn,\
                'nonlinearity': attn_nonlinearity,\
                'nonlinearity_1': attn_nonlinearity_1,\
                'pooling': cnn_pooling} # default pos_en_scale is 1.0; for d_final, using Gaussian: d_final=5
gpu_number = 0

# Small value added to denominators to prevent NaN
small_value = 1e-5
# Simulation frequency in Hz
Hz = 20
# Hidden-layer sizes passed as input_layer to the MLP_Mag networks
NN_input_layer = [50,50]
# Interval at which the training loss alternates
training_loss_interval = 15
mean_over_everything = 0
save_dir = '_Da'    + str(dataset_type) 


# Fix all random seeds for reproducibility
seed_everything(42)

class Model(pl.LightningModule):

    def __init__(self, hparams, data_path=None):
        """Build the recognition (encoder), observation (decoder) and dynamics networks.

        Args:
            hparams: hyper-parameter namespace; this constructor reads ``T_pred``
                (``train_dataloader`` additionally reads ``homo_u`` and ``batch_size``).
            data_path: path handed to the dataset classes in ``train_dataloader``.

        Raises:
            ValueError: if ``enable_attn`` is set and the module-level
                ``model_variant`` / ``model_variant_attn`` pair has no network definition.
        """
        super(Model, self).__init__()

        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Element-wise and mean-reduced MSE losses.
        self.loss_fn = torch.nn.MSELoss(reduction='none')
        self.loss_fn_mean = torch.nn.MSELoss()
        # Flattened image size fed to the encoders.
        self.size_image = 64*64
        self.input_dim = 4  # (cos, sin, theta_dot, u)
        self.plu_output = 0.55
        # For plotting purposes.
        self.count = 0
        self.cutoff_index_input = cutoff_freq_input

        # Recognition (encoder) networks. Built before the other sub-networks so the
        # seeded parameter-initialization order is unchanged.
        if enable_attn:
            self._build_attention_encoders()
        else:
            self.recog_q_net = MLP_Encoder(self.size_image, 300, 3, nonlinearity=MLP_Encoder_nonLinear)

        # Observation (decoder) network: 1 input -> size_image outputs
        # (presumably MLP_Encoder(in, hidden, out)).
        self.obs_net = MLP_Encoder(1, dec_size, self.size_image, nonlinearity=MLP_Encoder_nonLinear)

        self.context_size = 10
        context_vector_size = 10
        mlp_kwargs = dict(input_layer=NN_input_layer, nonlinearity=nonlinearity, mag=1.0)
        self.context_network = MLP_Mag(input_size=2*self.context_size,
                                       output_size=context_vector_size, **mlp_kwargs)
        self.dynamics_network = MLP_Mag(input_size=2+context_vector_size, output_size=2, **mlp_kwargs)
        self.backward_network = MLP_Mag(input_size=2+context_vector_size, output_size=2, **mlp_kwargs)

        self.train_dataset = None
        # Next non-zero control index used by train_dataloader (cycles through 1..8).
        self.non_ctrl_ind = 1

        # Pre-compute banded source masks for sequence lengths T_pred+1, T_pred+1-gap_interval, ...
        # Uses self.T_pred (not the module-level default) so the masks agree with the
        # transformers' max_len and with the datasets built from hparams.T_pred.
        self.SRC_MAS_V = [self.src_mask(self.T_pred + 1 - i)
                          for i in np.arange(0, self.T_pred - cutoff_freq_input, gap_interval)]

        self.training_loss_flag = 0

    def _attention_encoder(self, layer_cls, d_final, pooling=True):
        """Instantiate ``layer_cls`` with the shared ``NN_parameter`` settings.

        ``pooling`` selects whether ``NN_parameter['pooling']`` is forwarded; some
        layer variants are constructed without a ``pooling`` argument.
        """
        kwargs = dict(d_model=NN_parameter['d_model'],
                      nhead=NN_parameter['nhead'],
                      d_middle=NN_parameter['d_middle'],
                      d_final=d_final,
                      dropout=NN_parameter['dropout'],
                      max_len=self.T_pred + 1,
                      pos_en_scale=NN_parameter['pos_en_scale'],
                      activation=NN_parameter['nonlinearity'],
                      activation_1=NN_parameter['nonlinearity_1'])
        if pooling:
            kwargs['pooling'] = NN_parameter['pooling']
        return layer_cls(**kwargs)

    def _build_attention_encoders(self):
        """Create the attention-based recognition networks selected by the module-level
        ``model_variant`` / ``model_variant_attn`` flags.

        Raises:
            ValueError: for a combination that has no network definition.
        """
        # Attention layer class for each model_variant_attn value.
        attn_layers = {
            0: TransformerEncoderLayer,
            1: TransformerEncoderLayer_v2,
            2: TransformerEncoderLayer_v2_Categoricals,
            3: TransformerEncoderLayer_v2_CategoricalsCatPos,
            4: TransformerEncoderLayer_v2_CatPos,
        }

        if model_variant == 0:
            # Separate attention nets: location (3 outputs) and velocity (2 outputs).
            self.recog_q_net = self._attention_encoder(TransformerEncoderLayer, 3, pooling=False)
            self.recog_q_net_velocity = self._attention_encoder(TransformerEncoderLayer, 2, pooling=False)
        elif model_variant in (1, 2):
            if model_variant_attn not in (0, 1, 2):
                raise ValueError(
                    f'model_variant={model_variant} supports model_variant_attn in (0, 1, 2), '
                    f'got {model_variant_attn}')
            layer_cls = attn_layers[model_variant_attn]
            if model_variant == 1:
                # Location is also encoded with attention.
                self.recog_q_net = self._attention_encoder(layer_cls, 3)
            else:
                # Location is encoded with an MLP.
                self.recog_q_net = MLP_Encoder(self.size_image, 300, 3, nonlinearity=MLP_Encoder_nonLinear)
            # Velocity is encoded with attention.
            self.recog_q_net_velocity = self._attention_encoder(layer_cls, 3)
        elif model_variant == 3:
            # Location via MLP; velocity via the categorical attention layer (model_variant_attn unused).
            self.recog_q_net = MLP_Encoder(self.size_image, 300, 3, nonlinearity=MLP_Encoder_nonLinear)
            self.recog_q_net_velocity = self._attention_encoder(
                TransformerEncoderLayerCategoricalsCatPos, 2, pooling=False)
        elif model_variant == 4:
            if model_variant_attn not in attn_layers:
                raise ValueError(
                    f'model_variant=4 supports model_variant_attn in {sorted(attn_layers)}, '
                    f'got {model_variant_attn}')
            # A single network processes all the information (3 + 3 outputs).
            self.recog_q_net_state = self._attention_encoder(attn_layers[model_variant_attn], 3 + 3)
        else:
            raise ValueError(f'Unsupported model_variant={model_variant} with enable_attn set')

    def train_dataloader(self):
        """Return the training DataLoader and refresh ``self.t_eval`` from its dataset.

        With ``hparams.homo_u`` set, one HomoImageDataset is cached on
        ``self.train_dataset`` and its ``u_idx`` is switched on every call. For
        epochs below 1000, even epochs use index 0 and odd epochs step through
        indices 1..8 in turn (tracked by ``self.non_ctrl_ind``). From epoch 1000
        on, the index is ``current_epoch % 9``. Otherwise a fresh ImageDataset is
        built on each call.
        """
        if not self.hparams.homo_u:
            # Default setting: all control inputs u are zero.
            dataset = ImageDataset(self.data_path, self.hparams.T_pred)
            self.t_eval = torch.from_numpy(dataset.t_eval)
            return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=my_collate)

        # Requires the trainer flag reload_dataloaders_every_epoch=True so this runs every epoch.
        if self.train_dataset is None:
            self.train_dataset = HomoImageDataset(self.data_path, self.hparams.T_pred)

        epoch = self.current_epoch
        if epoch >= 1000:
            u_idx = epoch % 9
        elif epoch % 2 == 0:
            # Zero-control data on even epochs...
            u_idx = 0
        else:
            # ...and the next non-zero control index on odd epochs (1..8, wrapping back to 1).
            u_idx = self.non_ctrl_ind
            next_ind = self.non_ctrl_ind + 1
            self.non_ctrl_ind = 1 if next_ind == 9 else next_ind

        self.train_dataset.u_idx = u_idx
        self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        delta_cos = q1_m_n[:,0:1] - q0_m_n[:,0:1]
        delta_sin = q1_m_n[:,1:2] - q0_m_n[:,1:2]
        q_dot0 = - delta_cos * q0_m_n[:,1:2] / delta_t + delta_sin * q0_m_n[:,0:1] / delta_t
        return q_dot0

    def angle_vel_est_euler(self, q0, delta_t):

        T = q0.shape[0]
        theta_dot = (q0[1:T]-q0[0:T-1]) / delta_t

        return theta_dot.unsqueeze(1)

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos ; theta[:, 0, 1] += -sin ; theta[:, 0, 2] += - x * cos + y * sin
        theta[:, 1, 0] += sin ; theta[:, 1, 1] += cos ;  theta[:, 1, 2] += - x * sin - y * cos
        return theta

    def encode(self, batch_image):
        """Encode a batch of images with the MLP recognition network ``self.recog_q_net``.

        ``self.recog_q_net`` must return 3 values per row along dim 1. The first
        two form an unnormalized 2-D mean; the third is a raw scale parameter.

        Args:
            batch_image: encoder input; assumed to be (N, 64*64) flattened images
                to match the MLP_Encoder input size (TODO confirm against callers).

        Returns:
            q_m: raw 2-D mean (2 columns).
            q_v: scale ``softplus(raw) + 1``, so always > 1 (1 column).
            q_m_n: ``q_m`` divided by its L2 norm along the last dim; ``small_value``
                in the denominator guards against division by zero.
        """
        # Small constant offset added to the input before encoding (kept from the original).
        q_m_logv = self.recog_q_net(batch_image + 1e-5)
        q_m, q_logv = q_m_logv.split([2, 1], dim=1)
        q_m_n = q_m / (q_m.norm(dim=-1, keepdim=True) + small_value)
        q_v = F.softplus(q_logv) + 1

        return q_m, q_v, q_m_n


    def src_mask(self, dim):
        #https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390/2
        mask = torch.zeros(dim,dim).float() + float('-inf')
        # Define attend half range
        for i in range(dim):
            min_ = max(0, i-att_len)
            max_ = min(dim, i+att_len+1)
            for j in range(min_,max_):
                mask[i,j] = 0.0
        return mask

    def print_parameter(self):
        
        print('-----h: recog_q_net-----')
        for name, param in self.recog_q_net.named_parameters():
            print(name,param)
            break
        
        print('-----h: recog_q_net_velocity-----')
        for name, param in self.recog_q_net_velocity.named_parameters():            
            print(name,param)
            break
                
        print('-----f: MLP_Spec_mu-----')
        for name, param in self.MLP_Spec_mu.named_parameters():
            print(name,param)
            break

        print('-----g: obs_net-----')
        for name, param in self.obs_net.named_parameters():
            print(name,param)
            break
        

        


    def encode_self_attention(self, batch_image, src_mask_v):

        

        if model_variant==0:
            # X is in the shape of (T,64*64)
            batch_image = batch_image+1e-5
            X_attn = batch_image.unsqueeze(1)
            # Here, attn_output is of size [101, 1, 6]
            attn_output, attn_output_weight = self.recog_q_net(X_attn,src_mask=src_mask_v.to(self.device))
            # Add the source mask to make sure it attends to the right position
            attn_output_velocity, attn_output_weight_velocity = self.recog_q_net_velocity(X_attn,src_mask=src_mask_v.to(self.device))
            # Here, attn_output is of size [101, 6]
            attn_output = attn_output.squeeze()
            attn_output_velocity = attn_output_velocity.squeeze()

            q_m_loc, q_logv_loc = attn_output.split([2, 1], dim=1)
            q_m_vel, q_v_vel = attn_output_velocity.split([1, 1], dim=1)
            # Location
            q_m_loc_n = q_m_loc / (q_m_loc.norm(dim=-1, keepdim=True) + small_value)
            q_v_loc = F.softplus(q_logv_loc) + 1
            # Velocity
            q_m_vel = torch.tanh(q_m_vel) * 7
            q_v_vel = torch.sigmoid(q_v_vel) * 0.05 + 0.0001

            return q_m_loc, q_v_loc, q_m_loc_n, q_m_vel, q_v_vel, attn_output_weight, attn_output_weight_velocity

        elif model_variant==1 :
            # X is in the shape of (T,64*64)
            batch_image = batch_image+1e-5
            X_attn = batch_image.unsqueeze(1)
            # Here, attn_output is of size [101, 1, 3]
            # For position
            attn_output, attn_output_weight = self.recog_q_net(X_attn)
            # For velocity
            # Add the source mask to make sure it attends to the right position
            attn_output_velocity, attn_output_weight_velocity = self.recog_q_net_velocity(X_attn,src_mask=src_mask_v.to(self.device))
            # Here, attn_output is of size [101, 3]
            attn_output = attn_output.squeeze()
            # Here, attn_output_velocity is of size [101, 3]
            attn_output_velocity = attn_output_velocity.squeeze()

            q_m_loc, q_logv_loc = attn_output.split([2, 1], dim=1)
            q_m_vel, q_logv_vel = attn_output_velocity.split([2, 1], dim=1)
            # Location
            q_m_loc_n = q_m_loc / (q_m_loc.norm(dim=-1, keepdim=True) + small_value)
            q_v_loc = F.softplus(q_logv_loc) + 1
            # Velocity
            q_m_vel_n = q_m_vel / (q_m_vel.norm(dim=-1, keepdim=True) + small_value)
            q_v_vel = F.softplus(q_logv_vel) + 1

            return q_m_loc, q_v_loc, q_m_loc_n, \
                   q_m_vel, q_v_vel, q_m_vel_n, \
                   attn_output_weight, attn_output_weight_velocity

        elif model_variant==2 :
            # X is in the shape of (T,64*64)
            batch_image = batch_image+1e-5
            X_attn = batch_image.unsqueeze(1)
            # Here, attn_output is of size [101, 1, 3]
            # For position
            attn_output_weight = []
            attn_output = self.recog_q_net(X_attn)
            # For velocity
            # Add the source mask to make sure it attends to the right position
            attn_output_velocity, attn_output_weight_velocity = self.recog_q_net_velocity(X_attn,src_mask=src_mask_v.to(self.device))
            # Here, attn_output is of size [101, 3]
            attn_output = attn_output.squeeze()
            # Here, attn_output_velocity is of size [101, 3]
            attn_output_velocity = attn_output_velocity.squeeze()

            q_m_loc, q_logv_loc = attn_output.split([2, 1], dim=1)
            q_m_vel, q_logv_vel = attn_output_velocity.split([2, 1], dim=1)
            # Location
            q_m_loc_n = q_m_loc / (q_m_loc.norm(dim=-1, keepdim=True) + small_value)
            q_v_loc = F.softplus(q_logv_loc) + 1
            # Velocity
            q_m_vel_n = q_m_vel / (q_m_vel.norm(dim=-1, keepdim=True) + small_value)
            q_v_vel = F.softplus(q_logv_vel) + 1

            return q_m_loc, q_v_loc, q_m_loc_n, \
                   q_m_vel, q_v_vel, q_m_vel_n, \
                   attn_output_weight, attn_output_weight_velocity

        elif model_variant==3 :
            # X is in the shape of (T,64*64)
            batch_image = batch_image+1e-5
            X_attn = batch_image.unsqueeze(1)
            # Here, attn_output is of size [101, 1, 3]
            # For position
            attn_output_weight = []
            attn_output = self.recog_q_net(X_attn)
            # For velocity
            # Add the source mask to make sure it attends to the right position
            attn_output_velocity, attn_output_weight_velocity = self.recog_q_net_velocity(X_attn,src_mask=src_mask_v.to(self.device))
            # Here, attn_output is of size [101, 3]
            attn_output = attn_output.squeeze()
            # Here, attn_output_velocity is of size [101, 3]
            attn_output_velocity = attn_output_velocity.squeeze()

            q_m_loc, q_logv_loc = attn_output.split([2, 1], dim=1)

            # Location
            q_m_loc_n = q_m_loc / (q_m_loc.norm(dim=-1, keepdim=True) + small_value)
            q_v_loc = F.softplus(q_logv_loc) + 1

            return q_m_loc, q_v_loc, q_m_loc_n, \
                   attn_output_velocity, \
                   attn_output_weight, attn_output_weight_velocity

        elif model_variant == 4 :

            # X is in the shape of (T,64*64)
            batch_image = batch_image+1e-5
            X_attn = batch_image.unsqueeze(1)
            # Here, attn_output is of size [101, 1, 3]
            # For position and velocity
            # Add the source mask to make sure it attends to the right position
            attn_output, attn_output_weight = self.recog_q_net_state(X_attn,src_mask=src_mask_v.to(self.device))
            # Here, attn_output_velocity is of size [101, 6]
            attn_output = attn_output.squeeze()

            q_m_loc, q_logv_loc, q_m_vel, q_logv_vel = attn_output.split([2, 1, 2, 1], dim=1)
            # Location
            q_m_loc_n = q_m_loc / (q_m_loc.norm(dim=-1, keepdim=True) + small_value)
            q_v_loc = F.softplus(q_logv_loc) + 1
            # Velocity
            q_m_vel_n = q_m_vel / (q_m_vel.norm(dim=-1, keepdim=True) + small_value)
            q_v_vel = F.softplus(q_logv_vel) + 1

            return q_m_loc, q_v_loc, q_m_loc_n, \
                   q_m_vel, q_v_vel, q_m_vel_n, \
                   attn_output_weight, attn_output_weight

    def forward(self, X, u, S, TIME_INDEX, src_mask_v):

        '''Encode each frame, roll the context networks forward/backward, decode.

        Returns None. Results are stored on ``self`` for ``training_step``:

        * ``self.qT_enc`` / ``self.q_dotT_enc``: encoder (cos, sin) sample and
          velocity for steps ``0..T-2``. The velocity comes from the branch
          selected by ``enable_attn`` / ``model_variant``.
        * ``self.qT_enc_frameV`` / ``self.q_dotT_enc_frameV``: the same, but the
          velocity always comes from ``self.angle_vel_est`` on consecutive
          encoder means.
        * ``self.qT`` / ``self.q_dotT``: ``self.dynamics_network`` outputs
          converted back to (cos, sin) and velocity.
        * ``self.qT_backward`` / ``self.q_dotT_backward``: the same for
          ``self.backward_network``.
        * ``self.Xrec``: frames rendered from ``self.qT``, with shape
          ``[T_pred_, bs, d, d]``, where
          ``T_pred_ = (T-1) - (self.context_size+2) + 1``.
        * ``self.Q_q``, ``self.P_q``, ``self.q0_m``, ... (and ``self.Q_dot_q`` /
          ``self.P_normal`` on some branches). These are reassigned inside the
          per-sample loop, so afterwards they describe the LAST batch element
          only.

        Args:
            X: 4-D tensor unpacked as ``[T, bs, d, d]``.
            u: ignored; it is overwritten with zeros for every sample.
            S: not used in this method.
            TIME_INDEX: offset of ``X`` into the full sequence. It slices
                ``self.t_eval``, and ``self.count`` is only incremented when it
                is 0.
            src_mask_v: mask forwarded to ``self.encode_self_attention`` when
                ``enable_attn`` is truthy.
        '''
        # NOTE(review): `d` is bound twice, so it ends up as X.shape[3]; the
        # reshape(T, d*d) below assumes square frames.
        [T, self.bs, d, d] = X.shape
        #T = len(self.t_eval)

        # Per-sample outputs; stacked along the batch axis after the loop.
        x_enc_list = []
        x_enc_frameV_list = []
        x_sim_list = []
        x_sim_backward_list = []

        # NOTE(review): the FFT lists are never filled in this method, and the
        # attention-weight lists are only appended to, never returned.
        Enc_theta_FFT_list = []
        Enc_theta_dot_FFT_list = []
        ODE_theta_FFT_list = []
        ODE_theta_dot_FFT_list = []
        Attn_output_weight_list = []
        Attn_output_weight_velocity_list = []

        # Samples are processed one at a time; every batch element is encoded
        # as a length-T sequence of flattened d*d frames.
        #for batch_ii in tqdm(range(self.bs)):
        for batch_ii in range(self.bs):

            # The control input passed in as `u` is discarded: zero control.
            u = torch.zeros((T,1)).to(self.device)

            # =======Encode=======
            # Get the mean and the variance of the distribution
            # (the number of returned values depends on model_variant).
            if enable_attn:
                if model_variant==0:
                    self.q0_m, self.q0_v, self.q0_m_n, \
                            self.q0_dot_m, self.q0_dot_v, \
                                self.attn_output_weight, self.attn_output_weight_velocity \
                                    = self.encode_self_attention(X[:,batch_ii,:,:].reshape(T, d*d),src_mask_v)
                elif model_variant==3:
                    self.q0_m, self.q0_v, self.q0_m_n, \
                            self.q0_dot_m, \
                                self.attn_output_weight, self.attn_output_weight_velocity \
                                    = self.encode_self_attention(X[:,batch_ii,:,:].reshape(T, d*d),src_mask_v)
                else:
                    self.q0_m, self.q0_v, self.q0_m_n, \
                        self.q0_dot_m, self.q0_dot_v, self.q0_dot_m_n, \
                            self.attn_output_weight, self.attn_output_weight_velocity \
                                = self.encode_self_attention(X[:,batch_ii,:,:].reshape(T, d*d),src_mask_v)

                if model_variant == 1:
                    Attn_output_weight_list.append(self.attn_output_weight[0])
                Attn_output_weight_velocity_list.append(self.attn_output_weight_velocity[0])
            else:
                self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[:,batch_ii,:,:].reshape(T, d*d))

            # Sample mean and the variance
            # Position posterior: von Mises-Fisher around the normalised mean,
            # with a hyperspherical-uniform prior.
            self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
            self.P_q = HypersphericalUniform(1, device=self.device)
            self.q0 = self.Q_q.rsample().to(self.device) # bs, 2 = cos\theta and sin\theta instead of \theta

            # Resample until the draw contains no NaN (no retry limit).
            while torch.isnan(self.q0).any():
                self.q0 = self.Q_q.rsample().to(self.device) # a bad way to avoid nan

            # Velocity for steps 0..T-2. q_dot0 feeds the model;
            # q_dot0_compareFrame is always the finite-difference estimate.
            if enable_attn:
                if model_variant==0:
                    # Sample mean and the variance
                    # Gaussian velocity posterior against a standard-normal prior.
                    self.Q_dot_q = Normal(self.q0_dot_m, self.q0_dot_v)
                    self.P_normal = Normal(torch.zeros_like(self.q0_dot_m), torch.ones_like(self.q0_dot_v))
                    self.q_dot0 = self.Q_dot_q.rsample().to(self.device) # sample from the velocity posterior
                    self.q_dot0 = self.q_dot0[0:T-1]

                    # Compute the velocity using finit element
                    # This is achieved by comparing two frames
                    self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                elif model_variant == 3:
                    # The velocity head's raw output is used directly (no sampling).
                    self.q_dot0 = self.q0_dot_m[0:T-1].unsqueeze(1) # bs, 2 = cos\theta and sin\theta instead of \theta
                    self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                else:
                    # Using attention to estimate the velocity
                    # Sample mean and the variance
                    self.Q_dot_q = VonMisesFisher(self.q0_dot_m_n, self.q0_dot_v)
                    self.q_dot0 = self.Q_dot_q.rsample().to(self.device) # bs, 2 = cos\theta and sin\theta instead of \theta
                    while torch.isnan(self.q_dot0).any():
                        self.q_dot0 = self.Q_dot_q.rsample().to(self.device) # a bad way to avoid nan
                    # Keep the first T-1 steps and only the first component of
                    # the sample, giving shape [T-1, 1].
                    self.q_dot0 = self.q_dot0[0:T-1,0].unsqueeze(1)

                    # Make it in the resonable scale [-7,7]
                    # (the component of a unit-vector sample is scaled by 7).
                    self.q_dot0 = self.q_dot0 * 7

                    # Compute the velocity using finit element
                    # This is achieved by comparing two frames
                    self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
            else:
                # Estimate velocity using finit element
                # Without attention, both velocities are the same finite difference.
                self.q_dot0 = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                self.q_dot0_compareFrame = self.angle_vel_est(self.q0_m_n[0:T-1], self.q0_m_n[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)

            # Estimate euler velocity
            #self.q_dot0 = self.angle_vel_est_euler(torch.atan2(self.q0[:,1],self.q0[:,0]) + np.pi, \
            #                                       self.t_eval[1]-self.t_eval[0]).to(self.device)

            # predict
            # Encoder states per step, as columns (cos, sin, velocity, u).
            z0_u = torch.cat((self.q0[0:T-1], self.q_dot0, u[0:T-1]), dim=1) #torch.Size([simulation_length, 4])
            x_enc_list.append(z0_u)
            x_enc_frameV_list.append( torch.cat((self.q0[0:T-1], self.q_dot0_compareFrame, u[0:T-1]), dim=1))

            #if batch_ii == 0:
            #    print('before ode (cos,sin,v,u):',z0_u[0:10,:],'index:',batch_ii,z0_u.shape)
            # Transfer to the theta, theta_dot state
            # Flatten to a single row [1, 4*(T-1)]: (cos, sin, vel, u) repeated per step.
            z0_u = z0_u.reshape((1,-1))

            # This is form: atan2(y=sin, x=cos)
            theta = torch.zeros((1,1)).to(self.device) + torch.atan2(z0_u[:,1],z0_u[:,0]) + np.pi
            theta_dot =  torch.zeros((1,1)).to(self.device) + z0_u[0,2]
            # Append with the inital state
            # NOTE(review): these two prepended columns are stripped again by
            # `z0_u[:,2:]` below, so theta/theta_dot/s_init do not affect the result.
            s_init = torch.cat((theta, theta_dot), dim=1).to(self.device)
            z0_u = torch.cat((s_init.to(self.device),z0_u),dim=1)

            # NOTE(review): self.t_eval_ is not read anywhere in this method.
            self.t_eval_ = self.t_eval[TIME_INDEX:]

            # Predict the next states
            # De-interleave the flattened row back into per-step (theta, theta_dot),
            # with theta = atan2(sin, cos) + pi, giving zT_u of shape [T-1, 2].
            z0_u = z0_u[:,2:]
            theta = torch.atan2(z0_u[:,1::4],z0_u[:,0::4]) + np.pi
            theta_dot = z0_u[:,2::4]
            zT_u = torch.cat((theta,theta_dot),axis=0)
            zT_u = zT_u.T
            zT_u_forward_temp = []
            zT_u_backward_temp = []
            # Sliding window: the context vector summarises steps
            # [idx, idx+context_size). The forward network also sees step
            # idx+context_size; the backward network sees step idx+context_size+1.
            for idx in range(0,(T-1)-(self.context_size+2)+1):
                context_vector = self.context_network(zT_u[idx:idx+self.context_size,:].flatten())
                context_vector_dynamics = torch.cat((context_vector,zT_u[idx+self.context_size]))
                next_state = self.dynamics_network(context_vector_dynamics)
                context_vector_backward = torch.cat((context_vector,zT_u[idx+self.context_size+1]))
                current_state = self.backward_network(context_vector_backward)
                zT_u_forward_temp.append(next_state)
                zT_u_backward_temp.append(current_state)

            # retain_grad() keeps .grad on these non-leaf tensors after backward.
            zT_u_forward_temp = torch.stack(zT_u_forward_temp, axis=0)
            zT_u_forward_temp.retain_grad()
            zT_u_backward_temp = torch.stack(zT_u_backward_temp, axis=0)
            zT_u_backward_temp.retain_grad()

            # Get the state
            # Convert forward predictions back to (cos, sin, vel, u=0), undoing the +pi shift.
            zT_u = zT_u_forward_temp
            # get cosine
            z_cos = torch.cos(zT_u[:,0] - np.pi).unsqueeze(1)
            z_sin = torch.sin(zT_u[:,0] - np.pi).unsqueeze(1)
            z_vel = zT_u[:,1].unsqueeze(1).to(self.device)
            z_u   = torch.zeros((T-1)-(self.context_size+2)+1).unsqueeze(1).to(self.device)
            zT_u  = torch.cat((z_cos, z_sin, z_vel, z_u),dim=1) # T,4

            x_sim_list.append(zT_u) # T, bs, 4

            # Get the state
            # Same conversion for the backward-network outputs.
            zT_u = zT_u_backward_temp
            # get cosine
            z_cos = torch.cos(zT_u[:,0] - np.pi).unsqueeze(1)
            z_sin = torch.sin(zT_u[:,0] - np.pi).unsqueeze(1)
            z_vel = zT_u[:,1].unsqueeze(1).to(self.device)
            z_u   = torch.zeros((T-1)-(self.context_size+2)+1).unsqueeze(1).to(self.device)
            zT_u  = torch.cat((z_cos, z_sin, z_vel, z_u),dim=1) # T,4

            x_sim_backward_list.append(zT_u) # T, bs, 4

        # Stack the data, and retain_grad() after using stack function
        # Each list becomes [steps, bs, 4] after stack (on axis 0) + permute.
        x_sim_list = torch.stack(x_sim_list, axis=0)
        x_sim_list.retain_grad()
        x_sim_list = x_sim_list.permute(1,0,2)

        x_sim_backward_list = torch.stack(x_sim_backward_list, axis=0)
        x_sim_backward_list.retain_grad()
        x_sim_backward_list = x_sim_backward_list.permute(1,0,2)

        x_enc_list = torch.stack(x_enc_list, axis=0)
        x_enc_list.retain_grad()
        x_enc_list = x_enc_list.permute(1,0,2)

        x_enc_frameV_list = torch.stack(x_enc_frameV_list, axis=0)
        x_enc_frameV_list.retain_grad()
        x_enc_frameV_list = x_enc_frameV_list.permute(1,0,2)

        # Split (cos, sin) | velocity | u. The (cos, sin) parts are flattened to
        # [steps*bs, 2]; the velocities keep their [steps, bs, 1] shape.
        self.qT, self.q_dotT, _ = x_sim_list.split([2, 1, 1], dim=-1)
        self.qT = self.qT.contiguous()
        self.qT = self.qT.view(((T-1)-(self.context_size+2)+1)*self.bs, 2)

        self.qT_backward, self.q_dotT_backward, _ = x_sim_backward_list.split([2, 1, 1], dim=-1)
        self.qT_backward = self.qT_backward.contiguous()
        self.qT_backward = self.qT_backward.view(((T-1)-(self.context_size+2)+1)*self.bs, 2)

        self.qT_enc, self.q_dotT_enc, _ = x_enc_list.split([2, 1, 1], dim=-1)
        self.qT_enc = self.qT_enc.contiguous()
        self.qT_enc = self.qT_enc.view((T-1)*self.bs, 2)

        self.qT_enc_frameV, self.q_dotT_enc_frameV, _ = x_enc_frameV_list.split([2, 1, 1], dim=-1)
        self.qT_enc_frameV = self.qT_enc_frameV.contiguous()
        self.qT_enc_frameV = self.qT_enc_frameV.view((T-1)*self.bs, 2)

        # =======Decode=======
        # Here we want to get the content of the pole
        # obs_net maps a constant input of ones to the image content that is warped.
        ones = torch.ones_like(self.qT[:, 0:1])
        self.content = self.obs_net(ones)
        # Get the theta information to place the pole
        theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0, bs=((T-1)-(self.context_size+2)+1)*self.bs) # cos , sin
        grid = F.affine_grid(theta, torch.Size((((T-1)-(self.context_size+2)+1)*self.bs, 1, d, d)))
        # Get the reconstruction images
        self.Xrec = F.grid_sample(self.content.view(((T-1)-(self.context_size+2)+1)*self.bs, 1, d, d), grid)
        self.Xrec = self.Xrec.view([(T-1)-(self.context_size+2)+1, self.bs, d, d])

        # Plot something to track the performance
        # NOTE(review): one new figure is opened per predicted step and is never
        # closed, saved, or shown here. With Plot_enable set, open matplotlib
        # figures accumulate. The view(-1, 1, 64, 64) also hard-codes the frame
        # size instead of using `d`.
        if self.count % 50 == 0 and Plot_enable == True:
            for tt in range((T-1)-(self.context_size+2)+1):
                fig1 = plt.figure(constrained_layout=False, figsize=(10,4))
                gs = fig1.add_gridspec(1, 2, width_ratios=[1.0,1.0])
                # Left: ground-truth frame matching prediction tt (sample 0).
                ax = fig1.add_subplot(gs[0, 0])
                from torchvision import utils
                grid = utils.make_grid(X[tt+self.context_size+1, 0].view(-1, 1, 64, 64))
                X_ = np.array(grid.permute(1,2,0).detach().cpu().numpy())
                ax.imshow(X_)
                # Right: reconstructed frame.
                ax = fig1.add_subplot(gs[0, 1])

                grid = utils.make_grid(self.Xrec[tt, 0].view(-1, 1, 64, 64))
                X_ = np.array(grid.permute(1,2,0).detach().cpu().numpy())
                ax.imshow(X_)

        # The call counter (used by the plotting cadence above) only advances
        # on the untruncated TIME_INDEX == 0 pass.
        if TIME_INDEX == 0:
            self.count += 1

        return None

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one batch, averaged over time offsets.

        For every offset ``TIME_INDEX`` in
        ``np.arange(0, T_pred - cutoff_freq_input, gap_interval)``, the batch is
        truncated to ``X[TIME_INDEX:]`` / ``State[TIME_INDEX:]`` and passed
        through ``self.forward``. The terms below are collected per offset and
        then averaged over offsets:

        * reconstruction: ``self.loss_fn`` between ``self.Xrec`` and the frames
          it predicts (``X_[context_size+1:-1]``);
        * KL of the encoder posterior(s) against their prior(s);
        * norm penalty: squared deviation of the mean encoder-mean norm from 1;
        * state alignment: forward and backward position/velocity losses
          between the context-network outputs and the encoder states. The
          network outputs are detached, so this term back-propagates only
          through the encoder outputs.

        Args:
            train_batch: tuple ``(X, u, State)``. ``forward`` unpacks ``X`` as
                ``[T, bs, d, d]``.
            batch_idx: index of the batch (unused).

        Returns:
            dict with ``'loss'`` (scalar tensor to minimise) and ``'log'`` /
            ``'progress_bar'``. The latter two hold the same dict of scalar
            diagnostics, each averaged over the time offsets.
        """
        X, u, State = train_batch
        # Shapes as noted by the original author (not checked here):
        #   X:     (T+1, batch_size, 64, 64) = [time, batch, image_dim, image_dim]
        #   u:     (64, 1), a constant control input
        #   State: (T+1, batch_size, 7)

        # Per-offset loss terms; stacked and averaged after the loop.
        lhood_list = []
        kl_q_list = []
        penalty_list = []
        Time_loss_list = []
        Time_pos_loss_list = []
        Time_vel_loss_list = []
        Time_pos_backward_loss_list = []
        Time_vel_backward_loss_list = []

        # Norm-penalty weight: linearly annealed with the epoch, or a fixed
        # 1/100. It does not depend on TIME_INDEX, so compute it once.
        lambda_ = self.current_epoch/8000 if self.hparams.annealing else 1/100

        # NOTE(review): `iii` is never incremented, so every offset uses
        # self.SRC_MAS_V[0] as its source mask. If SRC_MAS_V holds one mask per
        # offset, this is a bug; confirm against where SRC_MAS_V is built.
        iii = 0
        for TIME_INDEX in np.arange(0,T_pred-cutoff_freq_input ,gap_interval): # Default: 20 is the gap interval
            # Drop the first TIME_INDEX frames and run the model on the rest.
            X_ = X[TIME_INDEX:,:,:,:]
            State_ = State[TIME_INDEX:,:,:]
            self.forward(X_, u, State_, TIME_INDEX, self.SRC_MAS_V[iii])

            T = X_.shape[0]
            # Number of steps forward() predicts for a length-T sequence.
            T_pred_ = (T-1)-(self.context_size+2)+1

            # State alignment. Forward predictions line up with encoder steps
            # [context_size+1:]; backward outputs line up with [context_size:-1].
            if mean_over_everything == 1:
                Time_pos_loss = self.loss_fn_mean(self.qT.view(T_pred_,self.bs,-1).detach(),self.qT_enc.view(T-1,self.bs,-1)[self.context_size+1:])
                Time_vel_loss = self.loss_fn_mean(self.q_dotT.detach(),self.q_dotT_enc[self.context_size+1:])
                Time_pos_backward_loss = self.loss_fn_mean(self.qT_backward.view(T_pred_,self.bs,-1).detach(),self.qT_enc.view(T-1,self.bs,-1)[self.context_size:-1])
                Time_vel_backward_loss = self.loss_fn_mean(self.q_dotT_backward.detach(),self.q_dotT_enc[self.context_size:-1])
            else:
                Time_pos_loss = self.loss_fn(self.qT.view(T_pred_,self.bs,-1).detach(),self.qT_enc.view(T-1,self.bs,-1)[self.context_size+1:])
                Time_vel_loss = self.loss_fn(self.q_dotT.detach(),self.q_dotT_enc[self.context_size+1:])
                Time_pos_backward_loss = self.loss_fn(self.qT_backward.view(T_pred_,self.bs,-1).detach(),self.qT_enc.view(T-1,self.bs,-1)[self.context_size:-1])
                Time_vel_backward_loss = self.loss_fn(self.q_dotT_backward.detach(),self.q_dotT_enc[self.context_size:-1])

                # Each term is [Time_Steps, BatchSize, n_state], e.g. [89, 32, 2]
                # for positions and [89, 32, 1] for velocities. Sum over time
                # and state, then average over the batch.
                Time_pos_loss = Time_pos_loss.sum([0,2]).mean()
                Time_vel_loss = Time_vel_loss.sum([0,2]).mean()
                Time_pos_backward_loss = Time_pos_backward_loss.sum([0,2]).mean()
                Time_vel_backward_loss = Time_vel_backward_loss.sum([0,2]).mean()

            Time_loss = Time_pos_loss + Time_vel_loss + Time_pos_backward_loss + Time_vel_backward_loss

            ######### Compute the loss #########
            # Reconstruction log-likelihood of the frames that forward() predicts.
            lhood = - self.loss_fn(self.Xrec, X_[self.context_size+1:-1])
            lhood = lhood.sum([0, 2, 3]).mean()

            # KL of the encoder posteriors against their priors. Which posteriors
            # exist depends on the model variant (see forward()).
            if model_variant==0:
                kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean() \
                     + torch.distributions.kl.kl_divergence(self.Q_dot_q, self.P_normal).mean()
            elif model_variant==3:
                kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
            else:
                if enable_attn == 1:
                    kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean() \
                         + torch.distributions.kl.kl_divergence(self.Q_dot_q, self.P_q).mean()
                else:
                    kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
            # Keep the encoder's (unnormalised) mean vectors close to unit norm.
            if enable_attn == 1:
                norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1) ** 2 + (self.q0_dot_m.norm(dim=-1).mean() - 1) ** 2
            else:
                norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1) ** 2

            lhood_list.append(lhood)
            kl_q_list.append(kl_q)
            penalty_list.append(lambda_ * norm_penalty)
            Time_loss_list.append(Time_loss)
            Time_pos_loss_list.append(Time_pos_loss)
            Time_vel_loss_list.append(Time_vel_loss)
            Time_pos_backward_loss_list.append(Time_pos_backward_loss)
            Time_vel_backward_loss_list.append(Time_vel_backward_loss)

        #### Final loss function ####
        # Reconstruction + KL + norm regularisation + forward/backward state alignment.
        # retain_grad() keeps .grad on these stacked (non-leaf) tensors.
        lhood_list = torch.stack(lhood_list, axis=0)
        lhood_list.retain_grad()
        kl_q_list = torch.stack(kl_q_list, axis=0)
        kl_q_list.retain_grad()
        penalty_list = torch.stack(penalty_list, axis=0)
        penalty_list.retain_grad()
        Time_loss_list = torch.stack(Time_loss_list, axis=0)
        Time_loss_list.retain_grad()
        Time_pos_loss_list = torch.stack(Time_pos_loss_list, axis=0)
        Time_pos_loss_list.retain_grad()
        Time_vel_loss_list = torch.stack(Time_vel_loss_list, axis=0)
        Time_vel_loss_list.retain_grad()
        Time_pos_backward_loss_list = torch.stack(Time_pos_backward_loss_list, axis=0)
        Time_pos_backward_loss_list.retain_grad()
        Time_vel_backward_loss_list = torch.stack(Time_vel_backward_loss_list, axis=0)
        Time_vel_backward_loss_list.retain_grad()

        # NOTE(review): penalty_list entries already include lambda_ (see the
        # append above). The penalty therefore enters the loss scaled by
        # lambda_**2, while 'Regulization_loss' logs it scaled once. This is
        # kept as-is so the training objective does not change; confirm intent.
        loss = - weight_recons * lhood_list.mean() \
               + lambda_ * penalty_list.mean() \
               + 1.0 * kl_q_list.mean()\
               + Time_loss_weight * Time_loss_list.mean() # This one combines the forward and the backward states

        # Every diagnostic is averaged over the time offsets. The backward
        # entries previously reported only the last offset's value.
        logs = {'Recons_Loss': -lhood_list.mean(), \
                'State_Loss': Time_loss_list.mean(), \
                'State_Pos_Loss': Time_pos_loss_list.mean(), \
                'State_Vel_Loss': Time_vel_loss_list.mean(), \
                'State_Pos_Backward_Loss': Time_pos_backward_loss_list.mean(), \
                'State_Vel_Backward_Loss': Time_vel_backward_loss_list.mean(), \
                'KL_loss': kl_q_list.mean(), \
                'Regulization_loss':  penalty_list.mean(), \
                'Regulization_loss_lambda':  lambda_, \
                'loss': loss, \
                'monitor': loss}

        # Log the running loss
        return {'loss':loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Build the Adam optimiser over all of this module's parameters.

        The learning rate is read from ``self.hparams.learning_rate``.
        """
        trainable = self.parameters()
        step_size = self.hparams.learning_rate
        adam = torch.optim.Adam(trainable, lr=step_size)
        return adam

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=lr, type=float)
        parser.add_argument('--batch_size', default=num_batch, type=int)
        
        return parser

def main(args):
    """Build the model for the configured dataset and train it with Lightning.

    The training file is selected by the module-level ``dataset_type`` and
    resolved as ``PARENT_DIR/datasets/<dataset_folder><file name>``.

    Args:
        args: parsed command-line namespace. It is passed to the model as
            ``hparams``, ``args.name`` names the checkpoint directory and the
            log sub-directory, and it is also handed to
            ``Trainer.from_argparse_args``.

    Raises:
        ValueError: if ``dataset_type`` is not one of the known dataset ids.
            The old if/elif chain failed with an ``UnboundLocalError`` instead.
    """
    # Dataset id -> training pickle file name.
    dataset_files = {
        0: 'pendulum-gym-image-dataset-train_L0.1-1.0_mu0.1-1.0_Hz100_sL750_nS64_DaTrue.pkl',
        1: 'pendulum-gym-image-dataset-train_20Hz_HighinitV.pkl',
        2: 'pendulum-gym-image-dataset-train_20Hz_HighinitV_mu.pkl',
        3: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL750_nS64_DaTrue.pkl',
        4: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL750_nS64_DaFalse.pkl',
        5: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL405_nS32_DaFalse.pkl',
        6: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL125_nS100_DaTrue.pkl',
        7: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz100_sL125_nS100_DaTrue.pkl',
        8: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-0.1_Hz20_sL102_nS10_DaTrue.pkl',
        9: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS100_DaTrue_Sv5.0-10.0.pkl',
        10: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS100_DaTrue_Sv0.5-4.0.pkl',
        11: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS100_DaTrue_Sv0.5-1.0_Sp-1.5707963267948966-1.5707963267948966.pkl',
        12: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS10_DaTrue_Sv1.0-1.0_Sp-1.5707963267948966-1.5707963267948966.pkl',
        13: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL102_nS20_DaTrue_Sv1.0-1.0_Sp1.5707963267948966-1.5707963267948966.pkl',
        14: 'pendulum-gym-image-dataset-train_L1.0-1.0_mu0.1-1.0_Hz20_sL125_nS1000_DaFalse_Sv0.5-4.0_Sp-3.141592653589793-3.141592653589793.pkl',
    }
    try:
        file_name = dataset_files[dataset_type]
    except KeyError:
        raise ValueError(
            f'Unknown dataset_type {dataset_type!r}; expected one of '
            f'{sorted(dataset_files)}') from None
    dataset_name = dataset_folder + file_name

    model = Model(hparams=args, data_path=os.path.join(PARENT_DIR, 'datasets', dataset_name))

    # doc link for "ModelCheckpoint"
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/model_checkpoint.py
    # Keep the 5 best checkpoints by the logged 'monitor' value, plus the last one.
    checkpoint_callback = ModelCheckpoint(monitor='monitor',
                                          dirpath=args.name + '/',
                                          filename='Model-{epoch:05d}-{loss:.2f}',
                                          save_top_k=5,
                                          save_last=True)

    # doc link for "Trainer"
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py
    # limit_train_batches=1: each epoch uses only a single training batch.
    trainer = Trainer.from_argparse_args(args,
                                         limit_train_batches=1,
                                         max_epochs=10000,
                                         deterministic=True,
                                         terminate_on_nan=True,
                                         log_every_n_steps=1,
                                         default_root_dir=os.path.join(PARENT_DIR, 'logs', args.name),
                                         checkpoint_callback=checkpoint_callback,gradient_clip_val=gradient_clip,track_grad_norm=2)

    # if want to load the model from the checkpoint
    trainer.fit(model)

if __name__ == '__main__':

    # Command-line interface, built in three layers: script-level options,
    # then the Lightning Trainer flags, then the model's own hyperparameters.
    cli_parser = ArgumentParser(add_help=False)
    cli_parser.add_argument('--name', default=''+save_dir, type=str)
    cli_parser.add_argument('--T_pred', default=T_pred, type=int)
    cli_parser.add_argument('--solver', default='rk4', type=str)  # choices in use: euler, rk4
    cli_parser.add_argument('--homo_u', dest='homo_u', action='store_true')
    cli_parser.add_argument('--annealing', dest='annealing', action='store_true')
    # Note: annealing defaults to True, so passing --annealing has no effect.
    cli_parser.set_defaults(homo_u=False, annealing=True)
    cli_parser = Trainer.add_argparse_args(cli_parser)
    cli_parser = Model.add_model_specific_args(cli_parser)
    args = cli_parser.parse_args()

    main(args)


    # make gif
    #https://ezgif.com/maker/ezgif-6-69fc2d3f-gif