import logging, sys
import numpy as np, os, pickle
from matplotlib import pyplot as plt
import seaborn as sns;
sns.set()
from compactem.oracles import get_calibrated_rf, get_calibrated_gbm
from compactem.utils.data_format import DataInfo
from compactem.main import compact_using_oracle
from compactem.model_builder import RandomForest, GradientBoostingModel, LinearProbabilityModel, DecisionTree
from compactem.utils.output_processors import Result
from sklearn.metrics import f1_score
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

# Directory (relative to the current working directory) that holds the log file and the
# per-demo task directories.
OP_DIR = 'output'
# Log file written by setup_logger(); it lives inside OP_DIR, so derive it from OP_DIR
# instead of repeating the directory name.
LOGFILE = f"{OP_DIR}/experiment_with_oracle.log"
# exist_ok=True replaces the separate exists()/isdir() check: no race between the check and
# the creation, and it still raises if the path exists but is not a directory.
os.makedirs(f"./{OP_DIR}", exist_ok=True)

def setup_logger():
    """
    Configure the root logger so that records of level INFO and above are written both to
    LOGFILE and to stdout, using one shared line format.
    :return: None
    """
    line_format = ("%(asctime)s [%(process)d] [%(threadName)s] [%(levelname)-5.5s] "
                   "[%(filename)s:%(lineno)d] [%(funcName)s]"
                   "  %(message)s")
    formatter = logging.Formatter(line_format)

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # File handler first, then console; mode='w' truncates the log file on every call.
    handlers = (logging.FileHandler(LOGFILE, mode='w'),
                logging.StreamHandler(sys.stdout))
    for handler in handlers:
        handler.setFormatter(formatter)
        root.addHandler(handler)

# Configure logging as soon as the module is loaded, before any demo is run.
setup_logger()


def _run_digits_demo(model_builder_class, task_dir, n_samples=1000, evals=5):
    """
    Shared driver for the demos: run compact_using_oracle on a subsample of the digits
    dataset with get_calibrated_gbm as the oracle, then print a summary of the result.

    :param model_builder_class: model builder class handed to compact_using_oracle,
        e.g. LinearProbabilityModel or DecisionTree
    :param task_dir: directory passed as task_dir to compact_using_oracle (which is called
        with overwrite=True)
    :param n_samples: number of points kept from the digits data
    :param evals: value passed as evals to DataInfo
    :return: the object returned by compact_using_oracle
    """
    X, y = load_digits(return_X_y=True)

    # use only n_samples points, for quicker experiments; the split is stratified on the
    # labels and uses a fixed random_state
    X, _, y, _ = train_test_split(X, y, train_size=n_samples, stratify=y, random_state=0)
    dataset_info = DataInfo("digits", (X, y), [2, 3, 4, 5], evals=evals)

    aggr_results_df = compact_using_oracle(datasets_info=dataset_info,
                                           model_builder_class=model_builder_class,
                                           oracle=get_calibrated_gbm,
                                           oracle_params={'num_boosting_rounds':200, 'max_depth':5, 'learning_rate': 0.1},
                                           task_dir=task_dir, overwrite=True)
    print("Result summary:")
    print(aggr_results_df[['dataset_name', 'complexity', 'avg_original_score',
                           'avg_new_score', 'pct_improvement']])
    return aggr_results_df


def demo_lpm():
    """
    Run the digits demo with LinearProbabilityModel as the model builder; outputs go under
    output/lpm and a result summary is printed.
    :return: None
    """
    _run_digits_demo(LinearProbabilityModel, task_dir=r'output/lpm')


def demo_dt():
    """
    Run the digits demo with DecisionTree as the model builder; outputs go under output/dt
    and a result summary is printed.
    :return: None
    """
    _run_digits_demo(DecisionTree, task_dir=r'output/dt')


# Script entry point: run one demo. Swap the commented-out line in to run the
# LinearProbabilityModel demo instead of (or in addition to) the DecisionTree one.
if __name__== "__main__":
    demo_dt()
    # demo_lpm()
    