import numpy as np
import tensorflow as tf
import gpflow
import numpy as np
import gpflow as gpf
import tensorflow as tf

from DataPreprocessor import DataPreprocessor
import datetime


class MultitaskGPModel(object):
    """Multi-output sparse variational GP over a 1-based time index.

    Observation vectors (one per time step, ``num_elements`` values each) are
    collected with ``append_data`` / ``append_data_list``. ``compute_step``
    preprocesses them with ``DataPreprocessor`` and fits a GPflow SVGP whose
    kernel is a linear model of coregionalization over ``L`` latent GPs.
    ``predict_y`` / ``predict_f`` then query the fitted model at arbitrary
    time points.
    """

    def __init__(self, num_elements, opt_max_iter=1000, opt_tol=3.5, L=12,
                 num_inducing_points=75, retry=20, mdl_warm=None):
        """Configure the model; no fitting happens here.

        Args:
            num_elements: number of outputs in each observation vector.
            opt_max_iter: ``maxiter`` option for the L-BFGS-B optimizer.
            opt_tol: initial ``gtol`` option for L-BFGS-B (raised on failed fits).
            L: initial number of latent GPs (may grow on failed fits).
            num_inducing_points: maximum number of inducing points per latent GP.
            retry: maximum number of fit attempts in ``compute_step``.
            mdl_warm: currently unused; accepted for API compatibility.
        """
        self.num_elements = num_elements
        self.raw_data = []
        self.mtgp_model = None
        # Created by compute_step; predict_y needs it for the inverse transform.
        self.data_preprocessor = None
        self.L = L
        self.num_inducing_points = num_inducing_points
        self.opt_tol = opt_tol
        self.opt_max_iter = opt_max_iter
        self.retry = retry

    def append_data(self, elem):
        """Record one observation vector (converted to a numpy array)."""
        self.raw_data.append(np.array(elem))

    def append_data_list(self, elemlist):
        """Record every observation vector in ``elemlist``, in order."""
        for elem in elemlist:
            self.append_data(elem)

    def prune(self, idxs):
        """Delete the given output indices from every stored observation.

        Args:
            idxs: a single index or a sequence of indices into each
                observation vector. Negative indices count from the end.
                Duplicates are removed only once, mirroring ``np.delete``.
        """
        idxs = np.atleast_1d(np.asarray(idxs, dtype=np.intp))
        self.prune_indeces = idxs
        for eidx, elem in enumerate(self.raw_data):
            self.raw_data[eidx] = np.delete(elem, idxs)
        # np.delete drops each position once, so count distinct positions:
        # normalize negative indices first so that e.g. -1 and n-1 coincide.
        positions = np.where(idxs < 0, idxs + self.num_elements, idxs)
        self.num_elements -= np.unique(positions).shape[0]

    def _build_svgp(self, X, nidp):
        """Construct an unoptimized LMC SVGP with ``self.L`` latent GPs."""
        # Inducing points evenly spaced over the observed time range, one set
        # per latent GP.
        Zinit = [
            gpf.inducing_variables.InducingPoints(
                np.linspace(np.min(X), np.max(X), nidp)[:, None])
            for _ in range(self.L)
        ]
        # NOTE(review): TExponential is not a documented upstream GPflow
        # kernel; presumably provided by a patched/project gpflow — confirm.
        kern_list = [gpflow.kernels.TExponential() for _ in range(self.L)]
        # Linear model of coregionalization: outputs = W @ latent GPs, with a
        # random (num_elements x L) mixing matrix as the starting point.
        kernel = gpf.kernels.LinearCoregionalization(
            kern_list, W=np.random.randn(self.num_elements, self.L))
        iv = gpf.inducing_variables.SeparateIndependentInducingVariables(Zinit)
        # Variational posterior: mean of shape (M, L), sqrt-covariance (L, M, M).
        q_mu = np.zeros((nidp, self.L))
        q_sqrt = np.repeat(np.eye(nidp)[None, ...], self.L, axis=0) * 1.0
        return gpf.models.SVGP(
            kernel, gpf.likelihoods.Gaussian(), inducing_variable=iv,
            q_mu=q_mu, q_sqrt=q_sqrt,
        )

    def _optimize(self, model, X, Y):
        """Fit ``model`` to (X, Y) with SciPy L-BFGS-B using the current tolerances."""
        optimizer = gpf.optimizers.Scipy()
        optimizer.minimize(
            model.training_loss_closure((X, Y)),
            variables=model.trainable_variables,
            method="l-bfgs-b",
            options={"gtol": self.opt_tol, "maxiter": self.opt_max_iter},
        )

    def compute_step(self):
        """Preprocess the collected raw data and fit the multi-output SVGP.

        Up to ``self.retry`` attempts are made. After each failed attempt
        ``self.opt_tol`` is multiplied by 1.15 (while below 50), and once the
        tolerance exceeds 15 and more than 8 attempts have failed, ``self.L``
        grows by one (up to 15). These adjustments persist on the instance.
        An attempt fails if construction/optimization raises or if the fitted
        model predicts NaN.

        Raises:
            RuntimeError: if no attempt managed to construct a model at all.
        """
        import warnings

        nidp = min(len(self.raw_data), self.num_inducing_points)

        # Fit the preprocessing on all collected observations.
        self.data_preprocessor = DataPreprocessor(self.num_elements)
        self.data_preprocessor.append_data_list(self.raw_data)
        self.data_preprocessor.setup_preprocess()

        # X is the 1-based time index of each observation, shape (T, 1);
        # Y holds one transformed observation vector per row.
        X = np.arange(1, len(self.raw_data) + 1, dtype=np.float64).reshape(-1, 1)
        Y = self.data_preprocessor.transform_y(
            np.array(self.raw_data, dtype=np.float64).T).T

        m = None
        last_error = None
        for r in range(self.retry):
            try:
                m = self._build_svgp(X, nidp)
                self._optimize(m, X, Y)
                self.mtgp_model = m
                # Sanity check that the fitted model yields usable predictions.
                pm, pc = self.predict_y([500], full_cov=True, full_output_cov=True)
                if np.isnan(pm).any() or np.isnan(pc).any():
                    raise FloatingPointError("fitted model predicts NaN")
                break
            except Exception as e:
                last_error = e
                # Cheap fits first: loosen the tolerance; if that keeps
                # failing, add latent GPs.
                if self.opt_tol < 50:
                    self.opt_tol = self.opt_tol * 1.15
                if self.L < 15 and self.opt_tol > 15 and r > 7:
                    self.L = self.L + 1
        else:
            if m is None:
                raise RuntimeError(
                    "MultitaskGPModel: could not build a model in %d attempts"
                    % self.retry
                ) from last_error
            warnings.warn(
                "MultitaskGPModel: all %d fit attempts failed; keeping the "
                "last model (last error: %r)" % (self.retry, last_error),
                RuntimeWarning,
            )
        self.mtgp_model = m

    def _require_model(self):
        """Raise a clear error if ``compute_step`` has not produced a model."""
        if self.mtgp_model is None:
            raise RuntimeError("MultitaskGPModel: call compute_step() before predicting")

    def predict_y(self, time_seq, full_cov=False, full_output_cov=False):
        """Predict observations (latent f plus Gaussian noise) in data space.

        Requires ``full_output_cov=True``. The returned covariance has shape
        (N, P, N, P) when ``full_cov`` is True and (N, P, P) otherwise, where
        N = len(time_seq) and P is the number of outputs.

        Returns:
            (mean, cov) as numpy arrays, inverse-transformed through the
            fitted ``DataPreprocessor`` (the "noscale" variants).

        Raises:
            ValueError: if ``full_output_cov`` is False.
            RuntimeError: if the model has not been fitted.
        """
        if not full_output_cov:
            raise ValueError("predict_y requires full_output_cov=True")
        self._require_model()
        xs = np.array(time_seq, dtype=np.float64).reshape(-1, 1)
        mean, cov = self.mtgp_model.predict_f(
            xs, full_cov=full_cov, full_output_cov=full_output_cov)
        mean, cov = mean.numpy(), cov.numpy()

        # Observation noise (plus jitter) is added to the diagonal of each
        # same-time P x P output block only.
        noise = np.eye(cov.shape[-1]) * (
            1e-6 + self.mtgp_model.likelihood.variance.numpy())
        inv_cov = self.data_preprocessor.inverse_transform_cov_noscale

        if cov.ndim == 4:  # (N, P, N, P)
            for i in range(cov.shape[0]):
                cov[i, :, i, :] += noise
            mean = self.data_preprocessor.inverse_transform_noscale(mean.T).T
            for i in range(cov.shape[0]):
                for j in range(cov.shape[2]):
                    cov[i, :, j, :] = inv_cov(cov[i, :, j, :])
        else:  # (N, P, P)
            for i in range(cov.shape[0]):
                cov[i] += noise
            mean = self.data_preprocessor.inverse_transform_noscale(mean.T).T
            for i in range(cov.shape[0]):
                cov[i] = inv_cov(cov[i])
        return mean, cov

    def predict_f(self, time_seq, full_cov=False, full_output_cov=False):
        """Predict the latent function in the model's (transformed) space.

        Returns GPflow's ``predict_f`` outputs unchanged (tensors, no noise,
        no inverse transform).
        """
        self._require_model()
        working_data_xs = np.array(time_seq, dtype=np.float64).reshape(-1, 1)
        return self.mtgp_model.predict_f(
            working_data_xs, full_cov=full_cov, full_output_cov=full_output_cov)
