import itertools
import pickle
from pathlib import Path
from collections import defaultdict
import multiprocessing as mp

# import nbimporter
# from simulation_y_ber_x_ber_marginal_cond_shift_parallel_copy import run_simulation
from simulations import run_simulation


# Hyperparameter sweep for the label-shift simulation: for every combination of
# M, history window w, test size and (p_x_y_0, p_x_y_1) pair, run
# `run_simulation` on a fresh process pool and pickle the results.
#
# Earlier sweeps also used n_tests = [30, 135] with p_x_y_0s = [0.2, 0.35] and
# p_x_y_1s = [0.9, 0.65]; the current sweep fixes a single setting.
N_TESTS = [135]
P_X_Y_0S = [0.35]
P_X_Y_1S = [0.65]
M_VALUES = [2, 5, 16, 32, 128]
HISTORY_WINDOWS = [0, 2, 5, 10, 20, 50]
P_NULL = 0.5
ALT_SHIFT = 0.02  # p_y_alt = p_y_null + ALT_SHIFT
N_PROCESSES = 250
OUTPUT_DIR = Path("simulation_res/hyperparameter_tunning/label_shift_hyperparams")


def _build_config(
    M: int,
    w: int,
    n_test: int,
    p_x_y_0: float,
    p_x_y_1: float,
    p_null: float,
    p_alt: float,
) -> dict:
    """Return the `run_simulation` config dict for one sweep point."""
    return {
        "n_training": 15,
        "n_test": n_test,
        "num_of_repititions": 500,
        "alpha": 0.05,
        "nsteps": 500,
        "epsilon": 0.1,
        "M": M,
        "p_x_y_0": p_x_y_0,
        "p_x_y_1": p_x_y_1,
        "k_classes_x": 1,  # the number of categories of X (1 - bernoulli, >1 - binomial)
        "p_y_null": p_null,  # Y's mean value under the null
        "p_y_alt": p_alt,
        "history_length": w,
        "min_history_length": 0,
        "settings": "label_shift",
    }


def _output_path(M: int, w: int, n_test: int, p_x_y_0: float, p_x_y_1: float) -> Path:
    """Return the pickle path that encodes one sweep point in its file name."""
    return OUTPUT_DIR / f"sim_M_{M}_w_{w}_n_test_{n_test}_p_x_y_0_{p_x_y_0}_p_x_y_1_{p_x_y_1}.pkl"


def _run_in_pool(config: dict, processes: int) -> tuple:
    """Run `run_simulation(config, pool)` on a fresh pool and return its result tuple.

    On Ctrl-C the pool is terminated immediately and the interrupt is re-raised.
    """
    with mp.Pool(processes=processes) as pool:
        try:
            results = run_simulation(config, pool)
            # Close/join explicitly so workers finish cleanly before the `with`
            # exits (the Pool context manager terminates on exit).
            pool.close()
            pool.join()
        except KeyboardInterrupt:
            print("KeyboardInterrupt received, shutting down executor...")
            pool.terminate()  # Terminate all processes immediately
            pool.join()
            print("Pool terminated.")
            raise
    return results


def _save_results(data: dict, path: Path) -> None:
    """Pickle `data` to `path`, creating missing parent directories first."""
    # Without this, a missing output directory raises FileNotFoundError only
    # after the (expensive) simulation has already finished.
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)


def main(processes: int = N_PROCESSES) -> None:
    """Run the full sweep, saving one pickle per configuration.

    Args:
        processes: number of worker processes in each per-configuration pool.

    Raises:
        ValueError: if the p_x_y_0 / p_x_y_1 setting lists differ in length.
    """
    # zip() would silently drop unmatched settings; fail loudly instead.
    if len(P_X_Y_0S) != len(P_X_Y_1S):
        raise ValueError("P_X_Y_0S and P_X_Y_1S must have the same length")

    p_null = P_NULL
    p_alt = p_null + ALT_SHIFT

    for M, w, n_test, (p_x_y_0, p_x_y_1) in itertools.product(
        M_VALUES, HISTORY_WINDOWS, N_TESTS, zip(P_X_Y_0S, P_X_Y_1S)
    ):
        print(
            f"Running simulation for M={M}, w={w}, n_test={n_test}, p_x_y_0={p_x_y_0}, p_x_y_1={p_x_y_1}"
        )
        config_LS = _build_config(M, w, n_test, p_x_y_0, p_x_y_1, p_null, p_alt)

        # run_simulation returns six values; the estimation-error bias/std
        # are currently not saved to the pickle.
        (
            powers,
            null_corr,
            alt_corr,
            rejection_by_test_results,
            _estimation_error_bias,
            _estimation_error_std,
        ) = _run_in_pool(config_LS, processes)

        data = {
            "rejection_lists": rejection_by_test_results,
            "config": config_LS,
            "null_corr": null_corr,
            "alt_corr": alt_corr,
            "resulting_powers": powers,
        }
        _save_results(data, _output_path(M, w, n_test, p_x_y_0, p_x_y_1))


if __name__ == "__main__":
    main()
