import jpype
import jpype.imports

# Start the JVM with the Tetrad jar on its classpath. Set the TETRAD_JAR
# environment variable, or replace the placeholder below with the absolute
# path of tetrad-gui-7.2.2-launch.jar.
import os

_TETRAD_JAR = os.environ.get(
    'TETRAD_JAR',
    'please change to the abs path of tetrad-gui-7.2.2-launch.jar',
)
if jpype.isJVMStarted():
    print('JVM already started')
else:
    # Check explicitly instead of catching OSError: startJVM also raises
    # OSError when no JVM can be found, and that must not be reported as
    # "already started".
    jpype.startJVM(classpath=[_TETRAD_JAR])

import json
import time
import signal
import datetime
import argparse
import resource

import numpy as np
import networkx as nx
import edu.cmu.tetrad.search as ts
from utils.translate import numpy_data_to_tetrad, tetrad_graph_to_networkx, networkx_graph_to_tetrad, statistic


def str2bool(v):
    """Parse a command-line boolean value for argparse.

    Accepts ``'True'`` / ``'False'`` in any letter case, ignoring surrounding
    whitespace. A value that is already a ``bool`` is returned unchanged.

    :param v: raw argument value (normally a string from the command line).
    :returns: the parsed boolean.
    :raises argparse.ArgumentTypeError: if ``v`` is not a recognised boolean;
        argparse reports this as a usage error.
    """
    if isinstance(v, bool):
        return v
    lowered = str(v).strip().lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    # Must be raised (not just constructed); otherwise bad input silently
    # becomes None, which is falsy and would quietly disable the option.
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser(description='running fGES with default argument')
# main() reads all four paths unconditionally, so they are required: a
# missing one is reported as a usage error up front instead of crashing,
# possibly only after the (expensive) search has already run.
parser.add_argument('--data', type=str, required=True, help='learning data, in .npy format')
parser.add_argument('--truth', type=str, required=True, help='truth graph, in .json (node link format) format')

# fGES / BIC-score hyper-parameters.
parser.add_argument('--penalty', type=float, default=1.0, help='penalty discount for bic score, 1.0 in default')
parser.add_argument('--faithful', type=str2bool, default=False, help='faithfulness assumed or not, False in default')
parser.add_argument('--degree', type=int, default=1000, help='max degree for the search, 1000 in default')

# Outputs and resource limits.
parser.add_argument('--est', type=str, required=True, help='output estimation graph, in .json (node link format) format')
parser.add_argument('--metrics', type=str, required=True, help='output metrics in .json format')
parser.add_argument('--cpulimit', type=int, default=None, help='cpu time limitation, in seconds')
args = parser.parse_args()


# A method to set the maximum (SOFT) CPU time in seconds
def set_max_cpu_time(seconds, grace=30):
    """Limit this process's CPU time via ``RLIMIT_CPU``.

    The soft limit is ``seconds``; exceeding it makes the kernel send
    ``SIGXCPU``. The hard limit is ``seconds + grace``. If the process already
    has a finite hard limit, both values are clamped to it. An unprivileged
    process cannot raise its hard limit, so ``setrlimit`` would otherwise
    fail with ``ValueError``.

    :param seconds: soft CPU-time limit, in seconds.
    :param grace: extra seconds between the soft and hard limit (30 by default).
    """
    _, current_hard = resource.getrlimit(resource.RLIMIT_CPU)
    soft, hard = seconds, seconds + grace
    if current_hard != resource.RLIM_INFINITY:
        hard = min(hard, current_hard)
        soft = min(soft, hard)  # the soft limit may not exceed the hard limit
    resource.setrlimit(resource.RLIMIT_CPU, (soft, hard))
    print('set cpu max time: {}, soft'.format(soft))


