import numpy as np 
import cv2 
import matplotlib.pyplot as plt 

import imageio
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

import sys 
import os 
import pathlib
current_dir = pathlib.Path(__file__).parent.absolute()
new_src_dir = os.path.join(current_dir, "new_src")
sys.path.append(new_src_dir)

from utilities import relative_error, MSE, avg_relative_error, avg_MSE, compute_stds_off


def display_images(data_set_x, data_set_y, num_pairs, title, figsize=None, print_mat=False):
    """
    Display num_pairs rows of (x, y) images, each subplot with its own colorbar.

    Each entry of data_set_x / data_set_y is either a single image or a tuple of
    images; a tuple is spread over consecutive columns. Whether a dataset holds
    tuples (and how many columns they take) is decided from its first entry.

    args:
        - data_set_x: list of x data (images or tuples of images)
        - data_set_y: list of y data (images or tuples of images)
        - num_pairs: number of pairs (rows) to display
        - title: figure title
        - figsize: optional (width, height) passed to plt.subplots; None uses the default
        - print_mat: True if the matrices should be printed
    returns:
        - None
    """
    def row_images(dataset, row, is_tuple, num_cols):
        # Return the images of one dataset entry as a list, one per column.
        sample = dataset[row]
        if is_tuple:
            return [sample[col] for col in range(num_cols)]
        return [sample]

    rows = num_pairs
    x_is_tuple = isinstance(data_set_x[0], tuple)
    y_is_tuple = isinstance(data_set_y[0], tuple)
    x_cols = len(data_set_x[0]) if x_is_tuple else 1
    y_cols = len(data_set_y[0]) if y_is_tuple else 1
    cols = x_cols + y_cols

    # squeeze=False keeps `ax` 2-D even when rows == 1 or cols == 1,
    # so ax[row, col] indexing is always valid.
    fig, ax = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    fig.suptitle(title)

    for row in range(rows):
        x_images = row_images(data_set_x, row, x_is_tuple, x_cols)
        y_images = row_images(data_set_y, row, y_is_tuple, y_cols)
        for col, image in enumerate(x_images + y_images):
            pos = ax[row, col].imshow(image)
            fig.colorbar(pos, ax=ax[row, col])

        if print_mat:
            # printing the matrices
            print("Row: ", row)
            print(data_set_x[row])
            print(data_set_y[row])
            print("\n\n\n")

    plt.show()
    return


def display_images_predicted(data_set_x, data_set_y, predicted_data, num_pairs, title, figsize=None, print_mat=False):
    """
    Display num_pairs rows of x, y, predicted and |predicted - y| difference images.

    Each entry of the datasets is either a single image or a tuple of images; a
    tuple is spread over consecutive columns. Whether a dataset holds tuples is
    decided from its first entry. One difference column is drawn per prediction
    column, pairing prediction column i with y column i.

    args:
        - data_set_x: list of x data (images or tuples of images)
        - data_set_y: list of y data (images or tuples of images)
        - predicted_data: list of predicted data (images or tuples of images)
        - num_pairs: number of rows to display
        - title: figure title; the column layout is appended to it
        - figsize: optional (width, height) passed to plt.subplots; None uses the default
        - print_mat: True if the matrices should be printed
    returns:
        - None
    """
    def row_images(dataset, row, is_tuple, num_cols):
        # Return the images of one dataset entry as a list, one per column.
        sample = dataset[row]
        if is_tuple:
            return [sample[col] for col in range(num_cols)]
        return [sample]

    rows = num_pairs
    x_is_tuple = isinstance(data_set_x[0], tuple)
    y_is_tuple = isinstance(data_set_y[0], tuple)
    pred_is_tuple = isinstance(predicted_data[0], tuple)
    x_cols = len(data_set_x[0]) if x_is_tuple else 1
    y_cols = len(data_set_y[0]) if y_is_tuple else 1
    pred_cols = len(predicted_data[0]) if pred_is_tuple else 1

    # prediction columns plus one difference column per prediction column
    cols = x_cols + y_cols + 2 * pred_cols

    # squeeze=False keeps `ax` 2-D even when rows == 1, so ax[row, col] always works.
    fig, ax = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    headers = ['x'] * x_cols + ['y'] * y_cols + ['pred'] * pred_cols + ['diff'] * pred_cols
    title_col_headers = "columns: " + "".join(header + ", " for header in headers)
    fig.suptitle(title + ": " + title_col_headers)

    for row in range(rows):
        x_images = row_images(data_set_x, row, x_is_tuple, x_cols)
        y_images = row_images(data_set_y, row, y_is_tuple, y_cols)
        pred_images = row_images(predicted_data, row, pred_is_tuple, pred_cols)
        diff_images = [np.abs(pred - true) for pred, true in zip(pred_images, y_images)]

        for col, image in enumerate(x_images + y_images + pred_images + diff_images):
            ax[row, col].imshow(image)

        if print_mat:
            # printing the matrices
            print("Row: ", row)
            print(data_set_x[row])
            print(data_set_y[row])
            print("\n\n\n")

    plt.show()
    return

