"""
Overall file: 

1. Load the data 
2. Pre-process the data
    2.1: Convert to image tuples
    2.2: Add different spaces if needed: Image differences, Optical Flow, Fourier Transformations
    2.3: Patchify datasets
        2.3.1: Patchification types: 
            2.3.1.1: 
            2.3.1.2: 
    2.4: Vectorize patches 
3. Train GP model: 
    3.1: Choose model type: GP Regression, Sparse GP Regression, GP Regression with white noise
    3.2: Create kernel: 
        3.2.1: Set lengthscale
        3.2.2: Set noise variance
    3.3: Optimize models - if desired
4. Predict with GP Model:  
    4.1: Predict on test set: 
        4.1.1: Predict all at once 
        4.1.2: Might need to predict one at a time: to prevent memory issues
    4.2: Sequential prediction: 
        4.2.1: Start with one starting point and roll out 
    4.3: Sequential prediction while seeing new data: 
        4.3.1: Do sequential rollouts, update the model and then continue to do sequential rollouts 
5. Unpatchify predictions - might integrate into prediction step if doing things like rollouts. 
6. Postprocessing: if needed - if predicting fourier transforms, or if predicting image differences. 
"""
import numpy as np 
import cv2 
import GPy 
import torch 
import scipy
import time

import matplotlib.pyplot as plt 
import os 
import sys 
import pdb

# Make modules that live in the parent of the current working directory importable.
# NOTE(review): the path is built from os.getcwd(), not from this file's location, so the
# local imports that follow resolve differently depending on where the script is launched - confirm intended.
current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, "..")
sys.path.append(parent_dir) # append parent directory to the path 

# import locally written files 
from patchify_unpatchify import patcher
from predict import predict 
from processing import processing

def gaussian_kernel_2d(size=(5,5), sigma=10):
    """
    Build a 2D Gaussian-shaped weight array as the outer product of two 1D Gaussian profiles.
    args:
        - size: tuple (rows, cols): shape of the returned array
        - sigma: standard deviation used for both axes (same units as the sample offsets)
    returns:
        - array of shape `size`; entry [i, j] is row_profile[i] * col_profile[j]

    NOTE(review): each 1D profile is scaled by 1/(sigma*sqrt(2)) rather than the usual
    1/(sigma*sqrt(2*pi)), and the result is not normalised to sum to one - confirm this is
    intended before relying on absolute values.
    """
    def _profile(num_samples):
        # sample offsets are centred on zero and span [-(n//2), n//2]
        half_span = num_samples // 2
        offsets = np.linspace(-half_span, half_span, num_samples)
        amplitude = 1 / (sigma * np.sqrt(2))
        return amplitude * np.exp(-0.5 * ((offsets / sigma) ** 2))

    row_profile = _profile(size[0])
    col_profile = _profile(size[1])
    return np.outer(row_profile, col_profile)

def create_kernel_isolated(base_kernel, kernel_groups, kernel_lengthscales, kernel_variances, patch_dim, all_xtypes):
    """
    Creates a kernel function based on the specified parameters, without needing a predictor object.
    One sub-kernel is built per group and all sub-kernels are multiplied together.
    args:
        - base_kernel: the base kernel constructor to use: eg: GPy.kern.RBF
            - called as base_kernel(input_dim=..., active_dims=..., lengthscale=..., variance=...)
        - kernel_groups: list of lists
            - each sublist contains the xtype keys that should be 'joined' together under the same kernel
        - kernel_lengthscales: type: list of numbers
            - lengthscale of the corresponding group in kernel_groups
            - same length as kernel_groups
        - kernel_variances: type: list of numbers
            - variance of the corresponding group in kernel_groups
            - same length as kernel_groups
        - patch_dim: type: tuple: the shape of the patches
        - all_xtypes: type: list of strings: each string refers to an xtype; the position of an
          xtype in this list selects which patch-sized slice of the input vector it occupies
    returns:
        - the product of the per-group kernels (None if kernel_groups is empty)
    raises:
        - AssertionError: if the list lengths disagree or a group names an xtype not in all_xtypes
    """
    assert len(kernel_groups) == len(kernel_lengthscales), "need one lengthscale per kernel group"
    assert len(kernel_groups) == len(kernel_variances), "need one variance per kernel group"

    # np.prod, not np.product: the np.product alias was removed in NumPy 2.0
    single_patch_len = int(np.prod(patch_dim))

    # first pass: work out (and validate) the input columns every group acts on
    kernel_activedims = []
    for group in kernel_groups:
        current_activedims = []
        for xtype in group:
            if xtype not in all_xtypes:
                print("ERROR: incorrect group for forming the kernel")
                # explicit raise so the check is not stripped when running with -O
                raise AssertionError("xtype %r is not one of %r" % (xtype, all_xtypes))
            xtype_index = all_xtypes.index(xtype)
            first_dim = single_patch_len * xtype_index
            # the xtype at position k owns columns [k*len, (k+1)*len)
            current_activedims.extend(np.arange(first_dim, first_dim + single_patch_len))
        kernel_activedims.append(current_activedims)

    # second pass: build one sub-kernel per group and multiply them together
    kernel = None
    for i, curr_kernel_activedims in enumerate(kernel_activedims):
        sub_kernel = base_kernel(input_dim=len(curr_kernel_activedims), active_dims=curr_kernel_activedims,
                                 lengthscale=kernel_lengthscales[i], variance=kernel_variances[i])
        kernel = sub_kernel if kernel is None else kernel * sub_kernel
    return kernel

class nonparam_predictor: 
    """
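    Gaussian-process (non-parametric) image-sequence predictor.

    Ties together the three project helpers built in __init__ - a patcher (self.patch_obj),
    a processing object (self.data_processor) and a predict object (self.predictor) - and
    keeps the trained GP model in self.GP_model.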
    """
    def __init__(self, patch_parameters, processor_parameters, max_all_together):
        """
        Store the configuration and build the patcher, processing and predict helpers.
        args: 
            - patch_parameters: dictionary of keyword arguments for the patcher class 
                - 'img_dim': tuple (rows, cols): dimension of the image
                - 'patch_dim': tuple (rows, cols): dimension of the patch
                - 'patch_border': tuple (rows, cols): number of rows, cols that are used as the border of the patch 
                - 'img_padlen': tuple (rows, cols): number of rows, cols used to pad the image
                - 'img_padtype': string: the method used for padding the image (see the patcher class for viable options)
                - 'wrap_x': whether the patches should wrap in the x or column direction
                - 'wrap_y': whether the patches should wrap in the y or row direction
                - 'stride': tuple (row stride, col stride): the number of pixels to skip by when creating image patches
                - 'patch_weight': the patch weight to use when recombining  
            - processor_parameters: dictionary of keyword arguments for the processing class
                - 'xtypes': list of strings - order indicates order of types of x components in dataset
                    - 'img-2': 2 images before current img
                    - 'img-1': 1 image before current img
                    - 'img0': current img - when predicting the next
                    - 'vel_diff': image difference velocity ('img0' - 'img-1')
                    - 'accel_diff': image difference acceleration (('img0' - 'img-1') - ('img-1' - 'img-2'))
                    - 'of_x': optical flow x 
                    - 'of_y': optical flow y
                - 'ytype': a string that indicates the type of the y components in dataset (also what is to be predicted)
                    - 'img': the next image 
                    - 'diff': difference between the last image and the next
                - 'farneback_flow_params': parameters for optical flow - can be None if not using any optical flow
            - max_all_together: passed to the predict class - when not processing all test vectors together, how many to process at once
        """
        self.patch_parameters = patch_parameters
        self.processor_parameters = processor_parameters
        self.max_all_together = max_all_together

        # every one of these keys is required; a missing key raises KeyError here
        patcher_keys = ('img_dim', 'patch_dim', 'patch_border', 'img_padlen', 'img_padtype',
                        'wrap_x', 'wrap_y', 'stride', 'patch_weight')
        patcher_kwargs = {key: patch_parameters[key] for key in patcher_keys}
        self.patch_obj = patcher(**patcher_kwargs)

        processor_keys = ('xtypes', 'ytype', 'farneback_flow_params')
        processor_kwargs = {key: processor_parameters[key] for key in processor_keys}
        self.data_processor = processing(**processor_kwargs)

        self.predictor = predict(max_all_together=max_all_together)

        # no model until train_model is called
        self.GP_model = None

    def create_kernel(self, base_kernel, kernel_groups, kernel_lengthscales, kernel_variances):
        """
        Creates a kernel function based on the specified parameters
        args: 
            - Implicit parameters: self.xtypes - already in the object
            - base_kernel: This is the base kernel function to use: eg: GPy.kern.RBF
            - kernel_groups: list of lists
                - each sublist contains the xtype keys that should be 'joined' together under the same kernel 
            - kernel_lengthscales: type: list of numbers
                - the lengthscale corresponds to the lengthscale of the corresponding group in kernel_groups
                - same length as kernel_groups
            - kernel_variances: type: list of numbers 
                - the lengthscale corresponds to the variance of the corresponding group in kernel_groups
                - same length as kernel groups
        returns: 
            - kernel to be used. 
        """
        assert(len(kernel_groups) == len(kernel_lengthscales))
        assert(len(kernel_groups) == len(kernel_variances))

        kernel_activedims = []
        single_patch_len = np.product(self.patch_obj.patch_dim)

        for group in kernel_groups:
            current_activedims = []
            for xtype in group:     
                if xtype not in self.data_processor.xtypes: 
                    print("ERROR: incorrect group for forming the kernel")
                    assert(False)
                xtype_index = self.data_processor.xtypes.index(xtype)
                current_activedims.extend(list(np.arange(single_patch_len*xtype_index, single_patch_len*(xtype_index + 1)))) # add one to account for 0 index

            kernel_activedims.append(current_activedims)

        # print("length kernel active dims: ", len(kernel_activedims))
        # print(kernel_activedims)
        # create the kernels and then multiply them together
        sub_kernels = []        
        for i, curr_kernel_activedims in enumerate(kernel_activedims): 
            current_kernel = base_kernel(input_dim=len(curr_kernel_activedims), active_dims=curr_kernel_activedims, 
                lengthscale=kernel_lengthscales[i], variance=kernel_variances[i])
            sub_kernels.append(current_kernel)

        # create the final kernel
        kernel = None
        for sub_kernel in sub_kernels: 
            if type(kernel) == type(None):
                kernel = sub_kernel
            else: 
                kernel = kernel * sub_kernel
        return kernel

    def train_model(self, kernel, datapoints, optimize=False, noise_var=0.01, use_sparse_GP=False, num_inducing_points=100, max_opt_iters=None):
        """
        Trains the GP model. To do this it first converts the train_image_seq into the desired training data. 
        and then trains the GP model using the created dataset. 

        args: 
            - kernel: initialized GPy kernel
            - datapoints: dictionary containing - only has one of the below
                Either: 
                    - 'image_seq': sequence of images to be converted into a patchified, vectorized dataset
                Or: 
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
            - optimize: if true will optimize the GP model kernel and other parameters
            - noise_var: the noise variance to use for the GP Regression model
            - use_sparse_GP: type: bool
                - True: use the sparse GP regression model 
                - False: use normal GP regression model 
            - num_inducing_points: type: int: the number of inducing points to use in the sparse GP model. not used if use_sparse_GP is False. 
            - max_opt_iters: type: int: the maximum optimization iterations to run - default is None - dont use it
        returns: None
            - the trained model is stored in the attribute self.GP_model

        # TODO: Maybe later add support for other types of GP models
        """

        """
        if 'image_seq' in datapoints.keys(): 
            # create image dataset
            train_x, train_y = self.data_processor.create_xy(datapoints['image_seq'])
        elif 'x_dataset' in datapoints.keys() and 'y_dataset' in datapoints.keys():
            train_x = datapoints['x_dataset']
            train_y = datapoints['y_dataset']

        # patch and vectorize dataset
        train_x_patch = self.patch_obj.patchify_dataset(dataset=train_x, dataset_type='x')
        train_y_patch = self.patch_obj.patchify_dataset(dataset=train_y, dataset_type='y')
        train_x_patch_vecs = self.data_processor.convert_imgdataset_to_vecdataset(dataset=train_x_patch)
        train_y_patch_vecs = self.data_processor.convert_imgdataset_to_vecdataset(dataset=train_y_patch)

        num_train_vecs = len(train_x_patch_vecs)

        # train using the vectorized patches
        self.kernel = kernel
        self.GP_model = GPy.models.GPRegression(X=np.array(train_x_patch_vecs).reshape((num_train_vecs, -1)), 
                                                Y=np.array(train_y_patch_vecs).reshape((num_train_vecs, -1)), 
                                                kernel=self.kernel, 
                                                noise_var=noise_var)

        if optimize:
            self.GP_model.optimize(messages=True)

        return self.GP_model
        """
        self.GP_model = self.predictor.train_model(kernel=kernel, datapoints=datapoints, 
                                                   data_processor=self.data_processor, 
                                                   patch_obj=self.patch_obj, 
                                                   optimize=optimize, 
                                                   noise_var=noise_var, 
                                                   use_sparse_GP=use_sparse_GP, 
                                                   num_inducing_points=num_inducing_points,
                                                   max_opt_iters=max_opt_iters)
        self.kernel = self.GP_model.kern
        return self.GP_model

    def add_datapoints(self, datapoints):
        """
        Add datapoints to the trained GP model. 
        args: 
            - datapoints: dictionary containing - only has one of the below
                Either: 
                    - 'image_seq': sequence of images to be converted into a patchified, vectorized dataset
                Or: 
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
        returns: 
            - None: but update the GP_model
        """
        """
        if 'image_seq' in datapoints.keys(): 
            x_dataset, y_dataset = self.data_processor.create_xy(image_seq=datapoints['image_seq'])
        elif 'x_dataset' in datapoints.keys() and 'y_dataset' in datapoints.keys():
            x_dataset = datapoints['x_dataset']
            y_dataset = datapoints['y_dataset']
        else:
            print("ERROR: Proper data not provided")
            assert(False)

        # patchify the data and turn it into vecs
        x_dataset = self.patch_obj.patchify_dataset(dataset=x_dataset, dataset_type='x')
        y_dataset = self.patch_obj.patchify_dataset(dataset=y_dataset, dataset_type='y')
        x_dataset = self.data_processor.convert_imgdataset_to_vecdataset(dataset=x_dataset)
        y_dataset = self.data_processor.convert_imgdataset_to_vecdataset(dataset=y_dataset)

        new_X = np.vstack((self.GP_model.X, np.array(x_dataset).reshape((len(x_dataset), -1))))
        new_Y = np.vstack((self.GP_model.Y, np.array(y_dataset).reshape((len(y_dataset), -1))))
        # update the model 
        self.GP_model.set_XY(X=new_X, Y=new_Y)
        return 
        """
        self.predictor.add_datapoints_to_GP(GP_model=self.GP_model, 
                                            datapoints=datapoints, 
                                            optimize=optimize, 
                                            data_processor=self.data_processor, 
                                            patch_obj=self.patch_obj)
        return 

    def calculate_iK_beta(self, svd_inverse = False, sigma_threshold = None, plot_for_thresholding = False):
        """
        Calculate the inverse of the training kernel gram matrix and the beta vector used in the PILCO paper. 
        args: 
            - implicit argument gp_model: GPy gp model
            - svd_inverse: type: boolean: 
                - True: use the svd inverse with the specified sigma cutoff
                - False: use the matrices returned by GPy directly 
            - sigma_threshold: 
                - None: take the sigma values up to the "large" dropoff
                    - We define a large dropoff as if the change in values between the next two sigma values is less than 1 order of magnitude compared to the change in the last 2 values
                - float: the maximum sigma value you want to keep before inverting
            - plot_for_thresholding: type: boolean:
                - False: just use the specified thresholds 
                - True: plot the sigma values and go into a set trace to allow the sigma threshold to be set manually by the user. 

            NOTE: order of precedence: svd_inverse,  sigma_threshold, then plot_for_thresholding
        returns: 
            - inv_K: inverse of the kernel gram matrix
            - beta: the beta vector used in the mean and covariance calculation
        """

        if not svd_inverse: 
            # Use GPy matrices
            train_data = self.GP_model._predictive_variable

            K_train = self.GP_model.kern.K(train_data, train_data)
            K_train_noise = K_train + (np.eye(K_train.shape[0])*float(self.GP_model.Gaussian_noise.variance))
            K_train_noise = K_train_noise + (np.eye(K_train_noise.shape[0])*1e-8) # additional noise for invertibility

            # # compute solutions using cholesky decomposition for stability: 1:16 in https://www.youtube.com/watch?v=4vGiHC35j9s&t=316s
            # lower_triang = True
            # L = scipy.linalg.cholesky(a=K_train_noise, lower=lower_triang)
            # beta = scipy.linalg.cho_solve((L, lower_triang), b=self.GP_model.Y, overwrite_b=False) # also works - but same result already computed in GPy model
            # K_train_noise_inv = self.stable_cho_inverse(matrix=L, cholesky=True)

            K_train_noise_inv = self.GP_model.posterior.woodbury_inv
            beta = self.GP_model.posterior.woodbury_vector

            return K_train_noise, K_train_noise_inv, beta
        
        else: 
            # Stable inverse calculation using the SVD
            from GPy.util.linalg import jitchol, dtrtri, dpotri, dpotrs
            K = self.GP_model.kern.K(self.GP_model._predictive_variable, self.GP_model._predictive_variable)
            noise_variance = self.GP_model.Gaussian_noise.variance
            jitter = 1e-8
            K_noise = K + ((noise_variance + jitter) * np.eye(K.shape[0]) ) 
            umat, svec, vhmat = np.linalg.svd(K_noise)

            if type(sigma_threshold) == type(None):
                keep_threshold = None
                # Find the dropoff point
                for i in range(1, len(svec) - 1):
                    last_difference = np.abs(svec[i-1] - svec[i])
                    next_difference = np.abs(svec[i] - svec[i+1])

                    if (last_difference/next_difference) > 10: 
                        keep_threshold = svec[i]
                        break 
                if type(keep_threshold) == type(None):
                    keep_threshold = svec[-1] - np.finfo(float).eps
            else: 
                keep_threshold = float(sigma_threshold)


            if plot_for_thresholding: 
                # plot and let the user manually input the threshold
                fig = plt.figure()
                plt.title("Sigma values K train noise")
                plt.plot(svec.flatten())
                plt.show()

                print("Sigma values are stored in variable svec")
                print("Default kee_threshold: ", keep_threshold)
                print("Set variable: \'keep_threshold\' as the maximum sigma value you want to keep")
                pdb.set_trace()

            
            # truncate the sigma vector 
            len_svec = len(svec)

            print("Keep threshold: ", keep_threshold)
            svec_truncated_inv = np.zeros(len_svec)
            svec_truncated_inv[np.abs(svec) > keep_threshold] = 1/np.array(svec[np.abs(svec) > keep_threshold]) 

            svec_truncated_inv_mat = np.diag(svec_truncated_inv)
            iK = vhmat.T@svec_truncated_inv_mat@umat.T
            beta = iK@self.GP_model.Y

            return K_noise, iK, beta

    def setup_mean_var_propogation_parameters(self, patch_obj, svd_inverse, sigma_threshold, plot_for_thresholding):
        """
        Setup the self. parameters for the mean variance propogation functions. 
        args: 
            - patch_obj: the patch object that is used for the test rollout
            - svd_inverse: argument to calculate_ik_beta
            - sigma_threshold: argument to calculate_ik_beta
            - plot_for_thresholding: argument to calculate_ik_beta
        returns: None - everything is stored on self:
            - device
            - lengthscale_matrix_dim{0,1,2,01,012}, their *_inv inverses and the *_inv_tr torch copies (dim0, dim01, dim012)
            - variance_dim{0,1,2}, full_variance, kernel_dim{0,1,2}
            - dim{0,1,2}_start_idx / dim{0,1,2}_end_idx
            - full_train_xvecs, num_training_points, dim{0,1,2,01}_train_xvecs
            - K_train_noise, K_train_noise_inv, Ba_vector (from calculate_iK_beta)

        NOTE(review): reads self.GP_model.kern.lengthscale / .variance directly and splits the variance
        with a cube root, i.e. it presumably expects a single kernel standing for three equal image
        patches - confirm against how the kernel is built.
        """
        # START SETUP 
        # Separate out the kernels - NOTE: these are all the same for each image space for now - structured to make it easy to change later
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # np.prod, not np.product: the np.product alias was removed in NumPy 2.0
        pixels_per_inpatch = int(np.prod(patch_obj.patch_dim))
        img_space_lengthscale = float(self.GP_model.kern.lengthscale)
        img_space_variance    = float(self.GP_model.kern.variance)**(1/3)
        
        self.lengthscale_matrix_dim0 = np.eye(pixels_per_inpatch) * img_space_lengthscale**2
        self.lengthscale_matrix_dim1 = np.eye(pixels_per_inpatch) * img_space_lengthscale**2
        self.lengthscale_matrix_dim2 = np.eye(pixels_per_inpatch) * img_space_lengthscale**2
        self.lengthscale_matrix_dim01 = scipy.linalg.block_diag(self.lengthscale_matrix_dim0, self.lengthscale_matrix_dim1)
        self.lengthscale_matrix_dim012 = scipy.linalg.block_diag(self.lengthscale_matrix_dim01, self.lengthscale_matrix_dim2)

        self.lengthscale_matrix_dim0_inv = np.linalg.inv(self.lengthscale_matrix_dim0)
        self.lengthscale_matrix_dim1_inv = np.linalg.inv(self.lengthscale_matrix_dim1)
        self.lengthscale_matrix_dim2_inv = np.linalg.inv(self.lengthscale_matrix_dim2)
        self.lengthscale_matrix_dim01_inv = scipy.linalg.block_diag(self.lengthscale_matrix_dim0_inv, self.lengthscale_matrix_dim1_inv)
        self.lengthscale_matrix_dim012_inv = scipy.linalg.block_diag(self.lengthscale_matrix_dim01_inv, self.lengthscale_matrix_dim2_inv)

        # load inverse matrices onto self.device as float64 copies. Building them on self.device
        # (not torch.cuda.DoubleTensor) keeps this working on CPU-only machines,
        # which is what the device fallback above is for.
        self.lengthscale_matrix_dim0_inv_tr = torch.tensor(self.lengthscale_matrix_dim0_inv, dtype=torch.float64, device=self.device)
        self.lengthscale_matrix_dim01_inv_tr = torch.tensor(self.lengthscale_matrix_dim01_inv, dtype=torch.float64, device=self.device)
        self.lengthscale_matrix_dim012_inv_tr = torch.tensor(self.lengthscale_matrix_dim012_inv, dtype=torch.float64, device=self.device)

        self.variance_dim0 = img_space_variance
        self.variance_dim1 = img_space_variance
        self.variance_dim2 = img_space_variance
        self.full_variance = float(self.GP_model.kern.variance)


        self.kernel_dim0 = GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=img_space_lengthscale, variance=self.variance_dim0) # kernel for the 0th image patch in the input vec
        self.kernel_dim1 = GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=img_space_lengthscale, variance=self.variance_dim1) # kernel for the 1st image patch in the input vec
        self.kernel_dim2 = GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=img_space_lengthscale, variance=self.variance_dim2) # kernel for the 2nd image patch in the input vec

        # Process Data: 
        # Seperate datasets by image patch dimensions: three back-to-back column ranges of pixels_per_inpatch each
        self.dim0_start_idx = 0 
        self.dim0_end_idx = self.dim0_start_idx + pixels_per_inpatch
        self.dim1_start_idx = self.dim0_end_idx
        self.dim1_end_idx = self.dim1_start_idx + pixels_per_inpatch
        self.dim2_start_idx = self.dim1_end_idx
        self.dim2_end_idx = self.dim2_start_idx + pixels_per_inpatch

        # Training data: seperate the image patch dimensions
        self.full_train_xvecs = np.array(self.GP_model._predictive_variable)
        self.num_training_points = self.full_train_xvecs.shape[0]
        self.dim0_train_xvecs = self.full_train_xvecs[:, self.dim0_start_idx:self.dim0_end_idx]
        self.dim1_train_xvecs = self.full_train_xvecs[:, self.dim1_start_idx:self.dim1_end_idx]
        self.dim2_train_xvecs = self.full_train_xvecs[:, self.dim2_start_idx:self.dim2_end_idx]
        self.dim01_train_xvecs = self.full_train_xvecs[:, self.dim0_start_idx:self.dim1_end_idx]
        # END SETUP

        ###########################################################################
        # Set up: pre-process variables needed for all predictions
        ###########################################################################
        print("Creating the Ba vector")
        start_time = time.time()

        self.K_train_noise, self.K_train_noise_inv, self.Ba_vector = self.calculate_iK_beta(svd_inverse = svd_inverse, sigma_threshold = sigma_threshold, plot_for_thresholding = plot_for_thresholding)
        
        elapsed_time = time.time() - start_time
        print("Finished creating the Ba vector: ", elapsed_time)

        return 

    


################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
######################################################### OLD LEGACY FUNCTIONS - DON'T WORK  ###################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################

    def full_predict_dataset(self, test_set, test_set_type, all_together=True, use_variance_weighting=True, test_patch_obj=None): 
        """
        Use the internal GP model to predict on the specified test dataset using the predict class
        args: 
            - test_set: type: list: This can be either just a sequence of images or a sequence of image tuples - the x of a dataset
            - test_set_type: type: string: the type of test_set inputted 
                - 'image_seq': This indicates the test set is a sequence of images that must be pre-processed into the appropriate x dataset 
                - 'xdataset': Pre-processed x dataset that can be patchified, vectorized and put into the model 
            - all_together: If the model should predict the dataset all together or parts at a time (memory, big matrix)
            - use_variance_weighting: whether you want to use intelligent weighting when unpatchifying images
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
        """
        assert(type(self.GP_model) != type(None)) # ensure GP model has been trained

        if test_set_type == 'image_seq': 
            test_x_images = self.data_processor.create_x(image_seq=test_set)

        elif test_set_type == 'xdataset':
            test_x_images = test_set

        curr_patch_obj = self.patch_obj
        if type(test_patch_obj) == patcher: 
            curr_patch_obj = test_patch_obj
            # asserts to ensure data compatibility between the two patch objects 
            assert(test_patch_obj.img_dim == self.patch_obj.img_dim)
            assert(test_patch_obj.get_ypatch_dim() == self.patch_obj.get_ypatch_dim())
        
        return self.predictor.pred_test_set(test_x=test_x_images, 
                                            data_processor=self.data_processor, 
                                            GP_model=self.GP_model, 
                                            patch_obj=curr_patch_obj, 
                                            all_together=all_together, 
                                            use_variance_weighting=use_variance_weighting)


    def full_pred_sequential(self,  starting_x_images, steps, all_together=True, use_variance_weighting=True, test_patch_obj=None): 
        """
        Starting from the image datapoint formed from starting_x_images rollout predictions for 'steps' steps
        args: 
            - starting_x_images: see predict.pred_sequential
            - steps: number of steps the predict/rollout
            - all_together: If the model should predict the dataset all together or parts at a time (memory, big matrix)
            - use_variance_weighting: whether you want to use intelligent weighting when unpatchifying images
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
        returns: 
            - see predict.pred_sequential
        """

        assert(type(self.GP_model) != type(None))

        curr_patch_obj = self.patch_obj
        if type(test_patch_obj) == patcher: 
            curr_patch_obj = test_patch_obj
            # asserts to ensure data compatibility between the two patch objects
            assert(test_patch_obj.img_dim == self.patch_obj.img_dim)
            assert(test_patch_obj.get_ypatch_dim() == self.patch_obj.get_ypatch_dim())

        return self.predictor.pred_sequential(starting_x_images=starting_x_images, 
                                            data_processor=self.data_processor, 
                                            GP_model=self.GP_model, 
                                            patch_obj=curr_patch_obj, 
                                            steps=steps, 
                                            all_together=all_together, 
                                            use_variance_weighting=use_variance_weighting)


    def full_pred_sequential_seqlearn(self, image_seq, learning_indices, rollout_indices, max_rollout_steps,
                                      test_patch_obj=None, optimize=False, all_together=True,
                                      use_variance_weighting=True, noise_var=0.01):
        """
        Runs several rollouts over an image sequence while the GP model is periodically updated with new data.

        This is a thin wrapper: all of the work is delegated to self.predictor.pred_sequential_seqlearn.
        The wrapper only (1) decides whether the existing GP model or just the kernel is handed over,
        (2) chooses the patcher used for testing and (3) stores the GP model returned by the predictor
        back on self.GP_model.

        args:
            - image_seq: sequence of images to predict - used to generate the training data and also used for testing against the rollouts
            - learning_indices: type: list of ints: list of integers corresponding to the index of the image in image_seq
                                - these indices indicate the index of the image when the model will be updated, these model updates
                                mean that the model will have data on all the images in the image sequence up to (but not inclusive of) the specified index
            - rollout_indices: type: list of ints: index of image in the image seq that will be the first expected image of the rollout
                                    - the starting x images to begin the rollout will include image_seq[index-3:index]
            - max_rollout_steps: type: int: bound on the number of steps to rollout
                                    - NOTE(review): the original note read "max of this and the steps remaining in the image_seq";
                                    presumably the rollout is cut off at the end of image_seq - confirm in predict.pred_sequential_seqlearn
            - test_patch_obj: object of type patcher used for the test/rollout side, if None self.patch_obj is used
                                    - self.patch_obj is always passed as the (training) patch_obj
            - optimize: type: boolean:
                - True: Optimize every time you add data to the GP model
                - False: do not optimize and use the same kernel parameters.
            - all_together: boolean indicator
                - True: predict the entire dataset in one go
                - False: predict the dataset one image at a time - NOTE: TODO: might need to change this to x patches at a time.
            - use_variance_weighting:
                - True: use variance for the mask weighting
                - False: do not use variance - just do averaging
            - noise_var: the noise variance to use when training or creating the GP model

        NOTE: the rollout indices must start with a number greater than 3 and the learning_indices must start with a number greater than 4, both must
        be in ascending order.

        returns:
            - pred_rollouts: list of lists showing the predict rollouts
            - true_rollouts: list of lists showing the true image sequences corresponding to the rollouts.
            - all_rollout_indices: list of lists where each entry is an integer corresponding to the index of the image in true_rollouts in the image_seq
            - last_train_indices: list of length pred_rollouts where each entry is the last image that the prediction model used to train.

        raises:
            - AssertionError: if both self.GP_model and self.kernel are None
        """
        # Hand over either the existing model or only the kernel, never both.
        if self.GP_model is not None:
            GP_model, kernel = self.GP_model, None
        elif self.kernel is not None:
            GP_model, kernel = None, self.kernel
        else:
            # Raised explicitly (instead of `assert False`) so the check survives `python -O`;
            # the exception type is kept so existing callers see the same error class.
            raise AssertionError("ERROR: You need to have at least the self.kernel initialized!")

        if test_patch_obj is None:
            test_patch_obj = self.patch_obj

        pred_rollout, true_rollout, all_rollout_indices, last_train_indices, returned_GP_model = \
            self.predictor.pred_sequential_seqlearn(image_seq=image_seq,
                                                    learning_indices=learning_indices,
                                                    rollout_indices=rollout_indices,
                                                    max_rollout_steps=max_rollout_steps,
                                                    GP_model=GP_model,
                                                    kernel=kernel,
                                                    data_processor=self.data_processor,
                                                    patch_obj=self.patch_obj,
                                                    test_patch_obj=test_patch_obj,
                                                    optimize=optimize,
                                                    all_together=all_together,
                                                    use_variance_weighting=use_variance_weighting,
                                                    noise_var=noise_var)

        # keep the model handed back by the predictor so later calls reuse it
        self.GP_model = returned_GP_model
        return pred_rollout, true_rollout, all_rollout_indices, last_train_indices




    ############################################## Useless functions ############################################
    def full_pred_sequential(self,  starting_x_images, steps, all_together=True, use_variance_weighting=True, test_patch_obj=None): 
        """
        Rolls out predictions for 'steps' steps starting from starting_x_images.

        The rollout itself is delegated to self.sub_pred_sequential_sampling; this method only checks
        that a GP model exists and picks the patcher to use.

        NOTE(review): the original notes describe the rollout as: get the prediction from the GP model,
        convert it to a proper multivariate gaussian, sample it, use the sample for reconstruction, repeat.
        That logic is not in this method - confirm against sub_pred_sequential_sampling.

        args: 
            - starting_x_images: see predict.pred_sequential
            - steps: number of steps to predict/rollout
            - all_together: If the model should predict the dataset all together or parts at a time (memory, big matrix)
            - use_variance_weighting: whether you want to use intelligent weighting when unpatchifying images
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
        returns: 
            - whatever self.sub_pred_sequential_sampling returns (the original notes point to predict.pred_sequential)
        """
        # a trained/created GP model is required before any rollout
        assert self.GP_model is not None

        # fall back to the training patcher when no dedicated test patcher is supplied
        patch_obj = self.patch_obj if test_patch_obj is None else test_patch_obj

        rollout_kwargs = dict(starting_x_images=starting_x_images,
                              steps=steps,
                              patch_obj=patch_obj,
                              all_together=all_together,
                              use_variance_weighting=use_variance_weighting)
        return self.sub_pred_sequential_sampling(**rollout_kwargs)

    def stable_cho_inverse(self, matrix, cholesky=False):
        """
        Inverts a matrix through its cholesky factor by solving against the identity matrix.
        NOTE: the matrix to invert must be Hermitian, Positive Definite. 

        args: 
            - matrix: the matrix you want to invert
                - or, when cholesky=True, its lower triangular cholesky factor
            - cholesky: boolean: 
                - True: 'matrix' already is the lower triangular cholesky factor
                - False: 'matrix' is the full matrix and is factorized here - DEFAULT option
        returns: 
            - matrix_inv: the inverse of the (full) matrix
        """
        if cholesky:
            lower_factor = matrix
        else:
            lower_factor = scipy.linalg.cholesky(a=matrix, lower=True)

        # solving A X = I gives X = A^-1; the identity is a throwaway so it may be overwritten
        identity = np.eye(lower_factor.shape[0])
        return scipy.linalg.cho_solve((lower_factor, True), b=identity, overwrite_b=True)

    def unscented_transform(self, input_mean, input_cov, dim1_input_val=None, dim2_input_val=None):
        """
        Propagates a gaussian input distribution through the GP mean function using the unscented transform.

        2*d + 1 sigma points (d = number of random input dimensions) are built from input_mean and the
        lower cholesky factor of input_cov, optionally extended with fixed (non-random) input values,
        and passed through self.GP_model.predict. The predicted means are then combined with the
        sigma-point weights.

        args: 
            - input_mean: the mean along the input dimensions that are random variables (flattened internally)
            - input_cov: the corresponding covariance matrix to the input mean (d x d, positive definite)
            - dim1_input_val: type: numpy array: fixed value appended to every sigma point
                - only used when dim2_input_val is given as well; dim1_input_val on its own is ignored
            - dim2_input_val: type: numpy array: fixed value appended to every sigma point
                - appended after dim1_input_val when both are given, on its own otherwise
        returns: 
            - output_mean: weighted mean of the predicted means, one value per output column of the GP
            - output_var: scalar: weighted mean of the squared distance of each predicted mean to output_mean

        NOTE: the unscented transform passes a probability distribution through a nonlinear function.
        Here the nonlinear function is the GP mean only: the predictive variances returned by the GP
        are not folded into output_var (open question from the original author: should the
        predictive variances be averaged in as well?).
        """ 
        alpha = 0.75  # scaling of the sigma-point spread
        k = 0         # secondary scaling parameter

        mean_row = np.asarray(input_mean).reshape((1, -1))
        d = mean_row.shape[1]
        num_points = 2*d + 1

        # np.array copies input_cov, so the in-place factorization cannot clobber the caller's matrix
        cov_chol = scipy.linalg.cholesky(a=np.array(input_cov), lower=True, overwrite_a=True)

        # weights: v0 for the central point, vi for each of the 2*d symmetric points (they sum to 1)
        # with alpha < 1 and k = 0 the central weight v0 is negative
        v0 = 1 - (d/((alpha**2)*(d + k)))
        vi = (1 - v0)/(2*d)
        weights = np.full(num_points, vi)
        weights[0] = v0

        # spread = sqrt(d/(1 - v0)) makes the weighted sigma points reproduce input_cov exactly
        # (an additional alpha factor here would shrink the represented covariance by alpha**2)
        spread = np.sqrt(d/(1 - v0))
        offsets = spread*cov_chol.T  # row i holds the i-th column of the cholesky factor

        # layout: [mean, mean + o_1, mean - o_1, mean + o_2, mean - o_2, ...]
        sigma_points = np.empty((num_points, d))
        sigma_points[0] = mean_row
        sigma_points[1::2] = mean_row + offsets
        sigma_points[2::2] = mean_row - offsets

        # append the fixed (non-random) inputs, identical for every sigma point
        fixed_inputs = None
        if dim1_input_val is not None and dim2_input_val is not None:
            fixed_inputs = np.hstack((np.asarray(dim1_input_val).reshape((1, -1)),
                                      np.asarray(dim2_input_val).reshape((1, -1))))
        elif dim2_input_val is not None:
            fixed_inputs = np.asarray(dim2_input_val).reshape((1, -1))

        if fixed_inputs is not None:
            sigma_points = np.hstack((sigma_points, np.repeat(fixed_inputs, repeats=num_points, axis=0)))

        # pass the points through the nonlinear function (the GP)
        all_output_means, _all_output_vars = self.GP_model.predict(Xnew=sigma_points)
        all_output_means = np.asarray(all_output_means).reshape((num_points, -1))

        # axis=0 is required: np.average rejects 1-D weights against a 2-D array when no axis is given
        output_mean = np.average(all_output_means, axis=0, weights=weights)
        demeaned_outputs = all_output_means - output_mean

        # row-wise squared norm == diagonal of demeaned @ demeaned.T, without building the N x N matrix
        squared_dists = np.sum(demeaned_outputs**2, axis=1)
        output_var = np.average(squared_dists, weights=weights)

        return output_mean, output_var 