import numpy as np
import argparse
import sys
import os
import builtins

builtins.ROOT_PATH = "your_path"
sys.path.append(builtins.ROOT_PATH)
from exp.dataloader import *
from baselines.IT.IT import IT
from baselines.MMDFUSE.MMDFUSE import MMDFuse
from baselines.C2ST.C2ST import C2ST
from baselines.MMDAgg.MMDAgg import MMDAgg
from baselines.MMDDeep.MMDDeep import MMDDeep
from baselines.MMDM.MMDM import MMDM
from baselines.RLTST.RLTST import RLTST

# builtins.IS_LOG = False  # mute the log
# NOTE(review): builtins.IS_LOG is not assigned in this file; it is presumably
# set by the star import from exp.dataloader — confirm before removing it there.
print("Logging is {}".format("on" if builtins.IS_LOG else "off"))


def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` is a trap: any non-empty string (including
    ``"False"``) is truthy, so the flag could never be switched off.
    Common spellings are accepted explicitly; anything else is an error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


parser = argparse.ArgumentParser()

# parameters to generate data
builtins.MODEL_ARCH = None
parser.add_argument('--dataset', default='blob', help='dataset name', type=str)
# Only 0 and 1 are meaningful; the result directory is chosen from this value.
parser.add_argument('--check', default=1, choices=(0, 1), help='check reject adv (1), reject clean(0)', type=int)
parser.add_argument('--N_more', default=4000, help='number of samples in referenced data', type=int)
# nargs='+' keeps the CLI value a list like the default; with a bare type=int,
# `--N_less 50` produced an int and the main loop's len()/indexing failed.
parser.add_argument('--N_less', default=[20, 50, 80, 100, 150, 200, 300], nargs='+', help='number of samples in input data', type=int)
# One seed per N_less entry. type=list used to split a CLI string into
# characters ('819' -> ['8', '1', '9']); parse each value as an int instead.
parser.add_argument('--rs', default=[819, 819, 819, 819, 819, 819, 819, 819], nargs='+', help='random seed', type=int)

# parameters to conduct exp
parser.add_argument('--n_exp', default=10, help='number of experiments', type=int)
parser.add_argument('--n_test', default=100, help='number of test times', type=int)
parser.add_argument('--n_per', default=100, help='number of permutations', type=int)
parser.add_argument('--alpha', default=0.05, help='probability of not reject adv', type=float)

# parameters for MMD-FUSE
# _str2bool instead of type=bool so that `--is_balanced False` really is False.
parser.add_argument('--is_balanced', default=False, help='is balanced', type=_str2bool)

args = parser.parse_args()

# Row j holds method j's results; column kk holds repetition kk. Each cell is
# the mean of that method's per-test outcome array for the repetition.
Results = np.zeros((8, args.n_exp))

# Per-method outcome arrays. Baselines whose calls are commented out below are
# never reassigned, so their rows in Results stay at 0.
H_IT = np.zeros(args.n_test)
H_MMDFUSE = np.zeros(args.n_test)
H_C2ST = np.zeros(args.n_test)
H_MMDAgg = np.zeros(args.n_test)
H_MMDM = np.zeros(args.n_test)
H_RLTST = np.zeros(args.n_test)
H_MMDDeep = np.zeros(args.n_test)
H_LOTT = np.zeros(args.n_test)

# Display names, in the same order as the rows of Results.
METHOD_NAMES = ("IT", "MMDFUSE", "C2ST", "MMDAgg", "MMDM", "RLTST", "MMDDeep", "LOTT")

# The `exp` package is imported from ROOT_PATH, so derive the experiment
# directory from it rather than repeating the path placeholder.
exp_path = os.path.join(builtins.ROOT_PATH, "exp")

# Pick the output directory once, with a single condition. Previously the
# per-repetition saves tested `if args.check` but the final summary tested
# `== 1` / `== 0`, so any other value silently skipped the summary.
result_kind = "test_power" if args.check else "typeI_error"
result_dir = os.path.join(exp_path, "Results", result_kind, str(args.alpha))
# Create it up front so the final 'a' append works even when n_exp == 0.
os.makedirs(result_dir, exist_ok=True)

# Each N_less setting consumes one seed; fail before running anything instead
# of raising IndexError after earlier settings have already finished.
if len(args.rs) < len(args.N_less):
    raise ValueError(
        "--rs needs at least one seed per --N_less value "
        "(got {} seeds for {} sizes)".format(len(args.rs), len(args.N_less))
    )

for N_less, rs in zip(args.N_less, args.rs):
    balanced_flag = "balanced" if args.is_balanced else ""
    file_name = "{}_Nmore{}_Nless{}_rs{}_{}".format(
        args.dataset, args.N_more, N_less, rs, balanced_flag
    )
    result_file = os.path.join(result_dir, file_name)
    setup_time_log()
    for kk in range(args.n_exp):
        # Every method in repetition kk receives the same seed and settings.
        seed = kk * 1000 + rs + N_less
        common = (args.dataset, args.N_more, N_less, seed, args.check,
                  args.is_balanced, args.n_test, args.n_per, args.alpha)

        # Uncomment a line to include that baseline in the run.
        # H_IT, _, _ = IT(*common)
        # H_MMDFUSE, _, _ = MMDFuse(*common)
        # H_C2ST, _, _ = C2ST(*common)
        # H_MMDAgg, _, _ = MMDAgg(*common)
        # H_MMDM, _, _ = MMDM(*common)
        H_RLTST, _, _ = RLTST(*common)
        H_MMDDeep, _, _ = MMDDeep(*common)
        H_LOTT, _, _ = IT(*common, is_selection=False)

        outcomes = (H_IT, H_MMDFUSE, H_C2ST, H_MMDAgg, H_MMDM, H_RLTST, H_MMDDeep, H_LOTT)
        for row, H in enumerate(outcomes):
            Results[row, kk] = np.mean(H)

        # Checkpoint (overwrite) after every repetition so a crashed run keeps
        # the repetitions completed so far.
        np.savetxt(result_file, Results, fmt='%.3f')

    # Column 0: mean over repetitions; column 1: standard error of that mean
    # (population std / sqrt(n_exp), as before).
    Final_results = np.column_stack((
        Results.mean(axis=1),
        Results.std(axis=1) / np.sqrt(args.n_exp),
    ))

    # Append the summary rows below the per-repetition table written above.
    with open(result_file, 'a') as f:
        np.savetxt(f, Final_results, fmt='%.3f')

    for name, (mean, sem) in zip(METHOD_NAMES, Final_results):
        print("{}: {:.3f} ± {:.3f}".format(name, mean, sem))