import os
import joblib
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import GradientBoostingRegressor

class GBDT:
    """Gradient-boosted regression tree learnware with on-disk persistence.

    Wraps a ``GradientBoostingRegressor`` and stores it with ``joblib`` at
    ``<cfg['dataset_path']>/learnwares/<learnware_id>.pkl``.
    """

    def __init__(self, cfg, learnware_id, cfe=None):
        """Build an unfitted regressor and work out where it is persisted.

        Args:
            cfg: Mapping that provides ``dataset_path``, ``n_estimators``,
                ``learning_rate``, ``max_depth`` and ``seed``.
            learnware_id: Identifier used in the file name and log messages.
            cfe: Accepted for interface compatibility; not used by this class.
        """
        self.learnware_id = learnware_id

        learnware_dir = os.path.join(cfg['dataset_path'], 'learnwares')
        self.path = os.path.join(learnware_dir, f'{learnware_id}.pkl')

        hyperparams = {
            'n_estimators': cfg['n_estimators'],
            'learning_rate': cfg['learning_rate'],
            'max_depth': cfg['max_depth'],
            'random_state': cfg['seed'],
        }
        self.model = GradientBoostingRegressor(**hyperparams)

    def __call__(self, x):
        """Return the wrapped model's predictions for ``x``."""
        predictions = self.model.predict(x)
        return predictions

    def train(self, trainloader, evalloader):
        """Fit on the training data, report the evaluation MSE, then save.

        Despite the names, ``trainloader`` and ``evalloader`` are plain
        ``(features, targets)`` tuples rather than iterable data loaders.
        """
        features, targets = trainloader
        self.model.fit(features, targets)

        metrics = self.evaluate(evalloader)
        mse = metrics[0]
        print(f'Learnware {self.learnware_id} evaluation MSE: {mse:.4f}')

        self.save()

    def evaluate(self, evalloader, **kwargs):
        """Compute the mean squared error on an ``(features, targets)`` tuple.

        Extra keyword arguments are accepted and ignored.

        Returns:
            A 3-tuple ``(mse, [], [])``; the two trailing lists are always
            empty.
        """
        features, targets = evalloader
        predictions = self(features)
        mse = mean_squared_error(targets, predictions)
        return mse, [], []

    def save(self):
        """Dump the model to ``self.path``, creating the directory if needed."""
        target_dir = os.path.dirname(self.path)
        os.makedirs(target_dir, exist_ok=True)
        joblib.dump(self.model, self.path)
        print(f'Learnware {self.learnware_id} model saved to {self.path}')

    def load(self):
        """Replace ``self.model`` with the persisted one, if it exists.

        Returns:
            ``True`` when a file was found at ``self.path`` and loaded,
            ``False`` otherwise (in which case ``self.model`` is untouched).
        """
        if not os.path.exists(self.path):
            return False

        # NOTE: joblib.load unpickles the file, so only trusted files should
        # ever live at this path.
        self.model = joblib.load(self.path)
        return True