import sklearn
import pickle
import pathlib
import os
import utils
import catboost
import torch

class CatboostWrapper:
    """Adapt a fitted regressor so it can be fed torch tensors as well as arrays.

    ``wrapper.predict(x)``, ``wrapper(x)`` and ``wrapper[x]`` are equivalent.
    """

    def __init__(self, clf):
        # Any object exposing ``predict``; load_model passes a CatBoostRegressor.
        self.clf = clf

    def __getitem__(self, item):
        # Indexing syntax is an alias for prediction.
        return self.predict(item)

    def predict(self, item):
        """Return the wrapped model's predictions for *item*.

        If *item* is a ``torch.Tensor`` it is detached from the autograd graph
        and copied to host memory before being passed to the model. The
        predictions come back as a CPU tensor of torch's default floating
        dtype. Any other input is passed straight to ``self.clf.predict`` and
        its return value is returned unchanged.
        """
        if not isinstance(item, torch.Tensor):
            return self.clf.predict(item)
        # .detach() is required: .numpy() raises on tensors that require grad.
        preds = self.clf.predict(item.detach().cpu().numpy())
        # Same result as the legacy ``torch.Tensor(ndarray)`` constructor (CPU,
        # default float dtype) without its "int argument means size" pitfall.
        return torch.as_tensor(preds, dtype=torch.get_default_dtype())

    def __call__(self, item):
        # Call syntax is an alias for prediction.
        return self.predict(item)



def load_model(dataset, depth):
    """Return a wrapped CatBoost regressor for *dataset*, training it on first use.

    The fitted model and its held-out R^2 score are pickled to
    ``<directory of this file>/cache/<dataset>/catboost_<depth>.pkl``. Later
    calls with the same ``(dataset, depth)`` load that file instead of
    retraining.

    Args:
        dataset: Name passed to ``utils.get_dataset``. It also names the cache
            sub-directory.
        depth: Tree depth for ``catboost.CatBoostRegressor``. It is also part of
            the cache file name.

    Returns:
        A ``CatboostWrapper`` around the fitted (or loaded) regressor.
    """
    model_dir = pathlib.Path(__file__).parent.resolve() / "cache" / f"{dataset}"
    model_path = model_dir / f"catboost_{depth}.pkl"

    if model_path.exists():
        # NOTE: unpickling can execute arbitrary code. This file is expected to
        # have been written by this function; never point it at untrusted data.
        with open(model_path, "rb") as f:
            clf, r2 = pickle.load(f)
        print(f"catboost model loaded. depth={depth} and r2={r2}")
        return CatboostWrapper(clf)

    # Imported explicitly: a bare ``import sklearn`` does not guarantee that
    # the ``sklearn.metrics`` submodule is loaded.
    from sklearn.metrics import r2_score

    model_dir.mkdir(parents=True, exist_ok=True)

    X_train, X_test, y_train, y_test = utils.get_dataset(dataset, with_splits=True)

    clf = catboost.CatBoostRegressor(depth=depth, n_estimators=10)
    clf.fit(X_train, y_train)
    r2 = r2_score(y_test, clf.predict(X_test))
    print(f"model : r2={r2}")

    # Dump to a temporary file and then atomically rename it. An interrupted
    # write therefore never leaves a truncated pickle that later calls would
    # fail to load.
    tmp_path = model_path.parent / (model_path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump([clf, r2], f)
    os.replace(tmp_path, model_path)

    return CatboostWrapper(clf)

if __name__ == "__main__":
    # Script entry point: train (or load from cache) the depth-4 avGFP model.
    load_model(dataset="avGFP", depth=4)

