import os
import argparse
import numpy as np
import time
from my_utils import count_accuracy, to_dag
from model import topo_colide_ev, topo_colide_nv

def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``path``, ``data_idx`` and ``seed``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when ``--path`` or
            ``--data-idx`` is missing or an option is malformed.
    """
    parser = argparse.ArgumentParser()

    # Both options are required: with the former ``default=None`` a missing
    # --path crashed inside os.path.join with a TypeError, and a missing
    # --data-idx silently produced the file names "DAGNone.npy"/"dataNone.npy".
    parser.add_argument('--path', type=str, required=True,
                        help='path to data files')
    parser.add_argument('--data-idx', type=str, required=True,
                        help='dataset index')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed number')

    return parser.parse_args(argv)


def _fit_and_score(model, label, X, W_gt, thr=0.3):
    """Fit ``model`` on ``X`` from a random initial ordering and score it.

    Args:
        model: Object exposing ``fit(X=..., topo=..., no_large_search=...,
            size_small=..., size_large=...)`` that returns four values, the
            first being the estimated weight matrix.
        label: Name used in the timing message.
        X: Two-dimensional data array; its second dimension gives the length
            of the initial ordering.
        W_gt: Ground-truth matrix; only its non-zero pattern is compared.
        thr: Value forwarded to ``to_dag`` as ``thr``.

    Returns:
        Whatever ``count_accuracy`` returns for the non-zero patterns of
        ``W_gt`` and the post-processed estimate.
    """
    _, d = X.shape
    # Drawn from NumPy's global RNG, exactly as before.
    topo_init = list(np.random.permutation(range(d)))

    t_start = time.time()
    W_hat, _, _, _ = model.fit(X=X, topo=topo_init, no_large_search=10,
                               size_small=100, size_large=1000)
    t_end = time.time()
    print(f'convergence time for {label}: {t_end-t_start:.4f}s')

    W_hat_post = to_dag(W_hat, thr=thr)
    return count_accuracy(W_gt != 0, W_hat_post != 0)


def main(argv=None):
    """Load one dataset, run CoLiDE-EV and CoLiDE-NV on it, print the metrics.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    W_gt = np.load(os.path.join(args.path, "DAG{}.npy".format(args.data_idx)))
    X = np.load(os.path.join(args.path, "data{}.npy".format(args.data_idx)))

    #####################
    # Running CoLiDE-EV #
    #####################

    # NOTE(review): each model is constructed immediately before its run so
    # the order "construct model -> draw permutation -> fit" of the original
    # script is kept; this matters if the constructor seeds NumPy's global
    # RNG (not visible from this file) -- do not hoist the constructors.
    model1 = topo_colide_ev(seed=args.seed)
    fdr_ev, tpr_ev, _, shd_ev, pred_size_ev = _fit_and_score(
        model1, 'CoLiDE-EV', X, W_gt)

    #####################
    # Running CoLiDE-NV #
    #####################

    model2 = topo_colide_nv(seed=args.seed)
    fdr_nv, tpr_nv, _, shd_nv, pred_size_nv = _fit_and_score(
        model2, 'CoLiDE-NV', X, W_gt)

    ######################
    # Displaying Results #
    ######################

    print('=== CoLiDE-EV Results for', args.path, '===')
    print('SHD:', shd_ev, 'FDR:', fdr_ev, 'TPR:', tpr_ev, 'NNZ:', pred_size_ev)

    print('=== CoLiDE-NV Results for', args.path, '===')
    print('SHD:', shd_nv, 'FDR:', fdr_nv, 'TPR:', tpr_nv, 'NNZ:', pred_size_nv)


if __name__ == '__main__':
    main()