import numpy as np
import torch
import torch.nn as nn

__all__ = [
    'Kernel',
    'RBFKernel',
    'PeriodicKernel',
    'CosineKernel',
    'Matern32Kernel',
    'Matern52Kernel',
    'RQKernel',
]


class Kernel(nn.Module):
    """Common base class for the GP covariance functions in this module.

    Subclasses must override :meth:`forward`. Any keyword arguments given to
    the constructor are accepted but ignored.
    """

    def __init__(self, **kwargs):
        # kwargs are intentionally not forwarded to nn.Module.
        super().__init__()

    def forward(self, x1, x2, diag=False):
        """Evaluate the kernel between ``x1`` and ``x2``.

        The base class provides no implementation.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError


class RBFKernel(Kernel):
    """A squared-exponential (RBF) kernel.

    Computes ``scale**2 * exp(-sum_d (x1_d - x2_d)**2 / lengthscale_d**2)``.
    Note that the exponent carries no factor of 1/2.

    :param lengthscale: A float/list, the initial lengthscale(s).
    :param scale: A float, the initial scale.
    """

    def __init__(self, lengthscale=1., scale=1., **kwargs):
        super().__init__(**kwargs)

        init_lengthscale = torch.tensor(lengthscale)
        init_scale = torch.tensor(scale)
        self.lengthscale = nn.Parameter(init_lengthscale, requires_grad=True)
        self.scale = nn.Parameter(init_scale, requires_grad=True)

    def forward(self, x1, x2, diag=False):
        """Evaluate the RBF covariance between two sets of inputs.

        :param x1: A Tensor, [M1, D] or [M1], the first set of inputs.
        :param x2: A Tensor, [M2, D] or [M2], the second set of inputs.
        :param diag: A bool; if True, return only k(x1[i], x2[i]) for each
            row i, which requires x1 and x2 to have the same shape.
        :return: A Tensor, [M1, M2], or [M] when ``diag`` is True.
        """
        # Treat 1-D inputs as a single feature column.
        if x1.dim() == 1:
            x1 = x1.unsqueeze(1)
        if x2.dim() == 1:
            x2 = x2.unsqueeze(1)

        same_dim = x1.shape[-1] == x2.shape[-1]
        assert same_dim, 'Inputs must be the same dimension.'

        if diag:
            assert x1.shape == x2.shape, 'Inputs must be the same shape.'
        else:
            # Broadcast [M1, 1, D] against [1, M2, D] to get every pair.
            x1, x2 = x1.unsqueeze(1), x2.unsqueeze(0)

        diff_sq = (x1 - x2).pow(2).clamp_min(0)
        exponent = diff_sq.div(self.lengthscale.pow(2)).sum(-1)
        return self.scale.pow(2) * torch.exp(-exponent)


class PeriodicKernel(Kernel):
    """A periodic (exp-sine-squared) kernel.

    Computes ``scale**2 * exp(-sum_d 2 * sin(pi * |x1_d - x2_d| / period_d)**2
    / lengthscale_d**2)``.

    :param lengthscale: A float/list, the initial lengthscale(s).
    :param period: A float/list, the initial period/periods.
    :param scale: A float, the initial scale.
    """

    def __init__(self, lengthscale=1., period=1., scale=1., **kwargs):
        super().__init__(**kwargs)

        same_type = type(lengthscale) is type(period)
        assert same_type, 'lengthscale and period must be of the same type.'

        self.lengthscale = nn.Parameter(torch.tensor(lengthscale),
                                        requires_grad=True)
        self.scale = nn.Parameter(torch.tensor(scale), requires_grad=True)

        # Store log(period); forward uses exp(raw_period), which is always > 0.
        log_period = torch.tensor(period).log()
        self.raw_period = nn.Parameter(log_period, requires_grad=True)

    def forward(self, x1, x2, diag=False):
        """Evaluate the periodic covariance between two sets of inputs.

        :param x1: A Tensor, [M1, D] or [M1], the first set of inputs.
        :param x2: A Tensor, [M2, D] or [M2], the second set of inputs.
        :param diag: A bool; if True, return only k(x1[i], x2[i]) for each
            row i, which requires x1 and x2 to have the same shape.
        :return: A Tensor, [M1, M2], or [M] when ``diag`` is True.
        """
        # Treat 1-D inputs as a single feature column.
        if x1.dim() == 1:
            x1 = x1.unsqueeze(1)
        if x2.dim() == 1:
            x2 = x2.unsqueeze(1)

        same_dim = x1.shape[-1] == x2.shape[-1]
        assert same_dim, 'Inputs must be the same dimension.'

        if diag:
            assert x1.shape == x2.shape, 'Inputs must be the same shape.'
        else:
            # Broadcast [M1, 1, D] against [1, M2, D] to get every pair.
            x1, x2 = x1.unsqueeze(1), x2.unsqueeze(0)

        abs_diff = (x1 - x2).abs()
        period = self.raw_period.exp()
        sin_sq = torch.sin(np.pi * abs_diff / period).pow(2)
        exponent = (2 * sin_sq / self.lengthscale.pow(2)).sum(-1)
        return self.scale.pow(2) * torch.exp(-exponent)


class CosineKernel(Kernel):
    """A purely periodic cosine kernel.

    Computes ``scale**2 * sum_d cos(2 * pi * |x1_d - x2_d| / period_d)``.
    The cosines are summed (not multiplied) over input dimensions, so the
    value at zero distance is ``D * scale**2`` for D-dimensional inputs.

    :param period: A float/list, the initial period/periods.
    :param scale: A float, the initial scale.
    """

    def __init__(self, period=1., scale=1., **kwargs):
        super().__init__(**kwargs)

        self.scale = nn.Parameter(torch.tensor(scale), requires_grad=True)

        # Period must be constrained to be positive: store log(period) and
        # use exp(raw_period) in forward.
        self.raw_period = nn.Parameter(torch.tensor(period).log(),
                                       requires_grad=True)

    def forward(self, x1, x2, diag=False):
        """ Returns the covariance matrix defined by the cosine kernel.

        :param x1: A Tensor, [M1, D] or [M1], the inputs for which to
        evaluate at.
        :param x2: A Tensor, [M2, D] or [M2], the inputs for which to
        evaluate at.
        :param diag: A bool, whether to evaluate only the diagonal values,
        i.e. k(x1[i], x2[i]); requires x1 and x2 to have the same shape.
        :return: A Tensor, [M1, M2], or [M] if ``diag`` is True.
        """
        # Add dimensions if needed.
        x1 = x1.unsqueeze(1) if len(x1.shape) == 1 else x1
        x2 = x2.unsqueeze(1) if len(x2.shape) == 1 else x2

        assert x1.shape[-1] == x2.shape[-1], 'Inputs must be the same ' \
                                             'dimension.'

        if not diag:
            # Broadcast [M1, 1, D] against [1, M2, D] to get every pair.
            x1 = x1.unsqueeze(1)
            x2 = x2.unsqueeze(0)
        else:
            assert x1.shape == x2.shape, 'Inputs must be the same shape.'

        # [M1, M2, D] or [M, D] if diag.
        ad = (x1 - x2).abs()

        # Apply period and sum over dimensions.
        ad_ = (2 * np.pi * ad / self.raw_period.exp()).cos().sum(-1)
        cov = self.scale ** 2 * ad_

        return cov


class Matern32Kernel(Kernel):
    """ A Matern 3/2 kernel.

    Computes ``scale**2 * (1 + r) * exp(-r)`` where
    ``r = sqrt(3 * sum_d (x1_d - x2_d)**2 / lengthscale_d**2)``.

    :param lengthscale: A float/list, the initial lengthscale(s).
    :param scale: A float, the initial scale.
    """
    def __init__(self, lengthscale=1., scale=1., **kwargs):
        super().__init__(**kwargs)

        self.lengthscale = nn.Parameter(torch.tensor(lengthscale),
                                        requires_grad=True)
        self.scale = nn.Parameter(torch.tensor(scale), requires_grad=True)

    def forward(self, x1, x2, diag=False):
        """ Returns the covariance matrix defined by the Matern 3/2 kernel.

        :param x1: A Tensor, [M1, D] or [M1], the inputs for which to
        evaluate at.
        :param x2: A Tensor, [M2, D] or [M2], the inputs for which to
        evaluate at.
        :param diag: A bool, whether to evaluate only the diagonal values,
        i.e. k(x1[i], x2[i]); requires x1 and x2 to have the same shape.
        :return: A Tensor, [M1, M2], or [M] if ``diag`` is True.
        """
        # Add dimensions if needed.
        x1 = x1.unsqueeze(1) if len(x1.shape) == 1 else x1
        x2 = x2.unsqueeze(1) if len(x2.shape) == 1 else x2

        assert x1.shape[-1] == x2.shape[-1], 'Inputs must be the same ' \
                                             'dimension.'

        if not diag:
            # Broadcast [M1, 1, D] against [1, M2, D] to get every pair.
            x1 = x1.unsqueeze(1)
            x2 = x2.unsqueeze(0)
        else:
            assert x1.shape == x2.shape, 'Inputs must be the same shape.'

        # [M1, M2, D] or [M, D] if diag. Squares are already non-negative.
        sd = (x1 - x2) ** 2

        # Apply lengthscale and sum over dimensions; sd_ == r ** 2.
        sd_ = (3 * sd / self.lengthscale ** 2).sum(-1)

        # d sqrt(s)/ds is infinite at s == 0, so a plain sqrt yields NaN
        # gradients (0 * inf) wherever x1 and x2 coincide, e.g. the whole
        # diagonal of K(X, X). Take the sqrt of a dummy 1 there, then force
        # r = 0. The kernel's derivative w.r.t. r is -r * exp(-r), which is
        # 0 at r = 0, so no gradient is lost. NaN entries are not masked and
        # still propagate.
        is_zero = sd_ == 0
        r = torch.where(is_zero, torch.ones_like(sd_), sd_).sqrt()
        r = torch.where(is_zero, torch.zeros_like(sd_), r)

        cov = self.scale ** 2 * (1 + r) * (-r).exp()

        return cov


class Matern52Kernel(Kernel):
    """ A Matern 5/2 kernel.

    Computes ``scale**2 * (1 + r + r**2 / 3) * exp(-r)`` where
    ``r = sqrt(5 * sum_d (x1_d - x2_d)**2 / lengthscale_d**2)``.

    :param lengthscale: A float/list, the initial lengthscale(s).
    :param scale: A float, the initial scale.
    """

    def __init__(self, lengthscale=1., scale=1., **kwargs):
        super().__init__(**kwargs)

        self.lengthscale = nn.Parameter(torch.tensor(lengthscale),
                                        requires_grad=True)
        self.scale = nn.Parameter(torch.tensor(scale), requires_grad=True)

    def forward(self, x1, x2, diag=False):
        """ Returns the covariance matrix defined by the Matern 5/2 kernel.

        :param x1: A Tensor, [M1, D] or [M1], the inputs for which to
        evaluate at.
        :param x2: A Tensor, [M2, D] or [M2], the inputs for which to
        evaluate at.
        :param diag: A bool, whether to evaluate only the diagonal values,
        i.e. k(x1[i], x2[i]); requires x1 and x2 to have the same shape.
        :return: A Tensor, [M1, M2], or [M] if ``diag`` is True.
        """
        # Add dimensions if needed.
        x1 = x1.unsqueeze(1) if len(x1.shape) == 1 else x1
        x2 = x2.unsqueeze(1) if len(x2.shape) == 1 else x2

        assert x1.shape[-1] == x2.shape[-1], 'Inputs must be the same ' \
                                             'dimension.'

        if not diag:
            # Broadcast [M1, 1, D] against [1, M2, D] to get every pair.
            x1 = x1.unsqueeze(1)
            x2 = x2.unsqueeze(0)
        else:
            assert x1.shape == x2.shape, 'Inputs must be the same shape.'

        # [M1, M2, D] or [M, D] if diag. Squares are already non-negative.
        sd = (x1 - x2) ** 2

        # Apply lengthscale and sum over dimensions; r_sq == r ** 2.
        r_sq = (5 * sd / self.lengthscale ** 2).sum(-1)

        # d sqrt(s)/ds is infinite at s == 0, so a plain sqrt yields NaN
        # gradients (0 * inf) wherever x1 and x2 coincide, e.g. the whole
        # diagonal of K(X, X). Take the sqrt of a dummy 1 there, then force
        # r = 0. Gradients w.r.t. lengthscale and inputs at those entries are
        # 0 (d r_sq / d theta vanishes when x1 == x2), so nothing is lost.
        # NaN entries are not masked and still propagate.
        is_zero = r_sq == 0
        r = torch.where(is_zero, torch.ones_like(r_sq), r_sq).sqrt()
        r = torch.where(is_zero, torch.zeros_like(r_sq), r)

        cov = self.scale ** 2 * (1 + r + r_sq / 3) * (-r).exp()

        return cov


class RQKernel(Kernel):
    """A Rational Quadratic kernel.

    Computes ``scale**2 * (1 + sum_d (x1_d - x2_d)**2
    / (2 * alpha * lengthscale_d**2)) ** (-alpha)``.

    :param alpha: A float, defines the RQ kernel. It is stored as a plain
    attribute (not an ``nn.Parameter``), so it is not optimised.
    :param lengthscale: A float/list, the initial lengthscale(s).
    :param scale: A float, the initial scale.
    """
    def __init__(self, alpha=1, lengthscale=1., scale=1., **kwargs):
        super().__init__(**kwargs)

        self.alpha = alpha
        self.lengthscale = nn.Parameter(torch.tensor(lengthscale),
                                        requires_grad=True)
        self.scale = nn.Parameter(torch.tensor(scale), requires_grad=True)

    def forward(self, x1, x2, diag=False):
        """ Returns the covariance matrix defined by the Rational Quadratic
        kernel.

        :param x1: A Tensor, [M1, D] or [M1], the inputs for which to
        evaluate at.
        :param x2: A Tensor, [M2, D] or [M2], the inputs for which to
        evaluate at.
        :param diag: A bool, whether to evaluate only the diagonal elements,
        i.e. k(x1[i], x2[i]); requires x1 and x2 to have the same shape.
        :return: A Tensor, [M1, M2], or [M] if ``diag`` is True.
        """
        # Add dimensions if needed.
        x1 = x1.unsqueeze(1) if len(x1.shape) == 1 else x1
        x2 = x2.unsqueeze(1) if len(x2.shape) == 1 else x2

        assert x1.shape[-1] == x2.shape[-1], 'Inputs must be the same ' \
                                             'dimension.'
        if not diag:
            # Broadcast [M1, 1, D] against [1, M2, D] to get every pair.
            x1 = x1.unsqueeze(1)
            x2 = x2.unsqueeze(0)
        else:
            assert x1.shape == x2.shape, 'Inputs must be the same shape.'

        # [M1, M2, D] or [M, D] if diag.
        sd = (x1 - x2) ** 2
        sd.clamp_min_(0)

        # Apply lengthscale and sum over dimensions.
        sd_ = (sd / self.lengthscale ** 2).sum(-1)
        cov = self.scale ** 2 * (1 + sd_ / (2 * self.alpha)) ** (-self.alpha)

        return cov
