import jpype
import jpype.imports

# The JVM must be running before the `edu.cmu.tetrad` imports below.
# Please change the classpath entry to the abs path of tetrad-gui-7.2.2-launch.jar.
# Check explicitly instead of catching OSError: startJVM also raises OSError
# for genuine start-up failures (e.g. JVM library not found), which must not
# be reported as "already started".
if jpype.isJVMStarted():
    print('JVM already started')
else:
    jpype.startJVM(classpath=['please change to the abs path of tetrad-gui-7.2.2-launch.jar'])

import json
import time
import signal
import resource
import datetime
import argparse

import numpy as np
import networkx as nx
import edu.cmu.tetrad.search as ts
from utils.translate import numpy_data_to_tetrad, tetrad_graph_to_networkx

parser = argparse.ArgumentParser(description='running PC_stable with default argument')
# --data, --est and --metrics are used unconditionally by main() (and --metrics by
# the SIGXCPU handler), so they are required: a missing path now fails at parse
# time instead of raising TypeError later, possibly after the search has finished.
parser.add_argument('--data', type=str, required=True, help='learning data, in .npy format')

parser.add_argument('--alpha', type=float, default=0.05, help='alpha values for fisher Z test, 0.05 in default')
parser.add_argument('--depth', type=int, default=1000, help=' the maximum number of conditioning nodes for any conditional independence checked, 1000 in default')

parser.add_argument('--est', type=str, required=True, help='output estimation graph, in .json (node link format) format')
parser.add_argument('--metrics', type=str, required=True, help='output metrics in .json format')
parser.add_argument('--cpulimit', type=int, default=None, help='cpu time limitation, in seconds')
args = parser.parse_args()


def set_max_cpu_time(seconds):
    """Limit this process's CPU time via ``RLIMIT_CPU``.

    The soft limit is ``seconds``. When it is exceeded, the kernel sends
    SIGXCPU. The hard limit is set 30 seconds higher as a grace period for a
    SIGXCPU handler to finish. Both values are clamped to the hard limit
    already in force, because an unprivileged process may not raise it
    (``setrlimit`` would raise ``ValueError``).

    Args:
        seconds: requested soft CPU-time limit, in seconds.

    Raises:
        ValueError: if ``seconds`` is negative.
    """
    if seconds < 0:
        raise ValueError('cpu time limit must be non-negative, got {}'.format(seconds))

    _, current_hard = resource.getrlimit(resource.RLIMIT_CPU)
    hard = seconds + 30
    if current_hard != resource.RLIM_INFINITY:
        hard = min(hard, current_hard)
    # The soft limit may never exceed the hard limit.
    soft = min(seconds, hard)

    resource.setrlimit(resource.RLIMIT_CPU, (soft, hard))
    print('set cpu max time: {}, soft'.format(soft))


def pass_soft_cpu_time(signum, frame):
    """SIGXCPU handler: record the CPU time reached, then exit with status 0.

    Writes a placeholder metrics file (``'bic': 0``) with the CPU time consumed
    so far to ``args.metrics``, prints that time, shuts down the JVM and exits.

    NOTE(review): Python runs signal handlers only in the main thread between
    bytecodes, so this presumably cannot fire while the main thread is inside a
    long Java call (e.g. ``search()``); in that case the hard CPU limit ends the
    process instead — confirm.

    Args:
        signum: signal number (unused).
        frame: interrupted stack frame (unused).
    """
    # Sample once so the written and printed values agree.
    cpu_time = time.process_time()
    metrics = {
        'bic': 0,
        'cpu_time_final': cpu_time,
    }

    # Persist the metrics before touching the JVM, so a failing or slow
    # shutdown cannot lose them (the hard limit follows shortly after).
    if args.metrics is not None:
        with open(args.metrics, 'w') as f:
            json.dump(metrics, f, indent='\t')

    print(f"pass soft cpu time, final CPU time: {cpu_time} seconds")
    jpype.shutdownJVM()
    # Same as the site builtin exit(0), without depending on `site`.
    raise SystemExit(0)


def main():
    """Run PC-stable on the ``.npy`` dataset and save the graph and metrics.

    Loads ``args.data`` and runs ``ts.PcStable`` with a Fisher-Z test at
    ``args.alpha`` and depth ``args.depth``. The estimated graph is written to
    ``args.est`` as networkx node-link JSON. Then
    ``SemBicScorer.scoreDag(SearchGraphUtils.dagFromCPDAG(graph), data)`` is
    stored under ``'bic'`` in ``args.metrics``. A timestamped progress line
    is printed after each stage.
    """
    now = datetime.datetime.now

    numpy_data = np.load(args.data)
    print(f'{now()} -- {args.data} loaded. start algorithm running.')

    tetrad_data = numpy_data_to_tetrad(numpy_data)
    independence_test = ts.IndTestFisherZ(tetrad_data, args.alpha)
    pc_search = ts.PcStable(independence_test)
    pc_search.setDepth(args.depth)
    tetrad_graph = pc_search.search()
    networkx_graph = tetrad_graph_to_networkx(tetrad_graph)
    print(f'{now()} -- algorithm running finished.')

    with open(args.est, 'w') as est_file:
        json.dump(nx.node_link_data(networkx_graph), est_file, indent='\t')
    print(f'{now()} -- saved estimation graph, {args.est}.')

    dag = ts.SearchGraphUtils.dagFromCPDAG(tetrad_graph)
    metrics = {'bic': ts.SemBicScorer.scoreDag(dag, tetrad_data)}

    with open(args.metrics, 'w') as metrics_file:
        json.dump(metrics, metrics_file, indent='\t')
    print(f'{now()} -- metrics saved, {args.metrics}.')


if __name__ == '__main__':
    if args.cpulimit is not None:
        # Install the handler BEFORE lowering the limit. JVM start-up has
        # already consumed CPU time, so a small --cpulimit can be exceeded as
        # soon as it is set. SIGXCPU's default action terminates the process,
        # and no metrics would be written.
        signal.signal(signal.SIGXCPU, pass_soft_cpu_time)
        set_max_cpu_time(args.cpulimit)

    main()