
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier
import torch
"""
This class is to generate the models used as predictors, given some dataset. 
We will probably want to do some shaningans here,  with some smart retraining using the weights of previous models.
"""


class PredictorFGenerator():
    """Build and evaluate the predictor models ``f`` for one dataset.

    The model family is chosen once, from the dataset name, when the generator
    is created: ``LogisticRegression`` for 'yelp', 'epicgames' and 'gauss',
    ``HistGradientBoostingClassifier`` for every other dataset. Each model is
    instantiated with its default hyper-parameters.

    Args:
        cfg_exp: experiment config; must contain the key 'loss_metric'.
        cfg_dataset: dataset config; must contain the key 'name'.
    """

    def __init__(self, cfg_exp: dict, cfg_dataset: dict):
        self.cfg_exp = cfg_exp
        self.cfg_dataset = cfg_dataset
        # Kept on the instance; get_loss_metric below does not consult it and
        # always computes 1 - accuracy.
        self.loss_metric = cfg_exp['loss_metric']
        # NOTE(review): not read anywhere in this class; presumably used by callers.
        self.already_stored_f = False
        linear_datasets = ('yelp', 'epicgames', 'gauss')
        use_linear = cfg_dataset['name'] in linear_datasets
        self.base_model_constructor = (
            LogisticRegression if use_linear else HistGradientBoostingClassifier
        )

    def get_trained_f(self, X_t, y_t, model_index=None):
        """Fit a fresh, default-configured model on ``(X_t, y_t)``.

        Args:
            X_t: training features.
            y_t: training labels.
            model_index: accepted for interface compatibility; currently unused.

        Returns:
            Whatever the estimator's ``fit`` returns (the fitted estimator
            itself for scikit-learn models).
        """
        estimator = self.base_model_constructor()
        return estimator.fit(X_t, y_t)

    def get_loss_metric(self, f, X_t_to_eval, y_t_to_eval) -> float:
        """Return the error of model ``f`` on ``(X_t_to_eval, y_t_to_eval)``.

        The error is ``1 - f.score(...)``; for scikit-learn classifiers
        ``score`` is mean accuracy, so this is the misclassification rate.
        """
        return 1 - f.score(X_t_to_eval, y_t_to_eval)


class PilotPredictorFGenerator(PredictorFGenerator):
    """Predictor generator for pilot runs; always uses ``LogisticRegression``.

    Unlike the parent it takes no dataset config. It still sets every attribute
    the parent's ``__init__`` sets (``cfg_dataset``, ``already_stored_f``,
    ``base_model_constructor``), so code written against
    ``PredictorFGenerator`` does not hit ``AttributeError`` on a pilot
    instance. ``get_loss_metric`` (1 - accuracy) is inherited unchanged.

    Args:
        cfg_exp: experiment config; must contain the key 'loss_metric'.
    """

    def __init__(self, cfg_exp: dict):
        # The parent __init__ is not called because it requires a dataset
        # config; its attributes are mirrored here instead.
        self.cfg_exp = cfg_exp
        self.cfg_dataset = None  # pilot generators receive no dataset config
        self.loss_metric = cfg_exp['loss_metric']
        self.already_stored_f = False
        self.base_model_constructor = LogisticRegression

    def get_trained_f(self, X_t, y_t, model_index=None):
        """Fit a default ``LogisticRegression`` on ``(X_t, y_t)`` and return it.

        Args:
            X_t: training features.
            y_t: training labels.
            model_index: accepted for interface compatibility; currently unused.
        """
        return LogisticRegression().fit(X_t, y_t)


