#!/usr/bin/env python3
import math

from linear_operator.operators import KernelLinearOperator

from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel


def _covar_func(x1, x2, nu=2.5, **params):
    """
    Evaluate the closed-form Matern covariance between ``x1`` and ``x2`` with KeOps.

    The inputs are wrapped and broadcast by ``_lazify_and_expand_inputs`` (presumably
    into KeOps LazyTensors laid out so that ``x1_ - x2_`` forms all pairs -- confirm in
    keops_kernel). Squared differences are summed over the last dimension to get the
    squared distance. No lengthscale is applied here; in this module,
    ``MaternKernel.forward`` divides the inputs by the lengthscale before they arrive.

    :param x1: First set of (already scaled) inputs.
    :param x2: Second set of (already scaled) inputs.
    :param nu: (Default: 2.5) Smoothness parameter; must be 0.5, 1.5, or 2.5.
    :param params: Extra keyword arguments; accepted for interface compatibility and ignored.
    :return: ``constant_component * exp(-sqrt(2 * nu) * distance)``, where the constant
        component is ``1`` (nu=0.5), ``1 + sqrt(3) d`` (nu=1.5), or
        ``1 + sqrt(5) d + 5/3 d^2`` (nu=2.5).
    :raises RuntimeError: If ``nu`` is not one of the supported values.
    """
    x1_, x2_ = _lazify_and_expand_inputs(x1, x2)

    sq_distance = ((x1_ - x2_) ** 2).sum(-1)
    distance = (sq_distance + 1e-20).sqrt()
    # ^^ Need to add epsilon to prevent small negative values with the sqrt
    # backward pass (otherwise we get NaNs).
    # using .clamp(1e-20, math.inf) doesn't work in KeOps; it also creates NaNs
    exp_component = (-math.sqrt(nu * 2) * distance).exp()

    if nu == 0.5:
        constant_component = 1
    elif nu == 1.5:
        constant_component = (math.sqrt(3) * distance) + 1
    elif nu == 2.5:
        constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * sq_distance)
    else:
        # Previously fell through and failed with an opaque UnboundLocalError.
        raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")

    return constant_component * exp_component


class MaternKernel(KeOpsKernel):
    """
    Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies.

    This class can be used as a drop in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases,
    and supports the same arguments.

    :param nu: (Default: 2.5) The smoothness parameter.
    :type nu: float (0.5, 1.5, or 2.5)
    :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
        input dimension. It should be `d` if x1 is a `... x n x d` matrix.
    :type ard_num_dims: int, optional
    :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
        batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
    :type batch_shape: torch.Size, optional
    :param active_dims: (Default: `None`) Set this if you want to
        compute the covariance of only a few input dimensions. The ints
        correspond to the indices of the dimensions.
    :type active_dims: Tuple(int)
    :param lengthscale_prior: (Default: `None`)
        Set this if you want to apply a prior to the lengthscale parameter.
    :type lengthscale_prior: ~gpytorch.priors.Prior, optional
    :param lengthscale_constraint: (Default: `Positive`) Set this if you want
        to apply a constraint to the lengthscale parameter.
    :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
    :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
    :type eps: float, optional
    :raises RuntimeError: If ``nu`` is not 0.5, 1.5, or 2.5.
    """

    has_lengthscale = True

    def __init__(self, nu=2.5, **kwargs):
        # _covar_func only has closed forms for these three smoothness values.
        supported_nu = {0.5, 1.5, 2.5}
        if nu not in supported_nu:
            raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
        super().__init__(**kwargs)
        self.nu = nu

    def forward(self, x1, x2, **kwargs):
        """
        Build a lazy Matern covariance operator between ``x1`` and ``x2``.

        Both inputs are shifted by the mean of ``x1`` and divided by the lengthscale.
        Shifting both sets by the same vector leaves pairwise distances unchanged; the
        centering is presumably there for numerical conditioning.

        :param x1: Inputs whose last dimension holds the features.
        :param x2: Inputs with the same feature dimension as ``x1``.
        :param kwargs: Forwarded unchanged to ``KernelLinearOperator``.
        :return: A ``KernelLinearOperator`` evaluated with ``_covar_func``.
        """
        num_features = x1.size(-1)
        # Mean over every leading axis of x1, giving a feature vector of shape (d,).
        center = x1.reshape(-1, num_features).mean(0)
        # Prepend singleton axes so the center broadcasts against x1 and x2.
        center = center[(None,) * (x1.dim() - 1)]

        lengthscale = self.lengthscale
        x1_scaled = (x1 - center) / lengthscale
        x2_scaled = (x2 - center) / lengthscale
        # Per the original note, this lazy operator is meant for the whole covariance
        # matrix (diagonal requests are presumably handled elsewhere -- confirm in KeOpsKernel).
        return KernelLinearOperator(x1_scaled, x2_scaled, covar_func=_covar_func, nu=self.nu, **kwargs)
