import torch as t
import torch.nn as nn
from .general import KG


class DistanceKernel(nn.Module):
    """Base class for kernels computed from pairwise squared distances.

    ``forward`` reads ``self.log_lengthscale`` and calls ``self.kernel(d2s)``,
    neither of which is defined here, so a concrete subclass has to provide
    both before the module can be called.

    Args:
        noise: If truthy, ``log_noise`` is a learnable scalar parameter
            initialised to -4; otherwise it is a fixed scalar buffer at -100.
        inducing_batch: Number of leading entries along dim 1 of the input
            that ``forward`` treats as the inducing block. Must not be
            ``None``.
    """

    def __init__(self, noise, inducing_batch):
        super().__init__()

        # A ``None`` here would not fail later: ``x[:, :None]`` and
        # ``x[:, None:]`` both select everything, silently corrupting the
        # inducing/test split in ``forward``.
        assert inducing_batch is not None, "inducing_batch must not be None"
        self.inducing_batch = inducing_batch

        if noise:
            self.log_noise = nn.Parameter(-4.*t.ones(()))
        else:
            # Buffer rather than parameter: moves with the module and is
            # saved in the state dict, but is not trained.
            self.register_buffer("log_noise", -100.*t.ones(()))

    def d2s(self, x, y=None):
        """Return pairwise squared Euclidean distances between rows.

        Args:
            x: Tensor of shape ``(..., n, d)``.
            y: Tensor of shape ``(..., m, d)``; defaults to ``x``.

        Returns:
            Tensor of shape ``(..., n, m)`` holding ``|x_i - y_j|**2``,
            clamped to be non-negative.
        """
        if y is None:
            y = x

        x2 = (x**2).sum(-1)[..., :, None]
        y2 = (y**2).sum(-1)[..., None, :]
        d2 = x2 + y2 - (2.*x) @ y.transpose(-1, -2)
        # The expanded form |x|^2 + |y|^2 - 2 x.y suffers from cancellation
        # and can dip slightly below zero for (near-)identical points; a
        # squared distance is never negative, so clamp it.
        return d2.clamp(min=0)

    def forward(self, x):
        """Build the inducing/test kernel blocks for ``x``.

        The input is rescaled by ``exp(-log_lengthscale)`` and split along
        dim 1 into the first ``inducing_batch`` entries (inducing) and the
        remainder (test).
        NOTE(review): presumably ``x`` is (batch, points, features) - confirm
        against callers.

        Returns:
            ``KG(Kii, Kit, Ktt)`` where ``Kii`` is the inducing-inducing
            kernel with the noise variance added on its diagonal, ``Kit`` is
            the inducing-test kernel, and ``Ktt`` holds only the per-point
            test variances (kernel at zero distance) plus the noise variance.
        """
        kwargs = {'device': x.device, 'dtype': x.dtype}
        x = x * (-self.log_lengthscale).exp()
        xi = x[:, :self.inducing_batch]
        xt = x[:, self.inducing_batch:]

        d2ii = self.d2s(xi, xi)
        d2it = self.d2s(xi, xt)
        # Only the diagonal of the test-test block is produced: the distance
        # from each test point to itself is exactly zero.
        d2tt = t.zeros(xt.shape[:-1], **kwargs)

        # Small constant keeps the variance strictly positive even when
        # exp(log_noise) underflows (e.g. the fixed -100 buffer).
        noise_var = 1e-8 + self.log_noise.exp()

        Kii = self.kernel(d2ii)
        Kii = Kii + noise_var*t.eye(Kii.shape[-1], **kwargs)
        Kit = self.kernel(d2it)
        Ktt = self.kernel(d2tt) + noise_var

        return KG(Kii, Kit, Ktt)


class SqExpKernel(DistanceKernel):
    """Squared-exponential kernel, ``exp(log_height - d2)``.

    Args:
        in_features: Size passed to ``torch.zeros`` for the
            ``log_lengthscale`` parameter (initialised to zero).
        noise: Forwarded to ``DistanceKernel``.
        inducing_batch: Forwarded to ``DistanceKernel``.
    """

    def __init__(self, in_features, noise=True, inducing_batch=None):
        super().__init__(noise=noise, inducing_batch=inducing_batch)

        # Registration order (lengthscale, then height) is kept so the order
        # of ``parameters()`` / ``state_dict()`` entries is unchanged.
        lengthscale_init = t.zeros(in_features)
        height_init = t.zeros(())
        self.log_lengthscale = nn.Parameter(lengthscale_init)
        self.log_height = nn.Parameter(height_init)

    def kernel(self, d2s):
        """Map squared distances ``d2s`` to kernel values."""
        log_k = self.log_height - d2s
        return log_k.exp()
