# import torch
# from gpytorch.kernels import Kernel, RBFKernel

# from parallel_opt.parallel_strategy import ParallelisationStrategy


# LOG_ON_DIM = ParallelisationStrategy.can_log_transform_dimension()


# class LogKernel(Kernel):
    
#     """
#     Computes the base kernel on log-transformed inputs, i.e., K(log(x1), log(x2)),
#     where the log is applied only to the dimensions selected by LOG_ON_DIM.
#     """
    
#     def __init__(self, base_kernel: Kernel = RBFKernel()):
#         super().__init__()
#         self.base_kernel = base_kernel

#     def forward(self, x1, x2, diag=False, **params):
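#         # Log-transform only the dimensions selected by LOG_ON_DIM, because some dimensions
#         # should not be log-transformed (cf. ParallelisationStrategy.log_transform).
#         # NOTE(review): the assignments below write into x1/x2 in place, so the caller's
#         # tensors are altered; if x1 and x2 are the same tensor, the log is applied twice.
#         # Consider cloning the inputs first if this class is re-enabled.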
#         x1[..., LOG_ON_DIM] = torch.log(x1[..., LOG_ON_DIM])
#         x2[..., LOG_ON_DIM] = torch.log(x2[..., LOG_ON_DIM])
#         return self.base_kernel.forward(
#             x1=x1,
#             x2=x2,
#             diag=diag,
#             **params
#         )
