import time

import gpflow
import tensorflow as tf
from gpflow.models.gpr import RegressionData, Kernel, Optional, MeanFunction, TensorData,\
      Gaussian, add_likelihood_noise_cov, GPModel, InternalDataTrainingLossMixin, multivariate_normal, \
        InputData, MeanAndVariance, assert_params_false
import numpy as np


class Graph_GPR(GPModel, InternalDataTrainingLossMixin):
    r"""Gaussian process regression whose training inputs are graphs.

    The model follows the layout of ``gpflow.models.GPR``. The difference is
    that the inputs ``G`` are never used directly as a feature matrix. They are
    handed to the kernel, which is expected to provide:

    * ``extract_kernel_matrix(G[, Gnew][, full_cov=...])`` returning the three
      component matrices ``(Kp, Kn, Ke)``;
    * ``K(Kp, Kn, Ke)`` combining those components into a covariance matrix;
    * ``_get_V_Matrix(G)`` returning the value passed to the mean function and
      to ``add_likelihood_noise_cov``;
    * ``calculate_shift(G)`` (used only by :meth:`get_K_constants`).

    The training-set components are computed once in :meth:`prepare_kernel_matrix`
    and cached as ``self.Kp``, ``self.Kn`` and ``self.Ke``.
    """

    def __init__(
            self,
            data: RegressionData,
            kernel: Kernel,
            mean_function: Optional[MeanFunction] = lambda Graph_List: tf.zeros(shape=(len(Graph_List), 1), dtype=tf.float64),
            noise_variance: Optional[TensorData] = None,
            likelihood: Optional[Gaussian] = None,
    ):
        """Build the model and precompute the training kernel components.

        :param data: ``(G, Y)`` pair of training graphs and targets. The last
            dimension of ``Y`` sets ``num_latent_gps``.
        :param kernel: graph kernel implementing the methods listed in the class
            docstring.
        :param mean_function: callable applied to ``kernel._get_V_Matrix(...)``.
            The default returns a float64 zero tensor of shape ``(len(arg), 1)``.
            NOTE(review): the default uses ``len()``, which assumes its argument
            has a Python length (e.g. eager tensor / array / list) — confirm
            before wrapping calls in ``tf.function``.
        :param noise_variance: initial variance for a freshly created Gaussian
            likelihood (``1.0`` when omitted). Mutually exclusive with
            ``likelihood``.
        :param likelihood: an existing Gaussian likelihood to use instead.
        :raises AssertionError: if both ``noise_variance`` and ``likelihood``
            are supplied.
        """
        assert (noise_variance is None) or (
                likelihood is None
        ), "Cannot set both `noise_variance` and `likelihood`."
        if likelihood is None:
            if noise_variance is None:
                noise_variance = 1.0
            likelihood = gpflow.likelihoods.Gaussian(noise_variance)
        _, Y_data = data
        super().__init__(kernel, likelihood, mean_function, num_latent_gps=Y_data.shape[-1])
        self.data = data
        # Cache the training kernel components; they depend only on the fixed
        # training graphs, so later calls pass them straight to kernel.K.
        self.prepare_kernel_matrix()

    def prepare_kernel_matrix(self):
        """Compute and cache ``(Kp, Kn, Ke)`` for the training graphs."""
        G, _ = self.data
        self.Kp, self.Kn, self.Ke = self.kernel.extract_kernel_matrix(G)

    # type-ignore is because of changed method signature:
    def maximum_log_likelihood_objective(self) -> tf.Tensor:  # type: ignore[override]
        """Training objective: the log marginal likelihood of the data."""
        return self.log_marginal_likelihood()

    def log_marginal_likelihood(self) -> tf.Tensor:
        r"""Compute the log marginal likelihood.

        .. math::
            \log p(Y | \theta).

        The cached kernel components are combined with ``kernel.K``, likelihood
        noise is added, and the result is Cholesky-factorised. The per-output
        multivariate-normal log densities are then summed.
        """
        G, Y = self.data
        X = self.kernel._get_V_Matrix(G)
        K = self.kernel.K(self.Kp, self.Kn, self.Ke)
        Ks = add_likelihood_noise_cov(K, self.likelihood, X)

        L = tf.linalg.cholesky(Ks)
        m = self.mean_function(X)

        # [R,] log-likelihoods for each independent dimension of Y
        log_prob = multivariate_normal(Y, m, L)
        return tf.reduce_sum(log_prob)

    def get_K_constants(self, jitter: float = 1e-3):
        """Return constants derived from the jittered training kernel matrix.

        The training covariance ``K`` is regularised as ``K + jitter * I``. This
        is a fixed jitter, not the likelihood noise variance. The result is
        Cholesky-factorised as ``L``, which is then used to compute ``L^{-1} Y``.

        NOTE(review): ``Y`` is used as-is here, without subtracting the mean
        function (unlike :meth:`predict_f`). Confirm this is intended for
        the caller.

        :param jitter: value added to the diagonal before the Cholesky
            factorisation. The default ``1e-3`` equals the previously
            hard-coded value.
        :return: ``(L, LinvY, shift)``. ``L`` is the lower Cholesky factor and
            ``LinvY`` is ``L^{-1} Y``, both as NumPy arrays. ``shift`` is the
            value returned by ``kernel.calculate_shift(G)``.
        """
        G, Y = self.data
        K_XX = self.kernel.K(self.Kp, self.Kn, self.Ke) + tf.eye(len(G), dtype=tf.float64) * jitter
        L = tf.linalg.cholesky(K_XX)
        LinvY = tf.linalg.triangular_solve(L, Y, lower=True)
        shift = self.kernel.calculate_shift(G)
        return L.numpy(), LinvY.numpy(), shift

    def predict_f(
            self, Gnew: InputData, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        r"""Predict the latent function at the new graphs ``Gnew``.

        .. math::
            p(F^* | Y)

        ``F^*`` are the GP values at ``Gnew``, and ``Y`` are the noisy
        observations at the training graphs.

        :param Gnew: graphs to predict at.
        :param full_cov: if True, return the full predictive covariance.
            Otherwise only its diagonal is kept.
        :param full_output_cov: must be False (checked by ``assert_params_false``).
        :return: ``(mean, var)`` as returned by
            ``gpflow.conditionals.base_conditional``, with
            ``mean_function(_get_V_Matrix(Gnew))`` added to the mean.
        """

        assert_params_false(self.predict_f, full_output_cov=full_output_cov)

        G, Y = self.data
        X = self.kernel._get_V_Matrix(G)
        Xnew = self.kernel._get_V_Matrix(Gnew)
        # Residuals of the training targets about the prior mean.
        err = Y - self.mean_function(X)

        kmm = self.kernel.K(self.Kp, self.Kn, self.Ke)
        Kp_nn, Kn_nn, Ke_nn = self.kernel.extract_kernel_matrix(Gnew, full_cov=full_cov)
        knn = self.kernel.K(Kp_nn, Kn_nn, Ke_nn)

        if not full_cov:
            knn = tf.linalg.diag_part(knn)

        # Cross-covariance between training graphs and new graphs.
        Kp_mn, Kn_mn, Ke_mn = self.kernel.extract_kernel_matrix(G, Gnew)
        kmn = self.kernel.K(Kp_mn, Kn_mn, Ke_mn)
        kmm_plus_s = add_likelihood_noise_cov(kmm, self.likelihood, X)

        conditional = gpflow.conditionals.base_conditional
        f_mean_zero, f_var = conditional(
            kmn, kmm_plus_s, knn, err, full_cov=full_cov, white=False
        )
        f_mean = f_mean_zero + self.mean_function(Xnew)
        return f_mean, f_var
