import time
import copy
import pickle
from tabulate import tabulate
from lgmvae import *
from methods import *
from ce_dataset_model import *
from recourse_evaluator import *
import argparse

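# Example usage. Paths are relative, so run from the directory that contains 0_outputs/;
# ./0_outputs/log/ must already exist because the shell opens the log file before Python starts.
# nohup python -u run.py --dataset heloc --model rf --device 1 > ./0_outputs/log/heloc_rf.log 2>&1 &
# nohup python -u run.py --dataset heloc --model nn --device 1 > ./0_outputs/log/heloc_nn.log 2>&1 &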

class Timer:
    """Context manager that measures wall-clock time averaged over a batch.

    The time spent inside the ``with`` block is divided by ``n_points``, so
    ``elapsed_time`` is an average number of seconds per processed item
    (e.g. per test point for which a counterfactual is generated).

    Attributes:
        name: Label for the timed section (informational only).
        n_points: Number of items processed inside the block; must be > 0.
        elapsed_time: Average seconds per point; ``None`` until the block exits.
        total_time: Total seconds spent in the block; ``None`` until the block exits.

    Raises:
        ValueError: If ``n_points`` is not positive. This is raised at construction
            rather than as a ZeroDivisionError (or a silently negative timing)
            when the block exits.
    """

    def __init__(self, name="Execution", n_points=100):
        if n_points <= 0:
            raise ValueError(f"n_points must be positive, got {n_points!r}")
        self.name = name
        self._start_time = None
        self.n_points = n_points
        self.elapsed_time = None
        self.total_time = None

    def __enter__(self):
        self._start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.total_time = time.perf_counter() - self._start_time
        self.elapsed_time = self.total_time / self.n_points
        # Returning None (falsy) lets any exception raised inside the block propagate.

def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        argparse.Namespace with ``dataset``, ``model``, ``device`` (a string:
        a CUDA device index or ``"cpu"``) and ``n_test_points`` (a positive int).
    """
    parser = argparse.ArgumentParser(description="Benchmark counterfactual explanation methods.")
    parser.add_argument("--dataset", type=str, choices=["heloc", "wine"], default="heloc",
                        help="dataset to benchmark on")
    parser.add_argument("--model", type=str, choices=["nn", "rf"], default="rf",
                        help="classifier type")
    parser.add_argument("--device", type=str, choices=["0", "1", "2", "cpu"], default="0",
                        help='CUDA device index, or "cpu"')
    # Parsed as int up front so a bad value fails at the command line, not mid-run.
    parser.add_argument("--n_test_points", type=int, default=100,
                        help="maximum number of test points to explain per run")
    args = parser.parse_args(argv)
    if args.n_test_points <= 0:
        parser.error(f"--n_test_points must be positive, got {args.n_test_points}")
    return args

def main():
    """Run the counterfactual-explanation (CE) benchmark for one dataset/model pair.

    Steps:
      1. Load the dataset, classifier(s) and generative model.
      2. Evaluate the generative model's utility on a split of the training data
         labelled with the classifier's own predictions.
      3. Over ``n_runs`` random subsets of the test points the classifier predicts
         as class 0, generate CEs with every method, time them (seconds per point)
         and score them with ``evaluate_ces``.
      4. Print a mean/std summary per method and save it as CSV, together with the
         generative-utility results as a pickle, under ./0_outputs/benchmark_results/.
      5. Evaluate path-based CEs against randomly sampled constraint rules
         ("actionability") on the test points of the last run.

    Raises:
        SystemExit: If there are no test points to explain.
    """
    import os

    args = parse_args()
    device = "cpu" if args.device == "cpu" else f"cuda:{args.device}"
    dname = args.dataset
    mname = args.model
    n_runs = 5  # number of random test subsets each method is evaluated on

    X_train, y_train, X_val, y_val, X_test, y_test, constraints = get_dataset(dname)
    clf, retrained_clfs = get_models(mname, X_train, y_train, X_val, y_val, X_test, y_test, dname=dname)

    model = get_generative_model_config(mname, dname)
    model.to(device)
    model.eval()

    # The generative model is evaluated on the training data labelled with the
    # classifier's *predictions* (not y_train).
    test_set_size = 0.1
    y_pred_train = clf.predict(X_train)
    X_train_new, X_test_new, y_train_new, y_test_new = train_test_split(
        X_train,
        y_pred_train,
        test_size=test_set_size,
        random_state=42,
        stratify=y_pred_train,
    )
    cluster_centroids, gen_util_results = evaluate_generative_model_utility(
        model, clf, X_train_new, y_train_new, X_test_new, y_test_new
    )

    print("======================= GENERATING COUNTERFACTUALS =======================")

    # Part 1: standard benchmarking.

    # Set up explainers once; they are reused across all runs.
    dice_exp = get_dice_explainer(X_train, y_train, clf)
    tree, X1_class1_clf = get_nnce_tree(X_train, clf)  # for nnce and stable ce
    face_exp = get_face_explainer(X_train, clf)

    # CEs are generated only for test points the classifier predicts as class 0.
    X_test_ce_all = X_test[np.where(clf.predict(X_test) == 0)]
    test_len = min(int(args.n_test_points), len(X_test_ce_all))
    if test_len <= 0:
        # Without this guard the per-point timings would divide by a zero or
        # negative count, and a negative slice below would silently select
        # "all but the last k" points instead of k points.
        raise SystemExit(
            f"No test points to explain: {len(X_test_ce_all)} test points are "
            f"predicted as class 0 and n_test_points={args.n_test_points}."
        )

    method_names = ['dice', 'nnce', 'icce', 'face', 'stce', 'ours-first', 'ours-middle', 'ours-last']
    ce_metric_names = ['val', 'cost', 'lof', 'div', 'm-rob', 'in-rob']
    timing_results = np.zeros((len(method_names), n_runs))
    benchmarking_res = np.zeros((len(method_names), n_runs, len(ce_metric_names)))

    # Values from the last run are reused by the actionability benchmark below.
    face_ces, ours_paths, X_test_ce = None, None, None
    print("======================= EVALUATING COUNTERFACTUALS =======================")
    for run in tqdm(range(n_runs)):
        # Reproducible random subset of the candidate test points for this run
        # (global seed kept so the selections match earlier runs of this script).
        idx_test = np.arange(len(X_test_ce_all))
        np.random.seed(run * 200 + 50)
        np.random.shuffle(idx_test)
        X_test_ce = X_test_ce_all[idx_test[:test_len]]
        r_eval = RecourseEvaluator(X_test_ce, clf, X_train)

        # Generate CEs with each method; Timer reports average seconds per point.
        with Timer("dice", test_len) as t:
            dice_ces = generate_dice_ce(X_test_ce, dice_exp)
        timing_results[0, run] = t.elapsed_time
        with Timer("nnce", test_len) as t:
            nnce_ces = generate_nn_ce(X_test_ce, X1_class1_clf, tree)
        timing_results[1, run] = t.elapsed_time
        with Timer("icce", test_len) as t:
            ic_ces = generate_ic_ce(X_test_ce, X1_class1_clf, tree, clf)
        timing_results[2, run] = t.elapsed_time
        with Timer("face", test_len) as t:
            face_ces = generate_face_ce(X_test_ce, face_exp)
        timing_results[3, run] = t.elapsed_time
        with Timer("stce", test_len) as t:
            stce_ces = generate_stable_ce(X_test_ce, X1_class1_clf, clf, tree, threshold=0.85)
        timing_results[4, run] = t.elapsed_time
        with Timer("ours", test_len) as t:
            ours_ces_first, ours_ces_middle, ours_ces_last, ours_paths = generate_ours_ce(
                X_test_ce, model, clf, cluster_centroids, device
            )
        # One call yields all three "ours" variants, so they share its timing.
        timing_results[5:8, run] = t.elapsed_time

        # Evaluate the CEs.
        util_variables = {
            'cgmvae': model,
            'dice_exp': dice_exp,
            'X1_class1_clf': X1_class1_clf,
            'tree': tree,
            'face_exp': face_exp,
            'clf': clf,
            'cluster_centroids': cluster_centroids,
            'device': device,
        }
        benchmarking_res[0, run] = evaluate_ces(r_eval, dice_ces, retrained_clfs, multi=True, name='dice', ce_function=generate_dice_ce, util_vars=util_variables)
        benchmarking_res[1, run] = evaluate_ces(r_eval, nnce_ces, retrained_clfs, multi=False, name='nnce', ce_function=generate_nn_ce, util_vars=util_variables)
        benchmarking_res[2, run] = evaluate_ces(r_eval, ic_ces, retrained_clfs, multi=True, name='icce', ce_function=generate_ic_ce, util_vars=util_variables)
        benchmarking_res[3, run] = evaluate_ces(r_eval, face_ces, retrained_clfs, multi=False, name='face', ce_function=generate_face_ce, util_vars=util_variables)
        benchmarking_res[4, run] = evaluate_ces(r_eval, stce_ces, retrained_clfs, multi=False, name='stce', ce_function=generate_stable_ce, util_vars=util_variables)
        benchmarking_res[5, run] = evaluate_ces(r_eval, ours_ces_first, retrained_clfs, multi=True, name='ours-first-ce', ce_function=generate_ours_ce, util_vars=util_variables)
        benchmarking_res[6, run] = evaluate_ces(r_eval, ours_ces_middle, retrained_clfs, multi=True, name='ours-middle-ce', ce_function=generate_ours_ce, util_vars=util_variables)
        benchmarking_res[7, run] = evaluate_ces(r_eval, ours_ces_last, retrained_clfs, multi=True, name='ours-last-ce', ce_function=generate_ours_ce, util_vars=util_variables)

    # Summarize: append timing as the last metric column -> (methods, runs, metrics + 1).
    results_merged = np.concatenate([benchmarking_res, np.expand_dims(timing_results, axis=2)], axis=2)
    metric_names = ce_metric_names + ['time']
    runs = [f"Run {i + 1}" for i in range(n_runs)]
    # Rows are method-major, run-minor, matching the C-order reshape below.
    multi_index = pd.MultiIndex.from_product([method_names, runs], names=["Method", "Run"])
    reshaped_results = results_merged.reshape(-1, len(metric_names))
    results_df = pd.DataFrame(reshaped_results, index=multi_index, columns=metric_names)
    summary_df = results_df.groupby('Method').agg(['mean', 'std'])
    print(tabulate(summary_df, headers='keys', tablefmt='pipe', floatfmt=".4f"))

    out_dir = './0_outputs/benchmark_results'
    os.makedirs(out_dir, exist_ok=True)
    # Keep the index: it holds the method names (index=False dropped them).
    summary_df.to_csv(f'{out_dir}/{dname}_{mname}_ce.csv')
    with open(f'{out_dir}/{dname}_{mname}_gen.pkl', 'wb') as f:
        pickle.dump(gen_util_results, f)

    # Part 2: benchmarking actionability.
    print("======================= EVALUATING CE ACTIONABILITY =======================")
    # Draw a reproducible set of constraint rules for each test point of the last
    # run. Indices are drawn independently (with replacement), so a point may
    # receive the same rule more than once.
    num_constraints = 5
    rule_lists = []
    for i in range(len(X_test_ce)):
        np.random.seed(i * 10 + 8)
        # Same draws as the deprecated np.random.random_integers(0, len(constraints) - 1).
        idxs = np.random.randint(0, len(constraints), size=(num_constraints,))
        rule_lists.append(constraints[idxs])

    n_clusters = int(model.c_dim / model.y_dim)

    # Get path CEs.
    face_interpolate_paths = get_path_ce(X_test_ce, face_ces, X1_class1_clf, tree, method="interpolate")
    face_greedy_paths = get_path_ce(X_test_ce, face_ces, X1_class1_clf, tree, method="greedy")
    ours_constrained_paths = generate_ours_path_ce(X_test_ce, model, clf, cluster_centroids, device, rule_lists)
    # Evaluate constraint satisfaction.
    face_int_res, _, _ = evaluate_actionability(X_test_ce, face_interpolate_paths, rule_lists, clf, enforce=True)
    face_greedy_res, _, _ = evaluate_actionability(X_test_ce, face_greedy_paths, rule_lists, clf, enforce=True)
    ours_naive_res, _, _ = evaluate_actionability(X_test_ce, ours_paths, rule_lists, clf, enforce=True, n_clusters=n_clusters)
    ours_constrained_res, _, _ = evaluate_actionability(X_test_ce, ours_constrained_paths, rule_lists, clf, enforce=False, n_clusters=n_clusters)

    print("face interpolate", "face_greedy", "ours_naive", "ours_constrained")
    print(face_int_res, face_greedy_res, ours_naive_res, ours_constrained_res)

# Script entry point: run the benchmark only when executed directly, not on import.
if __name__ == "__main__":
    main()
