import torch as th
from kernels.base import BCK

class RBF(BCK):
  """Gaussian (RBF) kernel: k(a, b) = exp(-gamma * ||a - b||^2).

  ``gamma`` is fixed to ``1 / self.inputs``, mirroring scikit-learn's
  ``rbf_kernel`` default of ``1 / n_features``. ``self.inputs`` is not set in
  this class; it is presumably populated by ``BCK.__init__`` from ``kwargs``
  and assumed to be the input feature dimension -- TODO confirm against BCK.
  """

  def __init__(self, **kwargs):
    """Initialise the base kernel, then derive ``gamma`` from ``self.inputs``.

    Args:
      **kwargs: Forwarded unchanged to ``BCK.__init__``.

    Raises:
      ZeroDivisionError: If ``self.inputs`` is the number 0.
    """
    super().__init__(**kwargs)
    # Stored as a 0-dim tensor (default float dtype) so it broadcasts as a
    # scalar factor against the distance matrix in ``kernel``.
    self.gamma = th.tensor(1.0 / self.inputs)  # from sklearn

  def kernel(self, A: th.Tensor, B: th.Tensor) -> th.Tensor:
    """Compute the RBF kernel between preprocessed inputs A and B.

    Args:
      A: Tensor of shape ``(..., P, M)`` (as required by ``torch.cdist``).
      B: Tensor of shape ``(..., R, M)`` with the same feature size ``M``.

    Returns:
      Tensor of shape ``(..., P, R)`` whose entry ``[i, j]`` is
      ``exp(-gamma * ||A[i] - B[j]||^2)``.
    """
    # cdist returns Euclidean distances; square them to get ||a - b||^2.
    return th.exp(-self.gamma * th.cdist(A, B) ** 2)
    

if __name__ == '__main__':
  # Demo entry point. The project imports live under the main guard, so they
  # only run when this file is executed as a script, not when it is imported.
  # NOTE(review): the star import can rebind names defined above (e.g. RBF).
  from config import *
  from run.train import train, fit, execute

  seed = 42
  # RBF kernel on the MOONS config, run through `fit`. `execute` returns a
  # 4-tuple; only the first element (metrics) is kept. `[seed]` is presumably
  # the list of seeds to run -- confirm against run.train.execute.
  # NOTE(review): `RBF(2)` is unpacked with `**`, so it must be a mapping, and
  # the RBF class above only accepts keyword arguments. This presumably
  # resolves to a same-named config helper from `from config import *` --
  # confirm, and consider importing it under an explicit, distinct name.
  metrics, _,_,_ = execute({**MOONS, **RBF(2)}, [seed], fit)
  print(metrics)

  # Example: trainable-FE variant on MNIST10, run through `train` instead of `fit`.
  # metrics, _, _, _ = execute({**MNIST10, **RBF(784, epochs=100, compression=8.0)}, [seed], train)