import torch
import torch.nn as nn

from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel
from gpytorch.utils.grid import ScaleToBounds


class FeatureModule(nn.Module):
    """Fully connected ReLU network that maps inputs to feature vectors.

    Layer layout::

        Linear(in_dim, hidden_width) -> ReLU
        (hidden_depth - 1) x [Linear(hidden_width, hidden_width) -> ReLU]
        Linear(hidden_width, out_dim)

    No activation is applied after the final linear layer.
    """

    def __init__(self, in_dim, out_dim, hidden_width=8, hidden_depth=2):
        """Build the layers.

        Args:
            in_dim: Number of input features (last dimension of the input).
            out_dim: Number of output features.
            hidden_width: Number of units in every hidden layer.
            hidden_depth: Number of hidden layers counting the one produced
                by the first linear layer; ``hidden_depth - 1`` additional
                hidden-to-hidden layers are created.
        """
        super().__init__()
        self.activate = nn.ReLU()
        self.fc_first = nn.Linear(in_dim, hidden_width)
        # The first layer already yields one hidden representation, hence
        # one fewer hidden-to-hidden layer than ``hidden_depth``.
        n_inner = hidden_depth - 1
        inner_layers = [nn.Linear(hidden_width, hidden_width) for _ in range(n_inner)]
        self.fc_hidden = nn.ModuleList(inner_layers)
        self.fc_last = nn.Linear(hidden_width, out_dim)

    def forward(self, x):
        """Return the features computed for ``x``."""
        hidden = self.fc_first(x)
        hidden = self.activate(hidden)
        for layer in self.fc_hidden:
            hidden = self.activate(layer(hidden))
        features = self.fc_last(hidden)
        return features



class DeepKernel(Kernel):
    """Kernel that evaluates a base kernel on learned features of its inputs.

    Both inputs are passed through one shared ``FeatureModule`` and the
    resulting features are handed to the base kernel's ``forward``.

    Args:
        in_dim: Number of input features (last dimension of the raw inputs).
        out_dim: Number of features produced by the feature network and
            consumed by the base kernel.
        base_kernel: Kernel evaluated on the features. When ``None``,
            ``ScaleKernel(MaternKernel(nu=2.5))`` is used.
        hidden_width: Width of each hidden layer of the feature network.
        hidden_depth: Depth setting forwarded to the feature network.
    """

    def __init__(self, in_dim, out_dim, base_kernel: Kernel = None, hidden_width: int = 8, hidden_depth: int = 2):
        super().__init__()
        # Forward the architecture settings; previously they were accepted
        # but dropped, so the network always used FeatureModule's defaults.
        self.feature_module = FeatureModule(
            in_dim=in_dim,
            out_dim=out_dim,
            hidden_width=hidden_width,
            hidden_depth=hidden_depth,
        )
        self.kernel = ScaleKernel(MaternKernel(nu=2.5)) if base_kernel is None else base_kernel

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the base kernel on the feature-transformed inputs.

        Args:
            x1: First input, transformed by the feature network.
            x2: Second input, transformed by the same feature network.
            diag: Passed through to the base kernel's ``forward``.
            **params: Extra keyword arguments passed through unchanged.

        Returns:
            Whatever the base kernel's ``forward`` returns for the features.
        """
        # NOTE: features are passed on exactly as the network produces them;
        # no rescaling (e.g. ScaleToBounds) is applied.
        x1_transform = self.feature_module(x1)
        x2_transform = self.feature_module(x2)
        return self.kernel.forward(x1=x1_transform, x2=x2_transform, diag=diag, **params)
