import pickle
from pathlib import Path
from collections import defaultdict
import multiprocessing as mp

# import nbimporter
# from simulation_y_ber_x_ber_marginal_cond_shift_parallel_copy import run_simulation
from simulations import run_simulation


# Null-hypothesis conditional probabilities for each correlation setting,
# stored as (p_y_x_0_null, p_y_x_1_null). A setting must be listed here to be
# usable in `corr_settings`; unknown names raise instead of silently reusing
# the previous iteration's values.
NULL_PROBS_BY_CORR = {
    "low_corr": (0.55, 0.65),
    "medium_corr_b": (0.4, 0.7),
    "high_corr": (0.2, 0.85),
}

# Directory receiving the pickled results (created on demand before writing).
OUTPUT_DIR = Path("simulation_res/hyperparameter_tunning/concept_shift_new_version")

# Number of worker processes handed to each `run_simulation` call.
N_PROCESSES = 250


def apply_corr_setting(config, corr, p_x, delta):
    """Write the probabilities for correlation setting ``corr`` into ``config``.

    Sets ``p_y_x_0_null`` / ``p_y_x_1_null`` from ``NULL_PROBS_BY_CORR``,
    derives the alternative by raising ``p_y_x_0`` by ``delta`` and lowering
    ``p_y_x_1`` by ``delta``, and mixes each pair with weight ``p_x`` as
    ``p_x * p_y_x_1 + (1 - p_x) * p_y_x_0`` into ``p_y_null`` / ``p_y_alt``.

    Args:
        config: Simulation config dict; updated in place.
        corr: Key of ``NULL_PROBS_BY_CORR``.
        p_x: Weight given to the ``x = 1`` conditional when mixing.
        delta: Shift applied to the null conditionals to form the alternative.

    Returns:
        The same ``config`` object, for convenience.

    Raises:
        ValueError: If ``corr`` is not a known correlation setting.
    """
    try:
        p_y_x_0_null, p_y_x_1_null = NULL_PROBS_BY_CORR[corr]
    except KeyError:
        raise ValueError(
            f"Unknown correlation setting {corr!r}; "
            f"expected one of {sorted(NULL_PROBS_BY_CORR)}"
        ) from None

    config["p_y_x_0_null"] = p_y_x_0_null
    config["p_y_x_1_null"] = p_y_x_1_null
    config["p_y_x_0_alt"] = config["p_y_x_0_null"] + delta
    config["p_y_x_1_alt"] = config["p_y_x_1_null"] - delta
    config["p_y_null"] = (
        p_x * config["p_y_x_1_null"] + (1 - p_x) * config["p_y_x_0_null"]
    )
    config["p_y_alt"] = (
        p_x * config["p_y_x_1_alt"] + (1 - p_x) * config["p_y_x_0_alt"]
    )
    return config


def output_path(corr, n_t, delta):
    """Return the pickle path for one run with setting ``corr`` and ``n_t``.

    ``delta == 0`` runs are saved as ``<corr>_null_ntest_<n_t>_sanity_check.pkl``;
    every other run is saved as ``<corr>_non_monotone.pkl``.

    NOTE(review): the non-zero-delta name encodes neither ``n_t`` nor ``M``,
    so if ``n_test`` or the ``M`` sweep ever holds more than one value, later
    runs overwrite earlier ones. Names are kept unchanged because downstream
    loaders may depend on them.
    """
    if delta == 0:
        name = f"{corr}_null_ntest_{n_t}_sanity_check.pkl"
    else:
        name = f"{corr}_non_monotone.pkl"
    return OUTPUT_DIR / name


if __name__ == "__main__":

    w = 0
    p_x = 0.5
    delta = 0.02  # set to 0 for a null / sanity-check run
    n_test = [135]
    corr_settings = ["medium_corr_b", "high_corr"]
    # corr_settings = ["medium_corr_b"]

    # Base config shared by all runs. The correlation-dependent probabilities,
    # "n_test" and "M" are overwritten per run inside the loop below.
    config_shared_X = {
        "n_training": 15,
        "n_test": n_test,
        "num_of_repititions": 500,
        "alpha": 0.05,
        "nsteps": 500,
        "epsilon": 0.1,
        "M": 128,
        "p_y_x_0_null": 0,
        "p_y_x_1_null": 0,
        "p_y_x_0_alt": 0,
        "p_y_x_1_alt": 0,
        "k_classes_x": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)
        "p_x": p_x,  # X's mean value
        "history_length": w,
        "min_history_length": 5,
        "settings": "shared_marginal_x",
        "p_y_null": 0,
    }

    for n_t in n_test:
        for corr in corr_settings:
            print(f"Running simulation for correlation setting: {corr}")
            for M in [128]:  # [2,4,8,16,32]:
                config_shared_X["n_test"] = n_t
                config_shared_X["M"] = M
                apply_corr_setting(config_shared_X, corr, p_x, delta)

                with mp.Pool(processes=N_PROCESSES) as pool:
                    try:
                        (
                            powers,
                            null_corr,
                            alt_corr,
                            rejection_by_test_results,
                            # NOTE(review): the two estimation-error outputs
                            # are returned but not saved below — confirm intent.
                            estimation_error_bias,
                            estimation_error_std,
                        ) = run_simulation(config_shared_X, pool)
                        pool.close()
                        pool.join()
                    except KeyboardInterrupt:
                        print("KeyboardInterrupt received, shutting down executor...")
                        pool.terminate()  # Terminate all processes immediately
                        pool.join()
                        print("Pool terminated.")
                        raise

                data = {
                    "rejection_lists": rejection_by_test_results,
                    # Snapshot the config: config_shared_X is mutated on the
                    # next iteration, so don't keep a live reference to it.
                    "config": dict(config_shared_X),
                    "null_corr": null_corr,
                    "alt_corr": alt_corr,
                    "powers": powers,
                }

                path = output_path(corr, n_t, delta)
                # Create the output directory so a long run isn't lost to a
                # FileNotFoundError at save time.
                path.parent.mkdir(parents=True, exist_ok=True)
                with path.open("wb") as f:
                    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
