"""
This file is used to run and save the results of multiple prediction tests done using the FNO 2D time model. 
"""

"""
This file implements the Fourier Neural Operator for 2D problems, such as the Navier-Stokes equation discussed in Section 5.3 of the paper
"FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS".
The model uses a recurrent structure to propagate in time.
It is derived largely from the code that was made public with that paper.
"""

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# append the parent directory path 
import os 
import sys 
current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, "..")
sys.path.append(parent_dir)

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import imageio
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io
import pickle
import pdb
from tqdm import tqdm 

# Seed NumPy and PyTorch up front so random draws (weight initialisation,
# DataLoader shuffling) are repeatable from run to run.
np.random.seed(0)
torch.manual_seed(0)

######################################## Save utilities #######################################

# Error functions

def relative_error(x, y):
    """
    Get the relative error: norm of the difference / norm of the ground truth.
    Derived from: code for "FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS"

    Uses the Euclidean 2-norm over all pixels: the difference image and the
    ground truth are flattened before taking the norm. This is the same
    per-image quantity that avg_relative_error averages. (np.linalg.norm with
    ord=2 on an unflattened 2D array would instead return the spectral norm,
    i.e. the largest singular value, and it rejects arrays with more than 2 dims.)
    args: 
        - x: image to compare 
        - y: second image - the ground truth image 
    returns: 
        - error: float 
    """
    # subtract before flattening so numpy broadcasting behaves as before
    difference_vec = (np.asarray(x) - np.asarray(y)).ravel()
    y_vec = np.asarray(y).ravel()
    difference_norm = np.linalg.norm(difference_vec, ord=2)
    y_norm = np.linalg.norm(y_vec, ord=2)
    return difference_norm / y_norm


def MSE(x,y):
    """
    Get the MSE (mean squared error) between the two images.

    The mean is taken over every element, whatever the image shape.
    (Replaces the former reshape via np.product, which was removed in NumPy 2.0;
    np.mean already averages over all elements.)
    args: 
        - x: image to compare 
        - y: second image - the ground truth image 
    returns: 
        - error: float 
    """
    difference_image = np.asarray(x) - np.asarray(y)
    return np.mean(np.square(difference_image))

def avg_relative_error(xset, yset):
    """
    Mean over images of ||x_i - y_i||_2 / ||y_i||_2.

    Each image is flattened to a vector before the norms are taken, so the
    2-norm is taken over all of the image's pixels.
    args: 
        - xset: list of images (predictions)
        - yset: list of images (ground truth)
    returns: 
        - average relative error 
    """
    preds = np.array(xset)
    truths = np.array(yset)

    count = preds.shape[0]
    flat_preds = preds.reshape((count, -1))
    flat_truths = truths.reshape((count, -1))

    per_image = (np.linalg.norm(flat_preds - flat_truths, ord=2, axis=1)
                 / np.linalg.norm(flat_truths, ord=2, axis=1))
    return np.mean(per_image)

def avg_MSE(xset, yset):
    """
    Mean squared error over every pixel of every image in xset vs. yset.
    args: 
        - xset: list of images
        - yset: list of images
    returns:
        - the mean of the squared pixel differences
    """
    preds = np.array(xset)
    truths = np.array(yset)

    count = preds.shape[0]
    residual = preds.reshape((count, -1)) - truths.reshape((count, -1))
    return np.mean(np.square(residual))

def compute_stds_off(mean_image, var_image, ground_truth_image):
    """
    Per-pixel count of standard deviations between the mean image and the ground truth.

    args: 
        - mean_image: the mean image
        - var_image: variance image - each pixel is the variance of the corresponding pixel in the mean image
        - ground_truth_image: the ground truth image
    returns: 
        - image of |mean - ground truth| / std. Pixels with zero variance are
          divided by 1 instead, i.e. they hold the plain absolute difference.
    """
    std_image = np.sqrt(var_image)  # new array, so var_image itself is never modified
    std_image[std_image == 0] = 1   # avoid division by zero
    abs_diff = np.abs(mean_image - ground_truth_image)
    return abs_diff / std_image

def _gif_num_columns(dataset):
    """Number of image columns a dataset contributes: 0 if absent, the tuple length if frames are tuples, else 1."""
    if dataset is None:
        return 0
    if isinstance(dataset[0], tuple):
        return len(dataset[0])
    return 1


