# This File contains the functions used to predict outputs on test data 

import numpy as np 
import cv2 
import matplotlib.pyplot as plt 
import GPy

# Append system path for file imports
import sys
import os 
# NOTE(review): the parent directory is derived from the current working
# directory (os.getcwd()), not from this file's location, so which directory
# gets appended to sys.path depends on where the process was launched from.
cwd = os.getcwd()
parent_dir = os.path.join(cwd, "..")
# Appended (not inserted first), so it is searched after the existing sys.path entries.
sys.path.append(parent_dir)

from patchify_unpatchify import patcher
from processing import processing

import pdb

class predict_separate:
    """
    Predict Separate: This predictor class creates a separate GP model for each output dimension.

    The trained models are returned to the caller (as a list ordered by output
    dimension) rather than stored on the instance.
    """

    def __init__(self, max_all_together = 1000):
        """
        args:
            - max_all_together: type: int: maximum predictions to be done at once if all
                                predictions are not to be processed together. It is only
                                stored on the instance here; none of the methods in this
                                class read it.
        """
        self.max_all_together = max_all_together

    @staticmethod
    def _vectorize_datapoints(datapoints, data_processor, patch_obj):
        """
        Convert a datapoints dictionary into vectorized x / y patch datasets.

        args:
            - datapoints: dictionary containing one of the below (checked in this order)
                Either:
                    - 'image_seq': sequence of images, split into x / y datasets with
                                   data_processor.create_xy and then patchified and vectorized
                Or:
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
                Or:
                    - 'x_patchvecs_dataset': already patchified and vectorized x dataset (returned as is)
                    - 'y_patchvecs_dataset': already patchified and vectorized y dataset (returned as is)
            - data_processor: type: processing
            - patch_obj: type: patchify_unpatchify
        returns:
            - tuple (x_patch_vecs, y_patch_vecs)
        raises:
            - ValueError: if datapoints holds none of the supported key combinations
        """
        if 'image_seq' in datapoints:
            x_images, y_images = data_processor.create_xy(datapoints['image_seq'])
        elif 'x_dataset' in datapoints and 'y_dataset' in datapoints:
            x_images = datapoints['x_dataset']
            y_images = datapoints['y_dataset']
        elif 'x_patchvecs_dataset' in datapoints and 'y_patchvecs_dataset' in datapoints:
            # Already patchified and vectorized by the caller: nothing left to do.
            return datapoints['x_patchvecs_dataset'], datapoints['y_patchvecs_dataset']
        else:
            raise ValueError(
                "Proper data not provided: datapoints must contain 'image_seq', or "
                "'x_dataset' and 'y_dataset', or 'x_patchvecs_dataset' and 'y_patchvecs_dataset'"
            )

        # patch and vectorize dataset
        x_patches = patch_obj.patchify_dataset(dataset=x_images, dataset_type='x')
        y_patches = patch_obj.patchify_dataset(dataset=y_images, dataset_type='y')
        x_patch_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=x_patches)
        y_patch_vecs = data_processor.convert_imgdataset_to_vecdataset(dataset=y_patches)
        return x_patch_vecs, y_patch_vecs

    def train_models(self, kernels, datapoints, data_processor, patch_obj, optimize=False, noise_var=0.01, use_sparse_GP=False, num_inducing_points=100, max_opt_iters=None):
        """
        Trains a GP model for each output dimension.
        To do this it first converts the given datapoints into the desired training data
        and then builds one GP model per kernel using the created dataset.

        args:
            - kernels: list of initialized GPy kernels: one for each output dimension.
                       Kernel number i is paired with column i of the output data.
            - datapoints: dictionary containing - only has one of the below
                Either:
                    - 'image_seq': sequence of images to be converted into a patchified, vectorized dataset
                Or:
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
                Or:
                    - 'x_patchvecs_dataset': already patchified and vectorized x dataset
                    - 'y_patchvecs_dataset': already patchified and vectorized y dataset
            - data_processor: type processing
            - patch_obj: type: patchify_unpatchify
            - optimize: if true will optimize the GP model kernel and other parameters
            - noise_var: the noise variance to use for the GP Regression model.
                         Only passed to the normal GP regression model, not to the sparse one.
            - use_sparse_GP: type: bool
                - True: use the sparse GP regression model
                - False: use normal GP regression model
            - num_inducing_points: type: int: the number of inducing points to use in the sparse GP model. not used if use_sparse_GP is False.
            - max_opt_iters: type: int: the maximum optimization iterations to run - default is None - dont use it
        returns:
            - list of GP models, one per kernel, in the order of the kernels
        raises:
            - ValueError: if datapoints holds none of the supported key combinations

        # TODO: Maybe later add support for other types of GP models
        """

        train_x_patch_vecs, train_y_patch_vecs = predict_separate._vectorize_datapoints(
            datapoints, data_processor, patch_obj)

        # One row per training vector; everything else is flattened into columns.
        num_train_vecs = len(train_x_patch_vecs)
        X_train = np.array(train_x_patch_vecs).reshape((num_train_vecs, -1))
        Y_train = np.array(train_y_patch_vecs).reshape((num_train_vecs, -1))

        GP_models = []
        for idx, kernel in enumerate(kernels):
            # Each model regresses a single output dimension, passed as an (N, 1) column.
            Y_column = Y_train[:, idx].reshape((num_train_vecs, 1))
            if use_sparse_GP:
                GP_model = GPy.models.sparse_gp_regression.SparseGPRegression(X=X_train,
                                                                              Y=Y_column,
                                                                              kernel=kernel,
                                                                              num_inducing=num_inducing_points)
            else:
                GP_model = GPy.models.GPRegression(X=X_train,
                                                   Y=Y_column,
                                                   kernel=kernel,
                                                   noise_var=noise_var)
            GP_models.append(GP_model)

        if optimize:
            for GP_model in GP_models:
                if max_opt_iters is None:
                    GP_model.optimize(messages=True)
                else:
                    GP_model.optimize(messages=True, max_iters=max_opt_iters)

        return GP_models

    def add_datapoints_to_GPs(self, GP_models, datapoints, optimize, data_processor, patch_obj):
        """
        Add datapoints to each trained GP model
        args:
            - GP_models: list of type GPy.models: GPy models stored in the order of output dimension. Column i of the new
                        output data is added as datapoints to model i. The inputs stored in GP_models[0] are used as the
                        existing inputs for every model (assumes all models were built on the same inputs, as
                        train_models does).
            - datapoints: dictionary containing - only has one of the below
                Either:
                    - 'image_seq': sequence of images to be converted into a patchified, vectorized dataset
                Or:
                    - 'x_dataset': pre-processed dataset of images
                    - 'y_dataset': pre-processed dataset of images corresponding to the given 'x_dataset'
                Or:
                    - 'x_patchvecs_dataset': already patchified and vectorized x dataset
                    - 'y_patchvecs_dataset': already patchified and vectorized y dataset
            - optimize: type: boolean:
                - True optimize the model after adding data
                - False don't optimize
            - data_processor: type: processing
            - patch_obj: type: patchify_unpatchify
        returns:
            - the list of GP_models (the same objects, updated in place)
        raises:
            - ValueError: if datapoints holds none of the supported key combinations
        """

        x_patch_vecs, y_patch_vecs = predict_separate._vectorize_datapoints(
            datapoints, data_processor, patch_obj)

        # One row per added vector; everything else is flattened into columns.
        added_X = np.array(x_patch_vecs).reshape((len(x_patch_vecs), -1))
        added_Y = np.array(y_patch_vecs).reshape((len(y_patch_vecs), -1))

        new_X = np.vstack((GP_models[0].X, added_X))

        for model_num, GP_model in enumerate(GP_models):
            # The slice is 1-D; it must be turned into a column before stacking,
            # otherwise vstack treats it as a single row of len(added_Y) columns.
            added_column = added_Y[:, model_num].reshape((-1, 1))
            new_Y = np.vstack((GP_model.Y, added_column))
            GP_model.set_XY(X=new_X, Y=new_Y)

        # Optimize only after every model has received its new data.
        if optimize:
            for GP_model in GP_models:
                GP_model.optimize(messages=True)

        return GP_models