import argparse
from PowerKCI import KCIByPower_Parallel
from OriginalKCI import OrgKCI_parallel
from SyntheticData import SyntheticData
from utils import *
import warnings
warnings.filterwarnings('ignore')
import time

def PowerTest(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method, alpha = 0.05):
    """Run the KCIByPower_Parallel test and threshold its p-values at ``alpha``.

    The train/test splits of z, x and y and ``test_method`` are passed straight
    through to ``KCIByPower_Parallel``; ``select_kernels()`` is expected to
    yield four p-values, unpacked here as (sel_xyz, sel_yz, sel_z, median).

    Returns:
        An 8-tuple: the four p-values reordered as
        (median, sel_z, sel_yz, sel_xyz), followed by the four matching
        ``p_value <= alpha`` flags in the same order.
    """
    tester = KCIByPower_Parallel(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method=test_method)
    p_xyz, p_yz, p_z, p_median = tester.select_kernels()

    # Callers expect the median-heuristic result first and the "select all"
    # result last, i.e. the reverse of the order select_kernels() yields them.
    p_values = (p_median, p_z, p_yz, p_xyz)
    rejected = tuple(p <= alpha for p in p_values)

    return p_values + rejected

def OriginalTest(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method, alpha = 0.05):
    """Run the OrgKCI_parallel test and threshold its p-values at ``alpha``.

    The train/test splits of z, x and y and ``test_method`` are passed straight
    through to ``OrgKCI_parallel``; ``select_kernels()`` is expected to yield
    four p-values, unpacked here as (sel_xyz, sel_xz, sel_y, median).

    Returns:
        An 8-tuple laid out like the one from ``PowerTest``: the four p-values
        ordered as (median, sel_y, sel_xz, sel_xyz), followed by the four
        matching ``p_value <= alpha`` flags in the same order.
    """
    KCI = OrgKCI_parallel(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method=test_method)
    sel_xyz_pvalue, sel_xz_pvalue, sel_y_pvalue, med_pvalue = KCI.select_kernels()

    alpha5_selxyz = (sel_xyz_pvalue <= alpha)
    alpha5_selxz = (sel_xz_pvalue <= alpha)
    alpha5_sely = (sel_y_pvalue <= alpha)
    alpha5_med = (med_pvalue <= alpha)

    # The second slot must carry the sel_y p-value itself; returning the
    # rejection flag there would report a boolean where a p-value is expected.
    return (med_pvalue, sel_y_pvalue, sel_xz_pvalue, sel_xyz_pvalue,
            alpha5_med, alpha5_sely, alpha5_selxz, alpha5_selxyz)



def TestMain(CI, dim_list, data_nums, repeat_times, test_method, noise_scale, T_scale, alpha=0.05):
    """Repeat the KCI test on fresh synthetic data and print empirical error rates.

    Each repetition draws a data set with ``SyntheticData``, splits it with
    ``data_split``, runs ``PowerTest`` and accumulates, per kernel-selection
    method, how often the p-value fell at or below ``alpha``.

    Args:
        CI: forwarded to ``SyntheticData``. When truthy, rejections are
            reported as Type I errors; otherwise non-rejections are reported
            as Type II errors.
        dim_list: forwarded to ``SyntheticData`` (presumably the dimensions of
            the generated variables; confirm against SyntheticData).
        data_nums: number of samples, forwarded to ``SyntheticData`` as ``nums``.
        repeat_times: number of repetitions; must be at least 1.
        test_method: forwarded to ``PowerTest``.
        noise_scale: forwarded to ``SyntheticData``.
        T_scale: forwarded to ``SyntheticData``.
        alpha: significance level used for rejection (default 0.05). The
            printed "alpha5_" labels are fixed text and do not follow this value.

    Raises:
        ValueError: if ``repeat_times`` is smaller than 1, since no rate can be
            computed from zero repetitions.
    """
    if repeat_times < 1:
        raise ValueError("repeat_times must be at least 1, got {!r}".format(repeat_times))

    # Draw a seed, reseed the global NumPy generator with it and print it so
    # the run can be identified from its log.
    seed = np.random.randint(0, 10000)
    np.random.seed(seed)
    print("data: sample data, CI: ", CI, ", dims: ", dim_list, ", data_nums: ", data_nums, "T_scale: ", T_scale, "noise_scale: ", noise_scale)
    print("test: ", seed, ", repeat_time:", repeat_times, ", test_method: ", test_method, ", alpha: ", alpha)

    ## decomposed version
    # Labels for the four results returned by PowerTest, in return order.
    method = ["median", "sel_z", "sel_yz", "sel_all"]
    error_kind = "I" if CI else "II"

    rejections = [0, 0, 0, 0]  # per-method count of p-values <= alpha
    time_list = []
    for i in range(repeat_times):
        epoch_seed = np.random.randint(0, 100000)
        X, Y, Z = SyntheticData(dim_list, nums=data_nums, seed=epoch_seed, CI=CI, noise_scale = noise_scale, T_scale=T_scale)
        z_tr, z_te, x_tr, x_te, y_tr, y_te = data_split(X, Y, Z)

        start_time = time.time()

        results = PowerTest(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method, alpha=alpha)

        ### original KCI without decomposition: call OriginalTest instead of
        ### PowerTest and use method = ["median", "sel_y", "sel_xz", "sel_all"]
        # results = OriginalTest(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method, alpha=alpha)

        time_list.append(time.time() - start_time)

        pvalues, flags = results[:4], results[4:]
        for k, flag in enumerate(flags):
            rejections[k] += flag

        # Progress report early on (i == 5) and then every 100 repetitions.
        if i == 5 or i % 100 == 0:
            if CI:
                counts = rejections
            else:
                # Non-rejections among the i + 1 repetitions done so far.
                counts = [i - r + 1 for r in rejections]
            progress = ["Type {} error, idx ".format(error_kind), i, "-> "]
            for k, (label, count, pvalue) in enumerate(zip(method, counts, pvalues)):
                if k:
                    progress.append(") ")
                progress.extend([label, ": (", count, " ", pvalue])
            progress.extend(["), time: (", round(np.mean(time_list), 3), "+", round(np.var(time_list), 4), ")"])
            print(*progress)

    print("------------------------------------------------------------")
    print("data: sample data, CI: ", CI, ", dim_list: ", dim_list, ", data_nums: ", data_nums, "T scale:", T_scale, "noise_scale: ", noise_scale)
    print("test: ", seed, ", repeat_time:", repeat_times, ", test_method: ", test_method, ", alpha: ", alpha)

    # Type I error is the rejection rate; Type II error is its complement.
    print("type {} error:".format(error_kind))
    for label, rejected in zip(method, rejections):
        rate = rejected / repeat_times
        print("final--->", "alpha5_", label, ": ", rate if CI else 1 - rate)
    # Mean and variance of the per-repetition test time in seconds.
    print("time: (", round(np.mean(time_list), 3), "+", round(np.var(time_list), 4), ")")
    print("------------------------------------------------------------")


