import jpype
import jpype.imports

# Boot the JVM that hosts Tetrad; it must be running before the
# ``edu.cmu.tetrad`` import further down. The classpath entry is a placeholder:
# replace it with the absolute path of tetrad-gui-7.2.2-launch.jar.
_tetrad_classpath = ['please change to the abs path of tetrad-gui-7.2.2-launch.jar']
try:
    jpype.startJVM(classpath=_tetrad_classpath)
except OSError:
    # Treated as "a JVM is already running in this process" (e.g. re-running in
    # an interactive session); presumably jpype signals that with OSError.
    print('JVM already started')

import json
import time
import signal
import resource
import datetime
import argparse

import numpy as np
import networkx as nx
import edu.cmu.tetrad.search as ts
from utils.translate import numpy_data_to_tetrad, tetrad_graph_to_networkx, networkx_graph_to_tetrad, statistic

# Command-line interface. The four paths are required: main() and the SIGXCPU
# handler open them unconditionally, so a missing one used to surface as a
# crash only after the (expensive) search had already run.
parser = argparse.ArgumentParser(description='running PC_stable with default argument')
parser.add_argument('--data', type=str, required=True, help='learning data, in .npy format')
parser.add_argument('--truth', type=str, required=True, help='truth graph, in .json (node link format) format')

parser.add_argument('--alpha', type=float, default=0.05, help='alpha values for fisher Z test, 0.05 in default')
parser.add_argument('--depth', type=int, default=1000, help=' the maximum number of conditioning nodes for any conditional independence checked, 1000 in default')

parser.add_argument('--est', type=str, required=True, help='output estimation graph, in .json (node link format) format')
parser.add_argument('--metrics', type=str, required=True, help='output metrics in .json format')
parser.add_argument('--cpulimit', type=int, default=None, help='cpu time limitation, in seconds')
args = parser.parse_args()


def set_max_cpu_time(seconds, grace=30):
    """Limit this process's CPU time through ``RLIMIT_CPU``.

    The soft limit is set to ``seconds`` and the hard limit to
    ``seconds + grace``. The gap gives the SIGXCPU handler time to act before
    the kernel kills the process at the hard limit.

    An unprivileged process cannot raise its hard limit (``setrlimit`` would
    raise ``ValueError``). So when an inherited, finite hard limit is lower
    than ``seconds + grace``, that hard limit is kept. The soft limit is then
    lowered to leave the ``grace`` window below it, but never below 0.

    Args:
        seconds: requested soft CPU-time limit, in seconds.
        grace: seconds between the soft and hard limits (30 by default, the
            value previously hard-coded).
    """
    _, current_hard = resource.getrlimit(resource.RLIMIT_CPU)
    hard = seconds + grace
    # RLIM_INFINITY must be checked first: it may be exposed as a negative int.
    if current_hard != resource.RLIM_INFINITY and hard > current_hard:
        hard = current_hard
    soft = min(seconds, max(hard - grace, 0))
    resource.setrlimit(resource.RLIMIT_CPU, (soft, hard))
    print('set cpu max time: {}, soft'.format(soft))


def pass_soft_cpu_time(signum, frame):
    """SIGXCPU handler: record a timeout result, stop the JVM and exit with status 0.

    Writes ``{'bic': 0, 'cpu_time_final': <process CPU seconds>}`` to
    ``args.metrics`` (when it is set), shuts down the JVM, logs the CPU time
    and exits.

    NOTE(review): CPython runs Python-level signal handlers on the main thread,
    and only between bytecode instructions. While the main thread is blocked
    inside a long Java call (e.g. ``PcStable.search``), this handler may not
    run before the hard-limit SIGKILL arrives. Confirm for long searches.

    Args:
        signum: signal number (unused).
        frame: interrupted stack frame (unused).
    """
    import sys

    # Sample once so the saved metrics and the log line report the same value.
    cpu_time = time.process_time()
    metrics = {
        'bic': 0,
        'cpu_time_final': cpu_time
    }
    # Persist the result before touching the JVM. If shutdownJVM() blocks,
    # the hard-limit SIGKILL would otherwise arrive before anything is written.
    if args.metrics is not None:
        with open(args.metrics, 'w') as f:
            json.dump(metrics, f, indent='\t')

    jpype.shutdownJVM()
    print(f"pass soft cpu time, final CPU time: {cpu_time} seconds")
    # sys.exit instead of the site-provided exit(), which is absent under `python -S`.
    sys.exit(0)