def _gif_frame_images(frame):
    """Return one frame's images as a list: the tuple's elements for tuple frames, else [frame]."""
    if isinstance(frame, tuple):
        return list(frame)
    return [frame]


def _gif_frame_error(pred_img, true_img, error_type):
    """Scalar error between a predicted image and its ground truth ('rel' -> relative_error, 'MSE' -> MSE)."""
    if error_type == 'rel':
        return relative_error(x=pred_img, y=true_img)
    if error_type == 'MSE':
        return MSE(x=pred_img, y=true_img)
    raise ValueError("ERROR: incorrect error_type inputted in function: " + str(error_type))


def _gif_figure_to_rgb(fig):
    """Render a matplotlib figure with the Agg canvas and return it as a (height, width, 3) uint8 array."""
    canvas = FigureCanvas(fig)
    # print_to_buffer draws the figure and returns its RGBA bytes plus (width, height);
    # this replaces np.fromstring/tostring_rgb, which are deprecated/removed.
    rgba_bytes, (width, height) = canvas.print_to_buffer()
    rgba = np.frombuffer(rgba_bytes, dtype=np.uint8).reshape(int(height), int(width), 4)
    return rgba[..., :3].copy()


def save_imageseq_gif(gif_savepath, max_steps, ydataset=None, predicted=None, frame_nums = None, 
    error_type='rel', show_difference=True, static_range=False, show_colorbar=True, show_error_graph=True, show_whole_error_graph=False, figsize=None, duration=3):
    """
    Save a gif comparing ground-truth and predicted image sequences side by side.

    Columns are laid out left to right as: ground truth ('y'), predicted, difference
    (y - predicted, titled with the per-frame error) and error graph. Only columns
    for the provided datasets are shown. The difference and error-graph columns
    need both ydataset and predicted. When show_error_graph is True, a static error
    graph (.png) and the raw error values (.npy) are also saved next to the gif,
    using the gif path without its extension.

    args: 
        - gif_savepath: the path to save the gif at 
        - max_steps: number of frames (from the start of the sequences) to render 
        - ydataset: list of ground-truth images, or of image tuples (one image per column)
        - predicted: list of predicted images, or of image tuples
        - frame_nums: optional list of ints giving the original frame number of each step;
                      shown in the y/predicted titles. Must be at least max_steps long.
        - error_type: the type of error shown over the difference image
            - 'rel': relative error - the default in FNO 
            - 'MSE': mean squared error
        - show_difference: show the difference between ydataset and predicted
        - static_range: if True, every image panel uses the same vmin/vmax
                        (min/max over the provided datasets)
        - show_colorbar: if True, draw a colorbar next to every image panel
        - show_error_graph: if True, add an error-vs-frame graph column (requires
                            show_difference) and save the static graph/.npy files
        - show_whole_error_graph:
            - True: always plot the whole error curve and highlight the current frame
            - False: plot the curve only up to the first multiple of 10 that is
                     >= the current frame (at least up to frame 10)
        - figsize: (width, height) tuple for each frame's figure; None uses matplotlib's default
        - duration: per-frame duration passed to imageio's gif writer
    raises:
        - ValueError: if neither ydataset nor predicted is given, or if error_type is
                      not 'rel'/'MSE' when errors have to be computed
    """
    if ydataset is None and predicted is None:
        raise ValueError("save_imageseq_gif needs at least one of ydataset or predicted")

    # ---- column layout ----
    ycols = _gif_num_columns(ydataset)
    pred_cols = _gif_num_columns(predicted)
    both_present = ycols > 0 and pred_cols > 0
    diff_cols = pred_cols if (show_difference and both_present) else 0
    # the error graph plots the difference errors, so it exists only alongside them
    graph_cols = diff_cols if (show_error_graph and both_present) else 0
    total_cols = ycols + pred_cols + diff_cols + graph_cols

    column_titles = (['y'] * ycols + ['predicted'] * pred_cols
                     + ['diff'] * diff_cols + ['error graph'] * graph_cols)

    error_name = {'rel': "Relative", 'MSE': "Mean Squared"}.get(error_type, '')

    # shared colour range over whichever datasets were provided
    provided = [np.array(d) for d in (predicted, ydataset) if d is not None]
    vmin = min(np.min(d) for d in provided)
    vmax = max(np.max(d) for d in provided)

    # ---- per-frame errors: error_matrix[step][col] ----
    error_matrix = []
    if diff_cols > 0:
        for step in range(max_steps):
            true_imgs = _gif_frame_images(ydataset[step])
            pred_imgs = _gif_frame_images(predicted[step])
            error_matrix.append([_gif_frame_error(pred_imgs[col], true_imgs[col], error_type)
                                 for col in range(diff_cols)])
    # layout of the saved .npy: one list per step with several columns, else one float per step
    if diff_cols > 1:
        all_errors = error_matrix
    else:
        all_errors = [row[0] for row in error_matrix]
    # error_series[col]: that column's errors over all steps
    error_series = [[row[col] for row in error_matrix] for col in range(graph_cols)]

    # ---- render one figure per frame ----
    imshow_kwargs = {'vmin': vmin, 'vmax': vmax} if static_range else {}
    all_figures = []
    for step in range(max_steps):
        if step % 10 == 0 or step == max_steps - 1:
            print("Step: ", step)

        fig, axs = plt.subplots(1, total_cols, figsize=figsize)
        axes = [axs] if total_cols == 1 else list(axs)

        frame_suffix = "" if frame_nums is None else ": " + str(frame_nums[step])
        true_imgs = _gif_frame_images(ydataset[step]) if ycols > 0 else []
        pred_imgs = _gif_frame_images(predicted[step]) if pred_cols > 0 else []

        # (image, title suffix) for every image panel, in column order
        image_panels = [(true_imgs[col], frame_suffix) for col in range(ycols)]
        image_panels += [(pred_imgs[col], frame_suffix) for col in range(pred_cols)]
        for col in range(diff_cols):
            error = error_matrix[step][col]
            image_panels.append((true_imgs[col] - pred_imgs[col],
                                 ": " + str(error_type) + ": " + str(np.round(error, 5))))

        for col, (img, suffix) in enumerate(image_panels):
            ax = axes[col]
            pos = ax.imshow(img, **imshow_kwargs)
            if show_colorbar:
                fig.colorbar(pos, ax=ax)
            ax.set_title(column_titles[col] + suffix)

        if show_whole_error_graph:
            visible = max_steps
        else:
            visible = min(max(int(np.ceil(step / 10)), 1) * 10 + 1, max_steps)
        for col in range(graph_cols):
            ax = axes[len(image_panels) + col]
            error = error_matrix[step][col]
            ax.plot(error_series[col][:visible], color='b')
            ax.scatter(step, error, s=100, color='r')  # highlight current frame
            ax.set_title("frame, error: (" + str(step) + ", " + str(np.round(error, 5)) + ")")

        fig.subplots_adjust(wspace=0.5, hspace=0.0)
        all_figures.append(_gif_figure_to_rgb(fig))
        plt.close(fig)

    # ---- write the gif ----
    with imageio.get_writer(gif_savepath, mode='I', duration=duration) as writer:
        for image in all_figures:
            writer.append_data(image)

    # ---- save the static error graphs and raw error values ----
    if show_error_graph:
        base_path = os.path.splitext(gif_savepath)[0]
        # max(1, ...) keeps the tick step positive when max_steps < 10
        xticks = list(np.arange(0, max_steps, max(1, int(max_steps // 10)))) + [max_steps]
        for col in range(graph_cols):
            graph_fig = plt.figure(figsize=(8, 8) if graph_cols > 1 else (10, 8))
            plt.title("Rollout frame number vs. " + str(error_type) + " error")
            plt.xlabel("Rollout frame number")
            plt.ylabel(str(error_name) + " error")
            plt.xticks(xticks)
            plt.plot(np.arange(max_steps), error_series[col], color='b')
            plt.scatter(np.arange(max_steps), error_series[col], s=20, color='b')
            graph_suffix = "_" + str(col) if graph_cols > 1 else ""
            plt.savefig(base_path + graph_suffix + ".png")
            plt.close(graph_fig)  # avoid accumulating open figures across calls

        np.save(base_path + ".npy", np.array(all_errors))

    print("Saved: ", gif_savepath)
    return 

############################################# End: Save Utilities ####################################

#Complex multiplication
def compl_mul2d(a, b):
    """
    Complex channel mixing of the retained Fourier modes.

    Complex values are stored as a trailing dimension of size 2 (real, imag).
    args:
        - a: Fourier modes, shape (batch, in_channels, modes1, modes2, 2)
        - b: weights, shape (in_channels, out_channels, modes1, modes2, 2)
    returns:
        - tensor of shape (batch, out_channels, modes1, modes2, 2) with
          out[b, o, x, y] = sum_i a[b, i, x, y] * b[i, o, x, y] (complex product)
    """
    # Contract the input-channel index i of `a` with the FIRST axis of the weights
    # (the previous "bctq,dctq->bdtq" contracted it against the out-channel axis,
    # which fails whenever in_channels != out_channels).
    op = partial(torch.einsum, "bixy,ioxy->boxy")
    return torch.stack([
        op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),  # real part
        op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])   # imaginary part
    ], dim=-1)

################################################################
# fourier layer
################################################################

class SpectralConv2d_fast(nn.Module):
    """
    2D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Only the lowest `modes1` x `modes2` Fourier modes are kept: modes1 from each
    end of the first spatial axis (positive and negative frequencies) and
    modes2 along the one-sided last axis. They are mixed across channels with
    learned complex weights, and every other mode is zeroed before the inverse FFT.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # complex weights stored with a trailing (real, imag) dimension of size 2
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))

    @staticmethod
    def _rfft2(x):
        """
        Orthonormal 2D real FFT over the last two dims, returned as a real tensor
        with a trailing (real, imag) dimension of size 2.
        """
        if hasattr(torch, "rfft"):
            # PyTorch < 1.8: legacy API (removed in 1.8); normalized=True == "ortho"
            return torch.rfft(x, 2, normalized=True, onesided=True)
        return torch.view_as_real(torch.fft.rfft2(x, norm="ortho"))

    @staticmethod
    def _irfft2(x_ft, signal_size):
        """Inverse of _rfft2: back to a real signal whose last two dims are signal_size."""
        if hasattr(torch, "irfft"):
            return torch.irfft(x_ft, 2, normalized=True, onesided=True, signal_sizes=signal_size)
        return torch.fft.irfft2(torch.view_as_complex(x_ft), s=signal_size, norm="ortho")

    def forward(self, x):
        batchsize = x.shape[0]
        size_x, size_y = x.size(-2), x.size(-1)

        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = self._rfft2(x)

        # Multiply relevant Fourier modes; all other modes stay zero
        out_ft = torch.zeros(batchsize, self.out_channels, size_x, size_y // 2 + 1, 2,
                             dtype=x_ft.dtype, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        #Return to physical space
        return self._irfft2(out_ft, (size_x, size_y))

class SimpleBlock2d(nn.Module):
    """
    The overall network. It contains 4 Fourier layers.
    1. Lift the input to the desired channel dimension with self.fc0.
    2. 4 layers of the integral operator u' = (W + K)(u),
       with W defined by self.w0..w3 and K by self.conv0..conv3. Each layer is
       followed by BatchNorm, and by ReLU for all but the last layer.
    3. Project from the channel space to the output space with self.fc1 and self.fc2.

    input: the solution of the previous T_in timesteps + 2 locations
           (u(t-T_in, x, y), ..., u(t-1, x, y), x, y)
    input shape: (batchsize, x, y, c=in_channels)
    output: the solution of the next timestep
    output shape: (batchsize, x, y, c=1)
    """

    def __init__(self, modes1, modes2, width, in_channels=None):
        """
        args:
            - modes1, modes2: Fourier modes kept per spatial dimension
            - width: channel width of the hidden Fourier layers
            - in_channels: number of input channels. When None (default), uses the
              module-level GLOBAL_TIN_PLUS2 (T_in past frames + 2 grid channels).
        """
        super(SimpleBlock2d, self).__init__()
        if in_channels is None:
            in_channels = GLOBAL_TIN_PLUS2

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        # Submodules are created in the original order so parameter initialisation
        # draws from the RNG in the same sequence.
        self.fc0 = nn.Linear(in_channels, self.width)

        self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm2d(self.width)
        self.bn1 = torch.nn.BatchNorm2d(self.width)
        self.bn2 = torch.nn.BatchNorm2d(self.width)
        self.bn3 = torch.nn.BatchNorm2d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        batchsize = x.shape[0]
        size_x, size_y = x.shape[1], x.shape[2]

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)  # channels-first for the spectral / 1x1 convolutions

        layers = (
            (self.conv0, self.w0, self.bn0),
            (self.conv1, self.w1, self.bn1),
            (self.conv2, self.w2, self.bn2),
            (self.conv3, self.w3, self.bn3),
        )
        last = len(layers) - 1
        for i, (conv, w, bn) in enumerate(layers):
            x1 = conv(x)
            # W: pointwise (1x1) convolution over the flattened spatial grid
            x2 = w(x.reshape(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)
            x = bn(x1 + x2)
            if i < last:
                x = F.relu(x)

        x = x.permute(0, 2, 3, 1)  # back to channels-last for the projection
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

class Net2d(nn.Module):
    """
    Thin wrapper around SimpleBlock2d using the same number of Fourier modes
    in both spatial dimensions.
    """

    def __init__(self, modes, width):
        super().__init__()
        # attribute name kept as `conv1` so state_dict keys are unchanged
        self.conv1 = SimpleBlock2d(modes, modes, width)

    def forward(self, x):
        """Run the wrapped SimpleBlock2d on x."""
        return self.conv1(x)

    def count_params(self):
        """Total number of scalar entries across all of the model's parameters."""
        return sum(reduce(operator.mul, list(p.size())) for p in self.parameters())

############################################# Additional helper function #############################################
# save as a text file to read easily 
def save_dictionary_to_text(save_dic, save_dic_filename): 
    """
    Write a dictionary to a plain-text file as "key: value" lines for easy reading.

    Values whose type is exactly the type of save_dic are expanded one level:
    a blank line, "key: ", then one "sub_key: sub_value" line per entry and a
    trailing blank line. Deeper nesting is written with str(). The file ends
    with the marker "EOF" (no trailing newline).
    args: 
        - save_dic: dictionary to save
        - save_dic_filename: txt filename to save dictionary to
    """
    dict_type = type(save_dic)
    with open(save_dic_filename, "w") as dic_file:
        pieces = []
        for key, value in save_dic.items():
            nested = type(value) is dict_type
            if nested:
                pieces.append("\n")  # extra space before a sub-dictionary
            pieces.append(str(key) + ": ")
            if nested:
                pieces.append("\n")
                pieces.extend(str(sub_key) + ": " + str(sub_value) + "\n"
                              for sub_key, sub_value in value.items())
                pieces.append("\n")
            else:
                pieces.append(str(value) + "\n")
        pieces.append("EOF")
        dic_file.write("".join(pieces))
    return 
#########################################################################################################################

#################################################### START: Modifiable parameters ####################################################
# --- experiment size ---
num_predictions = [50]      # sequences to train/predict on per file; keep the same length as nvstokes_filenames
prediction_steps = 15       # how many steps in the future to predict
num_training_imgs = 10      # frames at the start of each sequence used for training
random_img_start = False    # boolean: intended to pick a random part of each sequence (else start at 'start_point'); NOTE(review): not read by the data loading in this file
start_point = 0             # part of the sequence to evaluate for training and testing; NOTE(review): not read by the data loading in this file

# --- dataset ---
viscosity = 1e-3
nvstokes_filename = "../../dataset/ns_data_32by32_visc1e-3_1e-4dt_200T_50samples_2.mat"
nvstokes_filenames = [nvstokes_filename]
multiple_seq_in_file = True  # boolean: True if a single file stores multiple sequences, False if each file stores one sequence

# sequence indices to start and end (exclusive) predicting on, one entry per file
start_prediction_seq = [0]  # keep the same length as num_predictions
end_prediction_seq = [start_prediction_seq[idx] + count for idx, count in enumerate(num_predictions)]

# --- output ---
save_dirname = "50_comparison_fno_2d_time_20epochs_2"  # directory created to store all the results
save_filename_suffix_start = "fno2d_time_10train_32by32batchtest_visc1e-3_50test_2"  # prefix shared by every saved prediction file

GLOBAL_EPOCHS = 20  # training epochs per prediction run

#################################################### END: Modifiable parameters ####################################################

#################################################### START: Load train and test data ####################################################
train_seq_len = num_training_imgs  # frames [0, train_seq_len) of each sequence are used to build training pairs
test_seq_len = prediction_steps    # frames predicted (and compared to the ground truth) right after the training frames

sub = 1                      # spatial subsampling stride applied when reading the fields
S = 32                       # spatial resolution of the subsampled fields; must match the data for the grid concat below
T_in = 3                     # number of past frames given to the model as input
GLOBAL_TIN_PLUS2 = T_in + 2  # model input channels: T_in frames + (x, y) grid; read by SimpleBlock2d
T = 1                        # number of target frames per training datapoint
step = 1                     # frames produced per model call during the rollout

T_in_test = T_in
T_test = test_seq_len

# start: network training parameters
batch_size = 1
batch_size2 = batch_size
test_batch_size = 1
# end: training parameters 

# boolean toggle: True builds one datapoint per sliding (T_in -> T) window inside the
# training frames; False builds a single datapoint (first T_in frames -> rest of the training frames)
our_training_data_method = True

train_test_readers = [MatReader(filename) for filename in nvstokes_filenames]
all_sequences = [reader.read_field('u')[:, ::sub, ::sub, :] for reader in train_test_readers]

all_train_loaders = []
all_test_loaders = []

# pad the location (x,y): the grid does not depend on the sequence, so build it once
gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1])

