# This file contains the functions used to train GP models and to predict outputs on test data

import numpy as np 
import cv2 
import matplotlib.pyplot as plt 
import GPy

from patchify_unpatchify import patcher
from processing import processing

class predict: 

    def __init__(self, max_all_together = 1000):
        """
        Set up the predictor.
        args:
            - max_all_together: largest number of patch vectors handed to the GP model in a single
                                predict call when a prediction method is run with all_together=False
        """
        self.max_all_together = max_all_together

    def pred_test_set(self, test_x, data_processor, GP_model, patch_obj, all_together, use_variance_weighting): 
        """
        Predicts the y images for every entry of a test x dataset.

        Pipeline: patchify test_x -> flatten the patches to vectors -> GP predict -> reshape the predicted
        mean vectors to y patches -> unpatchify the y patches back into images.

        args: 
            - test_x: list of images/image tuples to be used as x values for prediction 
            - data_processor: object of class processing, used here to flatten the patches into vectors
            - GP_model: GP model object (GPy style) exposing predict(Xnew=...) -> (mean, variance)
            - patch_obj: object of type patcher used to patchify the x images and unpatchify the y patches
            - all_together: boolean indicator
                - True: send every patch vector to the GP in one predict call 
                - False: send the patch vectors in consecutive chunks of at most self.max_all_together vectors
            - use_variance_weighting: 
                - True: pass the predicted variances on to the unpatchifier
                - False: pass None instead of the variances
        returns: 
            - pred_images: first value returned by patch_obj.unpatchify_dataset
            - pred_var_images: second value returned by patch_obj.unpatchify_dataset - always returned;
              what it holds when no variances are passed is up to the patcher (not visible here)
        """
        def run_gp(vec_batch):
            # stack the vectors into a (num_vectors, num_features) matrix before querying the GP
            return GP_model.predict(Xnew=np.array(vec_batch).reshape((len(vec_batch), -1)))

        # images -> patches -> flattened patch vectors
        x_patches = patch_obj.patchify_dataset(dataset=test_x, dataset_type='x')
        x_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=x_patches)

        if all_together:
            mean_vecs, var_values = run_gp(x_vecs)
        else:
            # walk over the vectors one chunk at a time and collect the per-vector results
            mean_vecs, var_values = [], []
            total = len(x_vecs)
            lo = 0
            while lo < total:
                hi = min(lo + self.max_all_together, total)
                batch_mean, batch_var = run_gp(x_vecs[lo:hi])
                mean_vecs.extend(batch_mean)
                var_values.extend(batch_var)
                lo += self.max_all_together

        print("Finished prediction")
        # every predicted mean vector becomes one y patch
        ypatch_dim = patch_obj.get_ypatch_dim()
        mean_patches = [vec.reshape(ypatch_dim) for vec in mean_vecs]
        print("Finished reshaping")

        # the variances are only forwarded when variance weighting is requested
        variance_arg = var_values if use_variance_weighting else None
        pred_images, pred_var_images, _, _ = patch_obj.unpatchify_dataset(dataset_patch_list=mean_patches,
                                                                        dataset_patch_variance_list=variance_arg,
                                                                        dataset_type='y')
        print("Finished unpatchifying")
        return pred_images, pred_var_images


    def pred_sequential(self, starting_x_images, data_processor, GP_model, patch_obj, steps, all_together, use_variance_weighting): 
        """
        Rolls the model forward for `steps` steps, feeding every predicted image back in as input.

        The method keeps a sliding window of images. At each step the window is turned into x data,
        patchified, vectorized and pushed through the GP; the predicted patches are unpatchified into one
        image, that image is appended to the window and the oldest image of the window is dropped.

        args: 
            - starting_x_images: list of images used to seed the window (must not be empty when steps > 0)
            - data_processor: object of class processing - its create_x, convert_imgdataset_to_vecdataset
              and ytype are used here
            - GP_model: GP model object (GPy style) exposing predict(Xnew=...) -> (mean, variance)
            - patch_obj: object of type patcher used to patchify the x data and unpatchify the predicted patches
            - steps: the number of steps to rollout
            - all_together: boolean indicator
                - True: send every patch vector of a step to the GP in one predict call 
                - False: send them in consecutive chunks of at most self.max_all_together vectors
            - use_variance_weighting: 
                - True: pass the predicted variances on to the unpatchifier
                - False: pass None instead of the variances
        returns: 
            - predicted_seq_imgs: list with one predicted image per step
            - predicted_seq_vars: list with, per step, the variance image returned by patch_obj.unpatchify_image
            - xdataset: every x datapoint (as returned by data_processor.create_x) used during the rollout
        """
        def run_gp(vec_batch):
            # stack the vectors into a (num_vectors, num_features) matrix before querying the GP
            return GP_model.predict(Xnew=np.array(vec_batch).reshape((len(vec_batch), -1)))

        rollout_images = []
        rollout_variances = []
        xdataset = []

        # work on a copy so the caller's list is left untouched
        window = list(starting_x_images)

        for _ in range(steps):
            # NOTE: create_x always leaves the last image for the y dataset creator, so a zero image
            # is appended as a padder to make sure the whole window ends up in the x data
            padded_window = list(window) + [np.zeros(window[0].shape)]
            step_x = data_processor.create_x(image_seq=padded_window)
            xdataset.extend(step_x)

            # x data -> patches -> flattened patch vectors
            step_patches = patch_obj.patchify_dataset(dataset=step_x, dataset_type='x')
            step_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=step_patches)

            if all_together:
                mean_vecs, var_values = run_gp(step_vecs)
            else:
                mean_vecs, var_values = [], []
                total = len(step_vecs)
                lo = 0
                while lo < total:
                    hi = min(lo + self.max_all_together, total)
                    batch_mean, batch_var = run_gp(step_vecs[lo:hi])
                    mean_vecs.extend(batch_mean)
                    var_values.extend(batch_var)
                    lo += self.max_all_together

            # predicted vectors -> y patches -> one image
            ypatch_dim = patch_obj.get_ypatch_dim()
            mean_patches = [vec.reshape(ypatch_dim) for vec in mean_vecs]
            variance_arg = var_values if use_variance_weighting else None
            next_image, _padded_image, next_var_image, _, _ = patch_obj.unpatchify_image(patch_list=mean_patches,
                                                                                         patch_variance_list=variance_arg,
                                                                                         img_type='y')

            if data_processor.ytype == 'diff':
                # the model output is added on top of the newest image of the window
                next_image = next_image + window[-1]

            rollout_images.append(next_image)
            rollout_variances.append(next_var_image)

            # slide the window: newest prediction in, oldest image out
            window = window[1:] + [next_image]

        return rollout_images, rollout_variances, xdataset

    def train_model(self, kernel, datapoints, data_processor, patch_obj, optimize=False, noise_var=0.01, use_sparse_GP=False, num_inducing_points=100, max_opt_iters=None):
        """
        Builds (and optionally optimizes) a GPy regression model from the given training data.

        Image data is first patchified and vectorized; already vectorized patch data is used as is.

        args: 
            - kernel: initialized GPy kernel
            - datapoints: dictionary holding the training data in one of the forms below
                          (checked in this order, the first match is used)
                - 'image_seq': sequence of images, turned into an x/y dataset with data_processor.create_xy
                - 'x_dataset' and 'y_dataset': pre-processed image datasets that correspond to each other
                - 'x_patchvecs_dataset' and 'y_patchvecs_dataset': already patchified and vectorized datasets
            - data_processor: type processing
            - patch_obj: type: patchify_unpatchify
            - optimize: if true will optimize the GP model kernel and other parameters
            - noise_var: the noise variance passed to the (non sparse) GP regression model
                - it is not passed to the sparse model
            - use_sparse_GP: type: bool
                - True: use the sparse GP regression model 
                - False: use normal GP regression model 
            - num_inducing_points: type: int: the number of inducing points to use in the sparse GP model. not used if use_sparse_GP is False. 
            - max_opt_iters: type: int: the maximum optimization iterations to run - default is None - dont use it
        returns: 
            - GP_model: the created GPy model
        raises: 
            - ValueError: if datapoints does not contain one of the supported key combinations

        # TODO: Maybe later add support for other types of GP models
        """
        has_image_seq = 'image_seq' in datapoints
        has_image_datasets = 'x_dataset' in datapoints and 'y_dataset' in datapoints
        has_patch_vecs = 'x_patchvecs_dataset' in datapoints and 'y_patchvecs_dataset' in datapoints

        if has_image_seq or has_image_datasets:
            if has_image_seq:
                # create image dataset
                train_x, train_y = data_processor.create_xy(datapoints['image_seq'])
            else:
                train_x = datapoints['x_dataset']
                train_y = datapoints['y_dataset']

            # patch and vectorize dataset
            train_x_patch = patch_obj.patchify_dataset(dataset=train_x, dataset_type='x')
            train_y_patch = patch_obj.patchify_dataset(dataset=train_y, dataset_type='y')
            train_x_patch_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=train_x_patch)
            train_y_patch_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=train_y_patch)
        elif has_patch_vecs:
            train_x_patch_vecs = datapoints['x_patchvecs_dataset']
            train_y_patch_vecs = datapoints['y_patchvecs_dataset']
        else:
            # fail with a clear message instead of hitting an undefined variable further down
            raise ValueError("datapoints must contain 'image_seq', or 'x_dataset' and 'y_dataset', "
                             "or 'x_patchvecs_dataset' and 'y_patchvecs_dataset'")

        # one row per patch vector
        num_train_vecs = len(train_x_patch_vecs)
        train_X = np.array(train_x_patch_vecs).reshape((num_train_vecs, -1))
        train_Y = np.array(train_y_patch_vecs).reshape((num_train_vecs, -1))

        # train using the vectorized patches
        if use_sparse_GP:
            GP_model = GPy.models.sparse_gp_regression.SparseGPRegression(X=train_X, 
                                                                          Y=train_Y, 
                                                                          kernel=kernel,
                                                                          num_inducing=num_inducing_points)
        else: 
            GP_model = GPy.models.GPRegression(X=train_X, 
                                               Y=train_Y, 
                                               kernel=kernel, 
                                               noise_var=noise_var)

        if optimize:
            if max_opt_iters is None:
                GP_model.optimize(messages=True)
            else: 
                GP_model.optimize(messages=True, max_iters=max_opt_iters)

        return GP_model


    def add_datapoints_to_GP(self, GP_model, datapoints, optimize, data_processor, patch_obj):
        """
        Add datapoints to an existing GP model. 

        The new images are patchified and vectorized, stacked underneath the X and Y the model
        already holds, and handed back to the model with set_XY.

        args: 
            - GP_model: type GPy.models: the gp model to add the datapoints to 
            - datapoints: dictionary holding the data to add in one of the forms below
                          (checked in this order, the first match is used)
                - 'image_seq': sequence of images, turned into an x/y dataset with data_processor.create_xy
                - 'x_dataset' and 'y_dataset': pre-processed image datasets that correspond to each other
            - optimize: type: boolean: 
                - True optimize the model after adding data 
                - False don't optimize
            - data_processor: type: processing
            - patch_obj: type: patchify_unpatchify
        returns: 
            - GP_model: the same model object that was passed in, updated in place
        raises: 
            - ValueError: if datapoints does not contain one of the supported key combinations
        """
        if 'image_seq' in datapoints: 
            x_dataset, y_dataset = data_processor.create_xy(image_seq=datapoints['image_seq'])
        elif 'x_dataset' in datapoints and 'y_dataset' in datapoints:
            x_dataset = datapoints['x_dataset']
            y_dataset = datapoints['y_dataset']
        else:
            # an explicit exception keeps the check alive when python runs with -O (asserts are stripped there)
            raise ValueError("Proper data not provided: datapoints must contain 'image_seq', "
                             "or 'x_dataset' and 'y_dataset'")

        # patchify dataset to add 
        x_patches = patch_obj.patchify_dataset(dataset=x_dataset, dataset_type='x')
        y_patches = patch_obj.patchify_dataset(dataset=y_dataset, dataset_type='y')
        # vectorize dataset to add
        x_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=x_patches)
        y_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=y_patches)

        # one row per new patch vector, appended below the data already in the model
        new_X = np.vstack((GP_model.X, np.array(x_vecs).reshape((len(x_vecs), -1))))
        new_Y = np.vstack((GP_model.Y, np.array(y_vecs).reshape((len(y_vecs), -1))))
        # update the model 
        GP_model.set_XY(X=new_X, Y=new_Y)

        if optimize: 
            GP_model.optimize(messages=True)

        return GP_model
    
    
    def pred_sequential_seqlearn(self, image_seq, learning_indices, rollout_indices, max_rollout_steps, GP_model, 
        kernel, data_processor, patch_obj, test_patch_obj, optimize, all_together, use_variance_weighting, noise_var=0.01):
        """
        Interleaves online learning and rollouts while scanning through an image sequence.

        Methodology: the image indices are visited in ascending order. When the current index is a learning
        index the GP model is created (if there is none yet) or extended with the images seen since the last
        update. When the current index is a rollout index, self.pred_sequential is started from the 3 images
        right before it. The scan stops at the end of the sequence or once as many rollouts were run as there
        are rollout indices. 

        args: 
            - image_seq: sequence of images - used to generate the training data and as ground truth for the rollouts 
            - learning_indices: type: list of ints: image indices at which the model is updated
                                - an update at index i uses images up to (but not inclusive of) i
            - rollout_indices: type: list of ints: index of the image that is the first expected image of a rollout
                                - the rollout is started from image_seq[index-3:index]
            - max_rollout_steps: type: int: upper bound for the rollout length - the length used is the min of
                                 this and the number of images remaining in image_seq
            - GP_model: type: GPy Regression Model: pretrained GP model 
                - None: if None then the specified kernel is used to create a GPy model 
            - kernel: type: GPy.kern: used to create the GP_model when GP_model is None 
                - None: allowed when a GP model is specified (it is unused then)
            - data_processor: type: processing object: used for pre and post processing of the images
            - patch_obj: object of type patcher - used for the model updates and for the rollouts
            - test_patch_obj: type patchify_unpatchify 
                - NOTE(review): this argument is never read in this method - confirm whether the rollouts
                  were meant to use it instead of patch_obj
            - optimize: type: boolean:
                - True: Optimize every time the model is created or data is added to it
                - False: do not optimize and use the same kernel parameters. 
            - all_together: boolean indicator forwarded to self.pred_sequential
            - use_variance_weighting: boolean indicator forwarded to self.pred_sequential
            - noise_var: the noise variance to use when creating the GP model 
        
        NOTE: rollout_indices must start with a number >= 3 and learning_indices with a number >= 4 (asserted),
        both are expected in ascending order. When no GP_model is given and the first learning index comes
        after the first rollout index, a model is first trained on the first 4 images and the scan starts
        at index 4, so a rollout index of 3 is not run in that case. 

        returns: 
            - pred_rollouts: list of lists holding the predicted rollouts
            - true_rollouts: list of lists holding the true image sequences corresponding to the rollouts 
            - all_rollout_indices: list of integer arrays - the image_seq indices of the images in true_rollouts
            - last_train_indices: one entry per rollout - the index at which the model was last updated
              before that rollout
            - GP_model: the GP model after the last update
        """
        # at least one of GP_model / kernel has to be given
        assert GP_model is not None or kernel is not None
        assert rollout_indices[0] >= 3
        assert learning_indices[0] >= 4

        pred_rollouts, true_rollouts = [], []
        all_rollout_indices, last_train_indices = [], []

        # private copies - the learning list may be modified below
        learn_at = list(learning_indices)
        rollout_at = list(rollout_indices)

        trained_upto = 0  # index (inclusive) from which new data is added at the next model update
        first_image_num = 3  # rollouts need 3 seed images, so nothing can happen before index 3
        if GP_model is None and learn_at[0] > rollout_at[0]:
            # a rollout is requested before the first update: train on the first 4 images right away
            trained_upto = first_image_num = 4
            GP_model = self.train_model(kernel=kernel, datapoints={'image_seq': image_seq[:4]},
                                        data_processor=data_processor, patch_obj=patch_obj,
                                        optimize=optimize, noise_var=noise_var)
            # index 4 is covered by the training above
            if 4 in learn_at:
                learn_at.remove(4)

        learn_lookup = frozenset(learn_at)
        rollout_lookup = frozenset(rollout_at)
        num_rollouts = len(rollout_at)
        rollouts_done = 0
        seq_len = len(image_seq)

        print("Starting Rollouts and Online learning")
        # current_image_num is the first image that has not been seen yet ([:current_image_num] was seen)
        for current_image_num in range(first_image_num, seq_len):
            if rollouts_done >= num_rollouts:
                break

            if current_image_num in learn_lookup:
                # one datapoint is built from 4 images, so step back 3 images to not lose the
                # datapoints that overlap with the data added at the previous update
                first_new = max(0, trained_upto - 3)
                new_datapoints = {'image_seq': image_seq[first_new:current_image_num]}
                print("Learning at image: ", current_image_num)
                trained_upto = current_image_num
                if GP_model is None:
                    GP_model = self.train_model(kernel=kernel, datapoints=new_datapoints,
                                                data_processor=data_processor,
                                                patch_obj=patch_obj, optimize=optimize,
                                                noise_var=noise_var)
                else:
                    self.add_datapoints_to_GP(GP_model=GP_model, datapoints=new_datapoints, optimize=optimize,
                                              data_processor=data_processor, patch_obj=patch_obj)

                print("GP model: ", GP_model)
                print("End Learning \n")

            if current_image_num in rollout_lookup:
                # never roll out past the end of the sequence
                rollout_length = min(max_rollout_steps, seq_len - current_image_num)
                rollout_stop = current_image_num + rollout_length
                seed_images = image_seq[current_image_num - 3:current_image_num]
                print("Starting Rollout at image: ", current_image_num)
                current_pred_rollout, _, _ = self.pred_sequential(starting_x_images=seed_images,
                                                                  data_processor=data_processor,
                                                                  GP_model=GP_model,
                                                                  patch_obj=patch_obj,
                                                                  steps=rollout_length,
                                                                  all_together=all_together,
                                                                  use_variance_weighting=use_variance_weighting)

                pred_rollouts.append(current_pred_rollout)
                true_rollouts.append(image_seq[current_image_num:rollout_stop])
                all_rollout_indices.append(np.arange(current_image_num, rollout_stop))
                last_train_indices.append(trained_upto)
                rollouts_done += 1
                print("Ending Rollout \n")

        return pred_rollouts, true_rollouts, all_rollout_indices, last_train_indices, GP_model

    def pred_sequential_learning(self, image_seq, data_processor, GP_model, patch_obj, num_rollouts, steps_per_rollout, all_together, use_patch_variance):
        """
        Placeholder for repeated rollouts with a model update after every rollout - NOT IMPLEMENTED.

        Intended behaviour: run num_rollouts rollouts of steps_per_rollout steps each. After each rollout
        one datapoint is added to the GP dataset and the next rollout starts from the next datapoint, so
        that the rollout at index i of the returned list starts at datapoint i. 

        Intended methodology: use self.pred_sequential for each rollout, add data to the GP model and then repeat. 
        args: 
            - image_seq: list of images used both to train the GP model and as a baseline for comparison
                - the predictions are meant to start from the first 3 images of this sequence 
            - data_processor: object of class processing used to pre and post process the images
            - GP_model: GP model object of type GPy's GP model - may be pretrained
            - patch_obj: object of type patcher to help patchify and unpatchify the images
            - num_rollouts: how many prediction rollouts should occur
            - steps_per_rollout: how many steps should occur in each rollout
            - all_together: boolean indicator
                - True: predict the entire dataset in one go 
                - False: predict the dataset in chunks 
            - use_patch_variance: 
                - True: use variance for the mask weighting
                - False: do not use variance - just do averaging
        returns (intended): 
            - pred_rollouts: list of lists where each sublist contains a rollout steps_per_rollout images long
        raises: 
            - NotImplementedError: always - none of the arguments is used yet
        """
        raise NotImplementedError