def display_images_predicted_variablerange(data_set_x, data_set_y, predicted_data, num_pairs, title, error_type='rel', figsize=None,
    show_colorbar=True, print_mat=False):
    """
    Display num_pairs rows of x, y, predicted and |predicted - y| images - does not constrain the image output to be between 0 and 1

    Each entry of the datasets is either a single image or a tuple of images; a
    tuple is spread over consecutive columns. Whether a dataset holds tuples is
    decided from its first entry. One difference column is drawn per prediction
    column, pairing prediction column i with y column i, titled with the error.

    args:
        - data_set_x: list of x data (images or tuples of images)
        - data_set_y: list of y data (images or tuples of images)
        - predicted_data: list of predicted data (images or tuples of images)
        - num_pairs: number of rows to display
        - title: figure title; the column layout is appended to it
        - error_type: type: string - the type of error to display over the difference image
            - 'rel': relative error (utilities.relative_error) - The default in FNO
            - 'MSE': mean squared error (utilities.MSE)
        - figsize: optional (width, height) passed to plt.subplots; None uses the default
        - show_colorbar: boolean: if true shows the colorbar for every subplot
        - print_mat: True if the matrices should be printed
    returns:
        - None
    raises:
        - ValueError: if error_type is not 'rel' or 'MSE' (checked before any plotting)
    """
    if error_type not in ('rel', 'MSE'):
        raise ValueError("ERROR: incorrect error_type inputted in function: " + repr(error_type))

    def row_images(dataset, row, is_tuple, num_cols):
        # Return the images of one dataset entry as a list, one per column.
        sample = dataset[row]
        if is_tuple:
            return [sample[col] for col in range(num_cols)]
        return [sample]

    rows = num_pairs
    x_is_tuple = isinstance(data_set_x[0], tuple)
    y_is_tuple = isinstance(data_set_y[0], tuple)
    pred_is_tuple = isinstance(predicted_data[0], tuple)
    x_cols = len(data_set_x[0]) if x_is_tuple else 1
    y_cols = len(data_set_y[0]) if y_is_tuple else 1
    pred_cols = len(predicted_data[0]) if pred_is_tuple else 1

    # prediction columns plus one difference column per prediction column
    cols = x_cols + y_cols + 2 * pred_cols

    # squeeze=False keeps `ax` 2-D even when rows == 1, so ax[row, col] always works.
    fig, ax = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    headers = ['x'] * x_cols + ['y'] * y_cols + ['pred'] * pred_cols + ['diff'] * pred_cols
    title_col_headers = "columns: " + "".join(header + ", " for header in headers)
    fig.suptitle(title + ": " + title_col_headers)

    def show(row, col, image, label):
        # Draw one image in cell (row, col) with an optional colorbar and a subplot title.
        axis = ax[row, col]
        pos = axis.imshow(image)
        if show_colorbar:
            fig.colorbar(pos, ax=axis)
        axis.set_title(label)

    for row in range(rows):
        x_images = row_images(data_set_x, row, x_is_tuple, x_cols)
        y_images = row_images(data_set_y, row, y_is_tuple, y_cols)
        pred_images = row_images(predicted_data, row, pred_is_tuple, pred_cols)

        panels = ([(image, 'x') for image in x_images]
                  + [(image, 'y') for image in y_images]
                  + [(image, 'pred') for image in pred_images])
        for pred, true in zip(pred_images, y_images):
            if error_type == 'rel':
                # compute relative error
                error = relative_error(x=pred, y=true)
            else:
                # compute the MSE
                error = MSE(x=pred, y=true)
            panels.append((np.abs(pred - true), 'diff: ' + str(error_type) + ":" + str(np.round(error, 5))))

        for col, (image, label) in enumerate(panels):
            show(row, col, image, label)

        if print_mat:
            # printing the matrices
            print("Row: ", row)
            print(data_set_x[row])
            print(data_set_y[row])
            print("\n\n\n")

    plt.show()
    return



