import torch

from helpers.utils import to_torch


class Kernels(torch.nn.Module):
    """Bundle two kernel objects, exposed as the attributes ``e`` and ``z``.

    Args:
        kernel_e: Kernel stored as ``self.e``.
        kernel_z: Kernel stored as ``self.z``.
    """

    def __init__(self, kernel_e, kernel_z):
        super(Kernels, self).__init__()
        # When the arguments are nn.Module instances, attribute assignment
        # registers them as child modules (in the order e, then z).
        self.e, self.z = kernel_e, kernel_z


class LinearKernel(torch.nn.Module):
    """Linear kernel: the matrix of inner products ``X @ X.t()``."""

    def __init__(self):
        super(LinearKernel, self).__init__()

    def forward(self, X):
        """Return ``X`` multiplied by its transpose.

        For a 2-D ``X`` of shape (n, d) the result is the (n, n) Gram matrix.
        """
        return torch.matmul(X, X.t())


class RBFKernel(torch.nn.Module):
    """Gaussian (RBF) kernel: K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2)).

    Args:
        sigma: Bandwidth; every element must be strictly positive. Stored
            exactly as passed to the constructor.

    Raises:
        AssertionError: if ``sigma`` is not strictly positive.
    """

    def __init__(self, sigma):
        super().__init__()
        self._check_sigma(sigma)
        self.sigma = sigma

    @staticmethod
    def _check_sigma(sigma):
        """Raise AssertionError unless ``sigma`` is non-empty and all > 0."""
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        # ``torch.as_tensor`` lets scalars and multi-element tensors be checked
        # the same way (a bare ``sigma > 0`` is ambiguous for tensors).
        s = torch.as_tensor(sigma)
        if s.numel() == 0 or not bool((s > 0).all()):
            raise AssertionError('sigma must be > 0. Current %s' % str(sigma))

    def set_kernel_param(self, sigma=None):
        """Replace the bandwidth with ``to_torch(sigma)``.

        Passing ``None`` (the default) leaves the current bandwidth untouched
        instead of overwriting it with ``to_torch(None)``.

        Raises:
            AssertionError: if the converted ``sigma`` is not strictly positive.
        """
        if sigma is None:
            return
        sigma = to_torch(sigma)  # project helper; presumably yields a tensor
        self._check_sigma(sigma)
        self.sigma = sigma

    def forward(self, X):
        """Return the RBF Gram matrix of the rows of ``X``.

        Args:
            X: Tensor of shape (n, d), or batched (..., n, d).

        Returns:
            Tensor of shape (n, n), or (..., n, n) for batched input.
        """
        sq_norms = (X ** 2).sum(-1, keepdim=True)  # (..., n, 1)
        # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2, computed for all pairs.
        D2 = (
            sq_norms
            - 2.0 * X.matmul(X.transpose(-2, -1))
            + sq_norms.transpose(-2, -1)
        )
        # Floating-point round-off can make the expansion slightly negative.
        D2 = D2.clamp(min=0)
        return torch.exp(-D2 / (2 * self.sigma ** 2))


