import argparse
from PowerKCI import PowerKCI
from syn_data import synthetic_data
from utils import *
import warnings
# Silence every warning category for the whole run so the benchmark log
# stays readable across many repetitions.
warnings.filterwarnings(action="ignore")

def Test(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method, alpha=0.05):
    """Run a single PowerKCI test on pre-split train/test data.

    Args:
        z_tr, z_te: train / test parts of the conditioning variable Z.
        x_tr, x_te: train / test parts of X.
        y_tr, y_te: train / test parts of Y.
        test_method: forwarded unchanged to ``PowerKCI(test_method=...)``.
        alpha: significance level; a p-value ``<= alpha`` counts as a rejection.

    Returns:
        ``(med_pvalue, sel_pvalue, med_rejected, sel_rejected)``, where the
        two p-values are taken from ``PowerKCI.compute_pvalue()`` (which is
        unpacked here as ``(sel_pvalue, med_pvalue)``) and the two flags are
        the results of comparing each p-value against ``alpha``.
    """
    tester = PowerKCI(z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method=test_method)
    # NOTE(review): order assumed from the original unpacking -- selected first, median second.
    pvalue_selected, pvalue_median = tester.compute_pvalue()

    rejected_selected = pvalue_selected <= alpha
    rejected_median = pvalue_median <= alpha

    return pvalue_median, pvalue_selected, rejected_median, rejected_selected


def _print_run_header(CI, fea_dims, data_nums, seed, repeat_times, test_method, alpha):
    """Print the two-line summary that identifies one TestMain run."""
    print("data: syntheic, CI: ", CI, ", fea_dims: ", fea_dims, ", data_nums: ", data_nums)
    print("test: ", seed, ", repeat_time:", repeat_times, ", test_method: ", test_method, ", alpha: ", alpha)


def TestMain(CI, fea_dims, data_nums, repeat_times, test_method, alpha=0.05):
    """Estimate the type I (CI=True) or type II (CI=False) error of PowerKCI.

    Each of ``repeat_times`` repetitions draws a fresh synthetic dataset via
    ``synthetic_data(fea_dims, nums=data_nums, seed=..., CI=CI)``, splits it
    with ``data_split`` and runs ``Test``. The numbers of rejections for the
    "median" and "power" p-values are accumulated. Progress is printed at
    iteration 5 and every 50th iteration. A summary is printed at the end:

    * CI=True  -> rejection rate (type I error),
    * CI=False -> 1 - rejection rate (type II error).

    Args:
        CI: forwarded to ``synthetic_data``; also selects which error is reported.
        fea_dims: forwarded to ``synthetic_data`` as its first argument.
        data_nums: forwarded to ``synthetic_data`` as ``nums``.
        repeat_times: number of repetitions; must be >= 1.
        test_method: forwarded to ``Test``.
        alpha: significance level passed to ``Test`` (default 0.05, the
            previously hard-coded value).

    Raises:
        ValueError: if ``repeat_times`` < 1. Previously this surfaced as a
            ZeroDivisionError (for 0) or a meaningless rate (for negatives),
            and only after the header had been printed.
    """
    if repeat_times < 1:
        raise ValueError("repeat_times must be >= 1, got %r" % (repeat_times,))

    method = ["median", "power"]
    # Label used in the final summary; "alpha5_" at the default alpha=0.05.
    alpha_label = "alpha%g_" % (alpha * 100)

    # Draw a master seed and re-seed with it, so the printed value is enough
    # to reproduce every per-epoch seed drawn below.
    seed = np.random.randint(0, 10000)
    np.random.seed(seed)
    _print_run_header(CI, fea_dims, data_nums, seed, repeat_times, test_method, alpha)

    med_alpha_all = 0     # rejections by the "median" p-value so far
    power_alpha2_all = 0  # rejections by the "power" p-value so far

    for i in range(repeat_times):
        epoch_seed = np.random.randint(0, 100000)
        X, Y, Z = synthetic_data(fea_dims, nums=data_nums, seed=epoch_seed, CI=CI)
        z_tr, z_te, x_tr, x_te, y_tr, y_te = data_split(X, Y, Z)

        pvalue_median, pvalue_power, alpha_median, alpha_power = Test(
            z_tr, z_te, x_tr, x_te, y_tr, y_te, test_method, alpha=alpha)

        med_alpha_all += alpha_median
        power_alpha2_all += alpha_power

        if i == 5 or i % 50 == 0:
            n_done = i + 1  # repetitions completed so far
            if CI:
                print("Type I error, idx ", i,
                      "-> ", method[0], ": (", med_alpha_all, " ", pvalue_median,
                      ") ", method[1], ": (", power_alpha2_all, " ", pvalue_power,
                      ")")
            else:
                # Non-rejections so far = completed runs - rejections.
                print("Type II error, idx ", i,
                      "-> ", method[0], ": (", n_done - med_alpha_all, " ", pvalue_median,
                      ") ", method[1], ": (", n_done - power_alpha2_all, " ", pvalue_power,
                      ")")

    print("------------------------------------------------------------")
    _print_run_header(CI, fea_dims, data_nums, seed, repeat_times, test_method, alpha)

    reject_rate_median = med_alpha_all / repeat_times
    reject_rate_power = power_alpha2_all / repeat_times
    if CI:
        # Under the null, the rejection rate is the type I error.
        print("type I error:")
        print("final--->", alpha_label, method[0], ": ", reject_rate_median)
        print("final--->", alpha_label, method[1], ": ", reject_rate_power)
    else:
        # Under the alternative, type II error = 1 - rejection rate.
        print("type II error:")
        print("final--->", alpha_label, method[0], ": ", 1 - reject_rate_median)
        print("final--->", alpha_label, method[1], ": ", 1 - reject_rate_power)
    print("------------------------------------------------------------")


def main():
    """Parse command-line options and run the synthetic benchmark.

    For each conditioning dimension ``fea_dim`` in 1..9, ``TestMain`` is
    run first with ``CI=True`` (reports type I error) and then with
    ``CI=False`` (reports type II error).
    """
    parser = argparse.ArgumentParser(description="parameter setting")
    parser.add_argument("--data_nums", type=int, default=200, help="sample size")
    parser.add_argument("--repeat_times", type=int, default=1000, help="repeat times")
    # Restrict to the two documented methods so a typo fails fast at parse
    # time instead of deep inside the test.
    parser.add_argument("--test_method", type=str, default="chi_square",
                        choices=["chi_square", "gamma"],
                        help="testing method: 'chi_square' for weighted sum of chi square and 'gamma' for Gamma approximation")
    opt = parser.parse_args()

    for fea_dim in range(1, 10):  # dimension of conditioning Z
        for ci in (True, False):
            TestMain(CI=ci, fea_dims=fea_dim, data_nums=opt.data_nums,
                     repeat_times=opt.repeat_times, test_method=opt.test_method)

if __name__ == "__main__":
    # Only launch the benchmark when run as a script, not when imported.
    main()