import logging, traceback, os, pickle, sys, optuna
from pathlib import Path


def run_optuna(args, packed_data, run_experiment, utility_metric='accuracy'):
    """Run a multi-objective Optuna search over the experiment hyperparameters.

    Each trial overwrites ``args.num_teachers``, ``args.threshold``,
    ``args.sigma_threshold`` and ``args.sigma_gnmax`` with sampled values and
    calls ``run_experiment(args, log, packed_data=packed_data, results_db=None)``.
    The study is stored in a SQLite database under ``args.optuna_db_path`` and
    its sampler is pickled next to it so that a later call with the same
    study name can resume.

    Args:
        args: Mutable configuration object. It is modified in place by every
            trial. Besides the tuned attributes it must provide
            ``use_stratification``, ``verbose``,
            ``use_inference_time_postprocessing``, ``pate_based_model``,
            ``dataset``, ``fairness_metric``, ``fairness_threshold``,
            ``budget``, ``seed``, ``optuna_db_path`` and ``num_optuna_trials``.
        packed_data: Opaque payload forwarded unchanged to ``run_experiment``.
        run_experiment: Callable returning an 11-tuple ``(validation accuracy,
            validation disparity, validation coverage, validation AUC,
            achieved eps, max num query, num queries answered, test accuracy,
            test disparity, test coverage, test AUC)``.
        utility_metric: ``'accuracy'`` or ``'auc'``; selects which validation
            metric is the maximized utility objective.

    Returns:
        ``study.best_trials`` of the (possibly resumed) study.

    Raises:
        ValueError: If ``utility_metric`` is not ``'accuracy'`` or ``'auc'``.
        Exception: Any error other than ``KeyboardInterrupt`` raised while
            optimizing is propagated after the sampler has been saved.
    """
    # Fail fast: previously this was only detected inside a trial, after a
    # whole experiment had already been run.
    if utility_metric not in ('accuracy', 'auc'):
        raise ValueError(f"Unknown utility metric {utility_metric}")

    def objective(trial):
        """Sample hyperparameters, run one experiment, return the objectives."""
        if args.use_stratification:
            args.num_teachers = trial.suggest_categorical(
                'num_teachers', [2, 4, 8, 16, 32, 64, 128, 256, 512])
        else:
            args.num_teachers = trial.suggest_int('num_teachers', 10, 300)
        args.threshold = trial.suggest_int('threshold', 2, 10)
        # args.fairness_threshold is not tuned; it is part of the study name.
        args.sigma_threshold = trial.suggest_float('sigma_threshold', 40, 70)
        args.sigma_gnmax = trial.suggest_float('sigma_gnmax', 10, 30)

        log = print if args.verbose else (lambda *x, **y: None)
        try:
            (student_model_validation_accuracy, validation_disparity,
             validation_coverage, validation_auc, achieved_eps,
             _max_num_query, _num_queries_answered,
             student_model_test_accuracy, test_disparity, test_coverage,
             test_auc) = run_experiment(args, log, packed_data=packed_data, results_db=None)
            trial.set_user_attr("student_model_test_accuracy", student_model_test_accuracy)
            trial.set_user_attr("test_dem_parity", test_disparity)
            trial.set_user_attr("test_coverage", test_coverage)
            trial.set_user_attr("test_auc", test_auc)
            trial.set_user_attr("achieved_eps", achieved_eps)
        except ValueError as e:
            print("Error in run_experiment", e)
            # "No queries to be answered" is reported without a traceback;
            # any other ValueError gets its traceback printed.
            if "No queries to be answered" not in str(e):
                print(traceback.format_exc())
            # Penalty values: worst possible score for every objective.
            # validation_auc must be set too, otherwise utility_metric='auc'
            # would hit an unbound local below.
            student_model_validation_accuracy, validation_disparity, validation_coverage = -1000, 1000, -1000
            validation_auc = -1000

        if utility_metric == 'accuracy':
            utility_value = student_model_validation_accuracy
        else:  # 'auc' (validated when run_optuna was entered)
            utility_value = validation_auc

        if args.use_inference_time_postprocessing:
            return utility_value, validation_disparity, validation_coverage
        return utility_value, validation_disparity

    # Show Optuna's messages on stdout. Only add the handler once so that
    # repeated calls do not duplicate every log line.
    optuna_logger = optuna.logging.get_logger("optuna")
    already_on_stdout = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in optuna_logger.handlers
    )
    if not already_on_stdout:
        optuna_logger.addHandler(logging.StreamHandler(sys.stdout))

    # Unique identifier of the study.
    model_tag = args.pate_based_model + 'IPP' if args.use_inference_time_postprocessing else args.pate_based_model
    study_name = (
        f"{model_tag}_{args.dataset}_{args.fairness_metric}"
        f"_fairnessThreshold_{args.fairness_threshold}"
        f"_budget_{args.budget}_seed_{args.seed}"
    )

    Path(args.optuna_db_path).mkdir(parents=True, exist_ok=True)
    storage_name = "sqlite:///{}/{}.db".format(args.optuna_db_path, study_name)
    sampler_path = os.path.join(args.optuna_db_path, f"{study_name}_sampler.pkl")

    # Utility is maximized, disparity minimized; coverage (maximized) is only
    # an objective when inference-time post-processing is enabled.
    if args.use_inference_time_postprocessing:
        directions = ['maximize', 'minimize', 'maximize']
    else:
        directions = ['maximize', 'minimize']

    restored_sampler = None
    if os.path.exists(sampler_path):
        # NOTE(security): unpickling executes arbitrary code. Only point
        # args.optuna_db_path at a directory whose contents are trusted.
        with open(sampler_path, "rb") as fin:
            restored_sampler = pickle.load(fin)

    # sampler=None makes Optuna choose its default sampler.
    study = optuna.create_study(
        directions=directions,
        study_name=study_name,
        storage=storage_name,
        load_if_exists=True,
        sampler=restored_sampler,
    )

    try:
        study.optimize(objective, n_trials=args.num_optuna_trials)
        print("Optuna complete:")
    except KeyboardInterrupt:
        print("Keyboard Interrupt")
    finally:
        # Save the sampler with pickle to be loaded later.
        with open(sampler_path, "wb") as fout:
            pickle.dump(study.sampler, fout)

    # Returning from outside ``finally`` lets unexpected exceptions propagate
    # instead of being silently discarded.
    return study.best_trials