import tensorflow as tf
import gpflow
from .graph_kernels import SPKernel
from .models import Graph_GPR
from gpflow.utilities import print_summary

def graphGP_fit(Gs,
                Y,
                kernel=None,
                *,
                noise_variance=1e-3,
                verbose=True):
    '''Fit a graph GP regression model to the given graphs.

    Args:
        Gs (array): graph dataset, passed as the inputs of ``Graph_GPR``.
        Y (array): target values corresponding to the graphs in ``Gs``.
        kernel (kernel object, optional): graph kernel to use. If ``None``
            (the default), a fresh ``SPKernel`` with ``kernel_type="SP"``
            is created on every call, with only ``alpha`` and ``beta``
            trainable.
        noise_variance (float): value the likelihood variance is set to
            before optimisation. Defaults to ``1e-3``.
        verbose (bool): if True (the default), print a summary of the
            trained kernel.
    Return:
        m (Graph_GPR): trained graph GP model.
    '''
    # Build the default kernel per call. A kernel instance created in the
    # function signature is evaluated only once, so it would be shared by --
    # and re-trained across -- every call that relies on the default.
    if kernel is None:
        kernel = SPKernel(trainable_lengthscales=False,
                          trainable_variance=False,
                          trainable_alpha=True,
                          trainable_beta=True,
                          trainable_gamma=False,
                          kernel_type="SP",
                          )
    m = Graph_GPR((Gs, Y), kernel=kernel)
    # NOTE(review): assigning a tensor (rather than calling `.assign`) looks
    # like it replaces the likelihood's variance Parameter with a constant,
    # which would keep it out of `m.trainable_variables` and fixed during
    # optimisation -- presumably intended (small fixed noise); confirm
    # against the gpflow version in use.
    m.likelihood.variance = tf.constant([noise_variance], dtype=tf.float64)
    # Optimise all remaining trainable hyperparameters with scipy (L-BFGS-B
    # by default), compiling the loss with tf.function.
    opt = gpflow.optimizers.Scipy()
    opt.minimize(m.training_loss, m.trainable_variables, compile=True)
    if verbose:
        print_summary(m.kernel)
    return m