# SIGXCPU handler: record the final CPU time in the metrics file, shut down the JVM and exit
def pass_soft_cpu_time(signum, frame):
    """Handle ``SIGXCPU``: record the CPU time used, shut down the JVM and exit.

    Writes a placeholder metrics file (``'bic': 0``) together with the final
    process CPU time to ``args.metrics``, then shuts down the JVM and exits
    with status 0.

    NOTE(review): Python runs signal handlers only between bytecode
    instructions of the main thread. If the soft limit is hit while the main
    thread is inside a long Java call, the handler may be delayed until that
    call returns. Confirm with the JPype setup in use.

    :param signum: signal number (unused).
    :param frame: interrupted stack frame (unused).
    """
    import sys

    # Measure once so the file and the log line report the same value.
    cpu_time = time.process_time()

    # Persist the metrics before touching the JVM so the record survives
    # even if the JVM shutdown fails.
    if args.metrics is not None:
        metrics = {
            'bic': 0,
            'cpu_time_final': cpu_time
        }
        with open(args.metrics, 'w') as f:
            json.dump(metrics, f, indent='\t')

    jpype.shutdownJVM()
    print(f"pass soft cpu time, final CPU time: {cpu_time} seconds")
    sys.exit(0)


def main():
    """Run fGES on ``args.data`` and write the estimated graph and its metrics.

    Steps:

    1. Load the data with ``np.load`` and convert it to a Tetrad data set.
    2. Run ``Fges`` over a ``SemBicScore`` configured from ``--penalty``,
       ``--faithful`` and ``--degree``.
    3. Write the estimated graph to ``args.est`` as node-link JSON.
    4. Load the truth graph from ``args.truth`` and compute ``statistic``
       against it.
    5. Add the BIC of both graphs (each passed through ``dagFromCPDAG``) and
       the runtime, then write the metrics to ``args.metrics``.

    ``runtime`` is wall-clock seconds (``time.time``). It covers data
    conversion, the search, and conversion of the result to networkx.
    """
    def dump_json(payload, path):
        # Every output file uses the same tab-indented JSON layout.
        with open(path, 'w') as handle:
            json.dump(payload, handle, indent='\t')

    def bic_of(graph):
        # BIC of the DAG derived from ``graph``, scored on the loaded data.
        return ts.SemBicScorer.scoreDag(ts.SearchGraphUtils.dagFromCPDAG(graph), tetrad_data)

    raw = np.load(args.data)
    print('{} -- {} loaded. start algorithm running.'.format(datetime.datetime.now(), args.data))

    started = time.time()
    tetrad_data = numpy_data_to_tetrad(raw)
    score = ts.SemBicScore(tetrad_data)
    score.setPenaltyDiscount(args.penalty)
    fges = ts.Fges(score)
    fges.setFaithfulnessAssumed(args.faithful)
    fges.setMaxDegree(args.degree)
    estimate = fges.search()
    estimate_nx = tetrad_graph_to_networkx(estimate)
    elapsed = time.time() - started
    print('{} -- algorithm running finished.'.format(datetime.datetime.now()))

    dump_json(nx.node_link_data(estimate_nx), args.est)
    print('{} -- saved estimation graph, {}.'.format(datetime.datetime.now(), args.est))

    with open(args.truth) as handle:
        truth_nx = nx.node_link_graph(json.load(handle))
    # Rename the truth nodes to '0', '1', ... following their sorted order.
    # NOTE(review): presumably this matches the node names used for the
    # estimated graph; confirm against utils.translate.
    renaming = dict(zip(sorted(truth_nx), map(str, range(truth_nx.number_of_nodes()))))
    truth_nx = nx.relabel_nodes(truth_nx, renaming)
    truth = networkx_graph_to_tetrad(truth_nx, estimate)
    print('{} -- {} loaded. start metrics evaluation.'.format(datetime.datetime.now(), args.truth))

    metrics = statistic(truth, estimate)
    metrics['bic'] = bic_of(estimate)
    metrics['bic_truth'] = bic_of(truth)
    metrics['runtime'] = elapsed

    dump_json(metrics, args.metrics)
    print('{} -- metrics saved, {}.'.format(datetime.datetime.now(), args.metrics))


if __name__ == '__main__':
    if args.cpulimit is not None:
        # Install the handler before setting the limit. Otherwise a very small
        # limit could deliver SIGXCPU while its default action (terminating
        # the process) is still in effect, and no metrics would be written.
        signal.signal(signal.SIGXCPU, pass_soft_cpu_time)
        set_max_cpu_time(args.cpulimit)

    main()