# number of sliding (T_in -> T) windows that fit inside the training frames
num_train_windows = train_seq_len - (T_in + T) + 1
if our_training_data_method and num_train_windows < 1:
    raise ValueError("num_training_imgs (" + str(train_seq_len) + ") must be at least T_in + T ("
                     + str(T_in + T) + ") to build any training datapoint")

for file_num, sequences in enumerate(all_sequences):
    available_frames = sequences.shape[-1]
    if train_seq_len + T_test > available_frames:
        # the test slice would otherwise be silently truncated
        raise ValueError("file " + str(nvstokes_filenames[file_num]) + " has " + str(available_frames)
                         + " frames, but num_training_imgs + prediction_steps = " + str(train_seq_len + T_test))

    for seq_num in range(start_prediction_seq[file_num], end_prediction_seq[file_num]): 
        # keep a leading batch dimension of 1; the last axis indexes frames
        sequence = sequences[seq_num:seq_num + 1]

        if our_training_data_method: 
            # our method generates a datapoint out of each T_in -> T output pair in the train_seq_len
            train_a = torch.cat([sequence[..., s:s + T_in] for s in range(num_train_windows)])
            train_u = torch.cat([sequence[..., s + T_in:s + T_in + T] for s in range(num_train_windows)])
        else: 
            # their method generates a single datapoint T_in -> train_seq_len - T_in 
            train_a = sequence[..., :T_in]
            train_u = sequence[..., T_in:train_seq_len]

        # test input: the last T_in training frames; target: the next T_test frames
        test_a = sequence[..., train_seq_len - T_in_test:train_seq_len]
        test_u = sequence[..., train_seq_len:train_seq_len + T_test]

        num_train_points = train_a.shape[0]
        num_test_points = test_a.shape[0]

        train_a = torch.cat((train_a, gridx.repeat([num_train_points, 1, 1, 1]), gridy.repeat([num_train_points, 1, 1, 1])), dim=-1)
        test_a = torch.cat((test_a, gridx.repeat([num_test_points, 1, 1, 1]), gridy.repeat([num_test_points, 1, 1, 1])), dim=-1)

        train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=test_batch_size, shuffle=False)

        # loaders are appended file by file, sequence by sequence
        all_train_loaders.append(train_loader)
        all_test_loaders.append(test_loader)

