import jpype
import jpype.imports

# Start the JVM once, with the heap limit and the Tetrad launcher jar on the
# classpath. All options must go into a single startJVM call: JPype can only
# start the JVM once per process, so a second call cannot add options later.
if jpype.isJVMStarted():
    print('JVM already started')
else:
    jpype.startJVM(
        jpype.getDefaultJVMPath(),
        "-Xmx600G",
        # Replace with the absolute path of tetrad-gui-7.2.2-launch.jar.
        classpath=['please change to the abs path of tetrad-gui-7.2.2-launch.jar'],
    )

import json
import time
import datetime
import argparse

import numpy as np
import networkx as nx
import edu.cmu.tetrad.search as ts
from algorithms.utils.translate import numpy_data_to_tetrad, tetrad_graph_to_networkx, networkx_graph_to_tetrad, statistic

# Command-line interface. main() loads or opens every one of these paths, so
# they are all mandatory. Without required=True, a missing --est, --truth or
# --metrics would only fail (with a TypeError on open(None)) after the
# potentially long fGES search had already run.
parser = argparse.ArgumentParser(description='running fGES with default argument')
parser.add_argument('--data', type=str, required=True, help='learning data, in .npy format')
parser.add_argument('--truth', type=str, required=True, help='truth graph, in .json (node link format) format')
parser.add_argument('--est', type=str, required=True, help='output estimation graph, in .json (node link format) format')
parser.add_argument('--metrics', type=str, required=True, help='output metrics in .json format')
args = parser.parse_args()


def main():
    """Run fGES on the input data, save the estimated graph and its metrics.

    Uses the module-level parsed command-line ``args``:

    * ``args.data``    -- path to a ``.npy`` array loaded with ``np.load``.
    * ``args.est``     -- output path for the estimated graph (node-link JSON).
    * ``args.truth``   -- path to the ground-truth graph (node-link JSON).
    * ``args.metrics`` -- output path for the metrics JSON, which contains the
      entries returned by ``statistic`` plus ``runtime`` (seconds).
    """
    n_data = np.load(args.data)
    print('{} -- {} loaded. start algorithm running.'.format(datetime.datetime.now(), args.data))

    # perf_counter is monotonic, so the measured runtime cannot be skewed by
    # wall-clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    t_data = numpy_data_to_tetrad(n_data)
    fges_score = ts.SemBicScore(t_data)
    fges_algo = ts.Fges(fges_score)
    fges_algo.setParallelized(True)
    est_t_graph = fges_algo.search()

    # The conversion to networkx is deliberately part of the timed section.
    est_nx_graph = tetrad_graph_to_networkx(est_t_graph)
    runtime = time.perf_counter() - start
    print('{} -- algorithm running finished.'.format(datetime.datetime.now()))

    with open(args.est, 'w', encoding='utf-8') as f:
        json.dump(nx.node_link_data(est_nx_graph), f, indent='\t')
    print('{} -- saved estimation graph, {}.'.format(datetime.datetime.now(), args.est))

    with open(args.truth, encoding='utf-8') as f:
        truth_nx_graph = nx.node_link_graph(json.load(f))
    # Relabel truth nodes to '0', '1', ... following their sorted order.
    # NOTE(review): presumably this matches the node names produced by
    # numpy_data_to_tetrad for the data columns -- confirm.
    truth_nx_graph = nx.relabel_nodes(truth_nx_graph, {n: str(i) for i, n in enumerate(sorted(truth_nx_graph))})
    truth_t_graph = networkx_graph_to_tetrad(truth_nx_graph, est_t_graph)
    print('{} -- {} loaded. start metrics evaluation.'.format(datetime.datetime.now(), args.truth))

    metrics = statistic(truth_t_graph, est_t_graph)
    metrics['runtime'] = runtime

    with open(args.metrics, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent='\t')
    print('{} -- metrics saved, {}.'.format(datetime.datetime.now(), args.metrics))


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()