from autogluon.core.models import AbstractModel
from autogluon.tabular.trainer.model_presets.presets import MODEL_TYPES


class RealMLPModel(AbstractModel):
    """AutoGluon model wrapping pytabkit's ``RealMLP_TD_Regressor``.

    Training and validation both use pytabkit's ``multi_pinball`` metric,
    built from ``self.params_aux["quantile_levels"]``. The wrapper is therefore
    set up for quantile regression.
    """

    def _fit(self, X, y, X_val=None, y_val=None, time_limit=None, verbosity=0, **kwargs):
        """Construct the pytabkit regressor, store it on ``self.model``, and fit it.

        Args:
            X: Training features, forwarded to pytabkit unchanged.
            y: Training target, forwarded to pytabkit unchanged.
            X_val: Optional validation features, forwarded unchanged.
            y_val: Optional validation target, forwarded unchanged.
            time_limit: Forwarded to pytabkit as ``time_to_fit_in_seconds``.
            verbosity: Forwarded to the pytabkit estimator constructor.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # Imported lazily so pytabkit is only needed once this model is fit.
        from pytabkit import RealMLP_TD_Regressor

        params = self._get_model_params()
        # Use random_state=0 unless the hyperparameters already provide one.
        params.setdefault("random_state", 0)

        levels_csv = ",".join(map(str, self.params_aux["quantile_levels"]))
        pinball_metric = f"multi_pinball({levels_csv})"

        regressor = RealMLP_TD_Regressor(
            train_metric_name=pinball_metric,
            val_metric_name=pinball_metric,
            verbosity=verbosity,
            **params,
        )
        # Stored before fitting, so self.model is set even if fit() raises.
        self.model = regressor
        regressor.fit(X=X, y=y, X_val=X_val, y_val=y_val, time_to_fit_in_seconds=time_limit)

    def predict(self, X, **kwargs):
        """Return the fitted regressor's raw predictions for ``X``.

        This delegates to pytabkit's private ``_predict_raw``. It presumably
        returns one output column per quantile level; TODO confirm against
        the installed pytabkit version. Extra keyword arguments are ignored.
        """
        fitted = self.model
        return fitted._predict_raw(X)


# Import-time side effect: register RealMLPModel in AutoGluon's shared
# MODEL_TYPES mapping under the "REALMLP" key.
# NOTE(review): presumably this lets hyperparameter dicts refer to the model
# by "REALMLP"; confirm against the trainer's preset lookup.
MODEL_TYPES["REALMLP"] = RealMLPModel