#################################################### END: Load train and test data ####################################################

#################################################### Start: Create the save structure ####################################################
##### Create the directory to save #####
# Layout created under save_dirname:
#   multiple_gifs/     - the predicted gifs (plus the error graph .png / error .npy written next to each)
#   multiple_img_npys/ - the predicted images of every run, as .npy files
gif_folder = os.path.join(save_dirname, "multiple_gifs")
npy_folder = os.path.join(save_dirname, "multiple_img_npys")

# makedirs also creates save_dirname (and any missing parents); exist_ok makes reruns a no-op
os.makedirs(gif_folder, exist_ok=True)
os.makedirs(npy_folder, exist_ok=True)

print("Created directory structure")

# save all the parameters into the folder - prevents having long suffixes after each name
# saved both as a pickle (loads back as a dictionary) and as readable text
all_parameters = {}
all_parameters['nvstokes_filenames'] = nvstokes_filenames
all_parameters['num_predictions'] = num_predictions
all_parameters['prediction_steps'] = prediction_steps 
all_parameters['num_training_imgs'] = num_training_imgs
all_parameters['random_img_start'] = random_img_start
all_parameters['viscosity'] = viscosity
all_parameters['start_prediction_seq'] = start_prediction_seq
all_parameters['end_prediction_seq'] = end_prediction_seq
# further run settings that affect the results, recorded for reproducibility
all_parameters['start_point'] = start_point
all_parameters['T_in'] = T_in
all_parameters['S'] = S
all_parameters['GLOBAL_EPOCHS'] = GLOBAL_EPOCHS