def main():
    """Run PC-stable on ``args.data`` and write the estimate and its metrics.

    Steps:
      1. Load the ``.npy`` data, convert it for Tetrad, and run ``PcStable``
         with a Fisher-Z test at ``args.alpha`` and depth ``args.depth``.
      2. Save the estimated graph to ``args.est`` as node-link JSON.
      3. Load the truth graph from ``args.truth``. Compare it with the estimate
         via ``statistic``, then add SEM-BIC scores and the runtime.
      4. Save the metrics dict to ``args.metrics`` as JSON.
    """
    n_data = np.load(args.data)
    print('{} -- {} loaded. start algorithm running.'.format(datetime.datetime.now(), args.data))

    # perf_counter is monotonic; time.time() can jump when the wall clock is adjusted.
    start = time.perf_counter()
    t_data = numpy_data_to_tetrad(n_data)
    ind_test = ts.IndTestFisherZ(t_data, args.alpha)
    pcstable_algo = ts.PcStable(ind_test)
    pcstable_algo.setDepth(args.depth)
    est_t_graph = pcstable_algo.search()
    est_nx_graph = tetrad_graph_to_networkx(est_t_graph)
    runtime = time.perf_counter() - start
    print('{} -- algorithm running finished.'.format(datetime.datetime.now()))

    with open(args.est, 'w') as f:
        json.dump(nx.node_link_data(est_nx_graph), f, indent='\t')
    print('{} -- saved estimation graph, {}.'.format(datetime.datetime.now(), args.est))

    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(args.truth, encoding='utf-8') as f:
        truth_nx_graph = nx.node_link_graph(json.load(f))
    # Rename truth nodes to '0', '1', ... following their sorted order.
    # NOTE(review): this assumes the sorted truth labels line up with the data
    # columns used by numpy_data_to_tetrad. String labels would sort
    # lexicographically ('10' < '2'), so confirm the label type.
    truth_nx_graph = nx.relabel_nodes(truth_nx_graph, {n: str(i) for i, n in enumerate(sorted(truth_nx_graph))})
    truth_t_graph = networkx_graph_to_tetrad(truth_nx_graph, est_t_graph)
    print('{} -- {} loaded. start metrics evaluation.'.format(datetime.datetime.now(), args.truth))

    metrics = statistic(truth_t_graph, est_t_graph)
    # Both graphs are scored on the same data; dagFromCPDAG presumably picks a
    # DAG from each graph's equivalence class so it can be BIC-scored.
    metrics['bic'] = ts.SemBicScorer.scoreDag(ts.SearchGraphUtils.dagFromCPDAG(est_t_graph), t_data)
    metrics['bic_truth'] = ts.SemBicScorer.scoreDag(ts.SearchGraphUtils.dagFromCPDAG(truth_t_graph), t_data)
    metrics['runtime'] = runtime

    with open(args.metrics, 'w') as f:
        json.dump(metrics, f, indent='\t')
    print('{} -- metrics saved, {}.'.format(datetime.datetime.now(), args.metrics))


if __name__ == '__main__':
    if args.cpulimit is not None:
        # Install the handler before arming the limit. Otherwise an early
        # SIGXCPU would hit the default disposition, which terminates the process.
        signal.signal(signal.SIGXCPU, pass_soft_cpu_time)
        set_max_cpu_time(args.cpulimit)

    main()