# Copyright (c) 2022 Copyright holder of the paper Structural Kernel Search via Bayesian Optimization and Symbolical Optimal Transport submitted to NeurIPS 2022 for review.
# All rights reserved.
from typing import List, Optional, Tuple
import gpflow
from bosot.kernels.base_object_kernel import BaseObjectKernel
from gpflow.models.gpr import GPR_with_posterior
import numpy as np
import tensorflow as tf
from gpflow.config import default_float
from gpflow.mean_functions import Constant, MeanFunction
from bosot.models.object_mean_functions import Zero


class ObjectGPR(GPR_with_posterior):
    """
    GPR model that extends the standard gpflow GPR model from array inputs to arbitrary objects as input elements.

    The parent GPR is initialised with a dummy (N, 1) array of zeros in place of the real inputs. After that,
    ``self.data`` is overwritten with the actual list of input objects and the outputs converted to a tensor.
    The kernel must be a ``BaseObjectKernel``; it is the component expected to operate on the object inputs.
    """

    def __init__(
        self,
        data: Tuple[List[object], np.ndarray],
        kernel: BaseObjectKernel,
        noise_variance: float = 1.0,
        mean_function: Optional[MeanFunction] = None,
    ):
        """
        Arguments:
            data: tuple ``(X_data, Y_data)`` of N input objects and their N outputs
                (presumably Y_data is an (N, 1) array - TODO confirm against callers)
            kernel: object kernel, forwarded to the parent GPR
            noise_variance: initial likelihood noise variance, forwarded to the parent GPR
            mean_function: mean function of the GP; if None, a fresh ``Zero()`` object mean function
                is created for this model

        Raises:
            TypeError: if ``kernel`` is not a ``BaseObjectKernel``
            ValueError: if the number of input objects differs from the number of outputs
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(kernel, BaseObjectKernel):
            raise TypeError(f"kernel must be a BaseObjectKernel, got {type(kernel).__name__}")
        if mean_function is None:
            # Create the default per instance: an instance built in the signature would be shared by all models.
            mean_function = Zero()
        X_data, Y_data = data
        num_data = len(Y_data)
        # The placeholder below is sized from Y_data only, so a length mismatch would otherwise go unnoticed.
        if len(X_data) != num_data:
            raise ValueError(f"Number of input objects ({len(X_data)}) does not match number of outputs ({num_data})")
        # The parent GPR works on arrays, so hand it a correctly sized placeholder instead of the objects.
        X_placeholder = np.zeros((num_data, 1))
        super().__init__(data=(X_placeholder, Y_data), kernel=kernel, mean_function=mean_function, noise_variance=noise_variance)
        # Override the data tuple so the model holds the actual input objects.
        Y_data_tf = self.array_to_tensor(Y_data)
        self.data = (X_data, Y_data_tf)

    def array_to_tensor(self, array):
        """
        Convert ``array`` to a TensorFlow tensor.

        Tensors are returned unchanged. Numpy arrays are converted while keeping their own dtype. Anything else
        (e.g. nested lists) is converted with gpflow's ``default_float()`` dtype.
        """
        if tf.is_tensor(array):
            return array
        elif isinstance(array, np.ndarray):
            return tf.convert_to_tensor(array)
        return tf.convert_to_tensor(array, dtype=default_float())
