import logging
import os
import pickle
import sys
import optuna


def run_optuna(args, run_experiment):
    """Run (or resume) a two-objective Optuna hyperparameter search.

    Each trial samples ``num_teachers``, ``threshold``, ``fairness_threshold``,
    ``sigma_threshold`` and ``sigma_gnmax``, writes them onto ``args``, calls
    ``run_experiment`` and reports the pair
    (validation accuracy, validation demographic disparity). The study
    maximizes the first value and minimizes the second.

    The study is stored in ``sqlite:///<study_name>.db`` and resumed if it
    already exists (``load_if_exists=True``). The sampler is pickled to
    ``<study_name>_sampler.pkl`` when optimization ends, including after a
    Ctrl-C. It is restored from that file on the next call.

    Args:
        args: Experiment configuration object (presumably an
            ``argparse.Namespace`` -- confirm with caller). Reads
            ``use_stratification``, ``pate_based_model``, ``dataset``,
            ``fairness_metric``, ``budget``, ``seed`` and
            ``num_optuna_trials``. Every trial mutates it in place with the
            sampled hyperparameters.
        run_experiment: Callable invoked as
            ``run_experiment(args, callback, results_db=None)``. It must
            return a 7-tuple whose first two items are the validation accuracy
            and validation demographic disparity, and whose last two are the
            test accuracy and test demographic parity. A ``ValueError`` is
            scored as a failed configuration with sentinel values
            ``(-1000, 1000)``. This includes a ``ValueError`` raised by
            ``run_experiment`` or one caused by a result of the wrong length.
            The study is not aborted.

    Returns:
        None. Results live in the SQLite-backed study.
    """

    def objective(trial):
        """Sample one configuration onto ``args`` and return its two scores."""
        if args.use_stratification:
            # With stratification the teacher count is restricted to powers of two.
            args.num_teachers = trial.suggest_categorical(
                'num_teachers', [2, 4, 8, 16, 32, 64, 128, 256, 512])
        else:
            args.num_teachers = trial.suggest_int('num_teachers', 10, 300)
        args.threshold = trial.suggest_int('threshold', 2, 10)
        args.fairness_threshold = trial.suggest_float('fairness_threshold', 0.01, 0.1)
        args.sigma_threshold = trial.suggest_float('sigma_threshold', 40, 70)
        args.sigma_gnmax = trial.suggest_float('sigma_gnmax', 10, 30)

        try:
            # The second positional argument is a no-op callback.
            # results_db=None presumably disables result persistence -- confirm.
            # The full 7-way unpacking is intentional: a wrong-length result
            # raises ValueError and is treated as a failure below.
            (student_model_validation_accuracy, validation_dem_disparity,
             _achieved_eps, _max_num_query, _num_queries_answered,
             student_model_test_accuracy, test_dem_parity) = run_experiment(
                args, lambda *x, **y: None, results_db=None)
        except ValueError:
            # Failure case: worst-possible sentinels for both objectives,
            # i.e. very low accuracy and very high disparity.
            return -1000, 1000

        # Test-set metrics are only recorded on the trial; they are not optimized.
        trial.set_user_attr("student_model_test_accuracy", student_model_test_accuracy)
        trial.set_user_attr("test_dem_parity", test_dem_parity)
        return student_model_validation_accuracy, validation_dem_disparity

    # Echo Optuna's log messages to stdout. Only attach the handler once, so
    # repeated calls in the same process do not print every message N times.
    optuna_logger = optuna.logging.get_logger("optuna")
    if not any(
        isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout
        for handler in optuna_logger.handlers
    ):
        optuna_logger.addHandler(logging.StreamHandler(sys.stdout))

    study_name = f"{args.pate_based_model}_{args.dataset}_{args.fairness_metric}_budget_{args.budget}_seed_{args.seed}"  # Unique identifier of the study.
    storage_name = "sqlite:///{}.db".format(study_name)
    sampler_path = f"{study_name}_sampler.pkl"

    # sampler=None makes Optuna use its default sampler, as when no sampler is passed.
    sampler = None
    if os.path.exists(sampler_path):
        # NOTE(security): unpickling can execute arbitrary code. This file is
        # expected to be one written by this function; never point it at
        # untrusted data.
        with open(sampler_path, "rb") as fin:
            sampler = pickle.load(fin)

    study = optuna.create_study(
        directions=['maximize', 'minimize'],
        study_name=study_name,
        storage=storage_name,
        load_if_exists=True,
        sampler=sampler,
    )

    try:
        study.optimize(objective, n_trials=args.num_optuna_trials)
    except KeyboardInterrupt:
        print("Keyboard Interrupt")
    finally:
        # Save the sampler with pickle so a later call can restore it.
        # Write to a temp file and atomically swap it in. Otherwise an
        # interruption mid-write could leave a truncated pickle that crashes
        # the next resume.
        tmp_path = f"{sampler_path}.tmp"
        with open(tmp_path, "wb") as fout:
            pickle.dump(study.sampler, fout)
        os.replace(tmp_path, sampler_path)