from .utils import rbf_kernel
from .base import Base

class KME(Base):
    """Kernel-embedding comparator built on top of ``Base``.

    ``generate_helper`` caches the Gram matrix of the sample set against
    itself together with the sum of all its entries. ``compare`` then
    combines those cached sums with a freshly computed cross-Gram sum.

    NOTE(review): ``kernel_x``, ``get`` and ``X`` are not defined in this
    class; they are presumably supplied by ``Base`` -- confirm there.
    """

    def __init__(self, cfg, X, **kwargs):
        # Nothing KME-specific to set up; everything is forwarded to Base.
        super().__init__(cfg, X, **kwargs)

    def generate_helper(self, *args, **kwargs):
        """Cache the self-Gram matrix and the sum of its entries.

        Sets ``self.KX`` to ``kernel_x(X, X)`` and ``self.norm`` to the
        Python scalar obtained from ``KX.sum().item()``. Extra positional
        and keyword arguments are accepted but ignored.
        """
        # NOTE(review): this reads ``self.X`` directly whereas ``compare``
        # goes through ``self.get('X')`` -- confirm both resolve to the
        # same samples.
        gram = self.kernel_x(self.X, self.X)
        self.KX = gram
        self.norm = gram.sum().item()

    def compare(self, kme):
        """Return ``self.norm + kme.norm - 2 * sum(kernel_x(X_self, X_other))``.

        Both ``self`` and ``kme`` must already have run ``generate_helper``,
        since their cached ``norm`` values are read here.

        NOTE(review): the sums are not divided by sample counts, so for a
        positive-definite kernel this is the squared RKHS distance between
        the *summed* feature maps, not the mean embeddings. If mean
        embeddings (MMD) are intended and sample sizes can differ, each
        term needs its own normalization -- TODO confirm intent.
        NOTE(review): the cross term uses ``self.kernel_x``, while
        ``kme.norm`` was computed with ``kme``'s own kernel. The identity
        above only holds if both kernels coincide.
        """
        ours = self.get('X')
        theirs = self.get('X', kme)
        cross = self.kernel_x(ours, theirs).sum().item()
        return self.norm + kme.norm - 2 * cross

    def save(self):
        """No-op: nothing is persisted for this comparator."""

    def load_helper(self):
        """No-op: there is no persisted state to restore."""