import torch as th
from kernels.base import BCK

class Linear(BCK):
  """Linear (dot-product) kernel: k(a, b) = <a, b>.

  All construction options are forwarded unchanged to the ``BCK`` base class.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def kernel(self, A: th.Tensor, B: th.Tensor) -> th.Tensor:
    """Compute the linear kernel between preprocessed inputs A and B.

    Args:
      A: 2-D tensor of shape (n, d).
      B: 2-D tensor of shape (m, d); must share the feature size d with A.

    Returns:
      The (n, m) Gram matrix ``A @ B.T``.

    Raises:
      RuntimeError: from ``torch.mm`` if either input is not 2-D or the
        feature dimensions do not match (``torch.mm`` does not broadcast).
    """
    # NOTE(review): "preprocessed" presumably refers to a transform applied
    # by BCK before calling kernel(); confirm against kernels.base.
    return th.mm(A, B.T)
    

if __name__ == '__main__':
  # Explicit imports instead of `from config import *`: a star import here runs
  # at module scope and could silently rebind names such as the `Linear` class
  # defined above, and it hides where MOONS / LIN come from.
  from config import MOONS, LIN
  from run.train import fit, execute

  seed = 42

  # Linear kernel on MOONS: merge the MOONS and LIN(2) config dicts (keys from
  # LIN(2) win on conflict) and run `fit` for a single seed. Only the first of
  # execute()'s four return values is used.
  metrics, _, _, _ = execute({**MOONS, **LIN(2)}, [seed], fit)
  print(metrics)

  # Example for Trainable FE.
  # NOTE(review): enabling this needs `MNIST10` from config and `train` from
  # run.train. `Linear` must also refer to a config helper that returns a
  # dict: the kernel class defined in this file takes only keyword arguments,
  # so `Linear(784, ...)` would raise TypeError. Confirm the intended helper.
  # metrics, _, _, _ = execute({**MNIST10, **Linear(784, epochs=100, compression=8.0)}, [seed], train)
