"""
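Same method as the base nonparametric predictor (nonparam_wrapper.nonparam_predictor), except that the sequential rollout is replaced by a sampling-based method:
at each prediction step the output pixels are sampled from the predictive distribution rather than taken as the predicted means.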
"""

import numpy as np 
import cv2 
import GPy 

import matplotlib.pyplot as plt 
import os 
import sys 
import time 
import multiprocessing as mp
import scipy
from scipy.linalg import lapack, blas


from tqdm import tqdm 

current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, "..")
sys.path.append(parent_dir) # append parent directory to the path 

# import locally written files 
from patchify_unpatchify import patcher
from predict import predict 
from processing import processing

from nonparam_wrapper import nonparam_predictor
import torch

import pdb

class nonparam_predictor_sampling(nonparam_predictor):
    # Same class as nonparam_predictor, except that during rollouts each predicted step is sampled from the predictive distribution rather than taken as the mean.
    # only change is to the function: 'full_pred_sequential'
        


    


#########################################################################################################################################################
#########################################################################################################################################################
#########################################################################################################################################################
############################################################# OLD LEGACY CODE ###########################################################################
#########################################################################################################################################################
#########################################################################################################################################################
#########################################################################################################################################################


    def sub_pred_sequential_sampling(self, starting_x_images, steps, patch_obj, all_together, use_variance_weighting): 
        """
        Predicts a rollout sequence from the starting_x_images for steps number of steps. 
        args: 
            - starting_x_images: list of starting images: depending on the dataprocessor up to 3 images are needed
                - this is used to start the prediction process
            - data_processor: object of class processing used to pre and post process the images to the requisite data types for the GP
            - GP_model: GP model object of type GPy's GP model 
            - patch_obj: object of type patcher to help patchify and unpatchify the images
            - steps: the number of steps to rollout
            - all_together: boolean indicator
                - True: predict the entire dataset in one go 
                - False: predict the dataset one image at a time - NOTE: TODO: might need to change this to x patches at a time. 
            - use_variance_weighting: 
                - True: use variance for the mask weighting
                - False: do not use variance - just do averaging
        returns: 
            - pred_images: list of unpatchified images that are predicted from the test_x set
            - pred_var_images: 
                - returns this if  use_variance_weighting is set to True
            - xdataset: the xdataset that was used to generate each of the predicted images during the rollout

        """
        # use_variance_weighting input only left in for compatibility - do not use with this approach 
        assert(use_variance_weighting == False)

        predicted_seq_imgs = []
        predicted_seq_vars = []
        xdataset = []

        curr_test_x = list(starting_x_images)

        for step in range(steps):
            # process the current image seq
            # NOTE: add a padder image as create_x always leaves the last image for the y dataset creator
            processed_curr_test_x = self.data_processor.create_x(image_seq=list(curr_test_x) + [np.zeros(curr_test_x[0].shape)])
            xdataset.extend(processed_curr_test_x)
            # patch the processed dataset
            patched_processed_curr_test_x = patch_obj.patchify_dataset(dataset=processed_curr_test_x,
                dataset_type='x')
            # convert patches to flattened vectors
            curr_test_x_vecs = self.data_processor.convert_imgdataset_to_vecdataset(dataset=patched_processed_curr_test_x)
            # perform the prediction
            if all_together:
                predicted_vecs, predicted_vars = GP_model.predict(Xnew=np.array(curr_test_x_vecs).reshape((len(curr_test_x_vecs), -1)))
            else: 
                predicted_vecs = []
                predicted_vars = []
                start_range = 0
                while start_range < len(curr_test_x_vecs):
                    end_range = min(start_range + self.predictor.max_all_together, len(curr_test_x_vecs))
                    predict_subset_x = curr_test_x_vecs[start_range:end_range]
                    predict_subset_ymean, predict_subset_yvar = self.GP_model.predict(Xnew=np.array(predict_subset_x).reshape((len(predict_subset_x), -1)))


                    predict_subset_samples = self.get_all_samples_gpoutput(predict_subset_ymean=predict_subset_ymean, predict_subset_yvar=predict_subset_yvar)
                    
                    # append
                    #predicted_vecs.extend(predict_subset_ymean)
                    predicted_vecs.extend(predict_subset_samples)
                    predicted_vars.extend(predict_subset_yvar)
                    #pdb.set_trace()

                    start_range += self.predictor.max_all_together
                    
            # convert the predicted vecs to patches
            predicted_patches = [vec.reshape(patch_obj.get_ypatch_dim()) for vec in predicted_vecs]

            #pdb.set_trace()
            # convert the predicted patches to one predicted image
            if use_variance_weighting:
                predicted_image, predicted_padded_image, predicted_var_image, _, _ = patch_obj.unpatchify_image(patch_list=predicted_patches, patch_variance_list=predicted_vars, 
                                                                                                          img_type='y')
            else:
                #print("NOT USING VARIANCE WEIGHTING!")
                predicted_image, predicted_padded_image, predicted_var_image, _, _ = patch_obj.unpatchify_image(patch_list=predicted_patches, patch_variance_list=None,
                                                                                                          img_type='y')
            if self.data_processor.ytype == 'diff':
                # properly calculate the predicted image
                last_image = curr_test_x[-1]
                predicted_image = predicted_image + last_image

            # append to stored lists
            predicted_seq_imgs.append(predicted_image)
            predicted_seq_vars.append(predicted_var_image)

            # update the curr test x
            curr_test_x.append(predicted_image)
            curr_test_x = curr_test_x[1:]

        return predicted_seq_imgs, predicted_seq_vars, xdataset


    def get_all_samples_gpoutput(self, predict_subset_ymean, predict_subset_yvar):
        """
        Given the mean and variance output of the GP model for several patches, draw one sample per patch.

        Each row ``i`` is sampled independently from N(mean[i], diag(var[i])): i.e. every output
        dimension of a patch gets independent Gaussian noise with that patch's variance. This is
        equivalent to the previous per-row ``np.random.multivariate_normal`` call with a diagonal
        covariance, but vectorized (no dense D x D covariance / SVD per row).

        args:
            - predict_subset_ymean: array of shape (num_patches, D) with the GP predictive means
            - predict_subset_yvar: predictive variances, either one value per patch (shape
              (num_patches, 1) or (num_patches,)) or one value per output dimension (shape (num_patches, D))

        returns:
            - predicted_samples: numpy array of shape (num_patches, D) with one sample per patch.
              Samples are drawn from NumPy's global random state, so ``np.random.seed`` controls them.
        """
        ymean = np.asarray(predict_subset_ymean, dtype=float)
        yvar = np.asarray(predict_subset_yvar, dtype=float)
        if yvar.ndim == 1:
            # one variance per patch given as a flat vector -> make it broadcast across the row
            yvar = yvar.reshape((-1, 1))

        # GP variances can come out as tiny negative numbers from round-off; clip so sqrt is defined
        ystd = np.sqrt(np.clip(yvar, 0.0, None))

        predicted_samples = ymean + ystd * np.random.standard_normal(ymean.shape)
        return predicted_samples


    def analytical_mean_var_propogation_Faster_OLDLEGACYCODE(self, input_means, input_covs, dim1_input_vals=None, dim2_input_vals=None):
        """ 
        Propogates forward the mean and variance of the probability distribution. 
        Attempt to do this faster by parallelizing multiple patches at once using tensor operations. 

        args: 
            - input_mean: the mean along the input dimensions that are random variables
            - input_cov: the corresponding covariance matrix to the input mean 
            - dim1_input_val: type: numpy array: the value to use for the 1st dimensional input 
                - NOTE: only needed for the first mean propogation
            - dim2_input_val: type: numpy array: the value to use for the 1st dimensional input 
                - NOTE: only needed for the first and second mean propogation
        returns: 
            - current_prediction_mean: the propogated mean 
            - current_prediction_var : the propogated variance
        """ 
        EIN_OPTIMIZER = 'optimal'
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if type(dim1_input_vals) != type(None) and type(dim2_input_vals) != type(None):

            start_time = time.time()
            dim0_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim0_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
            dim1_input_vals = dim1_input_vals
            dim2_input_vals = dim2_input_vals

            qa_coef = self.variance_dim0/np.sqrt(np.linalg.det(np.matmul(dim0_input_cov_patches, self.lengthscale_matrix_dim0_inv) + np.eye(dim0_input_cov_patches.shape[-1]))).reshape((-1, 1))
            k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            k_coef  = k1_coef * k2_coef #np.array(self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs) * self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs))
            # exponential part
            dim0_input_mean_patches_seperated = dim0_input_mean_patches.reshape((dim0_input_mean_patches.shape[0], 1, dim0_input_mean_patches.shape[1]))
            v = dim0_input_mean_patches_seperated - self.dim0_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim0_input_cov_patches + self.lengthscale_matrix_dim0) 
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))
            
            qa = qa_coef * k_coef * np.exp(-0.5 * qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)
            #print("Mean predicted: ", time.time() - start_time)

            #clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim0_input_mean_patches_seperated)
            del(qa_coef)

            """
            # TRY WITH PYTORCH - START
            # Variance Propagation
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            normal_start_time = time.time()
            start_time = time.time()
            R = np.dot(dim0_input_cov_patches, self.lengthscale_matrix_dim0_inv + self.lengthscale_matrix_dim0_inv) + np.eye(dim0_input_cov_patches.shape[-1])
            R_sqrt_det = np.sqrt(np.linalg.det(R)).reshape((-1,1,1))
            R_inv = np.linalg.inv(R)
            #print("R computations: ", time.time() - start_time) 

            # Z computation
            start_time = time.time()
            lengthscale_inv_v = np.dot(v, self.lengthscale_matrix_dim0_inv)
            all_zij = np.repeat(lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], lengthscale_inv_v.shape[1], 1, lengthscale_inv_v.shape[2])), repeats=lengthscale_inv_v.shape[1], axis=2) + \
                      np.repeat(lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1], lengthscale_inv_v.shape[2])), repeats=lengthscale_inv_v.shape[1], axis=1)
            #print("Z computation: ", time.time() - start_time)

            # Q matrix coefficient
            start_time = time.time()
            k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
            total_k_coef = k0_coef * k1_coef * k2_coef 
            total_k_coef = total_k_coef.reshape((total_k_coef.shape[0], 1, total_k_coef.shape[1]))
            Qmat_coefs = np.matmul(np.transpose(total_k_coef, axes=(0,2,1)), total_k_coef)/R_sqrt_det
            #print("Q matrix coefficient computation: ", time.time() - start_time)

            # Q matrix computation
            start_time = time.time()
            Q_mat_exp_covmat = np.matmul(R_inv, dim0_input_cov_patches)
            Qmat = Qmat_coefs * np.exp(0.5 * np.matmul(np.expand_dims(all_zij, axis=-2), np.matmul(np.expand_dims(Q_mat_exp_covmat, axis=[1,2]), np.expand_dims(all_zij, axis=-1))).reshape(all_zij.shape[0:3]))
            #print("Q computation: ", time.time() - start_time)

            # clear var variables
            # del(all_zij)
            # del(Q_mat_exp_covmat)
            # del(Qmat_coefs)
            # del(total_k_coef)
            # del(lengthscale_inv_v)

            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms = self.full_variance - np.trace(np.dot(Qmat, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms = np.dot(np.transpose(np.dot(Qmat, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms = np.square(current_prediction_mean_patches).flatten()

            current_prediction_var_patches = var_first_terms + var_second_terms - var_third_terms
            #print("Final variance computation: ", time.time() - start_time)
            #print("Total time for normal prediction: ", time.time() - normal_start_time)
            """
            
            ################################################################################################
            torch_start_time = time.time()
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                start_time = time.time()
                #dim0_input_mean_patches_tr = torch.cuda.DoubleTensor(dim0_input_mean_patches)
                dim0_input_cov_patches_tr = torch.cuda.DoubleTensor(dim0_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)
                #print("Time to load tensors: ", time.time() - start_time)

                start_time = time.time()
                R_tr = torch.matmul(dim0_input_cov_patches_tr, self.lengthscale_matrix_dim0_inv_tr + self.lengthscale_matrix_dim0_inv_tr) + torch.eye(n=dim0_input_cov_patches.shape[-1]).to(device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim0_input_cov_patches_tr)
                del(dim0_input_cov_patches_tr)
                del(R_tr)
                #print("R torch times: ", time.time() - start_time)

                # Z torch computation
                start_time = time.time()
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim0_inv_tr)
                all_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr, dim=2).expand(-1, -1, lengthscale_inv_v_tr.shape[1], -1) + \
                              torch.unsqueeze(lengthscale_inv_v_tr, dim=1).expand(-1, lengthscale_inv_v_tr.shape[1], -1, -1)
                del(lengthscale_inv_v_tr)
                del(v_tr)
                #print("Z torch times: ", time.time() - start_time)

                # Q matrix coefficient
                start_time = time.time()
                k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
                total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
                total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                del(total_k_coef_tr)
                #print("Q matrix coefficient computation: ", time.time() - start_time)

                # Q matrix torch computation 
                start_time = time.time()
                Qmat_tr = Qmat_coefs_tr * torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(all_zij_tr, dim=-2), \
                            torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(all_zij_tr, dim=-1))))
                #print("Q torch computation: ", time.time() - start_time)
                
                # Final variance torch computation 
                Qmat_tr = Qmat_tr.detach().cpu().numpy()
        
            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(np.dot(Qmat_tr, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms_tr = np.dot(np.transpose(np.dot(Qmat_tr, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()

            current_prediction_var_patches_tr = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr

            #print("Final variance computation: ", time.time() - start_time)
            #print("Total time for torch prediction: ", time.time() - torch_start_time)
            # TRY WITH PYTORCH - END
            current_prediction_var_patches = current_prediction_var_patches_tr

            """
            # OLD METHOD FOR TESTING
            all_current_prediction_means = []
            all_current_prediction_vars = []

            for patch_num in range(input_means.shape[0]): 
                input_mean = input_means[patch_num] 
                input_cov = input_covs[patch_num]
                dim1_input_val = dim1_input_vals[patch_num]
                dim2_input_val = dim2_input_vals[patch_num]

                # second prediction 
                dim0_input_mean = input_mean.reshape((1, -1))
                dim0_input_cov = input_cov #TODO: REMOVE AND CHANGE THIS BACK
                dim1_input_val = dim1_input_val.reshape((1, -1))
                dim2_input_val = dim2_input_val.reshape((1, -1))

                # mean propogation
                qa_coef_mat1 = np.einsum('ij,jk',dim0_input_cov, self.lengthscale_matrix_dim0_inv, optimize=EIN_OPTIMIZER) + np.eye(dim0_input_cov.shape[0])
                qa_coef1 = self.variance_dim0/np.sqrt(np.linalg.det(qa_coef_mat1))
                # kernel coefficients due to non random variable inputs
                k1_coef1 = self.kernel_dim1.K(self.dim1_train_xvecs, dim1_input_val)
                k2_coef1 = self.kernel_dim2.K(self.dim2_train_xvecs, dim2_input_val)
                k_coef1 = np.array(k1_coef1 * k2_coef1)
                # exponential part
                qa_exp_covmat1 = self.stable_cho_inverse(matrix=dim0_input_cov + self.lengthscale_matrix_dim0, cholesky=False)
                v1 = dim0_input_mean - self.dim0_train_xvecs
                #qa_exp = np.diagonal(np.dot(v, np.dot(qa_exp_covmat, v.T))).reshape((-1, 1))
                qa_exp1 = np.einsum('ij,ji->i', v1, np.einsum('ij,jk',qa_exp_covmat1, v1.T, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)

                qa1 = qa_coef1 * k_coef1 * np.exp(-0.5 * qa_exp1).reshape(k_coef1.shape)
                current_prediction_mean1 = np.einsum('ij,jk', self.Ba_vector.T, qa1, optimize=EIN_OPTIMIZER)

                # variance propogation
                # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
                # 1. Compute Q matrix
                R1 = np.einsum('ij,jk',dim0_input_cov, self.lengthscale_matrix_dim0_inv + self.lengthscale_matrix_dim0_inv, optimize=EIN_OPTIMIZER) + np.eye(dim0_input_cov.shape[0])
                R_sqrt_det1 = np.sqrt(np.linalg.det(R1))
                R_inv1 = np.linalg.inv(R1)
                
                Q_mat_exp_covmat1 = np.einsum('ij,jk',R_inv1, dim0_input_cov, optimize=EIN_OPTIMIZER)
                
                
                #start_time = time.time()
                # try computing Q matrix efficiently 
                k0_coef1 = self.kernel_dim0.K(self.dim0_train_xvecs, dim0_input_mean)
                total_k_coef = k0_coef1 * k1_coef1 * k2_coef1
                Q_coef1 = np.outer(total_k_coef, total_k_coef)/R_sqrt_det1#np.einsum('ij,kl->ik', k0_coef1, k0_coef1, optimize=EIN_OPTIMIZER)/R_sqrt_det
                lengthscale_inv_v1 = np.einsum('ij,kj->kj', self.lengthscale_matrix_dim0_inv, v1, optimize=EIN_OPTIMIZER)
                
                lengthscale_inv_v_seperated1 = lengthscale_inv_v1.reshape((lengthscale_inv_v1.shape[0], 1, lengthscale_inv_v1.shape[1]))
                all_zij1 = lengthscale_inv_v_seperated1 + lengthscale_inv_v1 # can index with all_zij[i, j] 
                
                Qexp_dot_prod_later1 = np.einsum('kk,ijk->ijk', Q_mat_exp_covmat1, all_zij1, optimize=EIN_OPTIMIZER)
                Qexp_dot_prod1 = np.einsum('ijk,ijk->ij', all_zij1, Qexp_dot_prod_later1, optimize=EIN_OPTIMIZER)
                Qexp_dot_prod_mat1 = 0.5 * Qexp_dot_prod1

                #Q_k_coef1 = np.einsum('ij,kl->ik', k_coef1, k_coef1, optimize=EIN_OPTIMIZER)
                Q_mat1 = Q_coef1 * np.exp(Qexp_dot_prod_mat1)

                #print("Finished computing faster Q compute: ", time.time() - start_time)
                # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
                var_first_term21 = self.full_variance - np.trace(np.einsum('ij,jk', self.K_train_noise_inv, Q_mat1, optimize=EIN_OPTIMIZER))
                # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
                var_second_term21 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Q_mat1, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)
                # 4. Compute 4th term mean^2
                var_third_term21 = current_prediction_mean1**2

                var_third_term_ein = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', qa1, np.einsum('ij,jk', qa1.T, self.Ba_vector)))

                current_prediction_var21 = var_first_term21 + var_second_term21 - var_third_term21

                current_prediction_var = current_prediction_var21

                all_current_prediction_means.append(current_prediction_mean1)
                all_current_prediction_vars.append(current_prediction_var)

            print("Debugging information")
            print("Batch computation mean : ", float(np.array(current_prediction_mean_patches)))
            print("Single computation mean: ", float(np.array(all_current_prediction_means)))
            print("\n")
            print("Batch computation var: ", float(np.array(current_prediction_var_patches)))
            print("Single computation var: ", float(np.array(all_current_prediction_vars)))
            print("Batch torch computation var: ", float(np.array(current_prediction_var_patches_tr)))
            print("\n")
            # # pdb.set_trace()
            # return current_prediction_mean_patches, current_prediction_var_patches#np.array(all_current_prediction_means), np.array(all_current_prediction_vars)
            print("Finished prediction")
            pdb.set_trace()
            """
        elif type(dim2_input_vals) != type(None):
            # Mean Propagation - NEW
            dim01_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim01_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
            dim0_input_mean_patches = dim01_input_mean_patches[:, self.dim0_start_idx:self.dim0_end_idx] 
            dim1_input_mean_patches = dim01_input_mean_patches[:, self.dim1_start_idx:self.dim1_end_idx]
            dim2_input_vals = dim2_input_vals

            qa_coef = (self.variance_dim0*self.variance_dim1)/np.sqrt(np.linalg.det(np.matmul(dim01_input_cov_patches, self.lengthscale_matrix_dim01_inv) + np.eye(dim01_input_cov_patches.shape[-1]))).reshape((-1, 1))
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            
            # exponential part
            dim01_input_mean_patches_seperated = dim01_input_mean_patches.reshape((dim01_input_mean_patches.shape[0], 1, dim01_input_mean_patches.shape[1]))
            v = dim01_input_mean_patches_seperated - self.dim01_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim01_input_cov_patches + self.lengthscale_matrix_dim01)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

            qa = qa_coef * k2_coef * np.exp(-0.5*qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim01_input_mean_patches_seperated)
            del(qa_coef)

            """
            start_time = time.time()
            # Var Propogation 
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            R = np.dot(dim01_input_cov_patches, self.lengthscale_matrix_dim01_inv + self.lengthscale_matrix_dim01_inv) + np.eye(dim01_input_cov_patches.shape[-1])
            R_sqrt_det = np.sqrt(np.linalg.det(R)).reshape((-1,1,1))
            R_inv = np.linalg.inv(R)
            print("R computation: ", time.time() - start_time)

            start_time = time.time()
            # Z computation
            lengthscale_inv_v = np.dot(v, self.lengthscale_matrix_dim01_inv)
            all_zij = np.repeat(lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], lengthscale_inv_v.shape[1], 1, lengthscale_inv_v.shape[2])), repeats=lengthscale_inv_v.shape[1], axis=2) + \
                      np.repeat(lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1], lengthscale_inv_v.shape[2])), repeats=lengthscale_inv_v.shape[1], axis=1) 
            print("zij creation: ", time.time() - start_time)

            start_time = time.time()
            # Q matrix coefficient
            k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
            k1_coef = self.kernel_dim1.K(dim1_input_mean_patches, self.dim1_train_xvecs)
            total_k_coef = k0_coef * k1_coef * k2_coef 
            total_k_coef = total_k_coef.reshape((total_k_coef.shape[0], 1, total_k_coef.shape[1]))
            Qmat_coefs = np.matmul(np.transpose(total_k_coef, axes=(0,2,1)), total_k_coef)/R_sqrt_det
            print("Q matrix coefficient computation: ", time.time() - start_time)

            start_time = time.time()
            # Q matrix computation
            Q_mat_exp_covmat = np.matmul(R_inv, dim01_input_cov_patches)
            Qmat = Qmat_coefs * np.exp(0.5 * np.matmul(np.expand_dims(all_zij, axis=-2), \
                                            np.matmul(np.expand_dims(Q_mat_exp_covmat, axis=[1,2]), np.expand_dims(all_zij, axis=-1))\
                                            ).reshape(all_zij.shape[0:3]))
            print("Actual Q matrix computation: ", time.time() - start_time)

            # clear var variables
            del(all_zij)
            del(Q_mat_exp_covmat)
            del(Qmat_coefs)
            del(total_k_coef)
            del(lengthscale_inv_v)

            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms = self.full_variance - np.trace(np.dot(Qmat, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms = np.dot(np.transpose(np.dot(Qmat, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms = np.square(current_prediction_mean_patches).flatten()

            current_prediction_var_patches = var_first_terms + var_second_terms - var_third_terms
            print("Final variance computation: ", time.time() - start_time)
            #pdb.set_trace()
            """
            
            ################################################################################################
            torch_start_time = time.time()
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                start_time = time.time()
                #dim01_input_mean_patches_tr = torch.cuda.DoubleTensor(dim01_input_mean_patches)
                dim01_input_cov_patches_tr = torch.cuda.DoubleTensor(dim01_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)
                #print("Time to load tensors: ", time.time() - start_time)

                start_time = time.time()
                R_tr = torch.matmul(dim01_input_cov_patches_tr, self.lengthscale_matrix_dim01_inv_tr + self.lengthscale_matrix_dim01_inv_tr) + torch.eye(n=dim01_input_cov_patches.shape[-1]).to(device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim01_input_cov_patches_tr)
                del(R_tr)
                del(dim01_input_cov_patches_tr)
                #print("R torch times: ", time.time() - start_time)

                # Z torch computation
                start_time = time.time()
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim01_inv_tr)
                all_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr, dim=2).expand(-1, -1, lengthscale_inv_v_tr.shape[1], -1) + \
                              torch.unsqueeze(lengthscale_inv_v_tr, dim=1).expand(-1, lengthscale_inv_v_tr.shape[1], -1, -1)
                del(v_tr)
                del(lengthscale_inv_v_tr)
                #print("Z torch times: ", time.time() - start_time)

                # Q matrix coefficient
                start_time = time.time()
                k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
                k1_coef = self.kernel_dim1.K(dim1_input_mean_patches, self.dim1_train_xvecs)
                total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
                total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                del(total_k_coef_tr)
                #print("Q matrix coefficient computation: ", time.time() - start_time)

                # Q matrix torch computation 
                start_time = time.time()
                Qmat_tr = Qmat_coefs_tr * torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(all_zij_tr, dim=-2), \
                            torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(all_zij_tr, dim=-1))))
                #print("Q torch computation: ", time.time() - start_time)
                
                # Final variance torch computation 
                Qmat_tr = Qmat_tr.detach().cpu().numpy()

            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(np.dot(Qmat_tr, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms_tr = np.dot(np.transpose(np.dot(Qmat_tr, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()

            current_prediction_var_patches_tr = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr
            #print("Final variance computation: ", time.time() - start_time)
            #print("Total time for torch prediction: ", time.time() - torch_start_time)
            #pdb.set_trace()

            current_prediction_var_patches = current_prediction_var_patches_tr
            
        else:
            """ 
            dim012_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim012_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)

            qa_coef_mat = np.einsum('ijk,jk->ijk', dim012_input_cov_patches, self.lengthscale_matrix_dim012_inv) + np.eye(dim012_input_cov_patches.shape[-1])
            qa_coef = (self.full_variance)/np.sqrt(np.linalg.det(qa_coef_mat)).reshape((-1, 1))
            # exponential part 
            qa_exp_covmat = np.linalg.inv(dim012_input_cov_patches + self.lengthscale_matrix_dim012)
            dim012_input_mean_patches_seperated = dim012_input_mean_patches.reshape((dim012_input_mean_patches.shape[0], 1, dim012_input_mean_patches.shape[1]))
            v = dim012_input_mean_patches_seperated - self.full_train_xvecs
            qa_exp_laterhalf = np.einsum('ijk,ilk->ilk', qa_exp_covmat, v)
            qa_exp = np.einsum('ijk,ijk->ij',v,qa_exp_laterhalf)
            qa = qa_coef * np.exp(-0.5*qa_exp)

            current_prediction_mean = np.dot(self.Ba_vector.T, qa.T)
            current_prediction_var  = np.zeros(current_prediction_mean.shape)
            """
            # Mean Propagation - NEW
            start_time = time.time()
            dim012_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim012_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)

            qa_coef = (self.full_variance)/np.sqrt(np.linalg.det(np.matmul(dim012_input_cov_patches, self.lengthscale_matrix_dim012_inv) + np.eye(dim012_input_cov_patches.shape[-1]))).reshape((-1, 1))
            # exponential part
            dim012_input_mean_patches_seperated = dim012_input_mean_patches.reshape((dim012_input_mean_patches.shape[0], 1, dim012_input_mean_patches.shape[1]))
            v = dim012_input_mean_patches_seperated - self.full_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim012_input_cov_patches + self.lengthscale_matrix_dim012)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

            qa = qa_coef * np.exp(-0.5*qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim012_input_mean_patches_seperated)
            del(qa_coef)

            """
            # Var Propogation 
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            start_time = time.time()

            R = np.dot(dim012_input_cov_patches, self.lengthscale_matrix_dim012_inv + self.lengthscale_matrix_dim012_inv) + np.eye(dim012_input_cov_patches.shape[-1])
            R_sqrt_det = np.sqrt(np.linalg.det(R)).reshape((-1,1,1))
            R_inv = np.linalg.inv(R)

            # Z computation
            lengthscale_inv_v = np.dot(v, self.lengthscale_matrix_dim012_inv)
            all_zij = np.repeat(lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], lengthscale_inv_v.shape[1], 1, lengthscale_inv_v.shape[2])), repeats=lengthscale_inv_v.shape[1], axis=2) + \
                      np.repeat(lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1], lengthscale_inv_v.shape[2])), repeats=lengthscale_inv_v.shape[1], axis=1) 

            # Q matrix coefficient
            total_k_coef = self.GP_model.kern.K(dim012_input_mean_patches, self.full_train_xvecs)
            total_k_coef = total_k_coef.reshape((total_k_coef.shape[0], 1, total_k_coef.shape[1]))
            Qmat_coefs = np.matmul(np.transpose(total_k_coef, axes=(0,2,1)), total_k_coef)/R_sqrt_det

            # Q matrix computation
            Q_mat_exp_covmat = np.matmul(R_inv, dim012_input_cov_patches)
            Qmat = Qmat_coefs * np.exp(0.5 * np.matmul(np.expand_dims(all_zij, axis=-2), \
                                            np.matmul(np.expand_dims(Q_mat_exp_covmat, axis=[1,2]), np.expand_dims(all_zij, axis=-1))\
                                            ).reshape(all_zij.shape[0:3]))

            # clear var variables
            del(all_zij)
            del(Q_mat_exp_covmat)
            del(Qmat_coefs)
            del(total_k_coef)
            del(lengthscale_inv_v)

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms = self.full_variance - np.trace(np.dot(Qmat, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms = np.dot(np.transpose(np.dot(Qmat, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms = np.square(current_prediction_mean_patches).flatten()

            current_prediction_var_patches = var_first_terms + var_second_terms - var_third_terms
            """

            ################################################################################################
            torch_start_time = time.time()
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                start_time = time.time()
                #dim012_input_mean_patches_tr = torch.cuda.DoubleTensor(dim012_input_mean_patches)
                dim012_input_cov_patches_tr = torch.cuda.DoubleTensor(dim012_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)
                #print("Time to load tensors: ", time.time() - start_time)

                start_time = time.time()
                R_tr = torch.matmul(dim012_input_cov_patches_tr, self.lengthscale_matrix_dim012_inv_tr + self.lengthscale_matrix_dim012_inv_tr) + torch.eye(n=dim012_input_cov_patches.shape[-1]).to(device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim012_input_cov_patches_tr)
                del(R_tr)
                del(dim012_input_cov_patches_tr)
                #print("R torch times: ", time.time() - start_time)

                # Z torch computation
                start_time = time.time()
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim012_inv_tr)
                all_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr, dim=2).expand(-1, -1, lengthscale_inv_v_tr.shape[1], -1) + \
                              torch.unsqueeze(lengthscale_inv_v_tr, dim=1).expand(-1, lengthscale_inv_v_tr.shape[1], -1, -1)
                del(v_tr)
                del(lengthscale_inv_v_tr)
                #print("Z torch times: ", time.time() - start_time)

                # Q matrix coefficient
                start_time = time.time()
                total_k_coef = self.GP_model.kern.K(dim012_input_mean_patches, self.full_train_xvecs)
                total_k_coef_tr = torch.unsqueeze(torch.cuda.DoubleTensor(total_k_coef), dim=1)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                del(total_k_coef_tr)
                #print("Q matrix coefficient computation: ", time.time() - start_time)

                # Q matrix torch computation 
                start_time = time.time()
                Qmat_tr = Qmat_coefs_tr * torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(all_zij_tr, dim=-2), \
                            torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(all_zij_tr, dim=-1))))
                #print("Q torch computation: ", time.time() - start_time)
                
                # Final variance torch computation 
                Qmat_tr = Qmat_tr.detach().cpu().numpy()

            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(np.dot(Qmat_tr, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms_tr = np.dot(np.transpose(np.dot(Qmat_tr, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()

            current_prediction_var_patches_tr = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr
            #print("Final variance computation: ", time.time() - start_time)
            #print("Total time for torch prediction: ", time.time() - torch_start_time)

            current_prediction_var_patches = current_prediction_var_patches_tr

            

        # force negative jitter to zero
        current_prediction_var_patches[current_prediction_var_patches < 0] = 0

        return current_prediction_mean_patches, current_prediction_var_patches

#########################################################################################################################################################
#########################################################################################################################################################
#########################################################################################################################################################
#########################################################################################################################################################
######################################################### Mean and Variance Propagation #################################################################
#########################################################################################################################################################
#########################################################################################################################################################
#########################################################################################################################################################
#########################################################################################################################################################

    def analytical_mean_var_propogation(self, input_mean, input_cov, dim1_input_val=None, dim2_input_val=None, set_debug=False):
        """ 
        Propogates forward the mean and variance of the probability distribution. 
        args: 
            - input_mean: the mean along the input dimensions that are random variables
            - input_cov: the corresponding covariance matrix to the input mean 
            - dim1_input_val: type: numpy array: the value to use for the 1st dimensional input 
                - NOTE: only needed for the first mean propogation
            - dim2_input_val: type: numpy array: the value to use for the 1st dimensional input 
                - NOTE: only needed for the first and second mean propogation
        returns: 
            - current_prediction_mean: the propogated mean 
            - current_prediction_var : the propogated variance
        """ 

        if type(dim1_input_val) != type(None) and type(dim2_input_val) != type(None):

            """
            # second prediction 
            dim0_input_mean = input_mean.reshape((1, -1))
            dim0_input_cov = input_cov

            # mean propogation
            qa_coef_mat = np.dot(dim0_input_cov, self.lengthscale_matrix_dim0_inv) + np.eye(dim0_input_cov.shape[0])
            qa_coef = self.variance_dim0/np.sqrt(np.linalg.det(qa_coef_mat))
            # kernel coefficients due to non random variable inputs
            k1_coef = self.kernel_dim1.K(self.dim1_train_xvecs, dim1_input_val.reshape((1, -1)))
            k2_coef = self.kernel_dim2.K(self.dim2_train_xvecs, dim2_input_val.reshape((1, -1)))
            k_coef = np.array(k1_coef * k2_coef)
            # exponential part
            qa_exp_covmat = self.stable_cho_inverse(matrix=dim0_input_cov + self.lengthscale_matrix_dim0, cholesky=False)
            v = self.dim0_train_xvecs - dim0_input_mean
            #qa_exp = np.diagonal(np.dot(v, np.dot(qa_exp_covmat, v.T))).reshape((-1, 1))
            qa_exp = np.einsum('ij,ji->i', v, np.dot(qa_exp_covmat, v.T)).reshape((-1, 1))

            qa = qa_coef * k_coef * np.exp(-0.5 * qa_exp)
            current_prediction_mean = np.dot(self.Ba_vector.T, qa)

            # variance propogation
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            R = np.dot(dim0_input_cov, self.lengthscale_matrix_dim0_inv + self.lengthscale_matrix_dim0_inv) + np.eye(dim0_input_cov.shape[0])
            R_sqrt_det = np.sqrt(np.linalg.det(R))
            R_inv = np.linalg.inv(R)
            
            Q_mat_exp_covmat = np.dot(R_inv, dim0_input_cov)
            
            Q = np.zeros((self.num_training_points, self.num_training_points))
            Q_exponents = np.zeros((self.num_training_points, self.num_training_points))
            all_zij_old = np.zeros((self.num_training_points, self.num_training_points, v.shape[-1]))
            
            start_time = time.time()
            for i in range(Q.shape[0]):
                for j in range(Q.shape[1]): 
                    train_dim0_i = self.dim0_train_xvecs[i].reshape((1, -1))
                    train_dim0_j = self.dim0_train_xvecs[j].reshape((1, -1))

                    zij = np.dot(self.lengthscale_matrix_dim0_inv, v[i]) + np.dot(self.lengthscale_matrix_dim0_inv, v[j])
                    Qij_coef = (self.kernel_dim0.K(train_dim0_i, dim0_input_mean)*self.kernel_dim0.K(train_dim0_j, dim0_input_mean))/R_sqrt_det
                    # additional coefficients from non-random inputs
                    Qij_k_coef = k1_coef[i]*k2_coef[i]*k1_coef[j]*k2_coef[j]

                    Qij_exponent = 0.5 * np.dot(zij, np.dot(Q_mat_exp_covmat, zij))
                    Qij = Qij_k_coef * Qij_coef * np.exp(Qij_exponent)
                    Q[i, j] = float(Qij)
                    Q_exponents[i, j] = Qij_exponent
                    all_zij_old[i, j] = zij

            print("Q matrix computed in time: ", time.time() - start_time)
            
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term = self.full_variance - np.trace(np.dot(self.K_train_noise_inv, Q))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term = np.dot(self.Ba_vector.T, np.dot(Q, self.Ba_vector))
            # 4. Compute 4th term mean^2
            var_third_term = current_prediction_mean**2
            print("Finished calculating the variance propogation")
            current_prediction_var = var_first_term + var_second_term - var_third_term
            
            #print("Starting faster compute")
            #start_time = time.time()
            # try computing Q matrix efficiently 
            Q_coef = np.dot(self.kernel_dim0.K(self.dim0_train_xvecs, dim0_input_mean), self.kernel_dim0.K(self.dim0_train_xvecs, dim0_input_mean).T)/R_sqrt_det
            lengthscale_inv_v = np.dot(self.lengthscale_matrix_dim0_inv, v.T).T
            lengthscale_inv_v_seperated = lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1]))
            all_zij = lengthscale_inv_v_seperated + lengthscale_inv_v # can index with all_zij[i, j] 
            # all_zij_2d = all_zij.reshape((all_zij.shape[0]*all_zij.shape[1], all_zij.shape[2]))
            # Qexpcovmat_allzij = np.dot(Q_mat_exp_covmat, all_zij_2d.T)
            # Qexp_dot_prod = np.einsum('ij,ij->i', all_zij_2d, Qexpcovmat_allzij.T)
            
            Qexp_dot_prod_later = np.einsum('kk,ijk->ijk', Q_mat_exp_covmat, all_zij)
            Qexp_dot_prod = np.einsum('ijk,ijk->ij', all_zij, Qexp_dot_prod_later)
            Qexp_dot_prod_mat = 0.5 * Qexp_dot_prod

            Q_k_coef = np.dot(k_coef, k_coef.T)
            Q_mat = Q_coef * Q_k_coef * np.exp(Qexp_dot_prod_mat)

            #print("Finished computing faster Q compute: ", time.time() - start_time)
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term2 = self.full_variance - np.trace(np.dot(self.K_train_noise_inv, Q_mat))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term2 = np.dot(self.Ba_vector.T, np.dot(Q_mat, self.Ba_vector))
            # 4. Compute 4th term mean^2
            var_third_term2 = current_prediction_mean**2
            current_prediction_var2 = var_first_term2 + var_second_term2 - var_third_term2
            """

            EIN_OPTIMIZER = 'optimal'#'optimal'

            # ATTEMPT WITH  EVERYTHING AS EINSUM 

            # second prediction 
            dim0_input_mean = input_mean.reshape((1, -1))
            dim0_input_cov = input_cov #TODO: REMOVE AND CHANGE THIS BACK
            dim1_input_val = dim1_input_val.reshape((1, -1))
            dim2_input_val = dim2_input_val.reshape((1, -1))

            # mean propogation
            qa_coef_mat = np.einsum('ij,jk',dim0_input_cov, self.lengthscale_matrix_dim0_inv, optimize=EIN_OPTIMIZER) + np.eye(dim0_input_cov.shape[0])
            qa_coef = self.variance_dim0/np.sqrt(np.linalg.det(qa_coef_mat))
            # kernel coefficients due to non random variable inputs
            k1_coef = self.kernel_dim1.K(self.dim1_train_xvecs, dim1_input_val)
            k2_coef = self.kernel_dim2.K(self.dim2_train_xvecs, dim2_input_val)
            k_coef = np.array(k1_coef * k2_coef)
            # exponential part
            qa_exp_covmat = self.stable_cho_inverse(matrix=dim0_input_cov + self.lengthscale_matrix_dim0, cholesky=False)
            v = self.dim0_train_xvecs - dim0_input_mean
            #qa_exp = np.diagonal(np.dot(v, np.dot(qa_exp_covmat, v.T))).reshape((-1, 1))
            qa_exp = np.einsum('ij,ji->i', v, np.einsum('ij,jk',qa_exp_covmat, v.T, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)

            qa = qa_coef * k_coef * np.exp(-0.5 * qa_exp).reshape(k_coef.shape)
            current_prediction_mean = np.einsum('ij,jk', self.Ba_vector.T, qa, optimize=EIN_OPTIMIZER)

            # variance propogation
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            R = np.einsum('ij,jk',dim0_input_cov, self.lengthscale_matrix_dim0_inv + self.lengthscale_matrix_dim0_inv, optimize=EIN_OPTIMIZER) + np.eye(dim0_input_cov.shape[0])
            R_sqrt_det = np.sqrt(np.linalg.det(R))
            R_inv = np.linalg.inv(R)
            
            Q_mat_exp_covmat = np.einsum('ij,jk',R_inv, dim0_input_cov, optimize=EIN_OPTIMIZER)
            
            # Q = np.zeros((self.num_training_points, self.num_training_points))
            # Q_exponents_later = np.zeros((self.num_training_points, self.num_training_points, v.shape[-1]))
            # Q_exponents = np.zeros((self.num_training_points, self.num_training_points))
            # all_zij_old = np.zeros((self.num_training_points, self.num_training_points, v.shape[-1]))
            # Q_coefs_old = np.zeros((self.num_training_points, self.num_training_points))
            # Q_k_coefs_old = np.zeros((self.num_training_points, self.num_training_points))
            
            # start_time = time.time()
            # for i in range(Q.shape[0]):
            #     for j in range(Q.shape[1]): 
            #         train_dim0_i = self.dim0_train_xvecs[i].reshape((1, -1))
            #         train_dim0_j = self.dim0_train_xvecs[j].reshape((1, -1))
            #         vi = v[i].reshape((1, -1))
            #         vj = v[j].reshape((1, -1))

            #         zij = np.einsum('ij,jk', self.lengthscale_matrix_dim0_inv, vi.T, optimize=EIN_OPTIMIZER) + np.einsum('ij,jk', self.lengthscale_matrix_dim0_inv, vj.T, optimize=EIN_OPTIMIZER)
            #         Qij_coef = (self.kernel_dim0.K(train_dim0_i, dim0_input_mean)*self.kernel_dim0.K(train_dim0_j, dim0_input_mean))/R_sqrt_det
            #         # additional coefficients from non-random inputs
            #         Qij_k_coef = k1_coef[i]*k2_coef[i]*k1_coef[j]*k2_coef[j]

            #         Qij_exponent_later = np.einsum('ij,jk', Q_mat_exp_covmat, zij, optimize=EIN_OPTIMIZER)
            #         Qij_exponent = 0.5 * np.einsum('ij,jk', zij.reshape((1, -1)), Qij_exponent_later.reshape((-1, 1)), optimize=EIN_OPTIMIZER)
                    
            #         Qij = Qij_k_coef * Qij_coef * np.exp(Qij_exponent)
            #         Q[i, j] = float(Qij)
                    
            #         Q_exponents_later[i,j] = Qij_exponent_later.flatten()
            #         Q_exponents[i, j] = Qij_exponent
            #         all_zij_old[i, j] = zij.flatten()
            #         Q_coefs_old[i,j] = Qij_coef
            #         Q_k_coefs_old[i,j] = Qij_k_coef

            # print("Q matrix computed in time: ", time.time() - start_time)
            
            # # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            # var_first_term = self.full_variance - np.trace(np.einsum('ij,jk', self.K_train_noise_inv, Q, optimize=EIN_OPTIMIZER))
            # # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            # var_second_term = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Q, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)
            # # 4. Compute 4th term mean^2
            # var_third_term = current_prediction_mean**2
            # current_prediction_var = var_first_term + var_second_term - var_third_term

            #start_time = time.time()
            # try computing Q matrix efficiently 
            k0_coef = self.kernel_dim0.K(self.dim0_train_xvecs, dim0_input_mean)
            Q_coef = np.einsum('ij,kl->ik', k0_coef, k0_coef, optimize=EIN_OPTIMIZER)/R_sqrt_det
            lengthscale_inv_v = np.einsum('ij,kj->kj', self.lengthscale_matrix_dim0_inv, v, optimize=EIN_OPTIMIZER)
            
            lengthscale_inv_v_seperated = lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1]))
            all_zij = lengthscale_inv_v_seperated + lengthscale_inv_v # can index with all_zij[i, j] 
            
            Qexp_dot_prod_later = np.einsum('kk,ijk->ijk', Q_mat_exp_covmat, all_zij, optimize=EIN_OPTIMIZER)
            Qexp_dot_prod = np.einsum('ijk,ijk->ij', all_zij, Qexp_dot_prod_later, optimize=EIN_OPTIMIZER)
            Qexp_dot_prod_mat = 0.5 * Qexp_dot_prod

            Q_k_coef = np.einsum('ij,kl->ik', k_coef, k_coef, optimize=EIN_OPTIMIZER)
            Q_mat = Q_coef * Q_k_coef * np.exp(Qexp_dot_prod_mat)

            #print("Finished computing faster Q compute: ", time.time() - start_time)
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term2 = self.full_variance - np.trace(np.einsum('ij,jk', self.K_train_noise_inv, Q_mat, optimize=EIN_OPTIMIZER))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term2 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Q_mat, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)
            # 4. Compute 4th term mean^2
            var_third_term2 = current_prediction_mean**2

            var_third_term_ein = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', qa, np.einsum('ij,jk', qa.T, self.Ba_vector)))

            current_prediction_var2 = var_first_term2 + var_second_term2 - var_third_term2

            current_prediction_var = current_prediction_var2


            # if current_prediction_var2 < 0: #or current_prediction_var < 0:
            #     print("There is an error: Variance should always be positive 1")
            #     pdb.set_trace()

            if False:#current_prediction_var < 0:

                # NOTE: All below results assume input covariance is zero
                dim0_input_mean = dim0_input_mean.reshape((1, -1))
                dim1_input_val = dim1_input_val.reshape((1, -1))
                dim2_input_val = dim2_input_val.reshape((1, -1))
                totalinput = np.hstack((dim0_input_mean, np.hstack((dim1_input_val, dim2_input_val))))

                kern = self.GP_model.kern
                pred_var = self.GP_model._predictive_variable
                Xnew = totalinput

                Kx = kern.K(pred_var, Xnew)
                Kxx = kern.Kdiag(Xnew)

                # GPy output 
                gpy_mean, gpy_var = self.GP_model.predict(Xnew)

                # GPy posterior raw output full_cov = True
                mu_fullcov, var_fullcov = self.GP_model.posterior._raw_predict(kern=kern, Xnew=Xnew, pred_var=pred_var, full_cov=True)

                # GPy posterior raw output full_cov = True 
                mu_nofullcov, var_nofullcov = self.GP_model.posterior._raw_predict(kern=kern, Xnew=Xnew, pred_var=pred_var, full_cov=False)

                # Completely manual computation with GPy posterior matrices
                var_manual_gpymats = Kxx - np.einsum('ij,jk', Kx.T, np.einsum('ij,jk', self.GP_model.posterior.woodbury_inv, Kx))

                # Completely manual computation using dtrtrs instead of inverse for matrix multiplication
                from GPy.util.linalg import dtrtrs
                tmp_dtrtrs = dtrtrs(self.GP_model.posterior.woodbury_chol, Kx)[0]
                var_dtrtrs = Kxx - np.sum(np.square(tmp_dtrtrs))

                tmp_dtrtrs_scipy = scipy.linalg.lapack.dtrtrs(self.GP_model.posterior.woodbury_chol, Kx)[0]
                var_dtrtrs_scipy = Kxx - np.sum(np.square(tmp_dtrtrs_scipy))

                # Completely manual computation - from kernels
                K_train_noise_inv = np.linalg.inv(kern.K(pred_var, pred_var) + float(self.GP_model.Gaussian_noise.variance)*np.eye(pred_var.shape[0]))
                var_manual = Kxx - np.einsum('ij,jk', Kx.T, np.einsum('ij,jk', K_train_noise_inv, Kx, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)

                K_train_noise_inv_stable = self.stable_cho_inverse(kern.K(pred_var, pred_var) + float(self.GP_model.Gaussian_noise.variance)*np.eye(pred_var.shape[0]))
                var_manual_stable = Kxx - np.einsum('ij,jk', Kx.T, np.einsum('ij,jk', K_train_noise_inv_stable, Kx, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)

                # test start 
                k0mat = np.outer(k0_coef, k0_coef)
                k1mat = np.outer(k1_coef, k1_coef)
                k2mat = np.outer(k2_coef, k2_coef)

                Qmat_total1 = k0mat * k1mat * k2mat
                secondterm_Qmattotal1 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Qmat_total1, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)

                total_k = k0_coef * k1_coef * k2_coef
                Qmat_total2 = np.outer(total_k, total_k)
                secondterm_Qmattotal2 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Qmat_total2, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)

                # COMPARISON TO DOING ALL THE STUFF TOGETHER
                k0_coef_manual = self.variance_dim0 * np.exp(-0.5 * (np.einsum('ij,ji->i', v, np.einsum('ij,jk', self.lengthscale_matrix_dim0_inv, v.T, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)))
                k0_coef_manual = k0_coef_manual.reshape(k_coef.shape)
                total_k = k0_coef_manual * k1_coef * k2_coef 
                Q_total_coef = np.einsum('ij,kl->ik', total_k, total_k, optimize=EIN_OPTIMIZER)/R_sqrt_det

                Q_mat3 = Q_total_coef * np.exp(Qexp_dot_prod_mat)

                #print("Finished computing faster Q compute: ", time.time() - start_time)
                # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
                var_first_term3 = self.full_variance - np.trace(np.einsum('ij,jk', self.K_train_noise_inv, Q_mat3, optimize=EIN_OPTIMIZER))
                # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
                var_second_term3 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Q_mat3, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)
                # 4. Compute 4th term mean^2
                var_third_term3 = current_prediction_mean**2
                
                current_prediction_var3 = var_first_term3 + var_second_term3 - var_third_term3

                # COMPARISON TO DOING ALL THE OUTER PRODUCTS APART!
                #k0_outer = np.einsum('ij,kl->ik', k0_coef, k0_coef, optimize=EIN_OPTIMIZER)
                k0_outer = np.einsum('ij,kl->ik', k0_coef_manual, k0_coef_manual, optimize=EIN_OPTIMIZER)
                k1_outer = np.einsum('ij,kl->ik', k1_coef, k1_coef, optimize=EIN_OPTIMIZER)
                k2_outer = np.einsum('ij,kl->ik', k2_coef, k2_coef, optimize=EIN_OPTIMIZER)
                Q_seperate_coef = (k0_outer *  k1_outer * k2_outer)/R_sqrt_det

                Q_mat4 = Q_seperate_coef * np.exp(Qexp_dot_prod_mat)

                #print("Finished computing faster Q compute: ", time.time() - start_time)
                # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
                var_first_term4 = self.full_variance - np.trace(np.einsum('ij,jk', self.K_train_noise_inv, Q_mat4, optimize=EIN_OPTIMIZER))
                # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
                var_second_term4 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Q_mat4, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)
                # 4. Compute 4th term mean^2
                var_third_term4 = current_prediction_mean**2
                
                current_prediction_var4 = var_first_term4 + var_second_term4 - var_third_term4


                print("current var: ", current_prediction_var)
                print("GPy output var: ", gpy_var)
                print("posterior raw predict fullcov, nofullcov: " + str(var_fullcov) + str(var_nofullcov))
                print("manual with gpy mats: ", var_manual_gpymats)
                print("manual np.linalg.inv: ", var_manual)
                print("manual stable inverse: ", var_manual_stable)
                print("current prediction var 2: ", current_prediction_var2)
                print("current prediction var 3: ", current_prediction_var3)
                print("current prediction var 4: ", current_prediction_var4)

                print("End first prediction here")
                pdb.set_trace()
            

        elif type(dim2_input_val) != type(None):
            # third prediction
            dim01_input_mean = input_mean.reshape((1, -1))
            dim01_input_cov = input_cov
            dim0_input_mean = dim01_input_mean[: ,self.dim0_start_idx:self.dim0_end_idx] 
            dim1_input_mean = dim01_input_mean[:, self.dim1_start_idx:self.dim1_end_idx]

            qa_coef_mat = np.dot(dim01_input_cov, self.lengthscale_matrix_dim01_inv) + np.eye(dim01_input_cov.shape[0])
            qa_coef = (self.variance_dim0 * self.variance_dim1)/np.sqrt(np.linalg.det(qa_coef_mat))
            # kernel coefficients due to non random variable inputs
            k2_coef = self.kernel_dim2.K(self.dim2_train_xvecs, dim2_input_val.reshape((1, -1)))
            k_coef = k2_coef
            # exponential part 
            qa_exp_covmat = self.stable_cho_inverse(matrix=dim01_input_cov + self.lengthscale_matrix_dim01, cholesky=False)
            v = self.dim01_train_xvecs - dim01_input_mean
            #qa_exp = np.diagonal(np.dot(v, np.dot(qa_exp_covmat, v.T))).reshape((-1, 1))
            qa_exp = np.einsum('ij,ji->i', v, np.dot(qa_exp_covmat, v.T)).reshape((-1, 1))

            qa = qa_coef * k_coef * np.exp(-0.5 * qa_exp)
            current_prediction_mean = np.dot(self.Ba_vector.T, qa)
            
            # variance propogation
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            R = np.dot(dim01_input_cov, self.lengthscale_matrix_dim01_inv + self.lengthscale_matrix_dim01_inv) + np.eye(dim01_input_cov.shape[0])
            R_sqrt_det = np.sqrt(np.linalg.det(R))
            R_inv = np.linalg.inv(R)
            Q_mat_exp_covmat = np.dot(R_inv, dim01_input_cov)
            
            """
            Q = np.zeros((self.num_training_points, self.num_training_points))
            for i in range(Q.shape[0]):
                for j in range(Q.shape[1]): 
                    zij = np.dot(self.lengthscale_matrix_dim01_inv, v[i]) + np.dot(self.lengthscale_matrix_dim01_inv, v[j])
                    Qij_coef = (self.kernel_dim0.K(self.dim0_train_xvecs[i], dim0_input_mean) * self.kernel_dim1.K(self.dim1_train_xvecs[i], dim1_input_mean) *
                                self.kernel_dim0.K(self.dim0_train_xvecs[j], dim0_input_mean) * self.kernel_dim1.K(self.dim1_train_xvecs[j], dim1_input_mean))/R_sqrt_det
                    # additional coefficients from non-random inputs
                    Qij_k_coef = k2_coef[i]*k2_coef[j]

                    Qij_exponent = 0.5 * np.dot(zij, np.dot(Q_mat_exp_covmat, zij))
                    Qij = Qij_k_coef * Qij_coef * np.exp(Qij_exponent)
                    Q[i, j] = Qij

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term = self.full_variance - np.linalg.trace(np.dot(K_train_noise_inv, Q))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term = np.dot(Ba_vector.T, np.dot(Q, Ba_vector))
            # 4. Compute 4th term mean^2
            var_third_term = current_prediction_mean**2
            current_prediction_var = var_first_term + var_second_term - var_third_term
            """

            # try computing Q matrix efficiently 
            k_dim0_mean_coef = self.kernel_dim0.K(self.dim0_train_xvecs, dim0_input_mean)
            k_dim1_mean_coef = self.kernel_dim1.K(self.dim1_train_xvecs, dim1_input_mean)
            k_dim01_mean_coef = k_dim0_mean_coef * k_dim1_mean_coef

            Q_coef = np.dot(k_dim01_mean_coef, k_dim01_mean_coef.T)/R_sqrt_det
            lengthscale_inv_v = np.dot(self.lengthscale_matrix_dim01_inv, v.T).T
            lengthscale_inv_v_seperated = lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1]))
            all_zij = lengthscale_inv_v_seperated + lengthscale_inv_v # can index with all_zij[i, j] 

            Qexp_dot_prod_later = np.einsum('kk,ijk->ijk', Q_mat_exp_covmat, all_zij)
            Qexp_dot_prod = np.einsum('ijk,ijk->ij', all_zij, Qexp_dot_prod_later)
            Qexp_dot_prod_mat = 0.5 * Qexp_dot_prod

            Q_k_coef = np.dot(k_coef, k_coef.T)
            Q_mat = Q_coef * Q_k_coef * np.exp(Qexp_dot_prod_mat)

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term2 = self.full_variance - np.trace(np.dot(self.K_train_noise_inv, Q_mat))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term2 = np.dot(self.Ba_vector.T, np.dot(Q_mat, self.Ba_vector))
            # 4. Compute 4th term mean^2
            var_third_term2 = current_prediction_mean**2
            current_prediction_var = var_first_term2 + var_second_term2 - var_third_term2
            # if current_prediction_var < 0:
            #     print("There is an error: Variance should always be positive 2")
            #     pdb.set_trace()

        else: 
            # # all other predictions
            # dim012_input_mean = input_mean.reshape((1, -1))
            # dim012_input_cov = input_cov

            # qa_coef_mat = np.dot(dim012_input_cov, self.lengthscale_matrix_dim012_inv)  + np.eye(dim012_input_cov.shape[0])
            # qa_coef = self.full_variance/np.sqrt(np.linalg.det(qa_coef_mat))
            # # exponential part
            # qa_exp_covmat = self.stable_cho_inverse(matrix=dim012_input_cov + self.lengthscale_matrix_dim012, cholesky=False)
            # v = self.full_train_xvecs - dim012_input_mean
            # #qa_exp = np.diagonal(np.dot(v, np.dot(qa_exp_covmat, v.T))).reshape((-1, 1))
            # qa_exp = np.einsum('ij,ji->i', v, np.dot(qa_exp_covmat, v.T)).reshape((-1, 1))

            # qa = qa_coef * np.exp(-0.5 * qa_exp)
            # current_prediction_mean = np.dot(self.Ba_vector.T, qa)
            
            # # variance propogation
            # # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # # 1. Compute Q matrix
            # R = np.dot(dim012_input_cov, self.lengthscale_matrix_dim012_inv + self.lengthscale_matrix_dim012_inv) + np.eye(dim012_input_cov.shape[0])
            # R_sqrt_det = np.sqrt(np.linalg.det(R))
            # R_inv = np.linalg.inv(R)
            # Q_mat_exp_covmat = np.dot(R_inv, dim012_input_cov)

            # """
            # Q = np.zeros((self.num_training_points, self.num_training_points))
            # for i in range(Q.shape[0]):
            #     for j in range(Q.shape[1]): 
            #         zij = np.dot(self.lengthscale_matrix_dim012_inv, v[i]) + np.dot(self.lengthscale_matrix_dim012_inv, v[j])
            #         Qij_coef = ((self.GP_model.kern.K(self.full_train_xvecs[i], dim012_input_mean) * 
            #                     self.GP_model.kern.K(self.full_train_xvecs[j], dim012_input_mean))/R_sqrt_det)

            #         Qij_exponent = 0.5 * np.dot(zij, np.dot(Q_mat_exp_covmat, zij))
            #         Qij = Qij_coef * np.exp(Qij_exponent)
            #         Q[i, j] = Qij

            # # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            # var_first_term = self.full_variance - np.linalg.trace(np.dot(K_train_noise_inv, Q))
            # # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            # var_second_term = np.dot(Ba_vector.T, np.dot(Q, Ba_vector))
            # # 4. Compute 4th term mean^2
            # var_third_term = current_prediction_mean**2

            # current_prediction_var = var_first_term + var_second_term - var_third_term
            # """

            # # try computing Q matrix efficiently 
            # k_dim012_mean_coef = self.GP_model.kern.K(self.full_train_xvecs, dim012_input_mean)

            # Q_coef = np.dot(k_dim012_mean_coef, k_dim012_mean_coef.T)/R_sqrt_det
            # lengthscale_inv_v = np.dot(self.lengthscale_matrix_dim012_inv, v.T).T
            # lengthscale_inv_v_seperated = lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1]))
            # all_zij = lengthscale_inv_v_seperated + lengthscale_inv_v # can index with all_zij[i, j] 

            # Qexp_dot_prod_later = np.einsum('kk,ijk->ijk', Q_mat_exp_covmat, all_zij)
            # Qexp_dot_prod = np.einsum('ijk,ijk->ij', all_zij, Qexp_dot_prod_later)
            # Qexp_dot_prod_mat = 0.5 * Qexp_dot_prod

            # Q_mat = Q_coef * np.exp(Qexp_dot_prod_mat)

            # # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            # var_first_term2 = self.full_variance - np.trace(np.dot(self.K_train_noise_inv, Q_mat))
            # # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            # var_second_term2 = np.dot(self.Ba_vector.T, np.dot(Q_mat, self.Ba_vector))
            # # 4. Compute 4th term mean^2
            # var_third_term2 = current_prediction_mean**2
            # current_prediction_var = var_first_term2 + var_second_term2 - var_third_term2
            # if current_prediction_var < 0:
            #     print("There is an error: Variance should always be positive 3")
            #     pdb.set_trace()

            # CONVERT EVERYTHING TO USE EINSUM for consistency 
            EIN_OPTIMIZER = 'optimal'

            # all other predictions
            dim012_input_mean = input_mean.reshape((1, -1))
            dim012_input_cov = input_cov

            qa_coef_mat = np.einsum('ij,jk',dim012_input_cov, self.lengthscale_matrix_dim012_inv, optimize=EIN_OPTIMIZER) + np.eye(dim012_input_cov.shape[0])

            qa_coef = self.full_variance/np.sqrt(np.linalg.det(qa_coef_mat))
            # exponential part
            qa_exp_covmat = self.stable_cho_inverse(matrix=dim012_input_cov + self.lengthscale_matrix_dim012, cholesky=False)
            v = self.full_train_xvecs - dim012_input_mean
            #qa_exp = np.diagonal(np.dot(v, np.dot(qa_exp_covmat, v.T))).reshape((-1, 1))
            qa_exp = np.einsum('ij,ji->i', v, np.einsum('ij,jk', qa_exp_covmat, v.T, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER).reshape((-1, 1))

            qa = qa_coef * np.exp(-0.5 * qa_exp)
            current_prediction_mean = np.einsum('ij,jk',self.Ba_vector.T, qa, optimize=EIN_OPTIMIZER)
            
            # variance propogation
            # var = E_x[var(f|x)] + E_x[E_f[f|x] E_f[f|x]] - mean^2
            # 1. Compute Q matrix
            R = np.einsum('ij,jk', dim012_input_cov, self.lengthscale_matrix_dim012_inv + self.lengthscale_matrix_dim012_inv, optimize=EIN_OPTIMIZER)+ np.eye(dim012_input_cov.shape[0]) 
            R_sqrt_det = np.sqrt(np.linalg.det(R))
            R_inv = np.linalg.inv(R)
            Q_mat_exp_covmat = np.einsum('ij,jk', R_inv, dim012_input_cov, optimize=EIN_OPTIMIZER)

            """
            Q = np.zeros((self.num_training_points, self.num_training_points))
            for i in range(Q.shape[0]):
                for j in range(Q.shape[1]): 
                    zij = np.dot(self.lengthscale_matrix_dim012_inv, v[i]) + np.dot(self.lengthscale_matrix_dim012_inv, v[j])
                    Qij_coef = ((self.GP_model.kern.K(self.full_train_xvecs[i], dim012_input_mean) * 
                                self.GP_model.kern.K(self.full_train_xvecs[j], dim012_input_mean))/R_sqrt_det)

                    Qij_exponent = 0.5 * np.dot(zij, np.dot(Q_mat_exp_covmat, zij))
                    Qij = Qij_coef * np.exp(Qij_exponent)
                    Q[i, j] = Qij

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term = self.full_variance - np.linalg.trace(np.dot(K_train_noise_inv, Q))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term = np.dot(Ba_vector.T, np.dot(Q, Ba_vector))
            # 4. Compute 4th term mean^2
            var_third_term = current_prediction_mean**2

            current_prediction_var = var_first_term + var_second_term - var_third_term
            """

            # try computing Q matrix efficiently 
            k_dim012_mean_coef = self.GP_model.kern.K(self.full_train_xvecs, dim012_input_mean)

            Q_coef = np.einsum('ij,jk', k_dim012_mean_coef, k_dim012_mean_coef.T, optimize=EIN_OPTIMIZER)/R_sqrt_det
            lengthscale_inv_v = np.einsum('ij,jk', self.lengthscale_matrix_dim012_inv, v.T, optimize=EIN_OPTIMIZER).T
            lengthscale_inv_v_seperated = lengthscale_inv_v.reshape((lengthscale_inv_v.shape[0], 1, lengthscale_inv_v.shape[1]))
            all_zij = lengthscale_inv_v_seperated + lengthscale_inv_v # can index with all_zij[i, j] 

            Qexp_dot_prod_later = np.einsum('kk,ijk->ijk', Q_mat_exp_covmat, all_zij, optimize=EIN_OPTIMIZER)
            Qexp_dot_prod = np.einsum('ijk,ijk->ij', all_zij, Qexp_dot_prod_later, optimize=EIN_OPTIMIZER)
            Qexp_dot_prod_mat = 0.5 * Qexp_dot_prod

            Q_mat = Q_coef * np.exp(Qexp_dot_prod_mat)

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_term2 = self.full_variance - np.trace(np.einsum('ij,jk', self.K_train_noise_inv, Q_mat, optimize=EIN_OPTIMIZER))
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_term2 = np.einsum('ij,jk', self.Ba_vector.T, np.einsum('ij,jk', Q_mat, self.Ba_vector, optimize=EIN_OPTIMIZER), optimize=EIN_OPTIMIZER)
            # 4. Compute 4th term mean^2
            var_third_term2 = current_prediction_mean**2
            current_prediction_var = var_first_term2 + var_second_term2 - var_third_term2

            # if current_prediction_var < 0:
            #     print("There is an error: Variance should always be positive 3")
            #     pdb.set_trace()

        if set_debug: 
            print("Debug is true: analytical mean var propagation")
            pdb.set_trace()

        # force set var to 0 if negative 
        if float(current_prediction_var) < 0: 
            current_prediction_var = 0

        return current_prediction_mean, current_prediction_var


    def mean_var_propogation(self, starting_x_images, steps, all_together, use_variance_weighting, test_patch_obj=None, single_kernel=False, svd_inverse = False, sigma_threshold = None, plot_for_thresholding = False):
        """
        Predicts a rollout sequence from the starting_x_images for steps number of steps. 
        Propogate the random variable forward at each timestep.  

        args:  
            - starting_x_images: list of starting images: depending on the dataprocessor up to 3 images are needed
                - this is used to start the prediction process
            - steps: the number of steps to rollout
            - all_together: boolean indicator
                - True: predict the entire dataset in one go 
                - False: predict the dataset one image at a time - NOTE: TODO: might need to change this to x patches at a time. 
            - use_variance_weighting: 
                - True: use variance for the mask weighting
                - False: do not use variance - just do averaging
            - single_kernel: boolean indicator: 
                - True: just one lengthscale for all dimensions
                - False: one lengthscale per different image space. 
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
            - svd_inverse: type: boolean: 
                - True: use the svd inverse with the specified sigma cutoff
                - False: use the matrices returned by GPy directly 
            - sigma_threshold: 
                - None: take the sigma values up to the "large" dropoff
                    - We define a large dropoff as if the change in values between the next two sigma values is less than 1 order of magnitude compared to the change in the last 2 values
                - float: the maximum sigma value you want to keep before inverting
            - plot_for_thresholding: type: boolean:
                - False: just use the specified thresholds 
                - True: plot the sigma values and go into a set trace to allow the sigma threshold to be set manually by the user. 
        returns: 
            - pred_seq_means: list of unpatchified images that are predicted from the test_x set - these are the means
            - pred_var_images: 
                - returns this if  use_variance_weighting is set to True
            - xdataset: the xdataset that was used to generate each of the predicted images during the rollout
        """
        # default toy variance value for testing
        print("This has been updated: 0")
        TOY_DEFAULT_VAR = 0.01

        # use_variance_weighting input only left in for compatibility - do not use with this approach 
        if type(test_patch_obj) == type(None):
            patch_obj = self.patch_obj
        else: 
            patch_obj = test_patch_obj

        assert(use_variance_weighting == False)
        assert(patch_obj.stride == (1,1))
        assert(patch_obj.get_ypatch_dim() == (1,1)) # This method is really only for when you are estimating the center pixel of the patch 

        # currently only coding for this input space and output space
        assert(self.data_processor.xtypes == ['img0', 'img-1', 'img-2'])
        assert(self.data_processor.ytype == 'img')

        predicted_seq_means = []
        predicted_seq_vars = [] # even though not being estimated set up here to establish framework 

        predicted_input_mean_patches = [] # list of sublists at most 3 long - each sublist is a list of patches - 0th list corresponds to patches of the last image and so on
        predicted_input_var_patches = [] # same as list of mean patches above but with variance 


        # Separate out the kernels - NOTE: these are all the same for each image space for now - structured to make it easy to change later
        pixels_per_inpatch = np.product(patch_obj.patch_dim)
        img_space_lengthscale = float(self.GP_model.kern.lengthscale)
        img_space_variance    = float(self.GP_model.kern.variance)**(1/3)
        
        self.lengthscale_matrix_dim0 = np.eye(pixels_per_inpatch) * img_space_lengthscale**2
        self.lengthscale_matrix_dim1 = np.eye(pixels_per_inpatch) * img_space_lengthscale**2
        self.lengthscale_matrix_dim2 = np.eye(pixels_per_inpatch) * img_space_lengthscale**2
        self.lengthscale_matrix_dim01 = scipy.linalg.block_diag(self.lengthscale_matrix_dim0, self.lengthscale_matrix_dim1)
        self.lengthscale_matrix_dim012 = scipy.linalg.block_diag(self.lengthscale_matrix_dim01, self.lengthscale_matrix_dim2)

        self.lengthscale_matrix_dim0_inv = np.linalg.inv(self.lengthscale_matrix_dim0)
        self.lengthscale_matrix_dim1_inv = np.linalg.inv(self.lengthscale_matrix_dim1)
        self.lengthscale_matrix_dim2_inv = np.linalg.inv(self.lengthscale_matrix_dim2)
        self.lengthscale_matrix_dim01_inv = scipy.linalg.block_diag(self.lengthscale_matrix_dim0_inv, self.lengthscale_matrix_dim1_inv)
        self.lengthscale_matrix_dim012_inv = scipy.linalg.block_diag(self.lengthscale_matrix_dim01_inv, self.lengthscale_matrix_dim2_inv)

        self.variance_dim0 = img_space_variance
        self.variance_dim1 = img_space_variance
        self.variance_dim2 = img_space_variance
        self.full_variance = float(self.GP_model.kern.variance)


        self.kernel_dim0 = GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=img_space_lengthscale, variance=self.variance_dim0) # kernel for the 0th image patch in the input vec
        self.kernel_dim1 = GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=img_space_lengthscale, variance=self.variance_dim1) # kernel for the 1st image patch in the input vec
        self.kernel_dim2 = GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=img_space_lengthscale, variance=self.variance_dim2) # kernel for the 2nd image patch in the input vec

        # Process Data: 
        # Seperate datasets by image patch dimensions
        dim0_start_idx = 0 
        dim0_end_idx = dim0_start_idx + pixels_per_inpatch
        dim1_start_idx = dim0_end_idx
        dim1_end_idx = dim1_start_idx + pixels_per_inpatch
        dim2_start_idx = dim1_end_idx
        dim2_end_idx = dim2_start_idx + pixels_per_inpatch

        self.dim0_start_idx = dim0_start_idx
        self.dim1_start_idx = dim1_start_idx
        self.dim2_start_idx = dim2_start_idx
        self.dim0_end_idx = dim0_end_idx
        self.dim1_end_idx = dim1_end_idx
        self.dim2_end_idx = dim2_end_idx

        # Training data: seperate the image patch dimensions
        self.full_train_xvecs = np.array(self.GP_model._predictive_variable)
        self.num_training_points = self.full_train_xvecs.shape[0]
        self.dim0_train_xvecs = self.full_train_xvecs[:, dim0_start_idx:dim0_end_idx]
        self.dim1_train_xvecs = self.full_train_xvecs[:, dim1_start_idx:dim1_end_idx]
        self.dim2_train_xvecs = self.full_train_xvecs[:, dim2_start_idx:dim2_end_idx]
        self.dim01_train_xvecs = self.full_train_xvecs[:, dim0_start_idx:dim1_end_idx]

        # Test data: starting x images
        # NOTE: add a padder image as create_x always leaves the last image for the y dataset creator
        curr_test_x = list(starting_x_images)
        processed_curr_test_x = self.data_processor.create_x(image_seq=list(curr_test_x) + [np.zeros(curr_test_x[0].shape)])
        patched_processed_curr_test_x = patch_obj.patchify_dataset(dataset=processed_curr_test_x,
            dataset_type='x')
        curr_test_x_vecs = np.array(self.data_processor.convert_imgdataset_to_vecdataset(dataset=patched_processed_curr_test_x))
        dim0_test_start_xvecs = curr_test_x_vecs[:, dim0_start_idx:dim0_end_idx]
        dim1_test_start_xvecs = curr_test_x_vecs[:, dim1_start_idx:dim1_end_idx]
        dim2_test_start_xvecs = curr_test_x_vecs[:, dim2_start_idx:dim2_end_idx]
        num_patches_per_image = curr_test_x_vecs.shape[0]

        ###########################################################################
        # Set up: pre-process variables needed for all predictions
        ###########################################################################
        print("Creating the Ba vector")
        start_time = time.time()

        self.K_train_noise, self.K_train_noise_inv, self.Ba_vector = self.calculate_iK_beta(svd_inverse = svd_inverse, sigma_threshold = sigma_threshold, plot_for_thresholding = plot_for_thresholding)
        
        elapsed_time = time.time() - start_time
        print("Finished creating the Ba vector: ", elapsed_time)
        
        ###########################################################################
        # 1st Prediction: GP output
        ###########################################################################
        # perform the prediction
        if all_together:
            first_prediction_means, first_prediction_vars = GP_model.predict(Xnew=np.array(curr_test_x_vecs).reshape((len(curr_test_x_vecs), -1)))
        else: 
            first_prediction_means = []
            first_prediction_vars = []

            start_range = 0
            print("About to start while loop for First mean prediction")
            while start_range < len(curr_test_x_vecs):
                end_range = min(start_range + self.predictor.max_all_together, len(curr_test_x_vecs))
                predict_subset_x = curr_test_x_vecs[start_range:end_range]
                predict_subset_ymean, predict_subset_yvar = self.GP_model.predict(Xnew=np.array(predict_subset_x).reshape((len(predict_subset_x), -1)))
                # append
                first_prediction_means.extend(predict_subset_ymean)
                first_prediction_vars.extend(predict_subset_yvar)

                start_range += self.predictor.max_all_together

        # convert the predicted vecs to patches
        first_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in first_prediction_means]
        first_prediction_var_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in first_prediction_vars]

        # convert the predicted patches to one predicted image
        if not use_variance_weighting:
            first_predicted_mean_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=first_prediction_mean_patches, patch_variance_list=None, 
                                                                          img_type = 'y')
            first_predicted_var_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=first_prediction_var_patches, patch_variance_list=None, 
                                                                         img_type='y')
        else: 
            print("ERROR: incorrectly tried to use variance weighting here")
            assert(False)

        # append to stored lists
        predicted_seq_means.append(first_predicted_mean_image)
        predicted_seq_vars.append(first_predicted_var_image) #* 1e4) # TODO: REMOVE INCREASED VARIANCE

        #np.save("predicted_seq_means.npy", np.array(predicted_seq_means))
        #np.save("predicted_seq_vars.npy", np.array(predicted_seq_vars))

        # Convert last predicted output to next input
        first_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
        first_predicted_input_var_patches = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')
        # vectorize
        first_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in first_predicted_input_mean_patches]) 
        first_predicted_input_vars = np.array([patch_var.flatten() for patch_var in first_predicted_input_var_patches])

        predicted_input_mean_patches.append(first_predicted_input_means)
        predicted_input_var_patches.append(first_predicted_input_vars)

        # show the first image
        fig = plt.figure()
        pos = plt.imshow(predicted_seq_means[-1])
        fig.colorbar(pos)
        plt.title("First Prediction")
        plt.show(block=False)

        # show the variance image
        fig = plt.figure()
        pos = plt.imshow(predicted_seq_vars[-1])
        fig.colorbar(pos)
        plt.title("First Variance Image")
        plt.show(block=False)

        ###########################################################################
        # 2nd and Following Predictions: PILCO
        ###########################################################################
        second_prediction_means = []
        second_prediction_means_normalgp = []
        second_prediction_vars = []
        print("About to start for loop for 2nd prediction")
        stopper = True

        for step_num in range(1, steps):
            start_time = time.time()

            current_prediction_means = []
            current_prediction_vars = []

            if step_num == 1: 
                """
                # Multi processing attempt
                print("Start multiprocessing")
                start_time = time.time()
                all_dim0_input_covs = [np.diag(var_patch.flatten()) for var_patch in predicted_input_var_patches[-1]]
                all_zipped_inputs = list(zip(predicted_input_mean_patches[-1], all_dim0_input_covs, dim1_test_start_xvecs, dim2_test_start_xvecs))

                num_processes = 5
                patches_per_division = (num_patches_per_image)//num_processes
                patch_process_indices = np.arange(0, num_patches_per_image, patches_per_division)
                if patch_process_indices[-1] != num_patches_per_image: 
                    patch_process_indices = list(patch_process_indices) + [num_patches_per_image]
                divided_inputs = [all_zipped_inputs[patch_process_indices[i:i+1]] for i in range(len(patch_process_indices) - 1)]

                pool = mp.Pool(num_processes)
                all_results = pool.starmap(self.analytical_mean_var_propogation_MP, [input_list for input_list in divided_inputs])

                time_taken = time.time() - start_time
                print("Multi processing time taken: ", time_taken)
                pdb.set_trace()
                """

                
                for patch_num in tqdm(range(len(predicted_input_mean_patches[-1]))):

                    dim0_input_mean = predicted_input_mean_patches[-1][patch_num]
                    dim0_input_var = predicted_input_var_patches[-1][patch_num] 
                    if patch_num == 0: 
                        print(dim0_input_var)
                    dim0_input_cov = np.diag(dim0_input_var.flatten())

                    dim1_input_val = dim0_test_start_xvecs[patch_num]
                    dim2_input_val = dim1_test_start_xvecs[patch_num]

                    current_prediction_mean, current_prediction_var = \
                                                self.analytical_mean_var_propogation(input_mean=dim0_input_mean, 
                                                                                     input_cov=dim0_input_cov,
                                                                                     dim1_input_val=dim1_input_val, 
                                                                                     dim2_input_val=dim2_input_val)
                    current_prediction_means.append(current_prediction_mean)
                    current_prediction_vars.append(current_prediction_var)
                
                """
                current_prediction_means2 = []
                current_prediction_vars2 = []
                for patch_num in tqdm(range(len(predicted_input_mean_patches[-1]))):

                    dim0_input_mean = predicted_input_mean_patches[-1][patch_num]
                    dim0_input_var = np.ones(dim0_input_mean.shape) * 1e-16#predicted_input_var_patches[-1][patch_num]
                    dim0_input_cov = np.diag(dim0_input_var.flatten())

                    dim1_input_val = dim1_test_start_xvecs[patch_num]
                    dim2_input_val = dim2_test_start_xvecs[patch_num]

                    input_mean = np.hstack((dim0_input_mean, np.hstack((dim1_input_val, dim2_input_val))))
                    input_var = np.hstack((dim0_input_var, np.hstack((np.ones(dim0_input_var.shape) * 1e-16, np.ones(dim0_input_var.shape) * 1e-16))))
                    input_cov = np.diag(input_var.flatten())

                    current_prediction_mean, current_prediction_var = \
                                                self.analytical_mean_var_propogation(input_mean=input_mean, 
                                                                                     input_cov=input_cov,
                                                                                     dim1_input_val=None, 
                                                                                     dim2_input_val=None)
                    current_prediction_means2.append(current_prediction_mean)
                    current_prediction_vars2.append(current_prediction_var)

                pdb.set_trace()
                """

            elif step_num == 2: 
                for patch_num in tqdm(range(len(predicted_input_mean_patches[-1]))):

                    dim0_input_mean = predicted_input_mean_patches[-1][patch_num].reshape((1, -1))
                    dim1_input_mean = predicted_input_mean_patches[-2][patch_num].reshape((1, -1))
                    dim01_input_mean = np.hstack((dim0_input_mean, dim1_input_mean))

                    dim0_input_var = predicted_input_var_patches[-1][patch_num].reshape((1, -1))
                    dim1_input_var = predicted_input_var_patches[-2][patch_num].reshape((1, -1))
                    dim01_input_var = np.hstack((dim0_input_var, dim1_input_var))
                    dim01_input_cov = np.diag(dim01_input_var.flatten())

                    dim2_input_val = dim0_test_start_xvecs[patch_num]

                    current_prediction_mean, current_prediction_var = \
                                                self.analytical_mean_var_propogation(input_mean=dim01_input_mean, 
                                                                                     input_cov=dim01_input_cov,
                                                                                     dim1_input_val=None, 
                                                                                     dim2_input_val=dim2_input_val)
                    current_prediction_means.append(current_prediction_mean)
                    current_prediction_vars.append(current_prediction_var)
            else: 
                set_debug = False
                for patch_num in tqdm(range(len(predicted_input_mean_patches[-1]))):

                    dim012_input_mean = np.hstack((predicted_input_mean_patches[-1][patch_num], 
                                                    np.hstack((predicted_input_mean_patches[-2][patch_num], 
                                                        predicted_input_mean_patches[-3][patch_num]))))
                    dim012_input_var = np.hstack((predicted_input_var_patches[-1][patch_num], 
                                                    np.hstack((predicted_input_var_patches[-2][patch_num], 
                                                        predicted_input_var_patches[-3][patch_num]))))
                    dim012_input_cov = np.diag(dim012_input_var.flatten())

                    current_prediction_mean, current_prediction_var = \
                                                self.analytical_mean_var_propogation(input_mean=dim012_input_mean, 
                                                                                     input_cov=dim012_input_cov,
                                                                                     dim1_input_val=None, 
                                                                                     dim2_input_val=None, 
                                                                                     set_debug=set_debug)
                    current_prediction_means.append(current_prediction_mean)
                    current_prediction_vars.append(current_prediction_var)

                    if set_debug: 
                        print("Debug set to true: in mean var propagation")
                        pdb.set_trace()

            # process the predicted means and variances to form images for saving and the next round of predictions
            current_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_means]
            current_prediction_var_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_vars]

            # unpatchify 
            if not use_variance_weighting: 
                current_predicted_mean_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=current_prediction_mean_patches, 
                                                                                      patch_variance_list=None, 
                                                                                      img_type='y')
                current_predicted_var_image, _, _, _, _  = patch_obj.unpatchify_image(patch_list=current_prediction_var_patches, 
                                                                                      patch_variance_list=None, 
                                                                                      img_type='y') 
            else: 
                print("ERROR: incorrectly tried to use variance weighting here")
                assert(False)
            print("Finished Predicting: " + str(step_num + 1) + " took: " + str(time.time() - start_time))
            # append 
            predicted_seq_means.append(current_predicted_mean_image)
            predicted_seq_vars.append(current_predicted_var_image)


            # process input patches for the next round
            current_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
            current_predicted_input_var_patches  = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')  
            current_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in current_predicted_input_mean_patches])
            current_predicted_input_vars  = np.array([patch_var.flatten() for patch_var in current_predicted_input_var_patches])

            predicted_input_mean_patches.append(current_predicted_input_means)
            predicted_input_var_patches.append(current_predicted_input_vars)

            if len(predicted_input_mean_patches) > 3:
                predicted_input_mean_patches = predicted_input_mean_patches[1:]
                predicted_input_var_patches  = predicted_input_var_patches[1:]

            # display the result
            fig = plt.figure()
            pos = plt.imshow(current_predicted_mean_image)
            plt.title("Prediction: " + str(step_num + 1))
            fig.colorbar(pos)
            plt.show(block=False)
            
            # display the variance image 
            fig = plt.figure()
            pos = plt.imshow(current_predicted_var_image)
            plt.title("Variance image: " + str(step_num + 1))
            fig.colorbar(pos)
            plt.show(block=False)

            #np.save("predicted_seq_means.npy", np.array(predicted_seq_means))
            #np.save("predicted_seq_vars.npy", np.array(predicted_seq_vars))

        print("Finished " + str(steps) + " predictions")

        return predicted_seq_means, predicted_seq_vars

################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
##################################################### Mean and Variance Propagation Faster: THE ONE ############################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################

    def analytical_mean_var_propogation_Faster_separateQij(self, input_means, input_covs, dim1_input_vals=None, dim2_input_vals=None):
        """ 
        Propogates forward the mean and variance of the probability distribution. 
        Batch solve for Qij - separate parts of the Q matrix based on based on 
        the available GPU memory. 
    
        Note: please only use this with a single input mean and single input cov. 

        Attempt to do this faster by parallelizing multiple patches at once using tensor operations. 

        args: 
            - input_mean: the mean along the input dimensions that are random variables
            - input_cov: the corresponding covariance matrix to the input mean 
            - dim1_input_val: type: numpy array: the value to use for the 1st dimensional input 
                - NOTE: only needed for the first mean propogation
            - dim2_input_val: type: numpy array: the value to use for the 1st dimensional input 
                - NOTE: only needed for the first and second mean propogation
        returns: 
            - current_prediction_mean: the propogated mean 
            - current_prediction_var : the propogated variance
        """ 
        max_data_size = ((1200*1200*243*243*0.25))#((1200*1200*243*243*4))#(1200*1200*243*243*8*2.15) # done with experimentation to get max capacity #7516192768 # 7 Giga Bytes
        float64_multiplier = 1 # number you mulitply the number of entries by when calculating max data size


        # # # START: REMOVE THIS
        # print("FORCING INPUT COV TO O!!!!")
        # input_covs = np.zeros(np.array(input_covs).shape)
        # # # END: REMOVE THIS


        if type(dim1_input_vals) != type(None) and type(dim2_input_vals) != type(None):

            # # START: REMOVE THIS
            # print("FORCING INPUT COV TO O!!!!")
            # input_covs = np.zeros(np.array(input_covs).shape)
            # # # END: REMOVE THIS

            dim0_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim0_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
            dim1_input_vals = dim1_input_vals
            dim2_input_vals = dim2_input_vals


            # # TODO: REMOVE
            # psuedo_index = 234
            # input_means = self.GP_model.X[psuedo_index][self.dim0_start_idx:self.dim0_end_idx].reshape(input_means.shape)
            # dim0_input_mean_patches = input_means
            # dim1_input_vals = self.GP_model.X[psuedo_index][self.dim1_start_idx:self.dim1_end_idx].reshape(dim1_input_vals.shape)
            # dim2_input_vals = self.GP_model.X[psuedo_index][self.dim2_start_idx:self.dim2_end_idx].reshape(dim2_input_vals.shape)
            # # TODO: REMOVE

            qa_coef = self.variance_dim0/np.sqrt(np.linalg.det(np.matmul(dim0_input_cov_patches, self.lengthscale_matrix_dim0_inv) + np.eye(dim0_input_cov_patches.shape[-1]))).reshape((-1, 1))
            k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            k_coef  = k1_coef * k2_coef 
            # exponential part
            dim0_input_mean_patches_seperated = dim0_input_mean_patches.reshape((dim0_input_mean_patches.shape[0], 1, dim0_input_mean_patches.shape[1]))
            v = dim0_input_mean_patches_seperated - self.dim0_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim0_input_cov_patches + self.lengthscale_matrix_dim0) 
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))
            
            qa = qa_coef * k_coef * np.exp(-0.5 * qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # #clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim0_input_mean_patches_seperated)
            del(qa_coef)

            #return current_prediction_mean_patches, np.zeros(current_prediction_mean_patches.shape)
            ################################################################################################
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                dim0_input_cov_patches_tr = torch.cuda.DoubleTensor(dim0_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)

                R_tr = torch.matmul(dim0_input_cov_patches_tr, self.lengthscale_matrix_dim0_inv_tr + self.lengthscale_matrix_dim0_inv_tr) + torch.eye(n=dim0_input_cov_patches.shape[-1]).to(self.device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim0_input_cov_patches_tr)
                del(dim0_input_cov_patches_tr)
                del(R_tr)
                del(R_inv_tr)

                # Q matrix coefficient
                k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
                total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
                total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)

                # # TODO: REMOVE
                # psuedo_index = 234
                # psuedo_input_means = self.GP_model.X[psuedo_index].reshape((1, -1))
                # psuedo_output_means = self.GP_model.Y[psuedo_index].reshape((1, -1))
                # total_k_coef = self.GP_model.kern.K(psuedo_input_means, self.GP_model._predictive_variable).reshape(total_k_coef.shape)
                # total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
                # # TODO: END REMOVE


                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                del(total_k_coef_tr)
                del(R_sqrt_det_tr)

                # Z torch computation
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim0_inv_tr)
                del(v_tr)
                
                # replace this part by doing n computations at a time (single column) - rather than n**2
                Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
                col_dim = Qmat_numpy.shape[-2]
                num_cols = Qmat_numpy.shape[-1]

                max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
                cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

                cols_list = list(np.arange(0, num_cols+1, cols_atonce))
                if cols_list[-1] != num_cols: 
                    cols_list.append(num_cols)
                    
                for i in range(len(cols_list) - 1):
                    curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
                                  torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

                    curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
                                   torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
                                    torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
                    del(curr_zij_tr)
                    Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
                    del(curr_Qmat_tr)

                del(lengthscale_inv_v_tr)

            """
            start_time = time.time()
            # Torch computations of final variance
            Qmat_tr2 = torch.cuda.DoubleTensor(Qmat_numpy)   
            K_train_noise_inv_tr2 = torch.cuda.DoubleTensor(self.K_train_noise_inv) 

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance         
            var_first_terms_tr2 = self.full_variance - torch.trace(torch.matmul(Qmat_tr2.view(K_train_noise_inv_tr2.shape), K_train_noise_inv_tr2).view(K_train_noise_inv_tr2.shape))

            #del(K_train_noise_inv_tr2)
            Ba_vector_tr2 = torch.cuda.DoubleTensor(self.Ba_vector)

            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms_tr2 = torch.matmul(torch.transpose(torch.matmul(Qmat_tr2, Ba_vector_tr2), dim0=1,dim1=2), Ba_vector_tr2)
            # 4. Compute 4th term mean^2
            var_third_terms_tr2 = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr2.detach().cpu().numpy() + var_second_terms_tr2.detach().cpu().numpy() - \
                                                var_third_terms_tr2

            print("Torch computation time: ", time.time() - start_time)
            """
            
            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(Qmat_numpy @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
            var_second_terms_tr = np.einsum('ijk,jk->ik', Qmat_numpy @ self.Ba_vector, self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr
            #print("Numpy time: ", time.time() - start_time)

            # #############################################################################################
            # # START: DEBUGGING COMPARISONS 
            # print("Starting: Debugging computations")
            # input_means = np.array(input_means)
            # total_input = np.hstack((input_means, np.hstack((dim1_input_vals, dim2_input_vals))))

            # # GPy GP Computation
            # gpy_mean, gpy_var = self.GP_model.predict(Xnew=total_input)
            # print("Finished GPy computation, starting manual computations")

            # # Manual GP Computation for mean and variance 
            # K_xstar_xstar = self.GP_model.kern.K(total_input, total_input)
            # K_X_xstar = self.GP_model.kern.K(self.GP_model._predictive_variable, total_input)
            # manual_mean = np.dot(K_X_xstar.T, self.GP_model.posterior.woodbury_vector)
            # var_intermediate = np.linalg.lstsq(self.GP_model.posterior.woodbury_chol, K_X_xstar, rcond=-1)
            # manual_var = K_xstar_xstar - (var_intermediate[0].T @ var_intermediate[0])
            # manual_var_matmul = K_xstar_xstar - (K_X_xstar.T @ (self.GP_model.posterior.woodbury_inv @ K_X_xstar))

            # print("Finished: Debugging Computations")
            # print("TODO: Comment out all the deletes again!")

            # print("At the end of the first prediction")


            # print("predicted var patches: ", current_prediction_var_patches)
            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print("manual var matmul: ", manual_var_matmul)

            # pdb.set_trace()

            # END: DEBUGGING COMPARISONS
            #############################################################################################

            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print()
            # print("propagation var: ", current_prediction_var_patches)
            # print()
            # print("STOP: End of computation")
            # print("The initially coded approach - not the copy of PILCO")
            # pdb.set_trace()

        # if type(dim1_input_vals) != type(None) and type(dim2_input_vals) != type(None):
        #     # PYTHON COPY OF PILCO ORIGINAL CODE

        #     # change imports later
        #     from GPy.util.linalg import jitchol, dtrtri, dpotri, dpotrs


        #     # 1) Compute cached variables: K, iK, L, beta
        #     K = self.GP_model.kern.K(self.GP_model._predictive_variable, self.GP_model._predictive_variable)
        #     noise_variance = self.GP_model.Gaussian_noise.variance
        #     jitter_matrix = 1e-8 * np.eye(K.shape[0])
        #     K_noise = K + (noise_variance * np.eye(K.shape[0])) + jitter_matrix

        #     L =  jitchol(K_noise)
        #     Li = dtrtri(L)
        #     iK = dpotri(L, lower=1)[0]
        #     beta, _ = dpotrs(L, self.GP_model.Y, lower=1)

        #     dim0_input_mean_patches = input_means # vector (patch nums, vec dim)
        #     dim0_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
        #     dim1_input_vals = dim1_input_vals
        #     dim2_input_vals = dim2_input_vals

        #     k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
        #     k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
        #     k_coef  = k1_coef * k2_coef 

        #     # 2) Mean Prediction: computing the predicted mean and inv(s) times input-output covariance
        #     m = dim0_input_mean_patches
        #     s_mat = dim0_input_cov_patches

        #     dim0_input_mean_patches_seperated = dim0_input_mean_patches.reshape((dim0_input_mean_patches.shape[0], 1, dim0_input_mean_patches.shape[1]))
        #     inp = self.dim0_train_xvecs - dim0_input_mean_patches_seperated # demeaned inputs

        #     iL = 1/self.GP_model.kern.lengthscale * np.eye(self.lengthscale_matrix_dim0.shape[0]) # sqrt of the lengthscale matrix 
        #     iN = inp @ iL #np.einsum('ij,klj->klj', iL, inp)
        #     B = (iL @ s_mat) @ iL + np.eye(iL.shape[0])

        #     assert(B.shape[0] == 1)
        #     assert(iN.shape[0] == 1)
        #     t = np.linalg.lstsq(a=B.reshape((B.shape[1], B.shape[2])), b=iN.reshape((iN.shape[1], iN.shape[2])).T, rcond=-1)[0].T
        #     t = t.reshape((1, t.shape[0], t.shape[1]))
        #     l_mat = np.exp(-np.sum(iN * t, axis=-1)/2) 
        #     #lb = l_mat * beta
        #     tiL = t @ iL 
        #     c_coef = self.variance_dim0/np.sqrt(np.linalg.det(B)) 
        #     lb_k = k_coef * l_mat
        #     M = c_coef * (lb_k @ beta)
        #     predicted_mean = M

        #     # 3) Compute predictive covariance 
        #     with torch.no_grad():
        #         # load tensors into cuda
        #         dim0_input_cov_patches_tr = torch.cuda.DoubleTensor(dim0_input_cov_patches)
        #         s_tr = dim0_input_cov_patches_tr
        #         v_tr = torch.cuda.DoubleTensor(inp)

        #         R_tr = torch.matmul(dim0_input_cov_patches_tr, self.lengthscale_matrix_dim0_inv_tr + self.lengthscale_matrix_dim0_inv_tr) + torch.eye(n=dim0_input_cov_patches.shape[-1]).to(self.device)
        #         t_R = 1/torch.sqrt(torch.det(R_tr)).reshape((-1, 1, 1))
        #         Qmat_exp_covmat_tr = torch.lstsq(input=s_tr.reshape((s_tr.shape[1], s_tr.shape[2])), A=R_tr.reshape((R_tr.shape[1], R_tr.shape[2]))).solution
        #         Qmat_exp_covmat_tr = torch.unsqueeze(Qmat_exp_covmat_tr, dim=0)

        #         # Q matrix coefficient
        #         k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
        #         total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
        #         total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
        #         Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)*t_R

        #         # Z torch computation
        #         lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim0_inv_tr)
        #         # del(v_tr)

        #         # replace this part by doing n computations at a time (single column) - rather than n**2
        #         Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
        #         col_dim = Qmat_numpy.shape[-2]
        #         num_cols = Qmat_numpy.shape[-1]

        #         max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
        #         cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

        #         cols_list = list(np.arange(0, num_cols+1, cols_atonce))
        #         if cols_list[-1] != num_cols: 
        #             cols_list.append(num_cols)

        #         for i in range(len(cols_list) - 1):
        #             curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
        #                           torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

        #             curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
        #                            torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
        #                             torch.einsum('ijk,ilmkn->ilmkn', Qmat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
        #             # del(curr_zij_tr)
        #             Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
        #             # del(curr_Qmat_tr)

        #         del(lengthscale_inv_v_tr)

        #     # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
        #     var_first_terms_tr2 = self.full_variance - np.trace(iK @ Qmat_numpy, axis1=1, axis2=2).flatten()
        #     # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
        #     #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
        #     var_second_terms_tr2 = np.einsum('ijk,jk->ik', Qmat_numpy @ beta, beta).flatten()
        #     # 4. Compute 4th term mean^2
        #     var_third_terms_tr2 = np.square(predicted_mean).flatten()
        #     current_prediction_var_patches2 = var_first_terms_tr2 + var_second_terms_tr2 - var_third_terms_tr2
        #     #print("Numpy time: ", time.time() - start_time)

        #     #diff = current_prediction_var_patches - current_prediction_var_patches2

        #     current_prediction_mean_patches = predicted_mean 
        #     current_prediction_var_patches = current_prediction_var_patches2

        #     print("THIS IS IN THE NEW THING ? ")
        #     print("Hello im here")
        #     pdb.set_trace()

        elif type(dim2_input_vals) != type(None):

            # Mean Propagation - NEW
            dim01_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim01_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
            dim0_input_mean_patches = dim01_input_mean_patches[:, self.dim0_start_idx:self.dim0_end_idx] 
            dim1_input_mean_patches = dim01_input_mean_patches[:, self.dim1_start_idx:self.dim1_end_idx]
            dim2_input_vals = dim2_input_vals

            qa_coef = (self.variance_dim0*self.variance_dim1)/np.sqrt(np.linalg.det(np.matmul(dim01_input_cov_patches, self.lengthscale_matrix_dim01_inv) + np.eye(dim01_input_cov_patches.shape[-1]))).reshape((-1, 1))
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            
            # exponential part
            dim01_input_mean_patches_seperated = dim01_input_mean_patches.reshape((dim01_input_mean_patches.shape[0], 1, dim01_input_mean_patches.shape[1]))
            v = dim01_input_mean_patches_seperated - self.dim01_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim01_input_cov_patches + self.lengthscale_matrix_dim01)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

            qa = qa_coef * k2_coef * np.exp(-0.5*qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim01_input_mean_patches_seperated)
            del(qa_coef)
            
            #return current_prediction_mean_patches, np.zeros(current_prediction_mean_patches.shape)
            ################################################################################################
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                dim01_input_cov_patches_tr = torch.cuda.DoubleTensor(dim01_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)

                R_tr = torch.matmul(dim01_input_cov_patches_tr, self.lengthscale_matrix_dim01_inv_tr + self.lengthscale_matrix_dim01_inv_tr) + torch.eye(n=dim01_input_cov_patches.shape[-1]).to(self.device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim01_input_cov_patches_tr)
                del(R_tr)
                del(dim01_input_cov_patches_tr)
                del(R_inv_tr)

                # Q matrix coefficient
                k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
                k1_coef = self.kernel_dim1.K(dim1_input_mean_patches, self.dim1_train_xvecs)
                total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
                total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                del(total_k_coef_tr)
                del(R_sqrt_det_tr)

                # Z torch computation
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim01_inv_tr)
                del(v_tr)

                # replace this part by doing n computations at a time (single column) - rather than n**2
                Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
                col_dim = Qmat_numpy.shape[-2]
                num_cols = Qmat_numpy.shape[-1]

                max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
                cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

                cols_list = list(np.arange(0, num_cols+1, cols_atonce))
                if cols_list[-1] != num_cols: 
                    cols_list.append(num_cols)
                    
                for i in range(len(cols_list) - 1):
                    curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
                                  torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

                    curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
                                   torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
                                    torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
                    del(curr_zij_tr)
                    Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
                    del(curr_Qmat_tr)

                del(lengthscale_inv_v_tr)

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(Qmat_numpy @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
            var_second_terms_tr = np.einsum('ijk,jk->ik', Qmat_numpy @ self.Ba_vector, self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr


            #############################################################################################
            # # START: DEBUGGING COMPARISONS 
            # print("Starting: Debugging computations")
            # input_means = np.array(input_means)
            # total_input = np.hstack((input_means, dim2_input_vals))

            # # GPy GP Computation
            # gpy_mean, gpy_var = self.GP_model.predict(Xnew=total_input)
            # print("Finished GPy computation, starting manual computations")

            # # Manual GP Computation for mean and variance 
            # K_xstar_xstar = self.GP_model.kern.K(total_input, total_input)
            # K_X_xstar = self.GP_model.kern.K(self.GP_model._predictive_variable, total_input)
            # manual_mean = np.dot(K_X_xstar.T, self.GP_model.posterior.woodbury_vector)
            # var_intermediate = np.linalg.lstsq(self.GP_model.posterior.woodbury_chol, K_X_xstar, rcond=-1)
            # manual_var = K_xstar_xstar - (var_intermediate[0].T @ var_intermediate[0])
            # manual_var_matmul = K_xstar_xstar - (K_X_xstar.T @ (self.GP_model.posterior.woodbury_inv @ K_X_xstar))

            # print("Finished: Debugging Computations")
            # print("TODO: Comment out all the deletes again!")

            # print("At the end of the second prediction")

            # print("predicted var patches: ", current_prediction_var_patches)
            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print("manual var matmul: ", manual_var_matmul)

            # pdb.set_trace()

            # END: DEBUGGING COMPARISONS
            #############################################################################################
            
        else:
            # Mean Propagation - NEW
            dim012_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim012_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)

            qa_coef = (self.full_variance)/np.sqrt(np.linalg.det(np.matmul(dim012_input_cov_patches, self.lengthscale_matrix_dim012_inv) + np.eye(dim012_input_cov_patches.shape[-1]))).reshape((-1, 1))
            # exponential part
            dim012_input_mean_patches_seperated = dim012_input_mean_patches.reshape((dim012_input_mean_patches.shape[0], 1, dim012_input_mean_patches.shape[1]))
            v = dim012_input_mean_patches_seperated - self.full_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim012_input_cov_patches + self.lengthscale_matrix_dim012)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

            qa = qa_coef * np.exp(-0.5*qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim012_input_mean_patches_seperated)
            del(qa_coef)

            ################################################################################################
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                dim012_input_cov_patches_tr = torch.cuda.DoubleTensor(dim012_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)

                R_tr = torch.matmul(dim012_input_cov_patches_tr, self.lengthscale_matrix_dim012_inv_tr + self.lengthscale_matrix_dim012_inv_tr) + torch.eye(n=dim012_input_cov_patches.shape[-1]).to(self.device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim012_input_cov_patches_tr)
                # del(R_tr)
                # del(dim012_input_cov_patches_tr)
                # del(R_inv_tr)

                # Q matrix coefficient
                total_k_coef = self.GP_model.kern.K(dim012_input_mean_patches, self.full_train_xvecs)
                total_k_coef_tr = torch.unsqueeze(torch.cuda.DoubleTensor(total_k_coef), dim=1)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                # del(total_k_coef_tr)
                # del(R_sqrt_det_tr)
                
                # Z torch computation
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim012_inv_tr)
                # del(v_tr)

                # replace this part by doing n computations at a time (single column) - rather than n**2
                Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
                col_dim = Qmat_numpy.shape[-2]
                num_cols = Qmat_numpy.shape[-1]

                max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
                
                cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

                cols_list = list(np.arange(0, num_cols+1, cols_atonce))
                if cols_list[-1] != num_cols: 
                    cols_list.append(num_cols)
                    
                for i in range(len(cols_list) - 1):
                    curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
                                  torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

                    curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
                                   torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
                                    torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
                    # del(curr_zij_tr)
                    Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
                    # del(curr_Qmat_tr)

                # del(lengthscale_inv_v_tr)
                
            # NOTE: TODO: ASDFASDF Move these computations to torch gpu to make them faster - Needed here where you have more training data
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(Qmat_numpy @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
            var_second_terms_tr = np.einsum('ijk,jk->ik', Qmat_numpy @ self.Ba_vector, self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr

            #############################################################################################
            # # START: DEBUGGING COMPARISONS 
            # print("Starting: Debugging computations")
            # input_means = np.array(input_means)
            # total_input = input_means

            # # GPy GP Computation
            # gpy_mean, gpy_var = self.GP_model.predict(Xnew=total_input)
            # print("Finished GPy computation, starting manual computations")

            # # Manual GP Computation for mean and variance 
            # K_xstar_xstar = self.GP_model.kern.K(total_input, total_input)
            # K_X_xstar = self.GP_model.kern.K(self.GP_model._predictive_variable, total_input)
            # manual_mean = np.dot(K_X_xstar.T, self.GP_model.posterior.woodbury_vector)
            # var_intermediate = np.linalg.lstsq(self.GP_model.posterior.woodbury_chol, K_X_xstar, rcond=-1)
            # manual_var = K_xstar_xstar - (var_intermediate[0].T @ var_intermediate[0])
            # manual_var_matmul = K_xstar_xstar - (K_X_xstar.T @ (self.GP_model.posterior.woodbury_inv @ K_X_xstar))

            # print("Finished: Debugging Computations")
            # print("TODO: Comment out all the deletes again!")

            # print("At the end of the third prediction")

            # print("predicted var patches: ", current_prediction_var_patches)
            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print("manual var matmul: ", manual_var_matmul)

            # pdb.set_trace()

            # END: DEBUGGING COMPARISONS
            #############################################################################################
        
        """
        print("End of the mean propagation separate")
        print("Doing prediction with GPy to test")
        total_input = np.hstack((dim0_input_mean_patches, np.hstack((dim1_input_vals, dim2_input_vals))))
        mucp, varcp = self.GP_model.predict(Xnew=total_input)

        print("Back in the Qij separate function")
        pdb.set_trace()
        """

        # force negative jitter to zero
        current_prediction_var_patches[current_prediction_var_patches < 0] = 0
        return current_prediction_mean_patches, current_prediction_var_patches    

    def analytical_mean_var_propogation_Faster(self, input_means, input_covs, dim1_input_vals=None, dim2_input_vals=None):
        """
        Propogates forward the mean and variance of the probability distribution.
        Attempt to do this faster by parallelizing multiple patches at once using tensor operations.

        The GP predictive distribution under a Gaussian input x ~ N(mean, cov) is moment matched:
            mean = Ba_vector^T q                      with q_i  = E_x[k(x, x_i)]
            var  = E_x[var(f|x)] + Var_x[E(f|x)]
                 = full_variance - trace([K + noise]^-1 Q) + Ba_vector^T Q Ba_vector - mean^2
                                                      with Q_ij = E_x[k(x, x_i) k(x, x_j)]

        Which input dimensions are treated as random depends on the optional arguments:
            - dim1_input_vals and dim2_input_vals both given: only dim0 is random (1st propagation)
            - only dim2_input_vals given: dims 0 and 1 are random (2nd propagation)
            - otherwise: all dimensions are random (3rd propagation)

        args:
            - input_means: type: numpy array (patch nums, vec dim): the means of the input dimensions that are random variables
            - input_covs: type: numpy array (patch nums, vec dim, vec dim): the covariance matrix corresponding to each input mean
            - dim1_input_vals: type: numpy array: the value to use for the 1st dimensional input
                - NOTE: only needed for the first mean propogation
            - dim2_input_vals: type: numpy array: the value to use for the 2nd dimensional input
                - NOTE: only needed for the first and second mean propogation
        returns:
            - current_prediction_mean: the propogated means, computed as np.dot(Ba_vector.T, q.T)
                - (output dims, patch nums) assuming Ba_vector is (train points, output dims)
            - current_prediction_var : the propogated variances, flattened, with negative jitter forced to 0
        """

        def propagate_mean(mean_patches, cov_patches, train_xvecs, lengthscale_matrix, lengthscale_matrix_inv,
                           signal_variance, k_coef=None):
            """Moment-match the predictive mean; also return v = mean - train inputs, shape (patches, train, dim)."""
            qa_coef = signal_variance / np.sqrt(np.linalg.det(
                np.matmul(cov_patches, lengthscale_matrix_inv) + np.eye(cov_patches.shape[-1]))).reshape((-1, 1))
            # difference between every patch mean and every training input
            v = mean_patches.reshape((mean_patches.shape[0], 1, mean_patches.shape[1])) - train_xvecs

            qa_exp_covmat = np.linalg.inv(cov_patches + lengthscale_matrix)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2),
                               np.expand_dims(v, axis=-1)).reshape(v.shape[0:2])

            # non-random dimensions contribute a deterministic kernel factor
            qa = qa_coef if k_coef is None else qa_coef * k_coef
            qa = qa * np.exp(-0.5 * qa_exp)
            return np.dot(self.Ba_vector.T, qa.T), v

        def compute_Qmat(cov_patches, v, lengthscale_inv_tr, total_k_coef):
            """Compute Q_ij = E_x[k(x, x_i) k(x, x_j)] for every patch; returns numpy (patches, train, train)."""
            with torch.no_grad():
                # place everything on the same device as the lengthscale inverse it is multiplied with
                device = lengthscale_inv_tr.device
                cov_tr = torch.as_tensor(cov_patches, dtype=torch.float64, device=device)
                v_tr = torch.as_tensor(v, dtype=torch.float64, device=device)
                eye_tr = torch.eye(cov_patches.shape[-1], dtype=torch.float64, device=device)

                # R = cov (Lambda^-1 + Lambda^-1) + I
                R_tr = torch.matmul(cov_tr, lengthscale_inv_tr + lengthscale_inv_tr) + eye_tr
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1, 1, 1))
                Q_mat_exp_covmat_tr = torch.matmul(torch.linalg.inv(R_tr), cov_tr)
                del R_tr, cov_tr, eye_tr

                # z_ij = Lambda^-1 v_i + Lambda^-1 v_j for every pair of training points: (patches, train, train, dim)
                lengthscale_inv_v_tr = torch.matmul(v_tr, lengthscale_inv_tr)
                all_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr, dim=2) + torch.unsqueeze(lengthscale_inv_v_tr, dim=1)
                del lengthscale_inv_v_tr, v_tr

                # Q matrix coefficient k(x_i, mean) k(x_j, mean) / sqrt(|R|)
                total_k_coef = np.asarray(total_k_coef)
                total_k_coef_tr = torch.as_tensor(
                    total_k_coef.reshape((total_k_coef.shape[0], 1, total_k_coef.shape[1])),
                    dtype=torch.float64, device=device)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1, dim1=2), total_k_coef_tr) / R_sqrt_det_tr
                del total_k_coef_tr, R_sqrt_det_tr

                # Quadratic form z^T (R^-1 cov) z. Contract over the column index of R^-1 cov so this is a
                # true matrix-vector product; this is exact for full (non-diagonal) input covariances.
                A_z_tr = torch.einsum('ijk,ilmk->ilmj', Q_mat_exp_covmat_tr, all_zij_tr)
                quad_tr = torch.einsum('ilmj,ilmj->ilm', all_zij_tr, A_z_tr)
                del A_z_tr, all_zij_tr

                Qmat_tr = Qmat_coefs_tr * torch.exp(0.5 * quad_tr)
                return Qmat_tr.detach().cpu().numpy()

        if dim1_input_vals is not None and dim2_input_vals is not None:
            # 1st propagation: only dim0 is random, dims 1 and 2 are known values
            k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            current_prediction_mean_patches, v = propagate_mean(
                input_means, input_covs, self.dim0_train_xvecs,
                self.lengthscale_matrix_dim0, self.lengthscale_matrix_dim0_inv,
                self.variance_dim0, k_coef=k1_coef * k2_coef)
            k0_coef = self.kernel_dim0.K(input_means, self.dim0_train_xvecs)
            total_k_coef = k0_coef * k1_coef * k2_coef
            lengthscale_inv_tr = self.lengthscale_matrix_dim0_inv_tr

        elif dim2_input_vals is not None:
            # 2nd propagation: dims 0 and 1 are random, dim2 is a known value
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            current_prediction_mean_patches, v = propagate_mean(
                input_means, input_covs, self.dim01_train_xvecs,
                self.lengthscale_matrix_dim01, self.lengthscale_matrix_dim01_inv,
                self.variance_dim0 * self.variance_dim1, k_coef=k2_coef)
            dim0_input_mean_patches = input_means[:, self.dim0_start_idx:self.dim0_end_idx]
            dim1_input_mean_patches = input_means[:, self.dim1_start_idx:self.dim1_end_idx]
            k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
            k1_coef = self.kernel_dim1.K(dim1_input_mean_patches, self.dim1_train_xvecs)
            total_k_coef = k0_coef * k1_coef * k2_coef
            lengthscale_inv_tr = self.lengthscale_matrix_dim01_inv_tr

        else:
            # 3rd propagation: every input dimension is random
            current_prediction_mean_patches, v = propagate_mean(
                input_means, input_covs, self.full_train_xvecs,
                self.lengthscale_matrix_dim012, self.lengthscale_matrix_dim012_inv,
                self.full_variance)
            total_k_coef = self.GP_model.kern.K(input_means, self.full_train_xvecs)
            lengthscale_inv_tr = self.lengthscale_matrix_dim012_inv_tr

        Qmat = compute_Qmat(input_covs, v, lengthscale_inv_tr, total_k_coef)

        # 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q), alpha^2 = self.full_variance
        var_first_terms = self.full_variance - np.trace(Qmat @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
        # 2nd term E_x[E_f[f|x]^2] = Ba_vector^T Q Ba_vector (per output, no cross-output terms)
        var_second_terms = np.einsum('ijk,jk->ik', Qmat @ self.Ba_vector, self.Ba_vector).flatten()
        # 3rd term mean^2
        var_third_terms = np.square(current_prediction_mean_patches).flatten()
        current_prediction_var_patches = var_first_terms + var_second_terms - var_third_terms

        # force negative jitter to zero
        current_prediction_var_patches[current_prediction_var_patches < 0] = 0
        return current_prediction_mean_patches, current_prediction_var_patches

    def mean_var_propogation_Faster(self, starting_x_images, steps, all_together, use_variance_weighting, test_patch_obj=None, single_kernel=False, 
        save_filename_suffix=None, dirname='', num_patches_list=[4,2,1], seperate=False, svd_inverse = False, sigma_threshold = None, plot_for_thresholding = False, 
        show_intermediate_imgs=True):
        """
        Predicts a rollout sequence from the starting_x_images for steps number of steps. 
        Propogate the random variable forward at each timestep.  

        args:  
            - starting_x_images: list of starting images: depending on the dataprocessor up to 3 images are needed
                - this is used to start the prediction process
            - steps: the number of steps to rollout
            - all_together: boolean indicator
                - True: predict the entire dataset in one go 
                - False: predict the dataset one image at a time - NOTE: TODO: might need to change this to x patches at a time. 
            - use_variance_weighting: 
                - True: use variance for the mask weighting
                - False: do not use variance - just do averaging
            - single_kernel: boolean indicator: 
                - True: just one lengthscale for all dimensions
                - False: one lengthscale per different image space. 
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
            - save_filename_suffix: type: string: ending string to add to  the .npy file that contains the predicted means and variances
                - if None: do not save 
            - dirname: type: str: the directory path where to save the files if saving
            - num_patches_list: type: list of integers: [x,y,z]: 
                - predicts x simultaneous patches for the first propagation
                - predicts y simultaneous patches for the second propagation 
                - predicts z simultaneous patches for the third propagation  
            - seperate: type: boolean 
                - True: if true then solve columns separately using the separate propagation function 
                - False: use the faster all together propagation function and the num patches specified in the list
            - svd_inverse: type: boolean: 
                - True: use the svd inverse with the specified sigma cutoff
                - False: use the matrices returned by GPy directly 
            - sigma_threshold: 
                - None: take the sigma values up to the "large" dropoff
                    - We define a large dropoff as if the change in values between the next two sigma values is less than 1 order of magnitude compared to the change in the last 2 values
                - float: the maximum sigma value you want to keep before inverting
            - plot_for_thresholding: type: boolean:
                - False: just use the specified thresholds 
                - True: plot the sigma values and go into a set trace to allow the sigma threshold to be set manually by the user. 
            - show_intermediate_imgs: type: boolean
                - False: does not show the images as they are predicted
                - True: Default: shows the images using matplotlib as soon as they are predicted. 
        returns: 
            - pred_seq_means: list of unpatchified images that are predicted from the test_x set - these are the means
            - pred_var_images: 
                - returns this if  use_variance_weighting is set to True
            - xdataset: the xdataset that was used to generate each of the predicted images during the rollout
        """
        if seperate: 
            num_patches_list = [1,1,1] # can only use single patch when evaluating Q cols seperately. 

        # default toy variance value for testing
        print("This has been updated: 1")

        # use_variance_weighting input only left in for compatibility - do not use with this approach 
        if type(test_patch_obj) == type(None):
            patch_obj = self.patch_obj
        else: 
            patch_obj = test_patch_obj

        assert(use_variance_weighting == False)
        #assert(patch_obj.stride == (1,1))
        if not seperate:
            assert(patch_obj.get_ypatch_dim() == (1,1)) # This method is really only for when you are estimating the center pixel of the patch 

        # currently only coding for this input space and output space
        assert(self.data_processor.xtypes == ['img0', 'img-1', 'img-2'])
        assert(self.data_processor.ytype == 'img')

        # create save names 
        if type(save_filename_suffix) == type(None):
            save_file = False
        else: 
            save_file = True
            predicted_mean_savefilename = os.path.join(dirname, "predicted_mean_" + str(save_filename_suffix) + ".npy")
            predicted_var_savefilename = os.path.join(dirname, "predicted_var_" + str(save_filename_suffix) + ".npy")

        predicted_seq_means = []
        predicted_seq_vars = [] # even though not being estimated set up here to establish framework 

        predicted_input_mean_patches = [] # list of sublists at most 3 long - each sublist is a list of patches - 0th list corresponds to patches of the last image and so on
        predicted_input_var_patches = [] # same as list of mean patches above but with variance 

        self.setup_mean_var_propogation_parameters(patch_obj=patch_obj, svd_inverse=svd_inverse, sigma_threshold=sigma_threshold, plot_for_thresholding=plot_for_thresholding)

        # Test data: starting x images
        # NOTE: add a padder image as create_x always leaves the last image for the y dataset creator
        curr_test_x = list(starting_x_images)
        processed_curr_test_x = self.data_processor.create_x(image_seq=list(curr_test_x) + [np.zeros(curr_test_x[0].shape)])
        patched_processed_curr_test_x = patch_obj.patchify_dataset(dataset=processed_curr_test_x,
            dataset_type='x')
        curr_test_x_vecs = np.array(self.data_processor.convert_imgdataset_to_vecdataset(dataset=patched_processed_curr_test_x))
        dim0_test_start_xvecs = curr_test_x_vecs[:, self.dim0_start_idx:self.dim0_end_idx]
        dim1_test_start_xvecs = curr_test_x_vecs[:, self.dim1_start_idx:self.dim1_end_idx]
        dim2_test_start_xvecs = curr_test_x_vecs[:, self.dim2_start_idx:self.dim2_end_idx]
        num_patches_per_image = curr_test_x_vecs.shape[0]

        # Save the GP model with the given suffix - save method is not implemented for the sparse models.
        if save_file and not ( type(self.GP_model) == GPy.models.sparse_gp_regression.SparseGPRegression ): 
            gp_model_filename = os.path.join(dirname, "GP_model" + str(save_filename_suffix))
            self.GP_model.save_model(output_filename=gp_model_filename)
        
        ###########################################################################
        # 1st Prediction: GP output
        ###########################################################################
        # perform the prediction
        if all_together:
            first_prediction_means, first_prediction_vars = GP_model.predict(Xnew=np.array(curr_test_x_vecs).reshape((len(curr_test_x_vecs), -1)))
        else: 
            first_prediction_means = []
            first_prediction_vars = []

            start_range = 0
            while start_range < len(curr_test_x_vecs):
                end_range = min(start_range + self.predictor.max_all_together, len(curr_test_x_vecs))
                predict_subset_x = curr_test_x_vecs[start_range:end_range]
                predict_subset_ymean, predict_subset_yvar = self.GP_model.predict(Xnew=np.array(predict_subset_x).reshape((len(predict_subset_x), -1)))
                # append
                first_prediction_means.extend(predict_subset_ymean)
                first_prediction_vars.extend(predict_subset_yvar)

                start_range += self.predictor.max_all_together

        # convert the predicted vecs to patches
        first_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in first_prediction_means]
        first_prediction_var_patches = [np.ones(patch_obj.get_ypatch_dim())*float(vec) for vec in first_prediction_vars]

        # convert the predicted patches to one predicted image
        if not use_variance_weighting:
            if patch_obj.get_ypatch_dim() == (1,1):
                # Each output pixel is determined only by one prediction so unpatchify just puts the prediction in the correct location. 
                first_predicted_mean_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=first_prediction_mean_patches, 
                                                                                      patch_variance_list=None, 
                                                                                      img_type='y')
                first_predicted_var_image, _, _, _, _  = patch_obj.unpatchify_image(patch_list=first_prediction_var_patches, 
                                                                                      patch_variance_list=None, 
                                                                                      img_type='y') 
            else: 
                # Predictions must be combined like a gaussian mixture model and moment matched to single gaussian distribution 
                first_predicted_mean_image, first_predicted_var_image = \
                    patch_obj.unpatchify_mean_variance_image(mean_patch_list=first_prediction_mean_patches, 
                                                             var_patch_list=first_prediction_var_patches, 
                                                             patch_weight=np.ones(patch_obj.get_ypatch_dim()))
        else: 
            print("ERROR: incorrectly tried to use variance weighting here")
            assert(False)

        # append to stored lists
        predicted_seq_means.append(first_predicted_mean_image)
        predicted_seq_vars.append(first_predicted_var_image)# * 1e4) # TODO: REMOVE THE INCREASED VARIANCE

        if save_file: 
            np.save(predicted_mean_savefilename, predicted_seq_means)
            np.save(predicted_var_savefilename, predicted_seq_vars)

        # Convert last predicted output to next input
        first_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
        first_predicted_input_var_patches = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')
        # vectorize
        first_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in first_predicted_input_mean_patches]) 
        first_predicted_input_vars = np.array([patch_var.flatten() for patch_var in first_predicted_input_var_patches])

        predicted_input_mean_patches.append(first_predicted_input_means)
        predicted_input_var_patches.append(first_predicted_input_vars)

        if show_intermediate_imgs:
            # show the first image
            fig = plt.figure()
            pos = plt.imshow(predicted_seq_means[-1])
            fig.colorbar(pos)
            plt.title("First Prediction")
            plt.show(block=False)

            # show the variance image
            fig = plt.figure()
            pos = plt.imshow(predicted_seq_vars[-1])
            fig.colorbar(pos)
            plt.title("First Variance Image")
            plt.show(block=False)

        ###########################################################################
        # 2nd and Following Predictions: PILCO
        ###########################################################################
        second_prediction_means = []
        second_prediction_means_normalgp = []
        second_prediction_vars = []
        stopper = True

        for step_num in range(1, steps):
            single_img_pred_start_time = time.time()

            current_prediction_means = []
            current_prediction_vars = []

            if step_num == 1: 

                patches_at_once = num_patches_list[0]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim0_input_means = predicted_input_mean_patches[-1][start_patches_num:end_patches_num]
                    dim0_input_vars = predicted_input_var_patches[-1][start_patches_num:end_patches_num] #* 1e4
                    dim0_input_covs = np.array([np.diag(dim0_input_var.flatten()) for dim0_input_var in dim0_input_vars])

                    dim1_input_vals = dim0_test_start_xvecs[start_patches_num:end_patches_num]
                    dim2_input_vals = dim1_test_start_xvecs[start_patches_num:end_patches_num]

                    if seperate: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim0_input_means, 
                                                                        input_covs=dim0_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim0_input_means, 
                                                                        input_covs=dim0_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            elif step_num == 2: 

                patches_at_once = num_patches_list[1]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim01_input_means = np.hstack((predicted_input_mean_patches[-1][start_patches_num:end_patches_num], predicted_input_mean_patches[-2][start_patches_num:end_patches_num]))
                    dim01_input_vars =  np.hstack((predicted_input_var_patches[-1][start_patches_num:end_patches_num], predicted_input_var_patches[-2][start_patches_num:end_patches_num]))
                    dim01_input_covs = np.array([np.diag(dim01_input_var.flatten()) for dim01_input_var in dim01_input_vars])

                    dim1_input_vals = None
                    dim2_input_vals = dim0_test_start_xvecs[start_patches_num:end_patches_num]

                    if seperate: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim01_input_means, 
                                                                        input_covs=dim01_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim01_input_means, 
                                                                        input_covs=dim01_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            else: 
                patches_at_once = num_patches_list[2]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim012_input_means = np.hstack((predicted_input_mean_patches[-1][start_patches_num:end_patches_num], 
                                              np.hstack((predicted_input_mean_patches[-2][start_patches_num:end_patches_num], 
                                                         predicted_input_mean_patches[-3][start_patches_num:end_patches_num]))))
                    dim012_input_vars = np.hstack((predicted_input_var_patches[-1][start_patches_num:end_patches_num], 
                                                  np.hstack((predicted_input_var_patches[-2][start_patches_num:end_patches_num], 
                                                             predicted_input_var_patches[-3][start_patches_num:end_patches_num]))))
                    dim012_input_covs = np.array([np.diag(dim012_input_var.flatten()) for dim012_input_var in dim012_input_vars])

                    dim1_input_vals = None
                    dim2_input_vals = None

                    if seperate:
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim012_input_means, 
                                                                        input_covs=dim012_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim012_input_means, 
                                                                        input_covs=dim012_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            # process the predicted means and variances to form images for saving and the next round of predictions
            current_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_means]
            current_prediction_var_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_vars]

            # unpatchify 
            if not use_variance_weighting: 

                if patch_obj.get_ypatch_dim() == (1,1):
                    # Each output pixel is determined only by one prediction so unpatchify just puts the prediction in the correct location. 
                    current_predicted_mean_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=current_prediction_mean_patches, 
                                                                                          patch_variance_list=None, 
                                                                                          img_type='y')
                    current_predicted_var_image, _, _, _, _  = patch_obj.unpatchify_image(patch_list=current_prediction_var_patches, 
                                                                                          patch_variance_list=None, 
                                                                                          img_type='y') 
                else: 
                    # Predictions must be combined like a gaussian mixture model and moment matched to single gaussian distribution 
                    current_predicted_mean_image, current_predicted_var_image = \
                        patch_obj.unpatchify_mean_variance_image(mean_patch_list=current_prediction_mean_patches, 
                                                                 var_patch_list=current_prediction_var_patches, 
                                                                 patch_weight=np.ones(patch_obj.get_ypatch_dim()))

            else: 
                print("ERROR: incorrectly tried to use variance weighting here")
                assert(False)
            print("Finished Predicting: " + str(step_num + 1) + " took: " + str(time.time() - single_img_pred_start_time))
            # append 
            predicted_seq_means.append(current_predicted_mean_image)
            predicted_seq_vars.append(current_predicted_var_image)

            if save_file: 
                np.save(predicted_mean_savefilename, predicted_seq_means)
                np.save(predicted_var_savefilename, predicted_seq_vars)

            # process input patches for the next round
            current_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
            current_predicted_input_var_patches  = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')  
            current_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in current_predicted_input_mean_patches])
            current_predicted_input_vars  = np.array([patch_var.flatten() for patch_var in current_predicted_input_var_patches])

            predicted_input_mean_patches.append(current_predicted_input_means)
            predicted_input_var_patches.append(current_predicted_input_vars)

            if len(predicted_input_mean_patches) > 3:
                predicted_input_mean_patches = predicted_input_mean_patches[1:]
                predicted_input_var_patches  = predicted_input_var_patches[1:]

            if show_intermediate_imgs: 
                # display the result
                fig = plt.figure()
                pos = plt.imshow(current_predicted_mean_image)
                plt.title("Prediction: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)
                
                # display the variance image 
                fig = plt.figure()
                pos = plt.imshow(current_predicted_var_image)
                plt.title("Variance image: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)

        print("Finished " + str(steps) + " predictions")

        return predicted_seq_means, predicted_seq_vars

    def continue_mean_var_propogation_Faster(self, predicted_seq_means_filename, predicted_seq_vars_filename, steps, 
        GP_model=None, GP_model_filename=None,
        test_patch_obj=None, save_filename_suffix=None, num_patches_list=[4,2,1], seperate=False, 
        extra_starting_x_images=None, show_intermediate_imgs=True):
        """
        Predicts a rollout sequence from the starting_x_images for steps number of steps. 
        Propogate the random variable forward at each timestep.  

        args:  
            - predicted_seq_means_filename: type: str: filename of the list of predicted mean images to continue from 
            - predicted_seq_vars_filename: type: str: filename of the list of predicted var images to continue from 
            - steps: the number of images to rollout 
            - GP_model_filename: type: str: the zip file where the GP model was saved
            - GP_model:type: GPy GP model 
                - Should only be provided if the GP_model_filename is None. 
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
            - save_filename_suffix: type: string: ending string to add to  the .npy file that contains the predicted means and variances
                - if None: do not save 
            - num_patches_list: type: list of integers: [x,y,z]: 
                - predicts x simultaneous patches for the first propagation
                - predicts y simultaneous patches for the second propagation 
                - predicts z simultaneous patches for the third propagation  
            - seperate: type: boolean 
                - True: if true then solve columns separately using the separate propagation function 
                - False: use the faster all together propagation function and the num patches specified in the list
            - extra_starting_x_images
            - show_intermediate_imgs: type: boolean
                - False: does not show the images as they are predicted
                - True: Default: shows the images using matplotlib as soon as they are predicted. 
        returns: 
            - pred_seq_means: list of unpatchified images that are predicted from the test_x set - these are the means
            - pred_var_images: 
                - returns this if  use_variance_weighting is set to True
            - xdataset: the xdataset that was used to generate each of the predicted images during the rollout
        """
        if seperate: 
            num_patches_list = [1,1,1] # can only use single patch when evaluating Q cols seperately. 

        # default toy variance value for testing
        print("This has been updated: 1")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # use_variance_weighting input only left in for compatibility - do not use with this approach 
        if type(test_patch_obj) == type(None):
            patch_obj = self.patch_obj
        else: 
            patch_obj = test_patch_obj

        assert(use_variance_weighting == False)
        assert(patch_obj.stride == (1,1))
        assert(not (type(GP_model) == type(None)) and (type(GP_model_filename) == type(None)))

        if not seperate:
            assert(patch_obj.get_ypatch_dim() == (1,1)) # This method is really only for when you are estimating the center pixel of the patch 

        # currently only coding for this input space and output space
        assert(self.data_processor.xtypes == ['img0', 'img-1', 'img-2'])
        assert(self.data_processor.ytype == 'img')

        # create save names 
        if type(save_filename_suffix) == type(None):
            save_file = False
        else: 
            save_file = True
            predicted_mean_savefilename = "predicted_mean_" + str(save_filename_suffix) + ".npy"
            predicted_var_savefilename = "predicted_var_" + str(save_filename_suffix) + ".npy"


        predicted_seq_means = list(np.load(predicted_seq_means_filename))
        predicted_seq_vars = list(np.load(predicted_seq_vars_filename))

        predicted_input_mean_patches = [] # list of sublists at most 3 long - each sublist is a list of patches - 0th list corresponds to patches of the last image and so on
        predicted_input_var_patches = [] # same as list of mean patches above but with variance 

        self.setup_mean_var_propogation_parameters(patch_obj=patch_obj)

        if len(predicted_seq_means) == 1: 
            assert(len(extra_starting_x_images) >= 2)

            last_last_starting_xpatches = patch_obj.patchify_image(img=extra_starting_x_images[-2], img_type='x')
            last_last_starting_xvecs = np.array([patch_val.flatten() for patch_val in dim2_starting_xpatches])

            last_starting_xpatches = patch_obj.patchify_image(img=extra_starting_x_images[-1], img_type='x')
            last_starting_xvecs = np.array([patch_val.flatten() for patch_val in dim1_starting_xpatches])

            predicted_img_idxs = np.arange(-1, 0)

        elif len(predicted_seq_means) == 2: 
            assert(len(extra_starting_x_images) >= 1)

            last_starting_xpatches = patch_obj.patchify_image(img=extra_starting_x_images[-1], img_type='x')
            last_starting_xvecs = np.array([patch_val.flatten() for patch_val in dim2_starting_xpatches])
            
            predicted_img_idxs = np.arange(-2, 0)

        else:
            assert(len(predicted_seq_means) >= 3)

            predicted_img_idxs = np.arange(-3, 0)
            
        for predicted_img_idx in predicted_img_idxs:
            current_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[predicted_img_idx], img_type='x')
            current_predicted_input_var_patches  = patch_obj.patchify_image(img=predicted_seq_vars[predicted_img_idx], img_type='x')  
            current_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in current_predicted_input_mean_patches])
            current_predicted_input_vars  = np.array([patch_var.flatten() for patch_var in current_predicted_input_var_patches])

            predicted_input_mean_patches.append(current_predicted_input_means)
            predicted_input_var_patches.append(current_predicted_input_vars)


        num_patches_per_image = predicted_input_mean_patches[-1].shape[0]#curr_test_x_vecs.shape[0]

        ###########################################################################
        # 2nd and Following Predictions: PILCO
        ###########################################################################
        second_prediction_means = []
        second_prediction_vars = []
        stopper = True

        for step_num in range(1, steps):
            single_img_pred_start_time = time.time()

            current_prediction_means = []
            current_prediction_vars = []

            if step_num == 1 and len(predicted_seq_imgs) == 1: 

                patches_at_once = num_patches_list[0]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim0_input_means = predicted_input_mean_patches[-1][start_patches_num:end_patches_num]
                    dim0_input_vars = predicted_input_var_patches[-1][start_patches_num:end_patches_num] #* 1e4
                    dim0_input_covs = np.array([np.diag(dim0_input_var.flatten()) for dim0_input_var in dim0_input_vars])

                    dim1_input_vals = last_starting_xvecs[start_patches_num:end_patches_num]
                    dim2_input_vals = last_last_starting_xvecs[start_patches_num:end_patches_num]

                    if seperate: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim0_input_means, 
                                                                        input_covs=dim0_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim0_input_means, 
                                                                        input_covs=dim0_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            elif ((step_num == 2 and len(predicted_seq_imgs) == 1) or 
                 (step_num == 1 and len(predicted_seq_imgs) == 2)): 

                patches_at_once = num_patches_list[1]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim01_input_means = np.hstack((predicted_input_mean_patches[-1][start_patches_num:end_patches_num], predicted_input_mean_patches[-2][start_patches_num:end_patches_num]))
                    dim01_input_vars =  np.hstack((predicted_input_var_patches[-1][start_patches_num:end_patches_num], predicted_input_var_patches[-2][start_patches_num:end_patches_num]))
                    dim01_input_covs = np.array([np.diag(dim01_input_var.flatten()) for dim01_input_var in dim01_input_vars])

                    dim1_input_vals = None
                    dim2_input_vals = last_starting_xvecs[start_patches_num:end_patches_num]

                    if seperate: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim01_input_means, 
                                                                        input_covs=dim01_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim01_input_means, 
                                                                        input_covs=dim01_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            else: 
                patches_at_once = num_patches_list[2]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim012_input_means = np.hstack((predicted_input_mean_patches[-1][start_patches_num:end_patches_num], 
                                              np.hstack((predicted_input_mean_patches[-2][start_patches_num:end_patches_num], 
                                                         predicted_input_mean_patches[-3][start_patches_num:end_patches_num]))))
                    dim012_input_vars = np.hstack((predicted_input_var_patches[-1][start_patches_num:end_patches_num], 
                                                  np.hstack((predicted_input_var_patches[-2][start_patches_num:end_patches_num], 
                                                             predicted_input_var_patches[-3][start_patches_num:end_patches_num]))))
                    dim012_input_covs = np.array([np.diag(dim012_input_var.flatten()) for dim012_input_var in dim012_input_vars])

                    dim1_input_vals = None
                    dim2_input_vals = None

                    if seperate:
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim012_input_means, 
                                                                        input_covs=dim012_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim012_input_means, 
                                                                        input_covs=dim012_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            # process the predicted means and variances to form images for saving and the next round of predictions
            current_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_means]
            current_prediction_var_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_vars]

            # unpatchify 
            if not use_variance_weighting: 

                if patch_obj.get_ypatch_dim() == (1,1):
                    # Each output pixel is determined only by one prediction so unpatchify just puts the prediction in the correct location. 
                    current_predicted_mean_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=current_prediction_mean_patches, 
                                                                                          patch_variance_list=None, 
                                                                                          img_type='y')
                    current_predicted_var_image, _, _, _, _  = patch_obj.unpatchify_image(patch_list=current_prediction_var_patches, 
                                                                                          patch_variance_list=None, 
                                                                                          img_type='y') 
                else: 
                    # Predictions must be combined like a gaussian mixture model and moment matched to single gaussian distribution 
                    current_predicted_mean_image, current_predicted_var_image = \
                        patch_obj.unpatchify_mean_variance_image(mean_patch_list=current_prediction_mean_patches, 
                                                                 var_patch_list=current_prediction_var_patches, 
                                                                 patch_weight=np.ones(patch_obj.get_ypatch_dim()))

            else: 
                print("ERROR: incorrectly tried to use variance weighting here")
                assert(False)
            print("Finished Predicting: " + str(step_num + 1) + " took: " + str(time.time() - single_img_pred_start_time))
            # append 
            predicted_seq_means.append(current_predicted_mean_image)
            predicted_seq_vars.append(current_predicted_var_image)

            if save_file: 
                np.save(predicted_mean_savefilename, predicted_seq_means)
                np.save(predicted_var_savefilename, predicted_seq_vars)

            # process input patches for the next round
            current_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
            current_predicted_input_var_patches  = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')  
            current_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in current_predicted_input_mean_patches])
            current_predicted_input_vars  = np.array([patch_var.flatten() for patch_var in current_predicted_input_var_patches])

            predicted_input_mean_patches.append(current_predicted_input_means)
            predicted_input_var_patches.append(current_predicted_input_vars)

            if len(predicted_input_mean_patches) > 3:
                predicted_input_mean_patches = predicted_input_mean_patches[1:]
                predicted_input_var_patches  = predicted_input_var_patches[1:]

            if show_intermediate_imgs:
                # display the result
                fig = plt.figure()
                pos = plt.imshow(current_predicted_mean_image)
                plt.title("Prediction: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)
                
                # display the variance image 
                fig = plt.figure()
                pos = plt.imshow(current_predicted_var_image)
                plt.title("Variance image: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)

        print("Finished " + str(steps) + " predictions")

        return predicted_seq_means, predicted_seq_vars

    


################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
######################################################### Debugging Code and extra functions   #################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################

    # DEBUGGING HELPER
    def plot_img(self, img):
        """
        Debug helper: draw ``img`` with plt.imshow in a fresh figure, attach a colorbar,
        and call plt.show() (without block=False). Returns None.
        """
        figure = plt.figure()
        figure.colorbar(plt.imshow(img))
        plt.show()

    def data_refinement(self, kernel_threshold=None, verbose=True):
        """
        Refine (sparsify) the training datapoints that the GP model uses during prediction.
        NOTE: the GP model must be trained and contain all the training data before this function
        is called.

        Methodology:
            1. Keep the first datapoint.
            2. For every subsequent datapoint, evaluate the kernel between it and every datapoint kept so far.
            3. Keep the datapoint only if ALL of those kernel values are strictly below ``kernel_threshold``
               (i.e. it is not too similar to any already-kept datapoint).
            4. Reset the GP model's training data to the kept datapoints.
            5. Re-optimize the GP model on the refined data.

        args:
            - kernel_threshold: type: float or None: the threshold below which a datapoint is added to the
              pool of refined datapoints.
                - if None: use 0.99 * the variance of the (optimized) kernel.
            - verbose: type: bool:
                - True: prints the refinement summary and the optimization messages
                - False: prints neither (the tqdm progress bar is still shown)
        returns:
            - kernel_threshold: type: float: the threshold that was actually used (useful when it was
              derived from the kernel variance).
        raises:
            - ValueError: if the GP model holds no training datapoints.
        """
        if kernel_threshold is None:
            # Assumes the kernel exposes a scalar `variance` (e.g. RBF, where k(x, x) == variance), so
            # 0.99 * variance marks near-duplicate points as "too similar".
            kernel_threshold = 0.99 * float(self.GP_model.kern.variance)

        # np.asarray strips the GPy array subclass so kern.K receives plain ndarrays.
        x_datapoints = np.asarray(self.GP_model.X)
        y_datapoints = np.asarray(self.GP_model.Y)
        num_points = x_datapoints.shape[0]
        if num_points == 0:
            raise ValueError("data_refinement: the GP model has no training datapoints to refine")

        # One flattened row per datapoint (equivalent to the per-point reshape((1, -1))).
        x_rows = x_datapoints.reshape((num_points, -1))
        y_rows = y_datapoints.reshape((num_points, -1))

        # Kept rows are written into preallocated buffers; the first `num_refined` rows are the refined
        # set. This avoids rebuilding an array from a growing Python list on every iteration.
        refined_x = np.empty_like(x_rows)
        refined_y = np.empty_like(y_rows)
        num_refined = 0

        for point_num in tqdm(range(num_points)):
            current_x_datapoint = x_rows[point_num:point_num + 1]  # shape (1, D)

            if num_refined > 0:
                kernel_values = self.GP_model.kern.K(current_x_datapoint, refined_x[:num_refined])
                # Skip the point if it is too similar to any datapoint already kept.
                if not np.all(kernel_values < kernel_threshold):
                    continue

            # Either the very first datapoint or dissimilar enough from all kept ones: keep it.
            refined_x[num_refined] = x_rows[point_num]
            refined_y[num_refined] = y_rows[point_num]
            num_refined += 1

        if verbose:
            print("Initial: " + str(num_points) + " datapoints refined to: " + str(num_refined) + " datapoints")
        # Reset the GP model training data. The copies avoid keeping the full-size buffers alive.
        self.GP_model.set_XY(X=refined_x[:num_refined].copy(),
                             Y=refined_y[:num_refined].copy())
        # Re-optimize the model on the refined data.
        self.GP_model.optimize(messages=verbose)

        return kernel_threshold




