# Copyright (c) 2022 Copyright holder of the paper Structural Kernel Search via Bayesian Optimization and Symbolical Optimal Transport submitted to NeurIPS 2022 for review.
# All rights reserved.
import gpflow
from typing import Tuple
import numpy as np

# Import-time side effect: make float64 gpflow's default float type before any kernel is built.
gpflow.config.set_default_float(np.float64)
# Short alias used throughout this module to cast values to gpflow's default float (float64 after the call above).
f64 = gpflow.utilities.to_default_float
from tensorflow_probability import distributions as tfd
from bosot.kernels.base_elementary_kernel import BaseElementaryKernel


class RationalQuadraticKernel(BaseElementaryKernel):
    """Elementary kernel backed by ``gpflow.kernels.RationalQuadratic``.

    The wrapped gpflow kernel is stored on ``self.kernel``. It gets one
    lengthscale per active input dimension, a single variance and a single
    ``alpha``. Optionally, Gamma priors are attached to all three
    hyperparameters.
    """

    def __init__(
        self,
        input_dimension: int,
        base_lengthscale: float,
        base_variance: float,
        base_alpha: float,
        add_prior: bool,
        lengthscale_prior_parameters: Tuple[float, float],
        variance_prior_parameters: Tuple[float, float],
        alpha_prior_parameters: Tuple[float, float],
        active_on_single_dimension: bool,
        active_dimension: int,
        name: str,
        **kwargs,
    ):
        """Build the rational-quadratic kernel and, if requested, its priors.

        Args:
            input_dimension: Input dimensionality, forwarded to the base class.
            base_lengthscale: Initial lengthscale, repeated for every active dimension.
            base_variance: Initial kernel variance (stored as a length-1 array).
            base_alpha: Initial rational-quadratic ``alpha`` (stored as a length-1 array).
            add_prior: If True, attach Gamma priors to lengthscales, variance and alpha.
            lengthscale_prior_parameters: ``(a, b)`` passed to ``tfd.Gamma`` as
                (concentration, rate), shared by every active dimension.
            variance_prior_parameters: ``(a, b)`` Gamma parameters for the variance.
            alpha_prior_parameters: ``(a, b)`` Gamma parameters for alpha.
            active_on_single_dimension: Forwarded to the base class.
            active_dimension: Forwarded to the base class.
            name: Forwarded to the base class.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__(input_dimension, active_on_single_dimension, active_dimension, name)
        # num_active_dimensions is presumably established by BaseElementaryKernel.__init__.
        n_active = self.num_active_dimensions

        def _per_dimension(value):
            # Broadcast a scalar into a float64 vector with one entry per active dimension.
            return f64(np.repeat(value, n_active))

        self.kernel = gpflow.kernels.RationalQuadratic(
            lengthscales=_per_dimension(base_lengthscale),
            variance=f64([base_variance]),
            alpha=f64([base_alpha]),
        )
        if not add_prior:
            return

        # Unpack every (a, b) pair first so a malformed tuple fails before any prior is assigned.
        ls_a, ls_b = lengthscale_prior_parameters
        var_a, var_b = variance_prior_parameters
        alpha_a, alpha_b = alpha_prior_parameters

        self.kernel.lengthscales.prior = tfd.Gamma(
            concentration=_per_dimension(ls_a),
            rate=_per_dimension(ls_b),
        )
        self.kernel.variance.prior = tfd.Gamma(concentration=f64([var_a]), rate=f64([var_b]))
        self.kernel.alpha.prior = tfd.Gamma(concentration=f64([alpha_a]), rate=f64([alpha_b]))