############################################ Save Image functions #############################################
def _imageseq_num_cols(dataset):
    """Return how many image columns one entry of `dataset` occupies (0 if `dataset` is None)."""
    if dataset is None:
        return 0
    first = dataset[0]
    # a tuple entry spans one column per element; anything else is a single image
    return len(first) if isinstance(first, tuple) else 1


def _imageseq_image(dataset, step, col):
    """Return image `col` of entry `step` of `dataset`; a non-tuple entry is returned whole."""
    entry = dataset[step]
    return entry[col] if isinstance(entry, tuple) else entry


def _imageseq_error(pred, true, error_type):
    """Return the scalar error between `pred` and `true` ('rel' -> relative_error, 'MSE' -> MSE)."""
    if error_type == 'rel':
        return relative_error(x=pred, y=true)
    if error_type == 'MSE':
        return MSE(x=pred, y=true)
    raise ValueError("ERROR: incorrect error_type inputted in function: " + repr(error_type))


def _imageseq_draw(fig, ax, image, title, show_colorbar):
    """Draw `image` on `ax` with an optional colorbar and the given title."""
    pos = ax.imshow(image)
    if show_colorbar:
        fig.colorbar(pos, ax=ax)
    ax.set_title(title)


def save_imageseq_gif(gif_savepath, max_steps, xdataset=None, ydataset=None, predicted=None, predicted_var=None, frame_nums=None,
    error_type='rel', show_difference=True, show_colorbar=True, show_error_graph=True, show_whole_error_graph=False, show_stds_off=False, figsize=None, duration=3):
    """
    Save the sequence of images as a GIF at gif_savepath.

    It automatically adjusts to whichever datasets you provide: the columns are
    labelled and the image sequences from the datasets are shown side by side,
    one GIF frame per step. Column order is:
    x, y, predicted, pred variance, stds off, diff, error graph.

    args:
        - gif_savepath: the path to save the file at
        - max_steps: the number of steps of the sequence to show (entries 0 .. max_steps-1)
        - xdataset: list of images/image tuples
        - ydataset: list of images/image tuples
        - predicted: list of predicted images/image tuples
        - predicted_var: list of predicted variance images/image tuples; drawn as its own column
        - frame_nums: list of ints: the ints refer to the frame numbers of the predicted frames and the true frames in the original image seq
                    - NOTE: must be same length as the other lists
        - error_type: type: string - the type of error to display over the difference image
            - 'rel': relative error - The default in FNO
            - 'MSE': mean squared error
        - show_difference: show the difference (y - predicted) between the y dataset and the predicted dataset
        - show_colorbar: boolean value: when True the colorbar is displayed
        - show_error_graph: boolean value: when True displays a graph of all the error values and highlights the error value of the current frame
            (only drawn when difference columns are shown)
        - show_whole_error_graph: boolean value:
            - True: always shows the whole error graph - rollout number vs error and highlights the current spot
            - False: only shows the graph up to (roughly) the current frame with the current frame error highlighted
        - show_stds_off: type: bool
            - True: show the number of standard deviations off (utilities.compute_stds_off) - needs predicted, ydataset and predicted_var
            - False: don't show this
        - figsize: (width, height) tuple; None uses the matplotlib default
        - duration: passed as `duration` to imageio.get_writer (frame duration; units depend on the imageio version)

    Side outputs when show_error_graph is True (paths share gif_savepath's stem):
        - <stem>.png (one error column) or <stem>_<col>.png (several columns): error graph(s)
        - <stem>.npy: the error values (1-D for one column, steps x columns otherwise)

    raises:
        - ValueError: if none of xdataset, ydataset, predicted is given, or if error_type is invalid
          while difference columns are shown
    """
    if xdataset is None and ydataset is None and predicted is None:
        raise ValueError("save_imageseq_gif: at least one of xdataset, ydataset or predicted must be given")

    xcols = _imageseq_num_cols(xdataset)
    ycols = _imageseq_num_cols(ydataset)
    pred_cols = _imageseq_num_cols(predicted)
    pred_var_cols = _imageseq_num_cols(predicted_var)

    can_compare = pred_cols > 0 and ycols > 0
    diff_cols = pred_cols if (show_difference and can_compare) else 0
    graph_cols = diff_cols if (show_error_graph and can_compare) else 0
    stds_off_cols = pred_var_cols if (show_stds_off and can_compare and pred_var_cols > 0) else 0

    total_cols = xcols + ycols + pred_cols + pred_var_cols + stds_off_cols + diff_cols + graph_cols

    # one title per column, in drawing order, so column_titles[col_num] matches the axis
    column_titles = (['x'] * xcols + ['y'] * ycols + ['predicted'] * pred_cols
                     + ['pred variance'] * pred_var_cols + ['stds off'] * stds_off_cols
                     + ['diff'] * diff_cols + ['error graph'] * graph_cols)

    error_name = {'rel': "Relative", 'MSE': "Mean Squared"}.get(error_type, '')

    # errors_by_step[step][col]: error of prediction column `col` at `step`
    errors_by_step = []
    for step in range(max_steps):
        errors_by_step.append([
            _imageseq_error(_imageseq_image(predicted, step, col), _imageseq_image(ydataset, step, col), error_type)
            for col in range(diff_cols)
        ])
    # per-column error series over all steps, used for the error graphs
    error_series = [[step_errors[col] for step_errors in errors_by_step] for col in range(diff_cols)]
    # saved form: scalars for a single column, per-step lists for several columns
    if diff_cols == 1:
        all_errors = error_series[0]
    elif diff_cols > 1:
        all_errors = errors_by_step
    else:
        all_errors = []

    # all_stds_off[step][col]: stds-off map of column `col` at `step`
    all_stds_off = []
    for step in range(max_steps):
        all_stds_off.append([
            compute_stds_off(mean_image=_imageseq_image(predicted, step, col),
                             var_image=_imageseq_image(predicted_var, step, col),
                             ground_truth_image=_imageseq_image(ydataset, step, col))
            for col in range(stds_off_cols)
        ])

    image_sources = ((xdataset, xcols), (ydataset, ycols), (predicted, pred_cols), (predicted_var, pred_var_cols))

    # create figures for the gif
    all_frames = []
    for step in range(max_steps):
        if step % 10 == 0 or step == max_steps - 1:
            print("Step: ", step)

        # squeeze=False keeps the axes array 2-D even with a single column
        fig, axs = plt.subplots(1, total_cols, figsize=figsize, squeeze=False)
        axes = axs[0]
        col_num = 0

        # x, y, predicted and predicted-variance images
        for dataset, num_cols in image_sources:
            for col in range(num_cols):
                title = column_titles[col_num]
                if frame_nums is not None:
                    title += ": " + str(frame_nums[step])
                _imageseq_draw(fig, axes[col_num], _imageseq_image(dataset, step, col), title, show_colorbar)
                col_num += 1

        # number of standard deviations off
        for col in range(stds_off_cols):
            stds_off = all_stds_off[step][col]
            title = ("stds: mean: " + str(np.round(np.mean(stds_off), 3))
                     + "\nstds: max: " + str(np.round(np.max(stds_off), 3)))
            _imageseq_draw(fig, axes[col_num], stds_off, title, show_colorbar)
            col_num += 1

        # difference images (y - predicted) titled with the error
        for col in range(diff_cols):
            error = errors_by_step[step][col]
            diff_image = _imageseq_image(ydataset, step, col) - _imageseq_image(predicted, step, col)
            title = column_titles[col_num] + ": " + str(error_type) + ": " + str(np.round(error, 5))
            _imageseq_draw(fig, axes[col_num], diff_image, title, show_colorbar)
            col_num += 1

        # error graphs with the current frame highlighted
        for col in range(graph_cols):
            ax = axes[col_num]
            error = errors_by_step[step][col]
            series = error_series[col]
            if show_whole_error_graph:
                ax.plot(series, color='b')
            else:
                # show the graph up to the next multiple of 10 past the current step
                ax.plot(series[:min(max(int(np.ceil(step / 10)), 1) * 10 + 1, max_steps)], color='b')
            ax.scatter(step, error, s=100, color='r')
            ax.set_title("frame, error: (" + str(step) + ", " + str(np.round(error, 5)) + ")")
            col_num += 1

        # render the figure and keep its RGB pixels as a GIF frame
        fig.subplots_adjust(wspace=0.5, hspace=0.0)
        canvas = FigureCanvas(fig)
        # print_to_buffer draws the figure and returns its RGBA bytes and (width, height)
        buffer, (width, height) = canvas.print_to_buffer()
        rgba = np.frombuffer(buffer, dtype=np.uint8).reshape(int(height), int(width), 4)
        all_frames.append(np.ascontiguousarray(rgba[..., :3]))
        plt.close(fig)

    # create the gif and write it
    with imageio.get_writer(gif_savepath, mode='I', duration=duration) as writer:
        for frame in all_frames:
            writer.append_data(frame)

    # save the error graphs
    if show_error_graph:
        base_path = os.path.splitext(gif_savepath)[0]
        # max(1, ...) avoids a zero arange step when max_steps < 10
        xticks = list(np.arange(0, max_steps, max(1, int(max_steps // 10)))) + [max_steps]
        for col in range(graph_cols):
            if graph_cols > 1:
                graph_fig = plt.figure(figsize=(8, 8))
                graph_path = base_path + "_" + str(col) + ".png"
            else:
                graph_fig = plt.figure(figsize=(10, 8))
                graph_path = base_path + ".png"
            plt.title("Rollout frame number vs. " + str(error_type) + " error")
            plt.xlabel("Rollout frame number")
            plt.ylabel(str(error_name) + " error")
            plt.xticks(xticks)
            plt.plot(np.arange(max_steps), error_series[col], color='b')
            plt.scatter(np.arange(max_steps), error_series[col], s=20, color='b')
            plt.savefig(graph_path)
            plt.close(graph_fig)

        # save the error values in an .npy file
        np.save(base_path + ".npy", np.array(all_errors))

    print("Saved: ", gif_savepath)
    return


def save_multiprediction_gif(gif_savepath, max_steps, xdatasets=None, ydatasets=None, predicted_datasets=None, predicted_frame_nums=None,
    train_rollout_info=None, error_type='rel', show_difference=True, show_colorbar=True, show_error_graph=True, show_whole_error_graph=False, 
    figsize=None, duration=3):
    """
    Save several rollouts side by side as an animated gif at gif_savepath.
    To be used to visualize the output of pred_sequential_seqlearn.

    Each row of the gif shows one rollout as (x, y, predicted, difference, error graph)
    columns, omitting the ones that are not provided or not requested. Gif frame `step`
    shows frame `step` of every rollout.

    args:
        - gif_savepath: path to save the gif at. When errors are computed they are also
          saved next to it, with the extension replaced by ".npy".
        - max_steps: number of frames to show. NOTE: every provided rollout must contain
          at least max_steps frames; shorter rollouts raise IndexError (no zero-padding).
        - xdatasets: list of rollouts; each rollout is a list of images or image tuples
        - ydatasets: list of rollouts; each rollout is a list of images or image tuples
        - predicted_datasets: list of rollouts of predicted images or image tuples
        - predicted_frame_nums: list of lists of frame numbers used in the titles.
          Defaults to 0..len(rollout)-1 for each rollout.
        - train_rollout_info: list with one entry per row, appended to the titles of the
          predicted columns (e.g. information about the last training image the model
          had before doing this rollout prediction)
        - error_type: type: string - the error shown over the difference images / graphs
            - 'rel': relative error (utilities.relative_error) - the default in FNO
            - 'MSE': mean squared error (utilities.MSE)
        - show_difference: show y - predicted for each predicted column (needs y and predicted)
        - show_colorbar: when True a colorbar is drawn next to every image
        - show_error_graph: show a plot of error vs. frame with the current frame highlighted
          (needs y and predicted; independent of show_difference)
        - show_whole_error_graph:
            - True: always plot the error curve over all max_steps frames
            - False: only plot the curve up to the smallest multiple of 10 (at least 10)
              that is >= the current frame
        - figsize: (width, height) tuple passed to plt.subplots (None -> matplotlib default)
        - duration: per-frame duration forwarded to imageio.get_writer
    raises:
        - ValueError: if no dataset is given, or if error_type is unknown when errors
          have to be computed

    NOTE: the tuple sizes of all rollouts within a dataset must be the same, and y tuples
    must line up with predicted tuples when differences/errors are shown.
    """
    if xdatasets is None and ydatasets is None and predicted_datasets is None:
        raise ValueError("at least one of xdatasets, ydatasets or predicted_datasets must be given")

    def _num_cols(datasets):
        """Number of image columns one frame of these datasets occupies (0 if not given)."""
        if datasets is None:
            return 0
        first_frame = datasets[0][0]
        return len(first_frame) if isinstance(first_frame, tuple) else 1

    def _split_frame(frame, ncols):
        """Return the images making up one frame; tuple frames are unpacked into ncols images."""
        return [frame[c] for c in range(ncols)] if ncols > 1 else [frame]

    def _frame_error(pred_img, true_img):
        """Error between one predicted image and its target, according to error_type."""
        if error_type == 'rel':
            return relative_error(x=pred_img, y=true_img)
        if error_type == 'MSE':
            return MSE(x=pred_img, y=true_img)
        raise ValueError("ERROR: incorrect error_type inputted in function: " + str(error_type))

    def _show_image(fig, ax, image, title):
        """imshow `image` on `ax`, optionally with a colorbar, and set its title."""
        pos = ax.imshow(image)
        if show_colorbar:
            fig.colorbar(pos, ax=ax)
        ax.set_title(title)

    def _figure_to_rgb(fig):
        """Render `fig` with Agg and return an (height, width, 3) uint8 RGB array."""
        canvas = FigureCanvas(fig)
        # print_to_buffer draws the figure and returns raw RGBA bytes; this replaces the
        # deprecated np.fromstring / canvas.tostring_rgb (removed in matplotlib 3.10).
        buf, (width, height) = canvas.print_to_buffer()
        rgba = np.frombuffer(buf, dtype=np.uint8).reshape(int(height), int(width), 4)
        return rgba[:, :, :3].copy()

    # the first dataset that was given determines the number of rows / default frame numbers
    first_dataset = next(d for d in (xdatasets, ydatasets, predicted_datasets) if d is not None)
    rows = len(first_dataset)

    if predicted_frame_nums is None:
        predicted_frame_nums = [list(range(len(rollout))) for rollout in first_dataset]

    # column structure, sampled from the first rollout of each dataset
    xcols = _num_cols(xdatasets)
    ycols = _num_cols(ydatasets)
    pred_cols = _num_cols(predicted_datasets)

    can_compare = pred_cols > 0 and ycols > 0
    diff_cols = pred_cols if (show_difference and can_compare) else 0
    graph_cols = pred_cols if (show_error_graph and can_compare) else 0
    total_cols = xcols + ycols + pred_cols + diff_cols + graph_cols

    column_titles = ['x'] * xcols + ['y'] * ycols + ['predicted'] * pred_cols + ['diff'] * diff_cols

    # all_errors[row][step] is a scalar error (single column) or a list with one error per column
    all_errors = None
    if diff_cols > 0 or graph_cols > 0:
        all_errors = []
        for dnum in range(rows):
            predicted = predicted_datasets[dnum]
            ydataset = ydatasets[dnum]
            if pred_cols > 1:
                rollout_errors = [
                    [_frame_error(predicted[step][col], ydataset[step][col]) for col in range(pred_cols)]
                    for step in range(max_steps)
                ]
            else:
                rollout_errors = [_frame_error(predicted[step], ydataset[step]) for step in range(max_steps)]
            all_errors.append(rollout_errors)
        print("All errors shape: ", np.array(all_errors).shape)

    # for multi-column graphs: error_lists_tograph[row][col] is the error curve over all steps
    error_lists_tograph = None
    if graph_cols > 1:
        error_lists_tograph = [
            [[rollout_errors[step][col] for step in range(max_steps)] for col in range(graph_cols)]
            for rollout_errors in all_errors
        ]

    # create figures for the gif
    all_figures = []
    for step in range(max_steps):
        if step % 5 == 0 or step == max_steps - 1:
            print("Step: ", step)

        # squeeze=False always yields a 2-D (rows, total_cols) array of axes
        fig, axs = plt.subplots(rows, total_cols, figsize=figsize, squeeze=False)

        for row_num in range(rows):
            frame_num = predicted_frame_nums[row_num][step]
            col = 0  # next free column of this row

            # x dataset
            if xcols > 0:
                for image in _split_frame(xdatasets[row_num][step], xcols):
                    _show_image(fig, axs[row_num, col], image, str(column_titles[col]) + ": " + str(frame_num))
                    col += 1

            # y dataset
            if ycols > 0:
                for image in _split_frame(ydatasets[row_num][step], ycols):
                    _show_image(fig, axs[row_num, col], image, str(column_titles[col]) + ": " + str(frame_num))
                    col += 1

            # predicted dataset
            if pred_cols > 0:
                for image in _split_frame(predicted_datasets[row_num][step], pred_cols):
                    pred_title = str(column_titles[col])
                    if train_rollout_info is not None:
                        pred_title += ": trn: " + str(train_rollout_info[row_num])
                    _show_image(fig, axs[row_num, col], image, pred_title + ": " + str(frame_num))
                    col += 1

            # difference y - predicted, titled with the precomputed error
            if diff_cols > 0:
                true_imgs = _split_frame(ydatasets[row_num][step], diff_cols)
                pred_imgs = _split_frame(predicted_datasets[row_num][step], diff_cols)
                step_errors = all_errors[row_num][step] if diff_cols > 1 else [all_errors[row_num][step]]
                for true_img, pred_img, error in zip(true_imgs, pred_imgs, step_errors):
                    title = str(column_titles[col]) + ": " + str(error_type) + ": " + str(np.round(error, 5))
                    _show_image(fig, axs[row_num, col], true_img - pred_img, title)
                    col += 1

            # error graph(s) with the current frame highlighted
            if graph_cols > 0:
                for graph_col in range(graph_cols):
                    if graph_cols > 1:
                        series = error_lists_tograph[row_num][graph_col]
                        error = all_errors[row_num][step][graph_col]
                    else:
                        series = all_errors[row_num]
                        error = all_errors[row_num][step]
                    ax = axs[row_num, col]
                    if show_whole_error_graph:
                        ax.plot(series, color='b')
                    else:
                        # reveal the curve in chunks of 10 frames
                        visible = min(max(int(np.ceil(step / 10)), 1) * 10 + 1, max_steps)
                        ax.plot(series[:visible], color='b')
                    ax.scatter(step, error, s=100, color='r')
                    ax.set_title("frame, error: (" + str(step) + ", " + str(np.round(error, 5)) + ")")
                    col += 1

        fig.subplots_adjust(wspace=0.5, hspace=0.35)
        all_figures.append(_figure_to_rgb(fig))
        plt.close(fig)

    # create the gif and write it
    kargs = {'duration': duration}
    with imageio.get_writer(gif_savepath, mode='I', **kargs) as writer:
        for image in all_figures:
            writer.append_data(image)

    # save the error values in an .npy file (only when errors were computed)
    if all_errors is not None:
        error_savepath = os.path.splitext(gif_savepath)[0] + ".npy"
        np.save(error_savepath, np.array(all_errors))

    print("Saved: ", gif_savepath)
    return
