from typing import Optional

import torch
import torch.nn as nn

from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel
from gpytorch.utils.grid import ScaleToBounds
    
    
class FeatureModule(torch.nn.Sequential):
    """Fully connected MLP assembled from a sequence of layer widths.

    ``dim_seq = [d0, d1, ..., dk]`` yields ``Linear(d0, d1), ReLU,
    Linear(d1, d2), ..., Linear(d{k-1}, dk)``. A ReLU follows every linear
    layer except the final one, so the output is not passed through a ReLU.
    Submodules are registered as ``linear{i}`` / ``relu{i}``; these names
    appear in the module's state-dict keys.

    Args:
        dim_seq: At least two layer widths (input width first, output last).
    """

    def __init__(self, dim_seq):
        super(FeatureModule, self).__init__()
        assert len(dim_seq) >= 2
        widths = list(dim_seq)
        final_idx = len(widths) - 2  # index of the last Linear layer (no ReLU after it)
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            self.add_module(f'linear{idx}', torch.nn.Linear(fan_in, fan_out))
            if idx != final_idx:
                self.add_module(f'relu{idx}', torch.nn.ReLU())
                

class DeepKernel(Kernel):
    """Kernel that evaluates a base kernel on learned neural-network features.

    Inputs are mapped through a :class:`FeatureModule` MLP. The features are
    then rescaled by a ``ScaleToBounds(-1., 1.)`` module and passed to
    ``base_kernel``.

    Args:
        dim_seq: Layer widths of the feature MLP. ``dim_seq[0]`` must equal the
            input's last dimension; ``dim_seq[-1]`` is the feature width that
            the base kernel sees.
        base_kernel: Kernel evaluated on the features. Defaults to
            ``ScaleKernel(MaternKernel(nu=2.5))``.
        freeze_nn: If True, the MLP's parameters are set to
            ``requires_grad=False``, so only the remaining parameters train.
    """

    def __init__(self, dim_seq, base_kernel: Optional[Kernel] = None, freeze_nn: bool = False):
        super(DeepKernel, self).__init__()
        self.feature_module = FeatureModule(dim_seq=dim_seq)
        if freeze_nn:
            # Parameters stay registered on the module but no longer receive gradients.
            self.feature_module.requires_grad_(False)
        self.kernel = ScaleKernel(MaternKernel(nu=2.5)) if base_kernel is None else base_kernel
        self.scale_to_bounds = ScaleToBounds(-1., 1.)

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the base kernel on rescaled MLP features of ``x1`` and ``x2``.

        Args:
            x1, x2: Input tensors whose last dimension is ``dim_seq[0]``.
            diag: Forwarded unchanged to ``self.kernel.forward``.
            **params: Extra keyword arguments forwarded to ``self.kernel.forward``.

        Returns:
            The result of ``self.kernel.forward`` on the rescaled features.
        """
        z1 = self.feature_module(x1)
        if x2 is x1:
            # k(x, x): run the MLP and the scaler once and share the result.
            z1 = z2 = self.scale_to_bounds(z1)
        else:
            z1, z2 = self._scale_jointly(z1, self.feature_module(x2))
        return self.kernel.forward(x1=z1, x2=z2, diag=diag, **params)

    def _scale_jointly(self, z1, z2):
        """Rescale two feature tensors with a single shared ScaleToBounds call."""
        # In training mode ScaleToBounds derives its range from the tensor it is
        # given. Scaling z1 and z2 in separate calls would map them with
        # different affine transforms and make k(x1, x2) inconsistent.
        # Flattening both into one tensor yields one common range and works for
        # any (possibly different) batch shapes. In eval mode the scaling is
        # elementwise with stored bounds, so the result equals separate calls.
        flat = torch.cat([z1.reshape(-1), z2.reshape(-1)])
        s1, s2 = self.scale_to_bounds(flat).split([z1.numel(), z2.numel()])
        return s1.reshape(z1.shape), s2.reshape(z2.shape)