def main():
    """Parse command-line options and run the experiment sweep.

    The active sweep varies the third entry of ``dim_list`` from 1 to 9 and,
    for each setting, calls ``TestMain`` once with ``CI=True`` and once with
    ``CI=False``. The sample size comes from ``--data_nums`` (default 500);
    the noise scale and ``T_scale`` are fixed below. Alternative sweeps over
    the sample size and the noise scale are kept as commented-out blocks.
    """
    parser = argparse.ArgumentParser(description="parameter setting")
    parser.add_argument("--data_nums", type=int, default=500, help="sample size")
    parser.add_argument("--repeat_times", type=int, default=1000, help="repeat times")
    parser.add_argument("--test_method", type=str, default="chi_square",
                        help="testing method: 'chi_square' for weighted sum of chi square and 'gamma' for Gamma approximation")
    opt = parser.parse_args()

    N_scale = 1
    T_scale = 0.5

    # NOTE: opt.data_nums is deliberately not overwritten here, so that the
    # --data_nums option is honoured; its default already is 500.

    ### on the conditioning dimension
    for dim_list in [[1, 1, i] for i in range(1, 10, 1)]:
        TestMain(CI=True, dim_list=dim_list, data_nums=opt.data_nums,
                 repeat_times= opt.repeat_times, test_method = opt.test_method, T_scale=T_scale, noise_scale=N_scale)
        TestMain(CI=False,  dim_list=dim_list, data_nums=opt.data_nums,
                 repeat_times= opt.repeat_times, test_method = opt.test_method, T_scale=T_scale, noise_scale=N_scale)

    ### on the sample size
    # dim_list = [1, 1, 5]
    # for opt.data_nums in [100, 200, 400, 600, 800, 1000]:
    #     TestMain(CI=True,  dim_list=dim_list, data_nums=opt.data_nums,
    #              repeat_times= opt.repeat_times, test_method = opt.test_method, T_scale=T_scale, noise_scale=N_scale)
    #     TestMain(CI=False,  dim_list=dim_list, data_nums=opt.data_nums,
    #              repeat_times= opt.repeat_times, test_method = opt.test_method, T_scale=T_scale, noise_scale=N_scale)
    
    ### on the noise scale
    # dim_list = [1, 1, 5]
    # opt.data_nums = 500
    # for N_scale in [0.5, 1, 1.5, 2, 2.5, 3]:
    #     TestMain(CI=True,  dim_list=dim_list, data_nums=opt.data_nums,
    #              repeat_times= opt.repeat_times, test_method = opt.test_method, T_scale=T_scale, noise_scale=N_scale)
    #     TestMain(CI=False,  dim_list=dim_list, data_nums=opt.data_nums,
    #              repeat_times= opt.repeat_times, test_method = opt.test_method, T_scale=T_scale, noise_scale=N_scale)


# Run the experiment sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()