import jpype
import jpype.imports

# Start the JVM hosting Tetrad once; the Tetrad imports below require it.
# Checking isJVMStarted() (rather than catching OSError from startJVM) keeps
# unrelated JVM startup failures from being misreported as "already started".
if jpype.isJVMStarted():
    print('JVM already started')
else:
    # please change to the abs path of tetrad-gui-7.2.2-launch.jar
    jpype.startJVM(classpath=['please change to the abs path of tetrad-gui-7.2.2-launch.jar'])

import json
import time
import datetime
import argparse

import numpy as np
import networkx as nx
import edu.cmu.tetrad.search as ts
from algorithms.utils.translate import numpy_data_to_tetrad, tetrad_graph_to_networkx, networkx_graph_to_tetrad, statistic

# Command-line interface. main() reads every one of these paths unconditionally,
# so they are all required: a missing option now yields an argparse usage error
# instead of a TypeError deep inside np.load()/open().
parser = argparse.ArgumentParser(description='running PC_stable with default argument')
parser.add_argument('--data', type=str, required=True, help='learning data, in .npy format')
parser.add_argument('--truth', type=str, required=True, help='truth graph, in .json (node link format) format')
parser.add_argument('--est', type=str, required=True, help='output estimation graph, in .json (node link format) format')
parser.add_argument('--metrics', type=str, required=True, help='output metrics in .json format')
args = parser.parse_args()


def main():
    """Run PC-stable on ``args.data``, save the estimated graph, and score it.

    Steps:
      1. Load the ``.npy`` array at ``args.data`` and convert it with
         ``numpy_data_to_tetrad``.
      2. Run Tetrad's ``PcStable`` search using ``IndTestFisherZ`` with
         alpha = 0.05.
      3. Write the estimated graph to ``args.est`` as networkx node-link JSON.
      4. Load the truth graph from ``args.truth`` (node-link JSON), compare it
         with the estimate via ``statistic`` and write the resulting metrics,
         plus ``runtime`` (seconds spent in steps 1b-2 and the graph
         conversion), to ``args.metrics``.
    """
    n_data = np.load(args.data)
    print('{} -- {} loaded. start algorithm running.'.format(datetime.datetime.now(), args.data))
    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # if the wall clock is adjusted while the search runs.
    start = time.perf_counter()

    t_data = numpy_data_to_tetrad(n_data)
    pcstable_test = ts.IndTestFisherZ(t_data, 0.05)  # 0.05 = significance level
    pcstable_algo = ts.PcStable(pcstable_test)
    est_t_graph = pcstable_algo.search()
    est_nx_graph = tetrad_graph_to_networkx(est_t_graph)

    elapsed = time.perf_counter() - start
    print('{} -- algorithm running finished.'.format(datetime.datetime.now()))

    # JSON is specified as UTF-8 (RFC 8259); don't depend on the locale encoding.
    with open(args.est, 'w', encoding='utf-8') as f:
        json.dump(nx.node_link_data(est_nx_graph), f, indent='\t')
    print('{} -- saved estimation graph, {}.'.format(datetime.datetime.now(), args.est))

    with open(args.truth, encoding='utf-8') as f:
        truth_nx_graph = nx.node_link_graph(json.load(f))
    # Rename truth nodes to '0', '1', ... following their sorted order.
    # NOTE(review): presumably this matches the node names numpy_data_to_tetrad
    # assigns to data columns — confirm in algorithms.utils.translate.
    truth_nx_graph = nx.relabel_nodes(truth_nx_graph, {n: str(i) for i, n in enumerate(sorted(truth_nx_graph))})
    # est_t_graph is passed along, presumably so the truth graph is built on the
    # same Tetrad node objects — verify in networkx_graph_to_tetrad.
    truth_t_graph = networkx_graph_to_tetrad(truth_nx_graph, est_t_graph)
    print('{} -- {} loaded. start metrics evaluation.'.format(datetime.datetime.now(), args.truth))

    metrics = statistic(truth_t_graph, est_t_graph)
    metrics['runtime'] = elapsed

    with open(args.metrics, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent='\t')
    print('{} -- metrics saved, {}.'.format(datetime.datetime.now(), args.metrics))


if __name__ == '__main__':
    main()