import os
import argparse
import numpy as np
import time
from utils import count_accuracy, to_dag
from model import colide_ev, colide_nv

def _parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace with ``path`` (str), ``data_idx`` (str) and ``seed`` (int).

    Raises:
        SystemExit: if ``--path`` or ``--data-idx`` is missing. Both are
            needed to build the input file names, so failing here gives a
            clear usage message instead of a ``TypeError`` from
            ``os.path.join(None, ...)`` or a lookup of ``DAGNone.npy``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True,
                        help='path to data files')
    parser.add_argument('--data-idx', type=str, required=True,
                        help='dataset index')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed number')
    return parser.parse_args(argv)


def _timed_fit(name, model, X):
    """Fit ``model`` on ``X`` and print the elapsed wall-clock time.

    Both CoLiDE variants are run with the same hyperparameters, so they
    live here in one place.

    Args:
        name: Label used in the timing message (e.g. ``'CoLiDE-EV'``).
        model: Object exposing ``fit(X, ...)`` that returns a pair.
        X: Data array loaded from ``data<idx>.npy``.

    Returns:
        The two values returned by ``model.fit`` (estimated weights and the
        model's noise estimate).
    """
    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments during a long fit.
    t_start = time.perf_counter()
    W_hat, noise_est = model.fit(
        X,
        lambda1=0.05,
        T=4,
        s=[1.0, .9, .8, .7],
        warm_iter=2e4,
        max_iter=7e4,
        lr=0.0003,
    )
    t_end = time.perf_counter()
    print(f'convergence time for {name}: {t_end-t_start:.4f}s')
    return W_hat, noise_est


def _score(W_gt, W_hat):
    """Post-process ``W_hat`` with ``to_dag`` and compare supports to ``W_gt``.

    Returns:
        ``(fdr, tpr, shd, pred_size)`` taken from ``count_accuracy``.
    """
    W_hat_post = to_dag(W_hat, thr=0.3)
    # Only the nonzero pattern of each matrix is compared.
    fdr, tpr, _fpr, shd, pred_size = count_accuracy(W_gt != 0, W_hat_post != 0)
    return fdr, tpr, shd, pred_size


def main(argv=None):
    """Run CoLiDE-EV and CoLiDE-NV on one dataset and print their metrics.

    Loads ``DAG<idx>.npy`` (ground truth) and ``data<idx>.npy`` from
    ``--path``, fits both models with the given seed, and prints timing,
    the EV noise estimate, and SHD / FDR / TPR / NNZ for each model.

    Args:
        argv: Optional list of argument strings; ``None`` uses ``sys.argv``.
    """
    args = _parse_args(argv)

    W_gt = np.load(os.path.join(args.path, "DAG{}.npy".format(args.data_idx)))
    X = np.load(os.path.join(args.path, "data{}.npy".format(args.data_idx)))

    #####################
    # Running CoLiDE-EV #
    #####################

    W_hat_ev, sigma_est_ev = _timed_fit('CoLiDE-EV', colide_ev(seed=args.seed), X)
    print(f'sigma hat for CoLiDE-EV: {sigma_est_ev:.3f}')
    fdr_ev, tpr_ev, shd_ev, pred_size_ev = _score(W_gt, W_hat_ev)

    #####################
    # Running CoLiDE-NV #
    #####################

    # The NV noise estimate is not reported by this script.
    W_hat_nv, _ = _timed_fit('CoLiDE-NV', colide_nv(seed=args.seed), X)
    fdr_nv, tpr_nv, shd_nv, pred_size_nv = _score(W_gt, W_hat_nv)

    ######################
    # Displaying Results #
    ######################

    print('=== CoLiDE-EV Results for', args.path, '===')
    print('SHD:', shd_ev, 'FDR:', fdr_ev, 'TPR:', tpr_ev, 'NNZ:', pred_size_ev)

    print('=== CoLiDE-NV Results for', args.path, '===')
    print('SHD:', shd_nv, 'FDR:', fdr_nv, 'TPR:', tpr_nv, 'NNZ:', pred_size_nv)


if __name__ == '__main__':
    main()