"""
Same method as the base code (nonparam_predictor), except that the sequential rollout is replaced by a sampling-based method:
for each prediction, sample the output pixels rather than using just the means.
"""

import numpy as np 
import cv2 
import GPy 

import matplotlib.pyplot as plt 
import os 
import sys 
import time 
import multiprocessing as mp
import scipy
from scipy.linalg import lapack, blas


from tqdm import tqdm 

# Make the locally written modules one level above the working directory importable.
# NOTE: the lookup is relative to the process's working directory (os.getcwd()),
# not to the location of this file, so the script must be launched from the expected folder.
current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, os.pardir)
sys.path.append(parent_dir)  # append parent directory to the path

# import locally written files 
from patchify_unpatchify import patcher
from predict import predict 
from processing import processing

from nonparam_wrapper import nonparam_predictor
import torch

import pdb

class nonparam_predictor_sampling_difference(nonparam_predictor):
    # Same class as nonparam_predictor, except that during the rollouts this class samples each rollout rather than using the mean.
    # The only change is to the function 'full_pred_sequential'.

################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
######################################################### Mean and Variance Propagation Faster #################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################

    def analytical_mean_var_propogation_Faster_separateQij(self, input_means, input_covs, dim1_input_vals=None, dim2_input_vals=None):
        """ 
        Propagates forward the mean and variance of the probability distribution.
        Batch solves for Qij: the Q matrix is computed in separate parts, sized based on
        the available GPU memory.

        Note: please only use this with a single input mean and single input cov.

        Attempts to do this faster by parallelizing multiple patches at once using tensor operations.

        args:
            - input_means: the mean along the input dimensions that are random variables
            - input_covs: the corresponding covariance matrix to the input mean
            - dim1_input_vals: type: numpy array: the value to use for the 1st dimensional input
                - NOTE: only needed for the first mean propagation
            - dim2_input_vals: type: numpy array: the value to use for the 2nd dimensional input
                - NOTE: only needed for the first and second mean propagation
        returns:
            - current_prediction_mean: the propagated mean
            - current_prediction_var : the propagated variance
        """ 
        max_data_size = ((1200*1200*243*243*0.25))#((1200*1200*243*243*4))#(1200*1200*243*243*8*2.15) # done with experimentation to get max capacity #7516192768 # 7 Giga Bytes
        float64_multiplier = 1 # number you mulitply the number of entries by when calculating max data size


        # # # START: REMOVE THIS
        # print("FORCING INPUT COV TO O!!!!")
        # input_covs = np.zeros(np.array(input_covs).shape)
        # # # END: REMOVE THIS


        if type(dim1_input_vals) != type(None) and type(dim2_input_vals) != type(None):

            # # START: REMOVE THIS
            # print("FORCING INPUT COV TO O!!!!")
            # input_covs = np.zeros(np.array(input_covs).shape)
            # # # END: REMOVE THIS

            dim0_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim0_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
            dim1_input_vals = dim1_input_vals
            dim2_input_vals = dim2_input_vals


            # # TODO: REMOVE
            # psuedo_index = 234
            # input_means = self.GP_model.X[psuedo_index][self.dim0_start_idx:self.dim0_end_idx].reshape(input_means.shape)
            # dim0_input_mean_patches = input_means
            # dim1_input_vals = self.GP_model.X[psuedo_index][self.dim1_start_idx:self.dim1_end_idx].reshape(dim1_input_vals.shape)
            # dim2_input_vals = self.GP_model.X[psuedo_index][self.dim2_start_idx:self.dim2_end_idx].reshape(dim2_input_vals.shape)
            # # TODO: REMOVE

            qa_coef = self.variance_dim0/np.sqrt(np.linalg.det(np.matmul(dim0_input_cov_patches, self.lengthscale_matrix_dim0_inv) + np.eye(dim0_input_cov_patches.shape[-1]))).reshape((-1, 1))
            k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            k_coef  = k1_coef * k2_coef 
            # exponential part
            dim0_input_mean_patches_seperated = dim0_input_mean_patches.reshape((dim0_input_mean_patches.shape[0], 1, dim0_input_mean_patches.shape[1]))
            v = dim0_input_mean_patches_seperated - self.dim0_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim0_input_cov_patches + self.lengthscale_matrix_dim0) 
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))
            
            qa = qa_coef * k_coef * np.exp(-0.5 * qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # #clear mean variables
            # del(qa)
            # del(qa_exp)
            # del(qa_exp_covmat)
            # del(dim0_input_mean_patches_seperated)
            # del(qa_coef)

            #return current_prediction_mean_patches, np.zeros(current_prediction_mean_patches.shape)
            ################################################################################################
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                dim0_input_cov_patches_tr = torch.cuda.DoubleTensor(dim0_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)

                R_tr = torch.matmul(dim0_input_cov_patches_tr, self.lengthscale_matrix_dim0_inv_tr + self.lengthscale_matrix_dim0_inv_tr) + torch.eye(n=dim0_input_cov_patches.shape[-1]).to(self.device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim0_input_cov_patches_tr)
                # del(dim0_input_cov_patches_tr)
                # del(R_tr)
                # del(R_inv_tr)

                # Q matrix coefficient
                k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
                total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
                total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)

                # # TODO: REMOVE
                # psuedo_index = 234
                # psuedo_input_means = self.GP_model.X[psuedo_index].reshape((1, -1))
                # psuedo_output_means = self.GP_model.Y[psuedo_index].reshape((1, -1))
                # total_k_coef = self.GP_model.kern.K(psuedo_input_means, self.GP_model._predictive_variable).reshape(total_k_coef.shape)
                # total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
                # # TODO: END REMOVE


                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                # del(total_k_coef_tr)
                # del(R_sqrt_det_tr)

                # Z torch computation
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim0_inv_tr)
                # del(v_tr)
                
                # replace this part by doing n computations at a time (single column) - rather than n**2
                Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
                col_dim = Qmat_numpy.shape[-2]
                num_cols = Qmat_numpy.shape[-1]

                max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
                cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

                cols_list = list(np.arange(0, num_cols+1, cols_atonce))
                if cols_list[-1] != num_cols: 
                    cols_list.append(num_cols)
                    
                for i in range(len(cols_list) - 1):
                    curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
                                  torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

                    curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
                                   torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
                                    torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
                    # del(curr_zij_tr)
                    Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
                    # del(curr_Qmat_tr)

                del(lengthscale_inv_v_tr)

            """
            start_time = time.time()
            # Torch computations of final variance
            Qmat_tr2 = torch.cuda.DoubleTensor(Qmat_numpy)   
            K_train_noise_inv_tr2 = torch.cuda.DoubleTensor(self.K_train_noise_inv) 

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance         
            var_first_terms_tr2 = self.full_variance - torch.trace(torch.matmul(Qmat_tr2.view(K_train_noise_inv_tr2.shape), K_train_noise_inv_tr2).view(K_train_noise_inv_tr2.shape))

            #del(K_train_noise_inv_tr2)
            Ba_vector_tr2 = torch.cuda.DoubleTensor(self.Ba_vector)

            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            var_second_terms_tr2 = torch.matmul(torch.transpose(torch.matmul(Qmat_tr2, Ba_vector_tr2), dim0=1,dim1=2), Ba_vector_tr2)
            # 4. Compute 4th term mean^2
            var_third_terms_tr2 = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr2.detach().cpu().numpy() + var_second_terms_tr2.detach().cpu().numpy() - \
                                                var_third_terms_tr2

            print("Torch computation time: ", time.time() - start_time)
            """
            
            start_time = time.time()
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(Qmat_numpy @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
            var_second_terms_tr = np.einsum('ijk,jk->ik', Qmat_numpy @ self.Ba_vector, self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr
            #print("Numpy time: ", time.time() - start_time)

            # #############################################################################################
            # # START: DEBUGGING COMPARISONS 
            # print("Starting: Debugging computations")
            # input_means = np.array(input_means)
            # total_input = np.hstack((input_means, np.hstack((dim1_input_vals, dim2_input_vals))))

            # # GPy GP Computation
            # gpy_mean, gpy_var = self.GP_model.predict(Xnew=total_input)
            # print("Finished GPy computation, starting manual computations")

            # # Manual GP Computation for mean and variance 
            # K_xstar_xstar = self.GP_model.kern.K(total_input, total_input)
            # K_X_xstar = self.GP_model.kern.K(self.GP_model._predictive_variable, total_input)
            # manual_mean = np.dot(K_X_xstar.T, self.GP_model.posterior.woodbury_vector)
            # var_intermediate = np.linalg.lstsq(self.GP_model.posterior.woodbury_chol, K_X_xstar, rcond=-1)
            # manual_var = K_xstar_xstar - (var_intermediate[0].T @ var_intermediate[0])
            # manual_var_matmul = K_xstar_xstar - (K_X_xstar.T @ (self.GP_model.posterior.woodbury_inv @ K_X_xstar))

            # print("Finished: Debugging Computations")
            # print("TODO: Comment out all the deletes again!")

            # print("At the end of the first prediction")


            # print("predicted var patches: ", current_prediction_var_patches)
            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print("manual var matmul: ", manual_var_matmul)

            # pdb.set_trace()

            # END: DEBUGGING COMPARISONS
            #############################################################################################

            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print()
            # print("propagation var: ", current_prediction_var_patches)
            # print()
            # print("STOP: End of computation")
            # print("The initially coded approach - not the copy of PILCO")
            # pdb.set_trace()

        # if type(dim1_input_vals) != type(None) and type(dim2_input_vals) != type(None):
        #     # PYTHON COPY OF PILCO ORIGINAL CODE

        #     # change imports later
        #     from GPy.util.linalg import jitchol, dtrtri, dpotri, dpotrs


        #     # 1) Compute cached variables: K, iK, L, beta
        #     K = self.GP_model.kern.K(self.GP_model._predictive_variable, self.GP_model._predictive_variable)
        #     noise_variance = self.GP_model.Gaussian_noise.variance
        #     jitter_matrix = 1e-8 * np.eye(K.shape[0])
        #     K_noise = K + (noise_variance * np.eye(K.shape[0])) + jitter_matrix

        #     L =  jitchol(K_noise)
        #     Li = dtrtri(L)
        #     iK = dpotri(L, lower=1)[0]
        #     beta, _ = dpotrs(L, self.GP_model.Y, lower=1)

        #     dim0_input_mean_patches = input_means # vector (patch nums, vec dim)
        #     dim0_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
        #     dim1_input_vals = dim1_input_vals
        #     dim2_input_vals = dim2_input_vals

        #     k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
        #     k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
        #     k_coef  = k1_coef * k2_coef 

        #     # 2) Mean Prediction: computing the predicted mean and inv(s) times input-output covariance
        #     m = dim0_input_mean_patches
        #     s_mat = dim0_input_cov_patches

        #     dim0_input_mean_patches_seperated = dim0_input_mean_patches.reshape((dim0_input_mean_patches.shape[0], 1, dim0_input_mean_patches.shape[1]))
        #     inp = self.dim0_train_xvecs - dim0_input_mean_patches_seperated # demeaned inputs

        #     iL = 1/self.GP_model.kern.lengthscale * np.eye(self.lengthscale_matrix_dim0.shape[0]) # sqrt of the lengthscale matrix 
        #     iN = inp @ iL #np.einsum('ij,klj->klj', iL, inp)
        #     B = (iL @ s_mat) @ iL + np.eye(iL.shape[0])

        #     assert(B.shape[0] == 1)
        #     assert(iN.shape[0] == 1)
        #     t = np.linalg.lstsq(a=B.reshape((B.shape[1], B.shape[2])), b=iN.reshape((iN.shape[1], iN.shape[2])).T, rcond=-1)[0].T
        #     t = t.reshape((1, t.shape[0], t.shape[1]))
        #     l_mat = np.exp(-np.sum(iN * t, axis=-1)/2) 
        #     #lb = l_mat * beta
        #     tiL = t @ iL 
        #     c_coef = self.variance_dim0/np.sqrt(np.linalg.det(B)) 
        #     lb_k = k_coef * l_mat
        #     M = c_coef * (lb_k @ beta)
        #     predicted_mean = M

        #     # 3) Compute predictive covariance 
        #     with torch.no_grad():
        #         # load tensors into cuda
        #         dim0_input_cov_patches_tr = torch.cuda.DoubleTensor(dim0_input_cov_patches)
        #         s_tr = dim0_input_cov_patches_tr
        #         v_tr = torch.cuda.DoubleTensor(inp)

        #         R_tr = torch.matmul(dim0_input_cov_patches_tr, self.lengthscale_matrix_dim0_inv_tr + self.lengthscale_matrix_dim0_inv_tr) + torch.eye(n=dim0_input_cov_patches.shape[-1]).to(self.device)
        #         t_R = 1/torch.sqrt(torch.det(R_tr)).reshape((-1, 1, 1))
        #         Qmat_exp_covmat_tr = torch.lstsq(input=s_tr.reshape((s_tr.shape[1], s_tr.shape[2])), A=R_tr.reshape((R_tr.shape[1], R_tr.shape[2]))).solution
        #         Qmat_exp_covmat_tr = torch.unsqueeze(Qmat_exp_covmat_tr, dim=0)

        #         # Q matrix coefficient
        #         k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
        #         total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
        #         total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
        #         Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)*t_R

        #         # Z torch computation
        #         lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim0_inv_tr)
        #         # del(v_tr)

        #         # replace this part by doing n computations at a time (single column) - rather than n**2
        #         Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
        #         col_dim = Qmat_numpy.shape[-2]
        #         num_cols = Qmat_numpy.shape[-1]

        #         max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
        #         cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

        #         cols_list = list(np.arange(0, num_cols+1, cols_atonce))
        #         if cols_list[-1] != num_cols: 
        #             cols_list.append(num_cols)

        #         for i in range(len(cols_list) - 1):
        #             curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
        #                           torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

        #             curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
        #                            torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
        #                             torch.einsum('ijk,ilmkn->ilmkn', Qmat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
        #             # del(curr_zij_tr)
        #             Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
        #             # del(curr_Qmat_tr)

        #         del(lengthscale_inv_v_tr)

        #     # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
        #     var_first_terms_tr2 = self.full_variance - np.trace(iK @ Qmat_numpy, axis1=1, axis2=2).flatten()
        #     # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
        #     #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
        #     var_second_terms_tr2 = np.einsum('ijk,jk->ik', Qmat_numpy @ beta, beta).flatten()
        #     # 4. Compute 4th term mean^2
        #     var_third_terms_tr2 = np.square(predicted_mean).flatten()
        #     current_prediction_var_patches2 = var_first_terms_tr2 + var_second_terms_tr2 - var_third_terms_tr2
        #     #print("Numpy time: ", time.time() - start_time)

        #     #diff = current_prediction_var_patches - current_prediction_var_patches2

        #     current_prediction_mean_patches = predicted_mean 
        #     current_prediction_var_patches = current_prediction_var_patches2

        #     print("THIS IS IN THE NEW THING ? ")
        #     print("Hello im here")
        #     pdb.set_trace()

        elif type(dim2_input_vals) != type(None):

            # Mean Propagation - NEW
            dim01_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim01_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)
            dim0_input_mean_patches = dim01_input_mean_patches[:, self.dim0_start_idx:self.dim0_end_idx] 
            dim1_input_mean_patches = dim01_input_mean_patches[:, self.dim1_start_idx:self.dim1_end_idx]
            dim2_input_vals = dim2_input_vals

            qa_coef = (self.variance_dim0*self.variance_dim1)/np.sqrt(np.linalg.det(np.matmul(dim01_input_cov_patches, self.lengthscale_matrix_dim01_inv) + np.eye(dim01_input_cov_patches.shape[-1]))).reshape((-1, 1))
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            
            # exponential part
            dim01_input_mean_patches_seperated = dim01_input_mean_patches.reshape((dim01_input_mean_patches.shape[0], 1, dim01_input_mean_patches.shape[1]))
            v = dim01_input_mean_patches_seperated - self.dim01_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim01_input_cov_patches + self.lengthscale_matrix_dim01)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

            qa = qa_coef * k2_coef * np.exp(-0.5*qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim01_input_mean_patches_seperated)
            del(qa_coef)
            
            #return current_prediction_mean_patches, np.zeros(current_prediction_mean_patches.shape)
            ################################################################################################
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                dim01_input_cov_patches_tr = torch.cuda.DoubleTensor(dim01_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)

                R_tr = torch.matmul(dim01_input_cov_patches_tr, self.lengthscale_matrix_dim01_inv_tr + self.lengthscale_matrix_dim01_inv_tr) + torch.eye(n=dim01_input_cov_patches.shape[-1]).to(self.device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim01_input_cov_patches_tr)
                del(R_tr)
                del(dim01_input_cov_patches_tr)
                del(R_inv_tr)

                # Q matrix coefficient
                k0_coef = self.kernel_dim0.K(dim0_input_mean_patches, self.dim0_train_xvecs)
                k1_coef = self.kernel_dim1.K(dim1_input_mean_patches, self.dim1_train_xvecs)
                total_k_coef = (k0_coef * k1_coef * k2_coef).reshape((k0_coef.shape[0], 1, k0_coef.shape[1])) 
                total_k_coef_tr = torch.cuda.DoubleTensor(total_k_coef)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                del(total_k_coef_tr)
                del(R_sqrt_det_tr)

                # Z torch computation
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim01_inv_tr)
                del(v_tr)

                # replace this part by doing n computations at a time (single column) - rather than n**2
                Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
                col_dim = Qmat_numpy.shape[-2]
                num_cols = Qmat_numpy.shape[-1]

                max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
                cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

                cols_list = list(np.arange(0, num_cols+1, cols_atonce))
                if cols_list[-1] != num_cols: 
                    cols_list.append(num_cols)
                    
                for i in range(len(cols_list) - 1):
                    curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
                                  torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

                    curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
                                   torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
                                    torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
                    del(curr_zij_tr)
                    Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
                    del(curr_Qmat_tr)

                del(lengthscale_inv_v_tr)

            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(Qmat_numpy @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
            var_second_terms_tr = np.einsum('ijk,jk->ik', Qmat_numpy @ self.Ba_vector, self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr


            #############################################################################################
            # # START: DEBUGGING COMPARISONS 
            # print("Starting: Debugging computations")
            # input_means = np.array(input_means)
            # total_input = np.hstack((input_means, dim2_input_vals))

            # # GPy GP Computation
            # gpy_mean, gpy_var = self.GP_model.predict(Xnew=total_input)
            # print("Finished GPy computation, starting manual computations")

            # # Manual GP Computation for mean and variance 
            # K_xstar_xstar = self.GP_model.kern.K(total_input, total_input)
            # K_X_xstar = self.GP_model.kern.K(self.GP_model._predictive_variable, total_input)
            # manual_mean = np.dot(K_X_xstar.T, self.GP_model.posterior.woodbury_vector)
            # var_intermediate = np.linalg.lstsq(self.GP_model.posterior.woodbury_chol, K_X_xstar, rcond=-1)
            # manual_var = K_xstar_xstar - (var_intermediate[0].T @ var_intermediate[0])
            # manual_var_matmul = K_xstar_xstar - (K_X_xstar.T @ (self.GP_model.posterior.woodbury_inv @ K_X_xstar))

            # print("Finished: Debugging Computations")
            # print("TODO: Comment out all the deletes again!")

            # print("At the end of the second prediction")

            # print("predicted var patches: ", current_prediction_var_patches)
            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print("manual var matmul: ", manual_var_matmul)

            # pdb.set_trace()

            # END: DEBUGGING COMPARISONS
            #############################################################################################
            
        else:
            # Mean Propagation - NEW
            dim012_input_mean_patches = input_means # vector (patch nums, vec dim)
            dim012_input_cov_patches = input_covs # vector (patch nums, vec dim, vec dim)

            qa_coef = (self.full_variance)/np.sqrt(np.linalg.det(np.matmul(dim012_input_cov_patches, self.lengthscale_matrix_dim012_inv) + np.eye(dim012_input_cov_patches.shape[-1]))).reshape((-1, 1))
            # exponential part
            dim012_input_mean_patches_seperated = dim012_input_mean_patches.reshape((dim012_input_mean_patches.shape[0], 1, dim012_input_mean_patches.shape[1]))
            v = dim012_input_mean_patches_seperated - self.full_train_xvecs

            qa_exp_covmat = np.linalg.inv(dim012_input_cov_patches + self.lengthscale_matrix_dim012)
            qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

            qa = qa_coef * np.exp(-0.5*qa_exp)
            current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

            # clear mean variables
            del(qa)
            del(qa_exp)
            del(qa_exp_covmat)
            del(dim012_input_mean_patches_seperated)
            del(qa_coef)

            ################################################################################################
            # Variance torch propagation 
            with torch.no_grad():
                # load tensors into cuda
                dim012_input_cov_patches_tr = torch.cuda.DoubleTensor(dim012_input_cov_patches)
                v_tr = torch.cuda.DoubleTensor(v)

                R_tr = torch.matmul(dim012_input_cov_patches_tr, self.lengthscale_matrix_dim012_inv_tr + self.lengthscale_matrix_dim012_inv_tr) + torch.eye(n=dim012_input_cov_patches.shape[-1]).to(self.device)
                R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
                R_inv_tr = torch.linalg.inv(R_tr)
                Q_mat_exp_covmat_tr = torch.matmul(R_inv_tr, dim012_input_cov_patches_tr)
                # del(R_tr)
                # del(dim012_input_cov_patches_tr)
                # del(R_inv_tr)

                # Q matrix coefficient
                total_k_coef = self.GP_model.kern.K(dim012_input_mean_patches, self.full_train_xvecs)
                total_k_coef_tr = torch.unsqueeze(torch.cuda.DoubleTensor(total_k_coef), dim=1)
                Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
                # del(total_k_coef_tr)
                # del(R_sqrt_det_tr)
                
                # Z torch computation
                lengthscale_inv_v_tr = torch.matmul(v_tr, self.lengthscale_matrix_dim012_inv_tr)
                # del(v_tr)

                # replace this part by doing n computations at a time (single column) - rather than n**2
                Qmat_numpy = np.zeros((Qmat_coefs_tr.shape))
                col_dim = Qmat_numpy.shape[-2]
                num_cols = Qmat_numpy.shape[-1]

                max_one_col_datasize = np.product(lengthscale_inv_v_tr.shape) * lengthscale_inv_v_tr.shape[1] * float64_multiplier
                
                cols_atonce = max(1, min(num_cols, int(np.floor(max_data_size/max_one_col_datasize))))

                cols_list = list(np.arange(0, num_cols+1, cols_atonce))
                if cols_list[-1] != num_cols: 
                    cols_list.append(num_cols)
                    
                for i in range(len(cols_list) - 1):
                    curr_zij_tr = torch.unsqueeze(lengthscale_inv_v_tr[:, cols_list[i]:cols_list[i+1], :], dim=2) + \
                                  torch.unsqueeze(lengthscale_inv_v_tr, dim=1)

                    curr_Qmat_tr = Qmat_coefs_tr[:, cols_list[i]:cols_list[i+1], :] * \
                                   torch.exp(0.5 * torch.einsum('ijklm,ijkml->ijk', torch.unsqueeze(curr_zij_tr, dim=-2), \
                                    torch.einsum('ijk,ilmkn->ilmkn', Q_mat_exp_covmat_tr, torch.unsqueeze(curr_zij_tr, dim=-1))))
                    
                    # del(curr_zij_tr)
                    Qmat_numpy[:, cols_list[i]:cols_list[i+1], :] = curr_Qmat_tr.detach().cpu().numpy()
                    # del(curr_Qmat_tr)

                # del(lengthscale_inv_v_tr)
                
            # NOTE: TODO: ASDFASDF Move these computations to torch gpu to make them faster - Needed here where you have more training data
            # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
            var_first_terms_tr = self.full_variance - np.trace(Qmat_numpy @ self.K_train_noise_inv, axis1=1, axis2=2).flatten()
            # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
            #var_second_terms_tr = (np.transpose(Qmat_numpy @ self.Ba_vector, axes=(0,2,1)) @ self.Ba_vector).flatten() # INCORRECT IN MULTI OUTPUT
            var_second_terms_tr = np.einsum('ijk,jk->ik', Qmat_numpy @ self.Ba_vector, self.Ba_vector).flatten()
            # 4. Compute 4th term mean^2
            var_third_terms_tr = np.square(current_prediction_mean_patches).flatten()
            current_prediction_var_patches = var_first_terms_tr + var_second_terms_tr - var_third_terms_tr

            #############################################################################################
            # # START: DEBUGGING COMPARISONS 
            # print("Starting: Debugging computations")
            # input_means = np.array(input_means)
            # total_input = input_means

            # # GPy GP Computation
            # gpy_mean, gpy_var = self.GP_model.predict(Xnew=total_input)
            # print("Finished GPy computation, starting manual computations")

            # # Manual GP Computation for mean and variance 
            # K_xstar_xstar = self.GP_model.kern.K(total_input, total_input)
            # K_X_xstar = self.GP_model.kern.K(self.GP_model._predictive_variable, total_input)
            # manual_mean = np.dot(K_X_xstar.T, self.GP_model.posterior.woodbury_vector)
            # var_intermediate = np.linalg.lstsq(self.GP_model.posterior.woodbury_chol, K_X_xstar, rcond=-1)
            # manual_var = K_xstar_xstar - (var_intermediate[0].T @ var_intermediate[0])
            # manual_var_matmul = K_xstar_xstar - (K_X_xstar.T @ (self.GP_model.posterior.woodbury_inv @ K_X_xstar))

            # print("Finished: Debugging Computations")
            # print("TODO: Comment out all the deletes again!")

            # print("At the end of the third prediction")

            # print("predicted var patches: ", current_prediction_var_patches)
            # print("gpy var: ", gpy_var)
            # print("manual var: ", manual_var)
            # print("manual var matmul: ", manual_var_matmul)

            # pdb.set_trace()

            # END: DEBUGGING COMPARISONS
            #############################################################################################
        
        """
        print("End of the mean propagation separate")
        print("Doing prediction with GPy to test")
        total_input = np.hstack((dim0_input_mean_patches, np.hstack((dim1_input_vals, dim2_input_vals))))
        mucp, varcp = self.GP_model.predict(Xnew=total_input)

        print("Back in the Qij separate function")
        pdb.set_trace()
        """

        # force negative jitter to zero
        current_prediction_var_patches[current_prediction_var_patches < 0] = 0
        return current_prediction_mean_patches, current_prediction_var_patches    

    def analytical_mean_var_propogation_Faster(self, input_means, input_covs, dim1_input_vals=None, dim2_input_vals=None):
        """ 
        Propagates the mean and variance of a Gaussian input distribution through the GP 
        (moment matching), processing several patches at once with batched tensor operations. 

        Which input dimensions are treated as random depends on the optional arguments: 
            - dim1_input_vals and dim2_input_vals both given: only dim0 is random (first propagation)
            - only dim2_input_vals given: dim0 and dim1 are random (second propagation)
            - otherwise: all of dim0, dim1 and dim2 are random (third and later propagations)
              NOTE: dim1_input_vals on its own (without dim2_input_vals) is ignored. 

        args: 
            - input_means: type: numpy array: (patch nums, vec dim): mean of the random input dimensions
            - input_covs: type: numpy array: (patch nums, vec dim, vec dim): covariance matrices matching input_means
                - may be full (non-diagonal) covariance matrices
            - dim1_input_vals: type: numpy array: the known (deterministic) values for the dim1 input 
                - NOTE: only needed for the first propagation
            - dim2_input_vals: type: numpy array: the known (deterministic) values for the dim2 input 
                - NOTE: only needed for the first and second propagation
        returns: 
            - current_prediction_mean_patches: the propagated means, computed as Ba_vector^T q  
            - current_prediction_var_patches : the propagated variances, flattened, negative values forced to zero
        """ 

        ################################################################################################
        # Select the hyper-parameters / kernel coefficients that belong to the random input dimensions
        if dim1_input_vals is not None and dim2_input_vals is not None:
            # only dim0 is random
            signal_variance = self.variance_dim0
            lengthscale_matrix = self.lengthscale_matrix_dim0
            lengthscale_matrix_inv = self.lengthscale_matrix_dim0_inv
            lengthscale_matrix_inv_tr = self.lengthscale_matrix_dim0_inv_tr
            train_xvecs = self.dim0_train_xvecs

            k1_coef = self.kernel_dim1.K(dim1_input_vals, self.dim1_train_xvecs)
            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            known_k_coef = k1_coef * k2_coef # kernel factor of the deterministic dimensions
            k0_coef = self.kernel_dim0.K(input_means, self.dim0_train_xvecs)
            total_k_coef = k0_coef * k1_coef * k2_coef

        elif dim2_input_vals is not None:
            # dim0 and dim1 are random
            signal_variance = self.variance_dim0 * self.variance_dim1
            lengthscale_matrix = self.lengthscale_matrix_dim01
            lengthscale_matrix_inv = self.lengthscale_matrix_dim01_inv
            lengthscale_matrix_inv_tr = self.lengthscale_matrix_dim01_inv_tr
            train_xvecs = self.dim01_train_xvecs

            k2_coef = self.kernel_dim2.K(dim2_input_vals, self.dim2_train_xvecs)
            known_k_coef = k2_coef
            k0_coef = self.kernel_dim0.K(input_means[:, self.dim0_start_idx:self.dim0_end_idx], self.dim0_train_xvecs)
            k1_coef = self.kernel_dim1.K(input_means[:, self.dim1_start_idx:self.dim1_end_idx], self.dim1_train_xvecs)
            total_k_coef = k0_coef * k1_coef * k2_coef

        else:
            # all input dimensions are random
            signal_variance = self.full_variance
            lengthscale_matrix = self.lengthscale_matrix_dim012
            lengthscale_matrix_inv = self.lengthscale_matrix_dim012_inv
            lengthscale_matrix_inv_tr = self.lengthscale_matrix_dim012_inv_tr
            train_xvecs = self.full_train_xvecs

            known_k_coef = None
            total_k_coef = self.GP_model.kern.K(input_means, self.full_train_xvecs)

        ################################################################################################
        # Mean propagation: mean = Ba_vector^T q 
        vec_dim = input_covs.shape[-1]
        qa_coef = signal_variance/np.sqrt(np.linalg.det(np.matmul(input_covs, lengthscale_matrix_inv) + np.eye(vec_dim))).reshape((-1, 1))
        if known_k_coef is not None:
            qa_coef = qa_coef * known_k_coef

        # exponential part: v[p, a] = mean_p - x_a   -> (patch nums, train nums, vec dim)
        v = input_means.reshape((input_means.shape[0], 1, input_means.shape[1])) - train_xvecs
        qa_exp_covmat = np.linalg.inv(input_covs + lengthscale_matrix)
        qa_exp = np.matmul(np.expand_dims(np.matmul(v, qa_exp_covmat), axis=-2), np.expand_dims(v, axis=-1)).reshape((v.shape[0:2]))

        qa = qa_coef * np.exp(-0.5 * qa_exp)
        current_prediction_mean_patches = np.dot(self.Ba_vector.T, qa.T)

        ################################################################################################
        # Variance torch propagation 
        with torch.no_grad():
            # load tensors onto the same device as the rest of the model tensors
            tensor_kwargs = {'dtype': torch.float64, 'device': self.device}
            input_covs_tr = torch.as_tensor(input_covs, **tensor_kwargs)
            v_tr = torch.as_tensor(v, **tensor_kwargs)
            total_k_coef_tr = torch.unsqueeze(torch.as_tensor(total_k_coef, **tensor_kwargs), dim=1)

            # R = Sigma (Lambda^-1 + Lambda^-1) + I 
            R_tr = torch.matmul(input_covs_tr, lengthscale_matrix_inv_tr + lengthscale_matrix_inv_tr) + torch.eye(vec_dim, **tensor_kwargs)
            R_sqrt_det_tr = torch.sqrt(torch.det(R_tr)).reshape((-1,1,1))
            Q_mat_exp_covmat_tr = torch.matmul(torch.linalg.inv(R_tr), input_covs_tr) # R^-1 Sigma
            del(R_tr)
            del(input_covs_tr)

            # Exponent: z_ab^T (R^-1 Sigma) z_ab with z_ab = u_a + u_b and u_a = Lambda^-1 v_a. 
            # Expanding the quadratic form gives G_aa + G_ab + G_ba + G_bb with G = U (R^-1 Sigma) U^T, 
            # so only (patch nums, train nums, train nums) tensors are needed and the full matrix 
            # R^-1 Sigma is used (also correct for non-diagonal input covariances). 
            lengthscale_inv_v_tr = torch.matmul(v_tr, lengthscale_matrix_inv_tr)
            gram_tr = torch.matmul(torch.matmul(lengthscale_inv_v_tr, Q_mat_exp_covmat_tr), torch.transpose(lengthscale_inv_v_tr, dim0=1, dim1=2))
            gram_diag_tr = torch.diagonal(gram_tr, dim1=1, dim2=2)
            Qmat_exp_tr = torch.unsqueeze(gram_diag_tr, dim=2) + torch.unsqueeze(gram_diag_tr, dim=1) + \
                          gram_tr + torch.transpose(gram_tr, dim0=1, dim1=2)
            del(v_tr)
            del(lengthscale_inv_v_tr)
            del(gram_tr)

            # Q matrix coefficient: k(mean, x_a) k(mean, x_b) / sqrt(|R|)
            Qmat_coefs_tr = torch.matmul(torch.transpose(total_k_coef_tr, dim0=1,dim1=2), total_k_coef_tr)/R_sqrt_det_tr
            del(total_k_coef_tr)
            del(R_sqrt_det_tr)

            # Q matrix, moved back to numpy for the final variance computation 
            Qmat = (Qmat_coefs_tr * torch.exp(0.5 * Qmat_exp_tr)).detach().cpu().numpy()

        # 2. Compute 1st term E_x[var(f|x)] = alpha^2 - trace([Kxx + noise]^-1 Q )  - alpha^2 = self.full_variance
        var_first_terms = self.full_variance - np.trace(np.matmul(Qmat, self.K_train_noise_inv), axis1=1, axis2=2).flatten()
        # 3. Compute 2nd term E_x[E_f[f|x] E_f[f|x]] = Ba_vector^T Q Ba_vector
        var_second_terms = np.matmul(np.transpose(np.matmul(Qmat, self.Ba_vector), axes=(0,2,1)), self.Ba_vector).flatten()
        # 4. Compute 3rd term mean^2
        var_third_terms = np.square(current_prediction_mean_patches).flatten()
        current_prediction_var_patches = var_first_terms + var_second_terms - var_third_terms

        # force negative jitter to zero
        current_prediction_var_patches[current_prediction_var_patches < 0] = 0
        return current_prediction_mean_patches, current_prediction_var_patches

    def mean_var_propogation_Faster(self, starting_x_images, steps, all_together, use_variance_weighting, test_patch_obj=None, single_kernel=False, 
        save_filename_suffix=None, dirname='', num_patches_list=[4,2,1], seperate=False, svd_inverse = False, sigma_threshold = None, plot_for_thresholding = False, 
        show_intermediate_imgs=True):
        """
        Predicts a rollout sequence from the starting_x_images for steps number of steps. 
        Propogate the random variable forward at each timestep.  

        args:  
            - starting_x_images: list of starting images: depending on the dataprocessor up to 3 images are needed
                - this is used to start the prediction process
            - steps: the number of steps to rollout
            - all_together: boolean indicator
                - True: predict the entire dataset in one go 
                - False: predict the dataset one image at a time - NOTE: TODO: might need to change this to x patches at a time. 
            - use_variance_weighting: 
                - True: use variance for the mask weighting
                - False: do not use variance - just do averaging
            - single_kernel: boolean indicator: 
                - True: just one lengthscale for all dimensions
                - False: one lengthscale per different image space. 
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
            - save_filename_suffix: type: string: ending string to add to  the .npy file that contains the predicted means and variances
                - if None: do not save 
            - dirname: type: str: the directory path where to save the files if saving
            - num_patches_list: type: list of integers: [x,y,z]: 
                - predicts x simultaneous patches for the first propagation
                - predicts y simultaneous patches for the second propagation 
                - predicts z simultaneous patches for the third propagation  
            - seperate: type: boolean 
                - True: if true then solve columns separately using the separate propagation function 
                - False: use the faster all together propagation function and the num patches specified in the list
            - svd_inverse: type: boolean: 
                - True: use the svd inverse with the specified sigma cutoff
                - False: use the matrices returned by GPy directly 
            - sigma_threshold: 
                - None: take the sigma values up to the "large" dropoff
                    - We define a large dropoff as if the change in values between the next two sigma values is less than 1 order of magnitude compared to the change in the last 2 values
                - float: the maximum sigma value you want to keep before inverting
            - plot_for_thresholding: type: boolean:
                - False: just use the specified thresholds 
                - True: plot the sigma values and go into a set trace to allow the sigma threshold to be set manually by the user. 
            - show_intermediate_imgs: type: boolean
                - False: does not show the images as they are predicted
                - True: Default: shows the images using matplotlib as soon as they are predicted. 
        returns: 
            - pred_seq_means: list of unpatchified images that are predicted from the test_x set - these are the means
            - pred_var_images: 
                - returns this if  use_variance_weighting is set to True
            - xdataset: the xdataset that was used to generate each of the predicted images during the rollout
        """
        if seperate: 
            num_patches_list = [1,1,1] # can only use single patch when evaluating Q cols seperately. 

        # default toy variance value for testing
        print("This has been updated: 1")

        # use_variance_weighting input only left in for compatibility - do not use with this approach 
        if type(test_patch_obj) == type(None):
            patch_obj = self.patch_obj
        else: 
            patch_obj = test_patch_obj

        assert(use_variance_weighting == False)
        #assert(patch_obj.stride == (1,1))
        if not seperate:
            assert(patch_obj.get_ypatch_dim() == (1,1)) # This method is really only for when you are estimating the center pixel of the patch 

        # currently only coding for this input space and output space
        assert(self.data_processor.xtypes == ['img0', 'img-1', 'img-2'])
        assert(self.data_processor.ytype == 'diff')

        # create save names 
        if type(save_filename_suffix) == type(None):
            save_file = False
        else: 
            save_file = True
            predicted_mean_savefilename = os.path.join(dirname, "predicted_mean_" + str(save_filename_suffix) + ".npy")
            predicted_var_savefilename = os.path.join(dirname, "predicted_var_" + str(save_filename_suffix) + ".npy")

        predicted_seq_means = []
        predicted_seq_vars = [] # even though not being estimated set up here to establish framework 

        predicted_input_mean_patches = [] # list of sublists at most 3 long - each sublist is a list of patches - 0th list corresponds to patches of the last image and so on
        predicted_input_var_patches = [] # same as list of mean patches above but with variance 

        self.setup_mean_var_propogation_parameters(patch_obj=patch_obj, svd_inverse=svd_inverse, sigma_threshold=sigma_threshold, plot_for_thresholding=plot_for_thresholding)

        # Test data: starting x images
        # NOTE: add a padder image as create_x always leaves the last image for the y dataset creator
        curr_test_x = list(starting_x_images)
        processed_curr_test_x = self.data_processor.create_x(image_seq=list(curr_test_x) + [np.zeros(curr_test_x[0].shape)])
        patched_processed_curr_test_x = patch_obj.patchify_dataset(dataset=processed_curr_test_x,
            dataset_type='x')
        curr_test_x_vecs = np.array(self.data_processor.convert_imgdataset_to_vecdataset(dataset=patched_processed_curr_test_x))
        dim0_test_start_xvecs = curr_test_x_vecs[:, self.dim0_start_idx:self.dim0_end_idx]
        dim1_test_start_xvecs = curr_test_x_vecs[:, self.dim1_start_idx:self.dim1_end_idx]
        dim2_test_start_xvecs = curr_test_x_vecs[:, self.dim2_start_idx:self.dim2_end_idx]
        num_patches_per_image = curr_test_x_vecs.shape[0]

        # Save the GP model with the given suffix - save method is not implemented for the sparse models.
        if save_file and not ( type(self.GP_model) == GPy.models.sparse_gp_regression.SparseGPRegression ): 
            gp_model_filename = os.path.join(dirname, "GP_model" + str(save_filename_suffix))
            self.GP_model.save_model(output_filename=gp_model_filename)
        
        ###########################################################################
        # 1st Prediction: GP output
        ###########################################################################
        # perform the prediction
        if all_together:
            first_prediction_means, first_prediction_vars = GP_model.predict(Xnew=np.array(curr_test_x_vecs).reshape((len(curr_test_x_vecs), -1)))
        else: 
            first_prediction_means = []
            first_prediction_vars = []

            start_range = 0
            while start_range < len(curr_test_x_vecs):
                end_range = min(start_range + self.predictor.max_all_together, len(curr_test_x_vecs))
                predict_subset_x = curr_test_x_vecs[start_range:end_range]
                predict_subset_ymean, predict_subset_yvar = self.GP_model.predict(Xnew=np.array(predict_subset_x).reshape((len(predict_subset_x), -1)))
                # append
                first_prediction_means.extend(predict_subset_ymean)
                first_prediction_vars.extend(predict_subset_yvar)

                start_range += self.predictor.max_all_together

        # convert the predicted vecs to patches
        first_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in first_prediction_means]
        first_prediction_var_patches = [np.ones(patch_obj.get_ypatch_dim())*float(vec) for vec in first_prediction_vars]

        # convert the predicted patches to one predicted image
        if not use_variance_weighting:
            if patch_obj.get_ypatch_dim() == (1,1):
                # Each output pixel is determined only by one prediction so unpatchify just puts the prediction in the correct location. 
                first_predicted_mean_diff, _, _, _, _ = patch_obj.unpatchify_image(patch_list=first_prediction_mean_patches, 
                                                                                      patch_variance_list=None, 
                                                                                      img_type='y')
                first_predicted_var_diff, _, _, _, _  = patch_obj.unpatchify_image(patch_list=first_prediction_var_patches, 
                                                                                      patch_variance_list=None, 
                                                                                      img_type='y') 
            else: 
                # Predictions must be combined like a gaussian mixture model and moment matched to single gaussian distribution 
                first_predicted_mean_diff, first_predicted_var_diff = \
                    patch_obj.unpatchify_mean_variance_image(mean_patch_list=first_prediction_mean_patches, 
                                                             var_patch_list=first_prediction_var_patches, 
                                                             patch_weight=np.ones(patch_obj.get_ypatch_dim()))
                    
        else: 
            print("ERROR: incorrectly tried to use variance weighting here")
            assert(False)

        # difference prediction to image
        first_predicted_mean_image = starting_x_images[-1] + first_predicted_mean_diff
        first_predicted_var_image  = first_predicted_var_diff

        # append to stored lists
        predicted_seq_means.append(first_predicted_mean_image)
        predicted_seq_vars.append(first_predicted_var_image)# * 1e4) # TODO: REMOVE THE INCREASED VARIANCE

        if save_file: 
            np.save(predicted_mean_savefilename, predicted_seq_means)
            np.save(predicted_var_savefilename, predicted_seq_vars)

        # Convert last predicted output to next input
        first_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
        first_predicted_input_var_patches = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')
        # vectorize
        first_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in first_predicted_input_mean_patches]) 
        first_predicted_input_vars = np.array([patch_var.flatten() for patch_var in first_predicted_input_var_patches])

        predicted_input_mean_patches.append(first_predicted_input_means)
        predicted_input_var_patches.append(first_predicted_input_vars)

        if show_intermediate_imgs:
            # show the first image
            fig = plt.figure()
            pos = plt.imshow(predicted_seq_means[-1])
            fig.colorbar(pos)
            plt.title("First Prediction")
            plt.show(block=False)

            # show the variance image
            fig = plt.figure()
            pos = plt.imshow(predicted_seq_vars[-1])
            fig.colorbar(pos)
            plt.title("First Variance Image")
            plt.show(block=False)

        ###########################################################################
        # 2nd and Following Predictions: PILCO
        ###########################################################################
        second_prediction_means = []
        second_prediction_means_normalgp = []
        second_prediction_vars = []
        stopper = True

        for step_num in range(1, steps):
            single_img_pred_start_time = time.time()

            current_prediction_means = []
            current_prediction_vars = []

            if step_num == 1: 

                patches_at_once = num_patches_list[0]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim0_input_means = predicted_input_mean_patches[-1][start_patches_num:end_patches_num]
                    dim0_input_vars = predicted_input_var_patches[-1][start_patches_num:end_patches_num] #* 1e4
                    dim0_input_covs = np.array([np.diag(dim0_input_var.flatten()) for dim0_input_var in dim0_input_vars])

                    dim1_input_vals = dim0_test_start_xvecs[start_patches_num:end_patches_num]
                    dim2_input_vals = dim1_test_start_xvecs[start_patches_num:end_patches_num]

                    if seperate: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim0_input_means, 
                                                                        input_covs=dim0_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim0_input_means, 
                                                                        input_covs=dim0_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            elif step_num == 2: 

                patches_at_once = num_patches_list[1]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim01_input_means = np.hstack((predicted_input_mean_patches[-1][start_patches_num:end_patches_num], predicted_input_mean_patches[-2][start_patches_num:end_patches_num]))
                    dim01_input_vars =  np.hstack((predicted_input_var_patches[-1][start_patches_num:end_patches_num], predicted_input_var_patches[-2][start_patches_num:end_patches_num]))
                    dim01_input_covs = np.array([np.diag(dim01_input_var.flatten()) for dim01_input_var in dim01_input_vars])

                    dim1_input_vals = None
                    dim2_input_vals = dim0_test_start_xvecs[start_patches_num:end_patches_num]

                    if seperate: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim01_input_means, 
                                                                        input_covs=dim01_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim01_input_means, 
                                                                        input_covs=dim01_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            else: 
                patches_at_once = num_patches_list[2]

                current_prediction_means = []
                current_prediction_vars = []

                patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
                if patches_num_list[-1] < num_patches_per_image: 
                    patches_num_list = list(patches_num_list) + [num_patches_per_image]

                for i in tqdm(range(len(patches_num_list) - 1)):
                    start_patches_num = patches_num_list[i]
                    end_patches_num = patches_num_list[i+1]

                    dim012_input_means = np.hstack((predicted_input_mean_patches[-1][start_patches_num:end_patches_num], 
                                              np.hstack((predicted_input_mean_patches[-2][start_patches_num:end_patches_num], 
                                                         predicted_input_mean_patches[-3][start_patches_num:end_patches_num]))))
                    dim012_input_vars = np.hstack((predicted_input_var_patches[-1][start_patches_num:end_patches_num], 
                                                  np.hstack((predicted_input_var_patches[-2][start_patches_num:end_patches_num], 
                                                             predicted_input_var_patches[-3][start_patches_num:end_patches_num]))))
                    dim012_input_covs = np.array([np.diag(dim012_input_var.flatten()) for dim012_input_var in dim012_input_vars])

                    dim1_input_vals = None
                    dim2_input_vals = None

                    if seperate:
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster_separateQij(input_means=dim012_input_means, 
                                                                        input_covs=dim012_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)
                    else: 
                        current_prediction_means2, current_prediction_vars2 = \
                            self.analytical_mean_var_propogation_Faster(input_means=dim012_input_means, 
                                                                        input_covs=dim012_input_covs, 
                                                                        dim1_input_vals=dim1_input_vals, 
                                                                        dim2_input_vals=dim2_input_vals)

                    current_prediction_means.append(list(current_prediction_means2.flatten()))
                    current_prediction_vars.append(list(current_prediction_vars2.flatten()))

            # process the predicted means and variances to form images for saving and the next round of predictions
            current_prediction_mean_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_means]
            current_prediction_var_patches = [np.array(vec).reshape(patch_obj.get_ypatch_dim()) for vec in current_prediction_vars]

            # unpatchify 
            if not use_variance_weighting: 

                if patch_obj.get_ypatch_dim() == (1,1):
                    # Each output pixel is determined only by one prediction so unpatchify just puts the prediction in the correct location. 
                    current_predicted_mean_diff, _, _, _, _ = patch_obj.unpatchify_image(patch_list=current_prediction_mean_patches, 
                                                                                          patch_variance_list=None, 
                                                                                          img_type='y')
                    current_predicted_var_diff, _, _, _, _  = patch_obj.unpatchify_image(patch_list=current_prediction_var_patches, 
                                                                                          patch_variance_list=None, 
                                                                                          img_type='y') 
                else: 
                    # Predictions must be combined like a gaussian mixture model and moment matched to single gaussian distribution 
                    current_predicted_mean_diff, current_predicted_var_diff = \
                        patch_obj.unpatchify_mean_variance_image(mean_patch_list=current_prediction_mean_patches, 
                                                                 var_patch_list=current_prediction_var_patches, 
                                                                 patch_weight=np.ones(patch_obj.get_ypatch_dim()))

            else: 
                print("ERROR: incorrectly tried to use variance weighting here")
                assert(False)

            # difference to mean image
            print("PREDICTING DIFFERENCE")
            current_predicted_mean_image = predicted_seq_means[-1] + current_predicted_mean_diff
            current_predicted_var_image = predicted_seq_vars[-1] + current_predicted_var_diff

            print("Finished Predicting: " + str(step_num + 1) + " took: " + str(time.time() - single_img_pred_start_time))
            # append 
            predicted_seq_means.append(current_predicted_mean_image)
            predicted_seq_vars.append(current_predicted_var_image)

            if save_file: 
                np.save(predicted_mean_savefilename, predicted_seq_means)
                np.save(predicted_var_savefilename, predicted_seq_vars)

            # process input patches for the next round
            current_predicted_input_mean_patches = patch_obj.patchify_image(img=predicted_seq_means[-1], img_type='x')
            current_predicted_input_var_patches  = patch_obj.patchify_image(img=predicted_seq_vars[-1], img_type='x')  
            current_predicted_input_means = np.array([patch_mean.flatten() for patch_mean in current_predicted_input_mean_patches])
            current_predicted_input_vars  = np.array([patch_var.flatten() for patch_var in current_predicted_input_var_patches])

            predicted_input_mean_patches.append(current_predicted_input_means)
            predicted_input_var_patches.append(current_predicted_input_vars)

            if len(predicted_input_mean_patches) > 3:
                predicted_input_mean_patches = predicted_input_mean_patches[1:]
                predicted_input_var_patches  = predicted_input_var_patches[1:]

            if show_intermediate_imgs:
                # display the result
                fig = plt.figure()
                pos = plt.imshow(current_predicted_mean_image)
                plt.title("Prediction: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)
                
                # display the variance image 
                fig = plt.figure()
                pos = plt.imshow(current_predicted_var_image)
                plt.title("Variance image: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)

        print("Finished " + str(steps) + " predictions")

        return predicted_seq_means, predicted_seq_vars

    def continue_mean_var_propogation_Faster(self, predicted_seq_means_filename, predicted_seq_vars_filename, steps,
        GP_model=None, GP_model_filename=None,
        test_patch_obj=None, save_filename_suffix=None, num_patches_list=[4,2,1], seperate=False,
        extra_starting_x_images=None, svd_inverse=False, sigma_threshold=None, plot_for_thresholding=False,
        show_intermediate_imgs=True, use_variance_weighting=False):
        """
        Continues a previously saved mean/variance rollout.
        The saved sequences are loaded from disk and the random variable is propagated forward
        at each timestep. The model output is a difference image (ytype == 'diff'), so every
        predicted difference is added to the most recent predicted image (means and variances
        alike) before it is appended to the sequence.

        args:
            - predicted_seq_means_filename: type: str: filename of the list of predicted mean images to continue from
            - predicted_seq_vars_filename: type: str: filename of the list of predicted var images to continue from
            - steps: type: int: the propagation loop runs over range(1, steps), i.e. steps - 1 new images are appended
            - GP_model_filename: type: str: the zip file where the GP model was saved
                - must currently be None (asserted below)
            - GP_model: type: GPy GP model
                - must currently be provided (asserted below); it is not otherwise read in this function
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
            - save_filename_suffix: type: string: ending string to add to the .npy files that contain the predicted means and variances
                - if None: do not save
            - num_patches_list: type: list of integers: [x,y,z]:
                - predicts x simultaneous patches when only 1 predicted image is available as input
                - predicts y simultaneous patches when 2 predicted images are available as input
                - predicts z simultaneous patches when 3 predicted images are available as input
            - seperate: type: boolean
                - True: solve columns separately using the separate propagation function (forces num_patches_list = [1,1,1])
                - False: use the faster all together propagation function and the num patches specified in the list
            - extra_starting_x_images: type: list of images: known (non-predicted) images that fill the 3 image input history
                - at least 2 are required when the loaded sequence holds 1 image
                - at least 1 is required when the loaded sequence holds 2 images
                - not read when the loaded sequence holds 3 or more images
            - svd_inverse: type: boolean:
                - True: use the svd inverse with the specified sigma cutoff
                - False: use the matrices returned by GPy directly
            - sigma_threshold:
                - None: take the sigma values up to the "large" dropoff
                    - We define a large dropoff as if the change in values between the next two sigma values is less than 1 order of magnitude compared to the change in the last 2 values
                - float: the maximum sigma value you want to keep before inverting
            - plot_for_thresholding: type: boolean:
                - False: just use the specified thresholds
                - True: plot the sigma values and go into a set trace to allow the sigma threshold to be set manually by the user.
            - show_intermediate_imgs: type: boolean
                - False: does not show the images as they are predicted
                - True: Default: shows the images using matplotlib as soon as they are predicted.
            - use_variance_weighting: type: boolean: only left in for compatibility - must be False for this approach
        returns:
            - predicted_seq_means: list of mean images: the loaded sequence followed by the newly predicted mean images
            - predicted_seq_vars: list of variance images: the loaded sequence followed by the newly predicted variance images
        raises:
            - AssertionError: if any of the configuration checks below fails
        """
        if seperate:
            num_patches_list = [1,1,1] # can only use single patch when evaluating Q cols seperately.

        print("This has been updated: 1")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        patch_obj = self.patch_obj if test_patch_obj is None else test_patch_obj

        # use_variance_weighting input only left in for compatibility - do not use with this approach
        assert(use_variance_weighting == False)
        assert(patch_obj.stride == (1,1))
        # NOTE(review): this requires the model object itself and rejects a filename - confirm that is still intended
        assert((GP_model is not None) and (GP_model_filename is None))

        if not seperate:
            assert(patch_obj.get_ypatch_dim() == (1,1)) # This method is really only for when you are estimating the center pixel of the patch

        # currently only coding for this input space and output space
        assert(self.data_processor.xtypes == ['img0', 'img-1', 'img-2'])
        assert(self.data_processor.ytype == 'diff')

        # create save names
        save_file = save_filename_suffix is not None
        if save_file:
            predicted_mean_savefilename = "predicted_mean_" + str(save_filename_suffix) + ".npy"
            predicted_var_savefilename = "predicted_var_" + str(save_filename_suffix) + ".npy"

        predicted_seq_means = list(np.load(predicted_seq_means_filename))
        predicted_seq_vars = list(np.load(predicted_seq_vars_filename))

        self.setup_mean_var_propogation_parameters(patch_obj=patch_obj, svd_inverse=svd_inverse, sigma_threshold=sigma_threshold, plot_for_thresholding=plot_for_thresholding)

        def image_to_input_vecs(img):
            # patchify an image into x patches and flatten every patch into one row vector
            patches = patch_obj.patchify_image(img=img, img_type='x')
            return np.array([patch.flatten() for patch in patches])

        # Known (non-predicted) images that complete the 3 image history when fewer than 3 predictions were loaded
        num_loaded_imgs = len(predicted_seq_means)
        last_starting_xvecs = None
        last_last_starting_xvecs = None

        if num_loaded_imgs == 1:
            assert(extra_starting_x_images is not None and len(extra_starting_x_images) >= 2)
            last_last_starting_xvecs = image_to_input_vecs(extra_starting_x_images[-2])
            last_starting_xvecs = image_to_input_vecs(extra_starting_x_images[-1])
        elif num_loaded_imgs == 2:
            assert(extra_starting_x_images is not None and len(extra_starting_x_images) >= 1)
            last_starting_xvecs = image_to_input_vecs(extra_starting_x_images[-1])
        else:
            assert(num_loaded_imgs >= 3)

        # Rolling history (oldest first, at most 3 long) of the flattened input patches of the predicted images
        predicted_input_mean_patches = []
        predicted_input_var_patches = []
        for predicted_img_idx in range(-min(num_loaded_imgs, 3), 0):
            predicted_input_mean_patches.append(image_to_input_vecs(predicted_seq_means[predicted_img_idx]))
            predicted_input_var_patches.append(image_to_input_vecs(predicted_seq_vars[predicted_img_idx]))

        num_patches_per_image = predicted_input_mean_patches[-1].shape[0]

        if seperate:
            propagate = self.analytical_mean_var_propogation_Faster_separateQij
        else:
            propagate = self.analytical_mean_var_propogation_Faster

        ypatch_dim = tuple(patch_obj.get_ypatch_dim())

        ###########################################################################
        # 2nd and Following Predictions: PILCO
        ###########################################################################
        for step_num in range(1, steps):
            single_img_pred_start_time = time.time()

            # The number of predicted images currently in the history decides which inputs are
            # uncertain (propagated) and which are known images (passed as plain values).
            num_history = len(predicted_input_mean_patches)
            patches_at_once = num_patches_list[num_history - 1]

            patches_num_list = np.arange(0, num_patches_per_image+1, patches_at_once)
            if patches_num_list[-1] < num_patches_per_image:
                patches_num_list = list(patches_num_list) + [num_patches_per_image]

            current_prediction_means = []
            current_prediction_vars = []

            for i in tqdm(range(len(patches_num_list) - 1)):
                batch = slice(patches_num_list[i], patches_num_list[i+1])

                # stack the uncertain inputs newest image first: [img0, img-1, img-2]
                input_means = np.hstack([predicted_input_mean_patches[-k][batch] for k in range(1, num_history + 1)])
                input_vars = np.hstack([predicted_input_var_patches[-k][batch] for k in range(1, num_history + 1)])
                input_covs = np.array([np.diag(input_var.flatten()) for input_var in input_vars])

                if num_history == 1:
                    dim1_input_vals = last_starting_xvecs[batch]
                    dim2_input_vals = last_last_starting_xvecs[batch]
                elif num_history == 2:
                    dim1_input_vals = None
                    dim2_input_vals = last_starting_xvecs[batch]
                else:
                    dim1_input_vals = None
                    dim2_input_vals = None

                batch_means, batch_vars = propagate(input_means=input_means,
                                                    input_covs=input_covs,
                                                    dim1_input_vals=dim1_input_vals,
                                                    dim2_input_vals=dim2_input_vals)

                # NOTE(review): a batch is split back into one y patch per input patch. This is identical to the
                # previous single reshape when a batch yields exactly one patch worth of values - confirm the
                # propagation functions return one y patch per input patch for larger batches.
                current_prediction_means.extend(np.asarray(batch_means).reshape((-1,) + ypatch_dim))
                current_prediction_vars.extend(np.asarray(batch_vars).reshape((-1,) + ypatch_dim))

            # unpatchify (variance weighting is excluded by the assert at the top of the function)
            if ypatch_dim == (1,1):
                # Each output pixel is determined only by one prediction so unpatchify just puts the prediction in the correct location.
                current_predicted_mean_diff, _, _, _, _ = patch_obj.unpatchify_image(patch_list=current_prediction_means,
                                                                                      patch_variance_list=None,
                                                                                      img_type='y')
                current_predicted_var_diff, _, _, _, _  = patch_obj.unpatchify_image(patch_list=current_prediction_vars,
                                                                                      patch_variance_list=None,
                                                                                      img_type='y')
            else:
                # Predictions must be combined like a gaussian mixture model and moment matched to single gaussian distribution
                current_predicted_mean_diff, current_predicted_var_diff = \
                    patch_obj.unpatchify_mean_variance_image(mean_patch_list=current_prediction_means,
                                                             var_patch_list=current_prediction_vars,
                                                             patch_weight=np.ones(ypatch_dim))

            # the model predicts a difference: accumulate it onto the most recent image
            current_predicted_mean_image = predicted_seq_means[-1] + current_predicted_mean_diff
            current_predicted_var_image = predicted_seq_vars[-1] + current_predicted_var_diff

            print("Finished Predicting: " + str(step_num + 1) + " took: " + str(time.time() - single_img_pred_start_time))
            # append
            predicted_seq_means.append(current_predicted_mean_image)
            predicted_seq_vars.append(current_predicted_var_image)

            if save_file:
                np.save(predicted_mean_savefilename, predicted_seq_means)
                np.save(predicted_var_savefilename, predicted_seq_vars)

            # process input patches for the next round and keep only the 3 most recent images
            predicted_input_mean_patches.append(image_to_input_vecs(predicted_seq_means[-1]))
            predicted_input_var_patches.append(image_to_input_vecs(predicted_seq_vars[-1]))

            if len(predicted_input_mean_patches) > 3:
                predicted_input_mean_patches = predicted_input_mean_patches[1:]
                predicted_input_var_patches  = predicted_input_var_patches[1:]

            if show_intermediate_imgs:
                # display the result
                fig = plt.figure()
                pos = plt.imshow(current_predicted_mean_image)
                plt.title("Prediction: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)

                # display the variance image
                fig = plt.figure()
                pos = plt.imshow(current_predicted_var_image)
                plt.title("Variance image: " + str(step_num + 1))
                fig.colorbar(pos)
                plt.show(block=False)

        print("Finished " + str(steps) + " predictions")

        return predicted_seq_means, predicted_seq_vars

    


################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
######################################################### Debugging Code and extra functions   #################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################
################################################################################################################################################################

    # DEBUGGING HELPER
    def plot_img(self, img):
        """
        Debugging helper: draw one image in a new figure together with a colorbar
        and show the figure via plt.show().

        args: 
            - img: the image to display (passed straight to plt.imshow)
        returns: 
            - Nothing
        """
        figure_handle = plt.figure()
        # imshow draws onto the figure created above; its return value is what the colorbar is built from
        figure_handle.colorbar(plt.imshow(img))
        plt.show()

    def data_refinement(self, kernel_threshold=None, verbose=True):
        """
        Code to refine the training datapoints that are used during prediction. 
        NOTE: the GP model must be trained and contain all the training data before this function 
        is called. 

        Methodology: 
            1. Keep the first datapoint
            2. For every subsequent datapoint compute the kernel function against the datapoints kept so far
            3. If the kernel function is below the threshold for ALL kept datapoints, keep the datapoint as well
            4. Set the training data of the GP model to the kept datapoints
            5. Re-optimize the GP model once all the training points have been processed
            
        args: 
            - kernel_threshold: type: float: the threshold value below which you add training points to the pool of refined datapoints
                - if None: 0.99 * self.GP_model.kern.variance is used
            - verbose: type: bool: 
                - True: prints the refinement summary and the optimization messages
                - False: does not print the refinement summary or the optimization messages
        returns: 
            - kernel_threshold: the threshold that was actually used (useful when it was derived from the kernel variance)
        """
        if kernel_threshold is None: 
            # Pull the single value out explicitly with .item() instead of calling float() on the
            # parameter directly. NOTE(review): assumes kern.variance holds exactly one value
            # (float() on a 1-element array is deprecated since NumPy 1.25) - confirm for the kernel in use.
            kernel_threshold = 0.99 * float(np.asarray(self.GP_model.kern.variance).item())

        x_datapoints = self.GP_model.X
        y_datapoints = self.GP_model.Y

        refined_x_datapoints = []
        refined_y_datapoints = []

        for point_num in tqdm(range(x_datapoints.shape[0])):    
            current_x_datapoint = x_datapoints[point_num].reshape((1, -1))
            current_y_datapoint = y_datapoints[point_num].reshape((1, -1))

            # The first datapoint is always kept (nothing to compare against yet). Every later
            # datapoint is skipped as soon as one kept datapoint reaches the threshold.
            if refined_x_datapoints: 
                kernel_values = self.GP_model.kern.K(current_x_datapoint, np.array(refined_x_datapoints))
                if not np.all(kernel_values < kernel_threshold): 
                    continue

            refined_x_datapoints.append(current_x_datapoint[0])
            refined_y_datapoints.append(current_y_datapoint[0])

        if verbose: 
            print("Initial: " + str(x_datapoints.shape[0]) + " datapoints refined to: " + str(len(refined_x_datapoints)) + " datapoints")
        # Reset the GP model training data
        self.GP_model.set_XY(X=np.array(refined_x_datapoints).reshape((len(refined_x_datapoints), -1)), 
                             Y=np.array(refined_y_datapoints).reshape((len(refined_y_datapoints), -1)))
        # Re-optimize the model 
        self.GP_model.optimize(messages=verbose)
        
        return kernel_threshold




