import abc
import warnings
from abc import ABC

from models.abstract_models.LearningModel import LearningModel
from models.abstract_models.NetworkLearningModel import NetworkLearningModel


class MeanRegressor(LearningModel, ABC):
    """Abstract base class for models that predict a mean.

    Subclasses must implement :meth:`fit` and :meth:`predict_mean`. The base
    :meth:`fit` only does bookkeeping: it counts how many times the model has
    been fitted and warns when it is fitted again. Subclass implementations
    are therefore expected to call ``super().fit(...)``.
    """

    def __init__(self, dataset_name: str, saved_models_path: str, seed: int, **kwargs):
        """Initialise the parent model and reset the fit counter.

        Args:
            dataset_name: Passed positionally to the next ``__init__`` in the MRO.
            saved_models_path: Passed positionally to the next ``__init__`` in the MRO.
            seed: Passed as the ``seed`` keyword to the next ``__init__`` in the MRO.
            **kwargs: Extra keyword arguments forwarded unchanged to the next
                ``__init__`` in the MRO.
        """
        super().__init__(dataset_name, saved_models_path, seed=seed, **kwargs)
        # Number of times fit() has been invoked on this instance.
        self.mean_regressor_fit_count = 0

    @abc.abstractmethod
    def fit(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            z_train=None, z_val=None, **kwargs):
        """Fit the model (abstract; the base implementation only does bookkeeping).

        This method increments ``mean_regressor_fit_count`` and emits a
        ``UserWarning`` when the model has already been fitted before. It does
        not train anything itself and ignores all of its data and
        hyper-parameter arguments; those are for subclass implementations.

        Args:
            x_train, y_train: Training inputs and targets.
            deleted_train: Per-sample training data whose meaning is defined by
                subclasses (presumably a missing/deleted-value indicator -- TODO confirm).
            x_val, y_val, deleted_val: Validation counterparts of the above.
            epochs: Training-epoch budget for subclasses (default 1000).
            batch_size: Mini-batch size for subclasses (default 64).
            n_wait: Used by subclasses (presumably early-stopping patience -- TODO confirm).
            z_train, z_val: Optional additional inputs; meaning defined by subclasses.
            **kwargs: Extra subclass-specific options.
        """
        if self.mean_regressor_fit_count > 0:
            # Use the warnings machinery instead of print so callers can filter,
            # capture or escalate re-fit notices. stacklevel=2 attributes the
            # warning to the caller (typically a subclass's fit via super().fit).
            # NOTE(review): self.name is assumed to be provided by the parent model.
            warnings.warn(
                f"{self.name} regressor model was fitted {self.mean_regressor_fit_count} times already "
                f"and is now fitted once again.",
                stacklevel=2,
            )
        self.mean_regressor_fit_count += 1

    @abc.abstractmethod
    def predict_mean(self, x, z, **kwargs):
        """Predict the mean response for ``x``; must be implemented by subclasses.

        Args:
            x: Model inputs.
            z: Additional inputs whose meaning is defined by the subclass
                (presumably the same kind of data as ``z_train`` in :meth:`fit`).
            **kwargs: Extra subclass-specific options.
        """
        pass

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Calibration hook; a no-op by default.

        Subclasses that need a calibration step can override this. The base
        implementation ignores its arguments and returns ``None``.
        """
        pass


class NetworkLearningMeanRegressor(MeanRegressor, NetworkLearningModel, ABC):
    """Abstract mean regressor that is also a ``NetworkLearningModel``.

    Combines the fit-count bookkeeping of ``MeanRegressor`` with
    ``NetworkLearningModel``. Because ``MeanRegressor`` precedes
    ``NetworkLearningModel`` in the bases, its abstract ``fit`` and
    ``predict_mean`` are the ones resolved here, so concrete subclasses must
    implement at least those two methods.
    """
    def __init__(self, dataset_name: str, saved_models_path: str, seed: int, figures_dir=None):
        """Initialise both parent classes explicitly.

        Args:
            dataset_name: Forwarded to both parent initialisers.
            saved_models_path: Forwarded to both parent initialisers.
            seed: Forwarded to both parent initialisers.
            figures_dir: Forwarded as a keyword to both parent initialisers
                (default None).
        """
        # NOTE(review): MeanRegressor.__init__ calls super().__init__(...). If
        # NetworkLearningModel subclasses LearningModel (not visible here), the
        # MRO is [this class, MeanRegressor, NetworkLearningModel, LearningModel, ...].
        # In that case the call below already runs NetworkLearningModel.__init__,
        # and the explicit call on the next line runs it a second time.
        # Confirm this is intended.
        MeanRegressor.__init__(self, dataset_name, saved_models_path, seed, figures_dir=figures_dir)
        NetworkLearningModel.__init__(self, dataset_name, saved_models_path, figures_dir=figures_dir, seed=seed)
