import os
from .utils import rbf_kernel, one_hot, indices_to_one_hot

class Base:
    """Scaffolding for a specification that is loaded from ``path`` or generated and saved.

    Subclasses are expected to implement ``generate_helper``, ``save``,
    ``load_helper`` and ``compare``; the versions here only raise
    ``NotImplementedError``.
    """

    def __init__(self, cfg, X, **kwargs):
        """Store configuration and move ``X`` to the requested device.

        Args:
            cfg: Mapping; ``cfg['task']`` is read here.
            X: Input data exposing ``.to(device)`` and ``.shape``.
            **kwargs: ``path`` and ``device`` are required. ``kernel_x`` is
                optional and defaults to ``rbf_kernel``.

        Raises:
            KeyError: If ``cfg['task']``, ``path`` or ``device`` is missing.
        """
        self.cfg = cfg
        self.task = cfg['task']
        self.path = kwargs['path']
        self.device = kwargs['device']
        self.kernel_x = kwargs.get('kernel_x', rbf_kernel)

        # A bare file name has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create the directory when there is one.
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        self.X = X.to(self.device)
        self.n = X.shape[0]

    def get(self, val, base=None):
        """Return the attribute named ``val``.

        Args:
            val: Attribute name; a dotted path such as ``'a.b'`` is resolved
                one attribute at a time.
            base: Optional object with its own ``get``. When given, the value
                comes from ``base.get(val)`` and is moved to ``self.device``.

        Raises:
            AttributeError: If any component of ``val`` does not exist.
        """
        if base is not None:
            return base.get(val).to(self.device)

        # Plain attribute resolution instead of eval(): same result for
        # attribute names, but no arbitrary expression is ever executed.
        target = self
        for name in val.split('.'):
            target = getattr(target, name)
        return target

    def generate(self, *args, **kwargs):
        """Load the specification, or build and save it when loading fails.

        All arguments are forwarded unchanged to ``generate_helper``.
        """
        if self.load():
            return
        print(f'Generating specification and saving to {self.path}...')
        self.generate_helper(*args, **kwargs)
        self.save()

    def generate_helper(self, *args, **kwargs):
        """Build the specification; must be overridden."""
        raise NotImplementedError('generate_helper function should be implemented in the derived class')

    def compare(self, other):
        """Compare against ``other``; must be overridden."""
        raise NotImplementedError('compare function should be implemented in the derived class')

    def save(self):
        """Persist the specification; must be overridden."""
        raise NotImplementedError('save function should be implemented in the derived class')

    def load(self):
        """Try to load via ``load_helper``.

        Returns:
            ``True`` when ``load_helper`` completes, ``False`` when it raises
            ``FileNotFoundError``. Any other exception propagates.
        """
        try:
            self.load_helper()
        except FileNotFoundError:
            return False
        return True

    def load_helper(self):
        """Read the stored specification; must be overridden."""
        raise NotImplementedError('load_helper function should be implemented in the derived class')

class CMEBase(Base):
    """``Base`` specification that also carries targets ``Y`` and a second kernel."""

    def __init__(self, cfg, X, Y, **kwargs):
        """Store targets and target-side settings on top of ``Base.__init__``.

        Args:
            cfg: Mapping; ``cfg['lambd']`` is required and ``n_classes`` is
                optional (``None`` when absent).
            X: Inputs, forwarded to ``Base.__init__``.
            Y: Targets exposing ``.to(device)``, ``.tolist()`` and ``.dim()``.
            **kwargs: Forwarded to ``Base.__init__``. ``kernel_y`` is optional
                and defaults to ``rbf_kernel``.

        Raises:
            KeyError: If ``cfg['lambd']`` is missing.
        """
        super().__init__(cfg, X, **kwargs)
        self.lambd = cfg['lambd']  # presumably a regularisation weight -- TODO confirm
        self.kernel_y = kwargs.get('kernel_y', rbf_kernel)
        self.n_classes = cfg.get('n_classes', None)
        self.Y = Y.to(self.device)

        if self.task == 'classification':
            # Sort the distinct labels: set iteration order is unspecified, and
            # the rows of classes_onehot follow the order of self.classes.
            self.classes = sorted(set(Y.tolist()))
            self.Y = one_hot(self.Y.long(), self.n_classes)
            self.classes_onehot = indices_to_one_hot(self.classes, self.n_classes).to(self.device)
        elif self.Y.dim() == 1:
            # Non-classification targets given as a 1-D tensor get a second axis.
            self.Y = self.Y.unsqueeze(1)