import logging, sys,  numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from compactem.core import oracle_transfer
from compactem.oracles import get_calibrated_gbm
from compactem.main import compact_using_oracle
from compactem.utils.data_format import DataInfo
from compactem.model_builder import DecisionTree
import pandas as pd
pd.options.display.float_format = '{:,.2f}'.format
from datetime import datetime

# Route every record at DEBUG level and above from the root logger to stdout,
# using a timestamped format that also shows the file, function and line of
# the call site.
logger = logging.getLogger('')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    '[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s',
    datefmt='%a, %d %b %Y %H:%M:%S')
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
logger.addHandler(sh)
# Silence the matplotlib font manager logger, presumably to keep its DEBUG
# chatter out of the root-level stdout output configured above.
logging.getLogger('matplotlib.font_manager').disabled = True

def demo_oracle_transfer():
    """Compact a decision tree on a small digits subsample, guided by a GBM oracle.

    Draws a stratified subsample of ``N`` rows from scikit-learn's digits
    dataset and runs ``compact_using_oracle`` with ``DecisionTree`` as the
    model builder and ``get_calibrated_gbm`` as the oracle. It then prints a
    summary table (original vs. new scores per complexity) and the total
    wall-clock runtime.

    Artifacts are written under ``output/quickstart``. ``overwrite=True`` lets
    the demo be re-run against the same directory.
    """
    # Small N (sample size) and T (passed to DataInfo as ``evals``) keep the
    # demo quick; num_runs is forwarded as ``runs``.
    N, T, num_runs = 1000, 100, 1

    X, y = load_digits(return_X_y=True)

    # Keep only N rows; stratify=y preserves the per-digit class proportions.
    X, _, y, _ = train_test_split(X, y, train_size=N, stratify=y, random_state=0)
    # [3, 4, 5] are the complexity values to evaluate (presumably DataInfo's
    # complexity_params slot - confirm against DataInfo's signature).
    dataset_info = DataInfo("digits", (X, y), [3, 4, 5], evals=T)

    start_time = datetime.now()
    aggr_results_df = compact_using_oracle(datasets_info=dataset_info,
                                           model_builder_class=DecisionTree,
                                           oracle=get_calibrated_gbm,
                                           task_dir=r'output/quickstart',
                                           overwrite=True,
                                           optimizer='pysot',
                                           optimizer_params={'strategy': 'dycors'},
                                           runs=num_runs,
                                           use_weights=True,
                                           sampling_trials=10,
                                           num_weight_trials=10)
    elapsed = datetime.now() - start_time

    summary_cols = ['dataset_name', 'complexity', 'avg_original_score',
                    'avg_new_score', 'pct_improvement']
    print("Result summary:")
    print(aggr_results_df[summary_cols])
    # timedelta.seconds ignores the days component; total_seconds() covers the
    # full duration.
    print(f"runtime: {int(elapsed.total_seconds())} sec")
    # To regenerate the docs sample output:
    # aggr_results_df[summary_cols].to_csv(r'docs/sample_op.csv', float_format='%.2f', index=False)


def sampling_edge_case():
    """Run ``oracle_transfer.sample_using_oracle`` on a degenerate 1-D, single-label input.

    Builds ``N = 50`` evenly spaced points on [0, 1], all labelled ``1``, with a
    per-point value ``unc`` that rises linearly from 0 to 1. It requests
    ``10 * N`` samples with ``pct_from_original=0.0`` and ``use_weights=False``
    and prints how many samples were returned.

    Returns:
        The ``(sample_X, sample_y)`` pair produced by ``sample_using_oracle``.
    """
    N = 50
    X = np.linspace(0, 1, N).reshape((-1, 1))  # shape (N, 1): one feature column
    # Presumably per-point oracle uncertainty - confirm with sample_using_oracle.
    unc = np.linspace(0, 1, N)
    y = np.ones(N)  # every point has the same label: the edge case under test

    # NOTE(review): names suggest a Dirichlet-process concentration and Beta
    # prior parameters; their exact meaning is defined by sample_using_oracle.
    dp_alpha = 1
    prior_for_a_beta_A, prior_for_a_beta_B = 0.1, 0.1
    prior_for_b_beta_A, prior_for_b_beta_B = 300, 300

    sample_X, sample_y = oracle_transfer.sample_using_oracle(10 * N, X, y, unc, dp_alpha,
                                                             prior_for_a_beta_A,
                                                             prior_for_a_beta_B,
                                                             prior_for_b_beta_A,
                                                             prior_for_b_beta_B,
                                                             scale_a=10, scale_b=10,
                                                             pct_from_original=0.0,
                                                             use_weights=False,
                                                             num_weight_trials=10)
    print(f"std sampling, sample size={len(sample_X)}")
    return sample_X, sample_y

def demo_usage():
    """Compact gradient-boosting models on two variants of the digits dataset.

    The first variant, ``digits``, tries three complexity settings. The second,
    ``digits_with_categ_feat``, uses the same data with feature indices 0-2
    marked as categorical through ``additional_info``. Both are compacted with
    ``GradientBoostingModel`` as the model builder and a calibrated random
    forest as the oracle. The oracle search space is narrowed for speed.

    The imports are kept local so this function can be copied as a
    self-contained usage snippet.

    Returns:
        Whatever ``compact_using_oracle`` returns for the two datasets. Results
        are also written under ``output/usage_demo``.
    """
    from sklearn.datasets import load_digits
    from compactem.oracles import get_calibrated_rf
    from compactem.main import compact_using_oracle
    from compactem.utils.data_format import DataInfo
    from compactem.model_builder import GradientBoostingModel

    X, y = load_digits(return_X_y=True)
    d1 = DataInfo(dataset_name="digits", data=(X, y),
                  complexity_params=[(2, 2), (2, 3), (4, 4)],
                  evals=50)

    # Same data, but columns 0, 1 and 2 are declared categorical.
    d2 = DataInfo(dataset_name="digits_with_categ_feat", data=(X, y),
                  complexity_params=[(3, 3)],
                  evals=100,
                  additional_info={'categorical_idxs': [0, 1, 2]})

    # Random forest oracle; params_range restricts its hyperparameter search
    # to a small grid so the demo finishes quickly.
    results = compact_using_oracle(datasets_info=[d1, d2],
                                   model_builder_class=GradientBoostingModel,
                                   oracle=get_calibrated_rf,
                                   oracle_params={'params_range': {'max_depth': [3, 5],
                                                                   'n_estimators': [2, 5, 10, 30]}},
                                   task_dir=r'output/usage_demo', runs=3, overwrite=True)
    return results


if __name__ == "__main__":
    # Only the sampling edge case runs by default. To run one of the other
    # demos instead, replace the call below with:
    #   demo_oracle_transfer()
    #   demo_usage()
    sampling_edge_case()