parameter_pkl_filename = os.path.join(save_dirname, "all_parameters.pkl")
with open(parameter_pkl_filename, "wb") as pkl_file:  # closed by the with-block
    pickle.dump(all_parameters, pkl_file)

parameter_txt_filename = os.path.join(save_dirname, "all_parameters.txt")
save_dictionary_to_text(save_dic=all_parameters, save_dic_filename=parameter_txt_filename)

#################################################### End: Create the save structure ####################################################
#################################################### Start: For loop for multiple tests ####################################################

def run_prediction(prediction_num): 
    """
    Train a fresh FNO model on one sequence, then roll it out on that sequence's test window.

    The predictions are saved as a .npy file in npy_folder. A gif comparing them
    with the ground truth (plus the error graph and error values) is written to
    gif_folder via save_imageseq_gif.

    args: 
        - prediction_num: index of the train and test loader to use
    returns: 
        - pred: tensor of the predicted frames for the (last) test batch, with the
                rollout steps concatenated along the last dimension
    """
    global all_train_loaders, all_test_loaders
    global npy_folder, gif_folder, save_filename_suffix_start
    global S, GLOBAL_EPOCHS

    # parameters 
    modes = 12
    width = 20
    epochs = GLOBAL_EPOCHS
    learning_rate = 0.0025
    scheduler_step = 100
    scheduler_gamma = 0.5
    device = torch.device('cpu')

    # (x, y) location channels, re-appended to the model input after every rollout step
    gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
    gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1]).to(device)
    gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
    gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1]).to(device)

    def next_input(xx, im):
        """Drop the oldest `step` frames and the 2 grid channels from xx, then append im and the grid."""
        n = xx.shape[0]  # actual batch size (the last batch may be smaller than batch_size)
        return torch.cat((xx[..., step:-2], im,
                          gridx.repeat([n, 1, 1, 1]), gridy.repeat([n, 1, 1, 1])), dim=-1)

    model = Net2d(modes, width).to(device)

    print(model.count_params())
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

    myloss = LpLoss(size_average=False)

    train_loader = all_train_loaders[prediction_num]
    test_loader = all_test_loaders[prediction_num]

    for ep in tqdm(range(epochs)):
        model.train()
        for xx, yy in train_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            n = xx.shape[0]

            loss = 0
            for t in range(0, T, step):
                y = yy[..., t:t + step]
                im = model(xx)
                loss += myloss(im.reshape(n, -1), y.reshape(n, -1))
                xx = next_input(xx, im)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # StepLR counts epochs; it was previously constructed but never stepped
        scheduler.step()

    print("Finished Training model")
    # eval mode: BatchNorm uses its running statistics and stops updating them during the rollout
    model.eval()
    with torch.no_grad(): 
        for xx, yy in test_loader: 
            xx = xx.to(device)
            for t in range(0, T_test, step):
                im = model(xx)
                pred = im if t == 0 else torch.cat((pred, im), -1)
                xx = next_input(xx, im)

    # save the prediction numpy files 
    predicted_images = [pred[0, :, :, i].cpu().numpy() for i in range(pred.shape[-1])]
    save_filename = save_filename_suffix_start + "_" + str(prediction_num)
    np.save(os.path.join(npy_folder, save_filename + ".npy"), predicted_images)

    # save the ground truth, prediction, error 
    gif_savepath = os.path.join(gif_folder, save_filename + ".gif")
    ground_truth_images = []
    for _, yy in test_loader: 
        for i in range(yy.shape[-1]):
            ground_truth_images.append(yy[0, :, :, i].numpy())
    ground_truth_images = np.array(ground_truth_images)
    save_imageseq_gif(gif_savepath, 
                      prediction_steps, 
                      ydataset=ground_truth_images, 
                      predicted=predicted_images, 
                      show_difference=True, 
                      static_range=True, 
                      show_colorbar=True, 
                      show_error_graph=True, 
                      show_whole_error_graph=True,  
                      figsize=(16, 2.1), 
                      duration=1)

    return pred

print("Starting the for loop")

# One train-then-predict run per (file, sequence) pair; the loaders were built in
# that order, so run i uses all_train_loaders[i] / all_test_loaders[i].
total_runs = int(sum(num_predictions))
for run_index in range(total_runs):
    run_prediction(prediction_num=run_index)


#################################################### End: For loop for multiple tests ####################################################
