"""
File for kernel regression to use as a baseline. 
"""

import numpy as np 
import cv2 
import GPy 

import matplotlib.pyplot as plt 
import os 
import sys 

current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, "..")
sys.path.append(parent_dir) # append parent directory to the path 

# import locally written files 
from patchify_unpatchify import patcher
from predict import predict 
from processing import processing

def compute_time_ein(x):
    """
    Time the product of a matrix with itself computed through np.einsum.
    args: 
        - x: 2-D square array; the 'ij,jk' contraction needs x.shape[1] == x.shape[0]
    returns: 
        - ret: result of the contraction (same values as x @ x)
        - time_taken: elapsed seconds spent inside the einsum call
    """
    # `time` is not imported at the top of this file, so the original body
    # raised NameError on every call; import it locally to keep the fix self-contained.
    import time

    # perf_counter is monotonic, so the measured interval can never be negative
    start_time = time.perf_counter()
    ret = np.einsum('ij,jk', x, x, optimize=False)
    time_taken = time.perf_counter() - start_time
    print("Time taken: ", time_taken)
    return ret, time_taken

class kernel_regression(): 
    """
    Method: 
        1. Pre-process data: image, difference images, optical flow images 
        2. Patchify the images to generate the training data 
        3. Store all the training data patches to use online during kernel regression 
        4. Create the kernel to use for similarity comparisons for kernel regression. (might need to use GPy to optimize kernel parameters ? )
        5. Use the training data and the kernel for prediction: 
            5.1. First Pre-process on the test images 
            5.2. Patchify the test images 
            5.3. Use the kernel and the training images to predict on the test patches 
            5.4. Unpatchify the predicted patches 
            5.5. Repeat and continue to rollout (if needed)
    """

    def __init__(self, patch_parameters, processor_parameters):
        """
        Build the patcher and the data processor used by the kernel regression model. 
        The kernel itself is not created here (self.kernel starts as None); call create_kernel afterwards. 
        args: 
            - patch_parameters: dictionary of keyword arguments forwarded to the patcher class 
                - 'img_dim': tuple (rows, cols): dimension of the image
                - 'patch_dim': tuple (rows, cols): dimension of the patch
                - 'patch_border': tuple (rows, cols): number of rows, cols that are used as the border of the patch 
                - 'img_padlen': tuple (rows, cols): number of rows, cols used to pad the image
                - 'img_padtype': string: the method used for padding the image (for viable options see the patcher's pad_image)
                - 'wrap_x': whether the patches should wrap in the x or column direction
                - 'wrap_y': whether the patches should wrap in the y or row direction
                - 'stride': tuple (row stride, col stride): the number of pixels to skip by when creating image patches
                - 'patch_weight': the patch weight to use when recombining  
            - processor_parameters: dictionary of keyword arguments forwarded to the processing class
                - 'xtypes': list of strings - order indicates order of types of x components in dataset
                    - 'img-2': 2 images before current img
                    - 'img-1': 1 image before current img
                    - 'img0': current img - when predicting the next
                    - 'vel_diff': image difference velocity ('img0' - 'img-1')
                    - 'accel_diff': image difference acceleration (('img0' - 'img-1') - ('img-1' - 'img-2'))
                    - 'of_x': optical flow x 
                    - 'of_y': optical flow y
                - 'ytype': a string that indicates the type of the y components in dataset (also what is to be predicted)
                    - 'img': the next image 
                    - 'diff': difference between the last image and the next
                - 'farneback_flow_params': parameters for optical flow - can be None if not using any optical flow
        """
        self.patch_parameters = patch_parameters
        self.processor_parameters = processor_parameters

        # only the keys listed here are forwarded; a missing key raises KeyError
        patcher_keys = ('img_dim', 'patch_dim', 'patch_border', 'img_padlen', 'img_padtype',
                        'wrap_x', 'wrap_y', 'stride', 'patch_weight')
        patcher_kwargs = {key: self.patch_parameters[key] for key in patcher_keys}
        self.patch_obj = patcher(**patcher_kwargs)

        processor_keys = ('xtypes', 'ytype', 'farneback_flow_params')
        processor_kwargs = {key: self.processor_parameters[key] for key in processor_keys}
        self.data_processor = processing(**processor_kwargs)

        # set later by create_kernel (and possibly replaced by train when optimizing)
        self.kernel = None
    def create_kernel(self, base_kernel, kernel_groups, kernel_lengthscales, kernel_variances):
        """
        Creates a product kernel with one sub-kernel per group of xtypes and stores it in self.kernel
        args: 
            - Implicit parameters: self.data_processor.xtypes and self.patch_obj.patch_dim - already in the object
            - base_kernel: This is the base kernel function to use: eg: GPy.kern.RBF
                - called as base_kernel(input_dim=..., active_dims=..., lengthscale=..., variance=...)
            - kernel_groups: list of lists
                - each sublist contains the xtype keys that should be 'joined' together under the same kernel 
            - kernel_lengthscales: type: list of numbers
                - the lengthscale corresponds to the lengthscale of the corresponding group in kernel_groups
                - same length as kernel_groups
            - kernel_variances: type: list of numbers 
                - the variance corresponds to the variance of the corresponding group in kernel_groups
                - same length as kernel_groups
        returns: 
            - kernel to be used (None when kernel_groups is empty). 
        raises: 
            - AssertionError: list lengths do not match or a group names an xtype the data processor does not have
        """
        # Raised explicitly (instead of via `assert`) so the checks survive `python -O`;
        # the exception type is kept as AssertionError for backward compatibility.
        if len(kernel_groups) != len(kernel_lengthscales) or len(kernel_groups) != len(kernel_variances):
            raise AssertionError("kernel_groups, kernel_lengthscales and kernel_variances must have the same length")

        xtypes = self.data_processor.xtypes
        # np.product was removed in NumPy 2.0; np.prod is the supported spelling
        single_patch_len = int(np.prod(self.patch_obj.patch_dim))

        # each xtype occupies one contiguous slice of single_patch_len columns,
        # positioned by the xtype's index in the data processor's xtypes
        kernel_activedims = []
        for group in kernel_groups:
            current_activedims = []
            for xtype in group:
                if xtype not in xtypes:
                    raise AssertionError("ERROR: incorrect group for forming the kernel: " + repr(xtype))
                start = single_patch_len * xtypes.index(xtype)
                current_activedims.extend(range(start, start + single_patch_len))
            kernel_activedims.append(current_activedims)

        # create one sub-kernel per group and multiply them together
        kernel = None
        for active_dims, lengthscale, variance in zip(kernel_activedims, kernel_lengthscales, kernel_variances):
            sub_kernel = base_kernel(input_dim=len(active_dims), active_dims=active_dims,
                                     lengthscale=lengthscale, variance=variance)
            kernel = sub_kernel if kernel is None else kernel * sub_kernel

        self.kernel = kernel
        return kernel

    def train(self, data, optimize_kernel=False): 
        """
        Store the vectorized training patches that are later used for prediction. 
        A kernel must already exist (see create_kernel) before calling this. 
        args: 
            - data: type: dictionary: either the raw image sequence or datasets already produced by the data processor
                - 'image_seq': sequence of images, converted with data_processor.create_xy
                OR 
                - 'xdataset': pre-processed image tuples
                - 'ydataset': the y data that corresponds to the pre-processed image tuples in 'xdataset'
            - optimize_kernel: type: boolean
                - True: fits a GPy GPRegression model on the stored data and keeps its optimized kernel 
                - False: leaves the kernel untouched 
        returns: 
            - None: the patch vectors are stored in self.x and self.y (one row per patch)
        """
        assert self.kernel is not None

        # 'image_seq' takes precedence when both forms are supplied
        if 'image_seq' in data:
            xdataset, ydataset = self.data_processor.create_xy(data['image_seq'])
        elif 'xdataset' in data and 'ydataset' in data:
            xdataset, ydataset = data['xdataset'], data['ydataset']
        else:
            print("ERROR: proper data was not provided for training")
            assert(False)

        # images -> patches
        x_patches = self.patch_obj.patchify_dataset(dataset=xdataset, dataset_type='x')
        y_patches = self.patch_obj.patchify_dataset(dataset=ydataset, dataset_type='y')

        # patches -> flat vectors
        to_vecs = self.data_processor.convert_imgdataset_to_vecdataset
        x_vecs = np.array(to_vecs(x_patches))
        y_vecs = np.array(to_vecs(y_patches))

        # both matrices get as many rows as there are x patch vectors
        row_count = len(x_vecs)
        self.x = x_vecs.reshape((row_count, -1))
        self.y = y_vecs.reshape((row_count, -1))

        if not optimize_kernel:
            return

        gpy_model = GPy.models.GPRegression(X=self.x, Y=self.y, kernel=self.kernel, noise_var=0.01)
        gpy_model.optimize(messages=True)
        self.kernel = gpy_model.kern

    def predict_patchvec(self, test_patchvec):
        """
        Predict the output for a single vectorized x patch as the kernel-weighted 
        average of the stored training outputs (self.y). 
        args: 
            - test_patchvec: x dataset patch that is vectorized 
        returns: 
            - the predicted output vectorized patch 
        """
        query = np.array(test_patchvec).reshape((1, -1))

        # one similarity weight per training row; shape (num_train, 1)
        similarities = self.kernel.K(self.x, query)

        # machine epsilon keeps the division defined when every weight is zero
        normalizer = np.sum(similarities) + np.finfo(float).eps

        # the weights broadcast across the columns of self.y before summing over rows
        return np.sum(similarities * self.y, axis=0) / normalizer

    def predict_patchvecs_together(self, test_patchvecs, max_all_together):
        """
        Use matrix operations to predict multiple patches together. 
        Each prediction is the kernel-weighted average of the stored training outputs (self.y). 
        args: 
            - test_patchvecs: sequence of vectorized x patches 
            - max_all_together: how many patches to push through the kernel at once
                - 0: predict everything in one matrix operation
                - positive number: predict at most max_all_together patches per batch (bounds the kernel matrix size)
        returns: 
            - array with one predicted output row per test patch (an empty array when there are no test patches)
        raises: 
            - ValueError: max_all_together is negative (the batching loop could never make progress)
        """
        if max_all_together < 0:
            raise ValueError("max_all_together must be 0 (all at once) or a positive batch size")

        num_test = len(test_patchvecs)
        if num_test == 0:
            # nothing to predict; avoids reshaping a size-0 array with an unknown (-1) dimension
            return np.array([])

        batch_size = num_test if max_all_together == 0 else max_all_together
        eps = np.finfo(float).eps

        batch_outputs = []
        for start in range(0, num_test, batch_size):
            batch = test_patchvecs[start:start + batch_size]
            test_points = np.array(batch).reshape((len(batch), -1))

            # rows index test patches, columns index training patches
            kernel_weights = self.kernel.K(test_points, self.x)
            # eps keeps the division defined when all weights of a row are zero
            total_weight = np.sum(kernel_weights, axis=1, keepdims=True) + eps

            batch_outputs.append(np.dot(kernel_weights, self.y) / total_weight)

        predicted_outputs = np.concatenate(batch_outputs, axis=0)
        assert len(predicted_outputs) == num_test
        return predicted_outputs

    def full_predict_dataset(self, test_set, test_set_type, max_all_together=1000, test_patch_obj=None):
        """
        Predict on the specified test dataset with kernel regression over the stored training patches
        args: 
            - test_set: type: list: This can be either just a sequence of images or a sequence of image tuples - the x of a dataset
            - test_set_type: type: string: the type of test_set inputted 
                - 'image_seq': This indicates the test set is a sequence of images that must be pre-processed into the appropriate x dataset 
                - 'xdataset': Pre-processed x dataset that can be patchified, vectorized and put into the model 
            - max_all_together: number of patches predicted per batch (0 predicts everything at once); see predict_patchvecs_together
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
        returns: 
            - the predicted images rebuilt from the predicted patches (first value returned by unpatchify_dataset)
        raises: 
            - AssertionError: test_set_type is not one of the supported strings
        """ 
        if test_set_type == 'image_seq': 
            test_xdataset = self.data_processor.create_x(image_seq=test_set)
        elif test_set_type == 'xdataset': 
            test_xdataset = test_set
        else: 
            # raised explicitly (not via `assert`) so the check survives `python -O`
            raise AssertionError("ERROR: data type of the test set is incorrect: " + repr(test_set_type))

        patch_obj = self.patch_obj if test_patch_obj is None else test_patch_obj

        test_xdataset_patches = patch_obj.patchify_dataset(dataset=test_xdataset, dataset_type='x')
        test_xdataset_patchvecs = self.data_processor.convert_imgdataset_to_vecdataset(dataset=test_xdataset_patches)

        predicted_patchvecs = self.predict_patchvecs_together(test_xdataset_patchvecs, max_all_together)

        # unvectorize the patches using the SAME patcher that rebuilds the images below
        # (previously self.patch_obj was always used here, even when test_patch_obj was given)
        ypatch_dim = patch_obj.get_ypatch_dim()
        predicted_patches = [patchvec.reshape(ypatch_dim) for patchvec in predicted_patchvecs]

        # re-create the images from the patches 
        pred_images, _, _, _ = patch_obj.unpatchify_dataset(dataset_patch_list=predicted_patches, 
                                                            dataset_patch_variance_list=None, 
                                                            dataset_type='y')
        return pred_images

    def full_pred_sequential(self, starting_x_images, steps, max_all_together=1000, test_patch_obj=None): 
        """
        Roll out predictions for 'steps' steps, feeding every predicted image back in as input for the next step
        args: 
            - starting_x_images: list of starting images: depending on the dataprocessor up to 3 images are needed
                - this is used to start the prediction process
            - steps: number of steps to predict/rollout
            - max_all_together: number of patches predicted per batch (0 predicts everything at once); see predict_patchvecs_together
            - test_patch_obj: type: patcher: Separate patcher to use for prediction, if None use self.patch_obj
        returns: 
            - list with the 'steps' predicted images, in rollout order
        """
        window = list(starting_x_images)
        print("Starting image list shape: ", np.array(window).shape)

        patch_obj = self.patch_obj if test_patch_obj is None else test_patch_obj
        predicts_diff = self.data_processor.ytype == 'diff'

        rollout = []
        for step in range(steps): 
            # NOTE: create_x always leaves the last image for the y dataset creator, so append a dummy image 
            dummy = np.zeros(window[0].shape)
            x_dataset = self.data_processor.create_x(image_seq=list(window) + [dummy])
            x_patches = patch_obj.patchify_dataset(dataset=x_dataset, dataset_type='x')
            x_patchvecs = self.data_processor.convert_imgdataset_to_vecdataset(dataset=x_patches)

            # predict the next image (or diff image) patch by patch, then stitch the patches together
            y_patchvecs = self.predict_patchvecs_together(x_patchvecs, max_all_together)
            y_patches = [vec.reshape(patch_obj.get_ypatch_dim()) for vec in y_patchvecs]
            next_image, _, _, _, _ = patch_obj.unpatchify_image(patch_list=y_patches, patch_variance_list=None, img_type='y')

            # a predicted difference has to be added onto the most recent image 
            if predicts_diff: 
                next_image = window[-1] + next_image

            rollout.append(next_image)
            # slide the window: drop the oldest image, add the newest prediction
            window = window[1:] + [next_image]

            print("Finished predicting image: ", step)

        return rollout