"""
Overall file: 

1. Load the data 
2. Pre-process the data
    2.1: Convert to image tuples
    2.2: Add different spaces if needed: Image differences, Optical Flow, Fourier Transformations
    2.3: Patchify datasets
        2.3.1: Patchification types: 
            2.3.1.1: 
            2.3.1.2: 
    2.4: Vectorize patches 
3. Train GP model: 
    3.1: Choose model type: GP Regression, Sparse GP Regression, GP Regression with white noise
    3.2: Create kernel: 
        3.2.1: Set lengthscale
        3.2.2: Set noise variance
    3.3: Optimize models - if desired
4. Predict with GP Model:  
    4.1: Predict on test set: 
        4.1.1: Predict all at once 
        4.1.2: Might need to predict one at a time: to prevent memory issues
    4.2: Sequential prediction: 
        4.2.1: Start with one starting point and roll out 
    4.3: Sequential prediction while seeing new data: 
        4.3.1: Do sequential rollouts, update the model and then continue to do sequential rollouts 
5. Unpatchify predictions - might integrate into prediction step if doing things like rollouts. 
6. Postprocessing: if needed - if predicting fourier transforms, or if predicting image differences. 
"""
import numpy as np 
import cv2 
import GPy 

import matplotlib.pyplot as plt 
import os 
import sys 

current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, "..")
sys.path.append(parent_dir) # append parent directory to the path 

# import locally written files 
from patchify_unpatchify import patcher
from predict_separate import predict_separate 
from processing import processing

import pdb 
import time 
import torch
import scipy

def gaussian_kernel_2d(size=(5,5), sigma=10): 
    """
    Builds a 2D Gaussian-shaped weight array as the outer product of two 1D profiles. 
    args: 
        - size: tuple (rows, cols): shape of the returned array
        - sigma: number: spread shared by both axes
    returns: 
        - gaussian_2d: array of shape (size[0], size[1])

    NOTE(review): the prefactor is 1/(sigma*sqrt(2)) rather than 1/(sigma*sqrt(2*pi)), 
    so the profiles are not normalised densities - confirm this is intended. 
    """
    def profile_1d(num_samples): 
        # sample positions are spread evenly over [-(n//2), n//2]
        half_width = num_samples//2
        positions = np.linspace(-half_width, half_width, num_samples)
        return (1/(sigma*np.sqrt(2)))* np.exp(-0.5 * ((positions/sigma)**2))

    return np.outer(profile_1d(size[0]), profile_1d(size[1]))

def create_kernel_isolated(base_kernel, kernel_groups, kernel_lengthscales, kernel_variances, patch_dim, all_xtypes):
    """
    Creates a kernel as the product of one sub-kernel per group of xtypes. 
    args: 
        - base_kernel: This is the base kernel function to use: eg: GPy.kern.RBF
            - called as base_kernel(input_dim=..., active_dims=..., lengthscale=..., variance=...)
        - kernel_groups: list of lists
            - each sublist contains the xtype keys that should be 'joined' together under the same kernel 
        - kernel_lengthscales: type: list of numbers
            - the lengthscale of the corresponding group in kernel_groups
            - same length as kernel_groups
        - kernel_variances: type: list of numbers 
            - the variance of the corresponding group in kernel_groups
            - same length as kernel_groups
        - patch_dim: type: tuple: the shape of the patches 
        - all_xtypes: type: list of strings: each string refers to an xtype
            - the position of an xtype in this list decides which slice of the input vector it owns
    returns: 
        - kernel to be used (None when kernel_groups is empty). 
    raises: 
        - AssertionError: if the list lengths disagree or a group names an xtype not in all_xtypes
    """
    # Raised explicitly (instead of `assert`) so the checks survive `python -O`.
    if len(kernel_groups) != len(kernel_lengthscales):
        raise AssertionError("kernel_groups and kernel_lengthscales must have the same length")
    if len(kernel_groups) != len(kernel_variances):
        raise AssertionError("kernel_groups and kernel_variances must have the same length")

    # np.prod: the np.product alias was removed in NumPy 2.0
    single_patch_len = np.prod(patch_dim)

    # First resolve every group to the input dimensions it covers, so that a bad
    # xtype is reported before any kernel object is constructed.
    kernel_activedims = []
    for group in kernel_groups:
        current_activedims = []
        for xtype in group:
            if xtype not in all_xtypes:
                raise AssertionError("ERROR: incorrect group for forming the kernel")
            first_dim = single_patch_len * all_xtypes.index(xtype)
            # each xtype owns one contiguous run of single_patch_len dimensions
            current_activedims.extend(np.arange(first_dim, first_dim + single_patch_len))
        kernel_activedims.append(current_activedims)

    # create the kernels and then multiply them together
    kernel = None
    for active_dims, lengthscale, variance in zip(kernel_activedims, kernel_lengthscales, kernel_variances):
        sub_kernel = base_kernel(input_dim=len(active_dims), active_dims=active_dims,
                                 lengthscale=lengthscale, variance=variance)
        kernel = sub_kernel if kernel is None else kernel * sub_kernel
    return kernel

class nonparam_predictor_separate: 
    """
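    Wires together a patcher, a data processor and a predict_separate instance 
    in order to train, update and query a list of GP models (stored in self.GP_models). 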
    """
    def __init__(self, patch_parameters, processor_parameters, max_all_together):
        """
        Initializes the whole predictor using the following parameters: 
        args: 
            - patch_parameters: dictionary of the arguments for the patcher class 
                - 'img_dim': tuple (rows, cols): dimension of the image
                - 'patch_dim': tuple (rows, cols): dimension of the patch
                - 'patch_border': tuple (rows, cols): number of rows, cols that are used as the border of the patch 
                - 'img_padlen': tuple (rows, cols): number of rows, cols used to pad the image
                - 'img_padtype': string: to indicate the method used for padding the image. For viable options look at the patcher's pad_image 
                - 'wrap_x': whether you want the patches to wrap in the x or column direction
                - 'wrap_y': whether you want the patches to wrap in the y or row direction
                - 'stride': tuple (row stride, col stride): the number of pixels to skip by when creating image patches
                - 'patch_weight': the patch weight to use when recombining  
            - processor_parameters: dictionary of the arguments for the processing class
                - 'xtypes': list of strings - order indicates order of types of x components in dataset
                    - 'img-2': 2 images before current img
                    - 'img-1': 1 image before current img
                    - 'img0': current img - when predicting the next
                    - 'vel_diff': image difference velocity ('img0' - 'img-1')
                    - 'accel_diff': image difference acceleration (('img0' - 'img-1') - ('img-1' - 'img-2'))
                    - 'of_x': optical flow x 
                    - 'of_y': optical flow y
                - 'ytype': a string that indicates the type of the y components in dataset (also what is to be predicted)
                    - 'img': the next image 
                    - 'diff': difference between the last image and the next
                - 'farneback_flow_params': parameters for optical flow - can be/default to None if not using any optical flow
            - max_all_together: parameter for the predict class - when not processing all test vectors together how many to process at once
        """ 

        self.patch_parameters = patch_parameters
        self.processor_parameters = processor_parameters
        self.max_all_together = max_all_together

        self.patch_obj = patcher(img_dim=self.patch_parameters['img_dim'], 
                                 patch_dim=self.patch_parameters['patch_dim'], 
                                 patch_border=self.patch_parameters['patch_border'], 
                                 img_padlen=self.patch_parameters['img_padlen'], 
                                 img_padtype=self.patch_parameters['img_padtype'], 
                                 wrap_x=self.patch_parameters['wrap_x'], 
                                 wrap_y=self.patch_parameters['wrap_y'], 
                                 stride=self.patch_parameters['stride'], 
                                 patch_weight=self.patch_parameters['patch_weight'])

        self.data_processor = processing(xtypes=self.processor_parameters['xtypes'], 
                                         ytype=self.processor_parameters['ytype'], 
                                         farneback_flow_params=self.processor_parameters['farneback_flow_params'])
        self.predictor = predict_separate(max_all_together=max_all_together)
        
        # train_models() stores the fitted models under `GP_models` and the other
        # methods read that name, so it is the one that has to exist up front.
        self.GP_models = None # this will be a list of GP models
        self.GP_model = None # legacy spelling, kept so existing readers of it do not break
        return 

    def create_kernel(self, base_kernel, kernel_groups, kernel_lengthscales, kernel_variances):
        """
        Creates a kernel as the product of one sub-kernel per group of xtypes. 
        args: 
            - Implicit parameters: self.data_processor.xtypes and self.patch_obj.patch_dim - already in the object
            - base_kernel: This is the base kernel function to use: eg: GPy.kern.RBF
                - called as base_kernel(input_dim=..., active_dims=..., lengthscale=..., variance=...)
            - kernel_groups: list of lists
                - each sublist contains the xtype keys that should be 'joined' together under the same kernel 
            - kernel_lengthscales: type: list of numbers
                - the lengthscale of the corresponding group in kernel_groups
                - same length as kernel_groups
            - kernel_variances: type: list of numbers 
                - the variance of the corresponding group in kernel_groups
                - same length as kernel_groups
        returns: 
            - kernel to be used (None when kernel_groups is empty). 
        raises: 
            - AssertionError: if the list lengths disagree or a group names an unknown xtype
        """
        # Raised explicitly (instead of `assert`) so the checks survive `python -O`.
        if len(kernel_groups) != len(kernel_lengthscales):
            raise AssertionError("kernel_groups and kernel_lengthscales must have the same length")
        if len(kernel_groups) != len(kernel_variances):
            raise AssertionError("kernel_groups and kernel_variances must have the same length")

        known_xtypes = self.data_processor.xtypes
        # np.prod: the np.product alias was removed in NumPy 2.0
        single_patch_len = np.prod(self.patch_obj.patch_dim)

        # Resolve every group to the input dimensions it covers before building
        # any kernel, so a bad xtype is reported first.
        kernel_activedims = []
        for group in kernel_groups:
            current_activedims = []
            for xtype in group:
                if xtype not in known_xtypes:
                    raise AssertionError("ERROR: incorrect group for forming the kernel")
                first_dim = single_patch_len * known_xtypes.index(xtype)
                # each xtype owns one contiguous run of single_patch_len dimensions
                current_activedims.extend(np.arange(first_dim, first_dim + single_patch_len))
            kernel_activedims.append(current_activedims)

        # create the kernels and then multiply them together
        kernel = None
        for active_dims, lengthscale, variance in zip(kernel_activedims, kernel_lengthscales, kernel_variances):
            sub_kernel = base_kernel(input_dim=len(active_dims), active_dims=active_dims,
                                     lengthscale=lengthscale, variance=variance)
            kernel = sub_kernel if kernel is None else kernel * sub_kernel
        return kernel
    
    def train_models(self, kernels, datapoints, optimize=False, noise_var=0.01, use_sparse_GP=False, num_inducing_points=100, max_opt_iters=None):
        """
        Hands the training request to self.predictor.train_models, together with this 
        object's data processor and patch object, and keeps the result on the object. 

        args: 
            - kernels: list of initialized GPy kernels
            - datapoints: dictionary containing - only has one of the below
                Either: 
                    - 'image_seq': sequence of images to be converted into a patchified, vectorized dataset
                Or: 
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
            - optimize: if true will optimize the GP model kernel and other parameters
            - noise_var: the noise variance to use for the GP Regression model
            - use_sparse_GP: type: bool
                - True: use the sparse GP regression model 
                - False: use normal GP regression model 
            - num_inducing_points: type: int: the number of inducing points to use in the sparse GP model. not used if use_sparse_GP is False. 
            - max_opt_iters: type: int: the maximum optimization iterations to run - default is None - dont use it
        returns: 
            - the models returned by the predictor
            - they are also stored in self.GP_models, and their .kern attributes in self.kernels

        # TODO: Maybe later add support for other types of GP models
        """
        training_options = {
            'optimize': optimize,
            'noise_var': noise_var,
            'use_sparse_GP': use_sparse_GP,
            'num_inducing_points': num_inducing_points,
            'max_opt_iters': max_opt_iters,
        }

        trained = self.predictor.train_models(kernels=kernels,
                                              datapoints=datapoints,
                                              data_processor=self.data_processor,
                                              patch_obj=self.patch_obj,
                                              **training_options)
        self.GP_models = trained

        # collect into a local first so self.kernels is only assigned once complete
        model_kernels = []
        for trained_model in trained:
            model_kernels.append(trained_model.kern)
        self.kernels = model_kernels
        return trained

    def add_datapoints(self, datapoints, optimize=False):
        """
        Add datapoints to the trained GP models by delegating to 
        self.predictor.add_datapoints_to_GPs. 
        args: 
            - datapoints: dictionary containing - only has one of the below
                Either: 
                    - 'image_seq': sequence of images to be converted into a patchified, vectorized dataset
                Or: 
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
            - optimize: forwarded unchanged to the predictor; defaults to False
                - previously this name was read without being defined, so every call raised NameError
        returns: 
            - None: the models in self.GP_models are handed to the predictor for updating
        """
        self.predictor.add_datapoints_to_GPs(GP_models=self.GP_models, 
                                            datapoints=datapoints, 
                                            optimize=optimize, 
                                            data_processor=self.data_processor, 
                                            patch_obj=self.patch_obj)
        return 

    def stable_cho_inverse(self, matrix, cholesky=False):
        """
        More stable inverse using the cholesky decomposition and least squares type solutions
        NOTE: the matrix for inversion must be Hermetian, Positive Definite. 

        args: 
            - matrix: the matrix you want to invert
                - can also enter the cholesky decomposition (lower triangular part) directly
            - cholesky: boolean: 
                - True: the matrix entered is the lower triangular cholesky decomposition
                - False: the matrix entered is the full matrix - DEFAULT option
        returns: 
            - matrix_inv: the inverse of the matrix
        """

        if not cholesky: 
            L = scipy.linalg.cholesky(a=matrix, lower=True)
        else:
            L = matrix

        matrix_inv = scipy.linalg.cho_solve((L, True), b=np.eye(L.shape[0]), overwrite_b=True)
        return matrix_inv

    def setup_mean_var_propogation_parameters(self, patch_obj, single_kernel=False, svd_inverse = False, sigma_threshold = None, plot_for_thresholding = False):
        """
        Setup the self. parameters for the mean variance propogation functions. 
        The input vector is treated as three consecutive image patches (dim0, dim1, dim2), 
        each np.prod(patch_obj.patch_dim) entries long. 
        args: 
            - patch_obj: the patch object that is used for the test rollout (only patch_dim is read here)
            - single_kernel: type: boolean: 
                - False: reads kern.lengthscale and kern.variance of each model directly; 
                         the per-patch variance is the cube root of kern.variance
                - True:  reads kern.rbf, kern.rbf_1 and kern.rbf_2 of each model and multiplies 
                         their lengthscales together and their variances together
                - NOTE(review): this looks inverted relative to the flag's name, and 
                  self.full_variances always reads kern.variance - confirm the intended meaning 
                  before relying on single_kernel=True
            - svd_inverse: argument to calculate_iK_beta
            - sigma_threshold: argument to calculate_iK_beta
            - plot_for_thresholding: argument to calculate_iK_beta
        returns: None
            - results are stored on self: device, lengthscale_matrices_*, variances_*, 
              full_variances, kernels_dim*, dim*_start_idx / dim*_end_idx, *_train_xvecs, 
              num_training_points, K_trains_noise, K_trains_noise_inv, Ba_vectors
        """
        # START SETUP 
        # Separate out the kernels - NOTE: these are all the same for each image space for now - structured to make it easy to change later
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # np.prod: the np.product alias was removed in NumPy 2.0
        pixels_per_inpatch = np.prod(patch_obj.patch_dim)
        if not single_kernel: 
            img_space_lengthscales = [float(model.kern.lengthscale) for model in self.GP_models]
            img_space_variances    = [float(model.kern.variance)**(1/3) for model in self.GP_models]# alpha part of rbf kernel
        else: 
            # a separate kernel per image space
            img_space_lengthscales = [float(model.kern.rbf.lengthscale) * float(model.kern.rbf_1.lengthscale) * float(model.kern.rbf_2.lengthscale) for model in self.GP_models]
            img_space_variances    = [float(model.kern.rbf.variance) * float(model.kern.rbf_1.variance) * float(model.kern.rbf_2.variance) for model in self.GP_models]

        # index i is the lengthscale matrix for output dimension i
        # dim0/dim1/dim2 currently hold the same values; they stay separate arrays so each can be changed independently
        self.lengthscale_matrices_dim0 = np.array([np.eye(pixels_per_inpatch) * lengthscale**2 for lengthscale in img_space_lengthscales])
        self.lengthscale_matrices_dim1 = np.array([np.eye(pixels_per_inpatch) * lengthscale**2 for lengthscale in img_space_lengthscales])
        self.lengthscale_matrices_dim2 = np.array([np.eye(pixels_per_inpatch) * lengthscale**2 for lengthscale in img_space_lengthscales])
        self.lengthscale_matrices_dim01 = np.array([scipy.linalg.block_diag(mat0, mat1) for mat0, mat1 in zip(self.lengthscale_matrices_dim0, self.lengthscale_matrices_dim1)])
        self.lengthscale_matrices_dim012 = np.array([scipy.linalg.block_diag(mat01, mat2) for mat01, mat2 in zip(self.lengthscale_matrices_dim01, self.lengthscale_matrices_dim2)])

        self.lengthscale_matrices_dim0_inv = np.array([np.eye(pixels_per_inpatch) * 1/(lengthscale**2) for lengthscale in img_space_lengthscales])
        self.lengthscale_matrices_dim1_inv = np.array([np.eye(pixels_per_inpatch) * 1/(lengthscale**2) for lengthscale in img_space_lengthscales])
        self.lengthscale_matrices_dim2_inv = np.array([np.eye(pixels_per_inpatch) * 1/(lengthscale**2) for lengthscale in img_space_lengthscales])
        self.lengthscale_matrices_dim01_inv = np.array([scipy.linalg.block_diag(inv0, inv1) for inv0, inv1 in zip(self.lengthscale_matrices_dim0_inv, self.lengthscale_matrices_dim1_inv)])
        self.lengthscale_matrices_dim012_inv = np.array([scipy.linalg.block_diag(inv01, inv2) for inv01, inv2 in zip(self.lengthscale_matrices_dim01_inv, self.lengthscale_matrices_dim2_inv)])

        # load inverse matrices onto self.device (the gpu when available)
        # torch.tensor copies, and honours the cpu fallback chosen above - torch.cuda.DoubleTensor failed without CUDA
        self.lengthscale_matrices_dim0_inv_tr = torch.tensor(np.array(self.lengthscale_matrices_dim0_inv), dtype=torch.float64, device=self.device)
        self.lengthscale_matrices_dim01_inv_tr = torch.tensor(np.array(self.lengthscale_matrices_dim01_inv), dtype=torch.float64, device=self.device)
        self.lengthscale_matrices_dim012_inv_tr = torch.tensor(np.array(self.lengthscale_matrices_dim012_inv), dtype=torch.float64, device=self.device)

        self.variances_dim0 = np.array(list(img_space_variances))
        self.variances_dim1 = np.array(list(img_space_variances))
        self.variances_dim2 = np.array(list(img_space_variances))
        self.full_variances = np.array([float(model.kern.variance) for model in self.GP_models])

        self.kernels_dim0 = [GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=lengthscale, variance=variance) for lengthscale, variance in zip(img_space_lengthscales, self.variances_dim0)]# kernels for the 0th image patch in the input vec
        self.kernels_dim1 = [GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=lengthscale, variance=variance) for lengthscale, variance in zip(img_space_lengthscales, self.variances_dim1)]# kernels for the 1st image patch in the input vec
        self.kernels_dim2 = [GPy.kern.RBF(input_dim=pixels_per_inpatch, lengthscale=lengthscale, variance=variance) for lengthscale, variance in zip(img_space_lengthscales, self.variances_dim2)]# kernels for the 2nd image patch in the input vec

        # Process Data: 
        # Seperate datasets by image patch dimensions
        self.dim0_start_idx = 0 
        self.dim0_end_idx = self.dim0_start_idx + pixels_per_inpatch
        self.dim1_start_idx = self.dim0_end_idx
        self.dim1_end_idx = self.dim1_start_idx + pixels_per_inpatch
        self.dim2_start_idx = self.dim1_end_idx
        self.dim2_end_idx = self.dim2_start_idx + pixels_per_inpatch

        # Training data: seperate the image patch dimensions
        # the training inputs are read from the first model only
        self.full_train_xvecs = np.array(self.GP_models[0]._predictive_variable)
        self.num_training_points = self.full_train_xvecs.shape[0]
        self.dim0_train_xvecs = self.full_train_xvecs[:, self.dim0_start_idx:self.dim0_end_idx]
        self.dim1_train_xvecs = self.full_train_xvecs[:, self.dim1_start_idx:self.dim1_end_idx]
        self.dim2_train_xvecs = self.full_train_xvecs[:, self.dim2_start_idx:self.dim2_end_idx]
        self.dim01_train_xvecs = self.full_train_xvecs[:, self.dim0_start_idx:self.dim1_end_idx]
        # END SETUP

        ###########################################################################
        # Set up: pre-process variables needed for all predictions
        ###########################################################################
        print("Creating the Ba vector")
        start_time = time.time()

        self.K_trains_noise, self.K_trains_noise_inv, self.Ba_vectors = self.calculate_iK_beta(svd_inverse = svd_inverse, sigma_threshold = sigma_threshold, plot_for_thresholding = plot_for_thresholding)
        
        elapsed_time = time.time() - start_time
        print("Finished creating the Ba vector: ", elapsed_time)

        return 

    def calculate_iK_beta(self, svd_inverse = False, sigma_threshold = None, plot_for_thresholding = False):
        """
        Calculate the inverse of the training kernel gram matrix and the beta vector used in the PILCO paper. 
        args: 
            - implicit argument gp_model: GPy gp model
            - svd_inverse: type: boolean: 
                - True: use the svd inverse with the specified sigma cutoff
                - False: use the matrices returned by GPy directly 
            - sigma_threshold: 
                - None: take the sigma values up to the "large" dropoff
                    - We define a large dropoff as if the change in values between the next two sigma values is less than 1 order of magnitude compared to the change in the last 2 values
                - float: the maximum sigma value you want to keep before inverting
            - plot_for_thresholding: type: boolean:
                - False: just use the specified thresholds 
                - True: plot the sigma values and go into a set trace to allow the sigma threshold to be set manually by the user. 

            NOTE: order of precedence: svd_inverse,  sigma_threshold, then plot_for_thresholding
        returns: 
            - inv_K: inverse of the kernel gram matrix
            - beta: the beta vector used in the mean and covariance calculation
        """
        if not svd_inverse: 
            # Use GPy matrices 
            train_data = self.GP_models[0]._predictive_variable

            K_trains = [self.GP_models[i].kern.K(train_data, train_data) for i in range(len(self.GP_models))]
            K_trains_noise = [K_trains[i] + (np.eye(K_trains[0].shape[0])*float(self.GP_models[i].Gaussian_noise.variance)) for i in range(len(self.GP_models))]
            K_trains_noise = [K_trains_noise[i] + (np.eye(K_trains_noise[i].shape[0])*1e-8) for i in range(len(K_trains_noise))]# additional noise for invertibility

            K_trains_noise_inv = [self.GP_models[i].posterior.woodbury_inv for i in range(len(self.GP_models))]#self.stable_cho_inverse(matrix=L, cholesky=True)
            betas = [self.GP_models[i].posterior.woodbury_vector for i in range(len(self.GP_models))]

            return np.array(K_trains_noise), np.array(K_trains_noise_inv), np.array(betas)

        else: 
            # Stable inverse calculation using the SVD
            from GPy.util.linalg import jitchol, dtrtri, dpotri, dpotrs
            X_train = self.GP_models[0]._predictive_variable
            Ks = [self.GP_models[i].kern.K(X_train, X_train) for i in range(len(self.GP_models))]
            noise_variances = [float(self.GP_models[i].Gaussian_noise.variance) for i in range(len(self.GP_models))]
            jitter = 1e-8
            K_noises = [Ks[i] + ((noise_variances[i] + jitter) * np.eye(Ks[0].shape[0])) for i in range(len(self.GP_models))]
            
            # compute the svds - inverses and beta vectors 
            iKs = []
            betas = []
            for model_num in range(len(self.GP_models)):
                umat, svec, vhmat = np.linalg.svd(K_noises[model_num])

                if type(sigma_threshold) == type(None):
                    keep_threshold = None
                    # Find the dropoff point
                    for i in range(1, len(svec) - 1): 
                        last_difference = np.abs(svec[i-1] - svec[i])
                        next_difference = np.abs(svec[i] - svec[i+1])

                        if (last_difference/next_difference) > 10: 
                            keep_threshold = svec[i]
                            break 
                    if type(keep_threshold) == type(None):
                        keep_threshold = svec[-1] - np.finfo(float).eps
                else: 
                    keep_threshold = float(sigma_threshold)

                if plot_for_thresholding: 
                    # plot and let the user manually input the threshold
                    fig = plt.figure()
                    plt.title("Sigma values K train noise")
                    plt.plot(svec.flatten())
                    plt.show()

                    print("Sigma values are stored in variable svec")
                    print("Default kee_threshold: ", keep_threshold)
                    print("Set variable: \'keep_threshold\' as the maximum sigma value you want to keep")
                    pdb.set_trace()

                # truncate the sigma vector 
                len_svec = len(svec)

                print("Keep threshold: ", keep_threshold)
                svec_truncated_inv = np.zeros(len_svec)
                svec_truncated_inv[np.abs(svec) > keep_threshold] = 1/np.array(svec[np.abs(svec) > keep_threshold]) 

                svec_truncated_inv_mat = np.diag(svec_truncated_inv)
                iK = vhmat.T@svec_truncated_inv_mat@umat.T
                beta = iK@self.GP_models[model_num].Y

                iKs.append(iK)
                betas.append(beta)

            return np.array(K_noises), np.array(iKs), np.array(betas)