import sklearn
from global_variables import TEST_SET_RATIO, RANDOM_SEED
from sklearn.ensemble import RandomForestRegressor
import pickle
import pathlib
import os
import utils
import torch

class RandomForestWrapper:
    """
    Class used to wrap a random forest so we can batch query it.

    The same prediction is reachable through three equivalent entry points:
    ``wrapper.predict(x)``, ``wrapper(x)`` and ``wrapper[x]``.
    """
    def __init__(self, clf):
        """
        Store the wrapped estimator.

        :param clf: object exposing a ``predict(X)`` method.
        """
        self.clf = clf

    def __getitem__(self, item):
        """Alias for :meth:`predict` so the wrapper can be indexed with a batch."""
        return self.predict(item)

    def predict(self, item):
        """
        Run the wrapped estimator on a batch of inputs.

        :param item: a ``torch.Tensor``, or anything ``clf.predict`` accepts
            directly.
        :return: for tensor input, a new ``torch.Tensor`` built from the
            predictions (it is not moved back to the input's device); for any
            other input, whatever ``clf.predict`` returns.
        """
        if isinstance(item, torch.Tensor):
            # detach() first: Tensor.numpy() raises on tensors that require
            # grad, and the wrapped estimator is not differentiable anyway.
            features = item.detach().cpu().numpy()
            return torch.Tensor(self.clf.predict(features))
        return self.clf.predict(item)

    def __call__(self, item):
        """Alias for :meth:`predict` so the wrapper can be called like a function."""
        return self.predict(item)


def load_model(dataset, depth, n_jobs=5):
    """
    Return a wrapped random forest for ``dataset``, training it on a cache miss.

    Models are cached next to this file under
    ``cache/<dataset>/random_forest_<depth>.pkl``. Each cache file holds the
    pickled list ``[clf, r2]``; only ``clf`` is used when loading.

    :param dataset: dataset identifier, forwarded to ``utils.get_dataset`` and
        used as the cache sub-directory name.
    :param depth: ``max_depth`` of the forest; also part of the cache file name.
    :param n_jobs: number of parallel jobs used for training. It is not part
        of the cache key, so it has no effect when a cached model exists.
    :return: a ``RandomForestWrapper`` around the fitted (or loaded) estimator.
    """
    this_directory = pathlib.Path(__file__).parent.resolve()
    model_dir = f"{this_directory}/cache/{dataset}"
    model_path = f"{model_dir}/random_forest_{depth}.pkl"

    if os.path.exists(model_path):
        # NOTE(security): unpickling runs arbitrary code from the file. This
        # presumes the cache directory is only ever written by this function.
        with open(model_path, "rb") as f:
            clf, _cached_r2 = pickle.load(f)
        return RandomForestWrapper(clf)

    # exist_ok avoids the check-then-create race between concurrent callers.
    os.makedirs(model_dir, exist_ok=True)

    X_train, X_test, y_train, y_test = utils.get_dataset(dataset, with_splits=True)

    # NOTE(review): no random_state is passed, so each retrain can yield a
    # different forest; RANDOM_SEED is imported by this module but unused here.
    clf = RandomForestRegressor(max_depth=depth, n_jobs=n_jobs)
    clf.fit(X_train, y_train)
    r2 = clf.score(X_test, y_test)
    print(f"model trained: r2={r2}")

    # Write to a per-process temp file and rename it into place, so an
    # interrupted dump never leaves a truncated pickle at model_path.
    tmp_path = f"{model_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump([clf, r2], f)
        os.replace(tmp_path, model_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return RandomForestWrapper(clf)

if __name__ == "__main__":
    # No script behaviour is defined: running this file directly does nothing
    # beyond executing the imports and definitions above.
    pass