#!/usr/bin/env python3

# from linear_operator.operators import KeOpsLinearOperator
from linear_operator.operators import KernelLinearOperator

from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel


def _covar_func(x1, x2, **kwargs):
    """Evaluate the unit-lengthscale RBF kernel ``exp(-||x1 - x2||^2 / 2)``.

    The inputs are expected to already be divided by the lengthscale. Both are
    passed through ``_lazify_and_expand_inputs`` (presumably wrapping them as
    KeOps lazy tensors with broadcastable pairwise dimensions — confirm in
    ``keops_kernel``). The squared difference is then summed over the last
    (feature) axis.

    :param x1: First set of (lengthscale-scaled) inputs.
    :param x2: Second set of (lengthscale-scaled) inputs.
    :param kwargs: Accepted for compatibility with ``KernelLinearOperator``,
        which forwards extra params to the covariance function; all are ignored.
    :return: The kernel values ``exp(-sq_dist / 2)``.
    """
    lazy_x1, lazy_x2 = _lazify_and_expand_inputs(x1, x2)
    sq_dist = ((lazy_x1 - lazy_x2) ** 2).sum(-1)
    return (-sq_dist / 2).exp()


class RBFKernel(KeOpsKernel):
    r"""
    Implements the RBF kernel using KeOps as a driver for kernel matrix multiplies.

    This class can be used as a drop-in replacement for :class:`gpytorch.kernels.RBFKernel` in most cases,
    and supports the same arguments.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.
    """

    has_lengthscale = True

    def forward(self, x1, x2, diag=False, **kwargs):
        """Compute the RBF covariance between ``x1`` and ``x2``.

        Both inputs are divided by ``self.lengthscale`` first, so the remaining
        computation is the unit-lengthscale kernel ``exp(-||x1_ - x2_||^2 / 2)``.

        :param x1: First set of inputs (``... x N x D`` per the class docstring).
        :param x2: Second set of inputs (``... x M x D``).
        :param diag: If ``True``, return only the diagonal ``K[..., i, i]``.
            This assumes ``x1`` and ``x2`` have matching ``N``, as in diag mode
            elsewhere in GPyTorch.
        :param kwargs: Forwarded to ``KernelLinearOperator`` in full-matrix mode.
        :return: A dense tensor of diagonal values when ``diag`` is ``True``,
            otherwise a lazy ``KernelLinearOperator`` over the full matrix.
        """
        x1_ = x1 / self.lengthscale
        x2_ = x2 / self.lengthscale
        if diag:
            # Row i of x1_ pairs with row i of x2_, so the diagonal is just an
            # elementwise squared distance. There is no need to build a lazy
            # N x M operator only to extract its diagonal afterwards.
            return (-((x1_ - x2_) ** 2).sum(-1) / 2).exp()
        # return KernelLinearOperator inst only when calculating the whole covariance matrix
        return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)
