import jpype
import jpype.imports

try:
    # Start the JVM through JPype before any Java package is imported; the
    # `edu.cmu.tetrad` import later in this file presumably needs a running
    # JVM with the Tetrad jar on its classpath.
    # The classpath entry below is a placeholder: replace it with the absolute
    # path of tetrad-gui-7.2.2-launch.jar before running this script.
    jpype.startJVM(classpath=['please change to the abs path of tetrad-gui-7.2.2-launch.jar'])
except OSError:
    # Interpreted as "a JVM is already running in this process".
    # NOTE(review): any other OSError raised by startJVM is reported with the
    # same message - confirm that this is intended.
    print('JVM already started')

import os
import json
import time
import datetime
import argparse
import subprocess

from multiprocessing.pool import ThreadPool

import numpy as np
import networkx as nx
import edu.cmu.tetrad.search as ts
from algorithms.utils.translate import numpy_data_to_tetrad, networkx_graph_to_tetrad, statistic
from partitions.partition_improved import improved_modified_hierarchical_clustering

# Command-line interface of the script.
# NOTE(review): the description mentions fGES, while main() launches
# ./algorithms/pcstable_params_runner_bic.py - confirm the wording.
parser = argparse.ArgumentParser(description='running fGES with default argument')
# Path handed to np.load() in main().
parser.add_argument('--data', type=str, default=None, help='learning data, in .npy format')
# Path opened in main() and parsed with nx.node_link_graph().
parser.add_argument('--truth', type=str, default=None, help='truth graph, in .json (node link format) format')
# Directory that receives est_graph.json, metrics.json and the cache/ folder;
# main() creates it when it does not exist.
parser.add_argument('--output', type=str, help='output file path, in dir')

# Forwarded as `beta` to improved_modified_hierarchical_clustering().
parser.add_argument('--beta', type=float, help='max ratio of sub graph')
# Forwarded to ThreadPool(); each worker thread waits on one sub-process.
parser.add_argument('--processes', type=int, help='number of processes running in the same time when estimation')

# Parsed at module level, so sys.argv is read as soon as this file is executed
# or imported.
args = parser.parse_args()


def print_with_time(content):
    """Print a UTC timestamp line, then *content* followed by a blank line.

    The timestamp is taken from ``time.gmtime()`` and rendered as
    ``YYYY-MM-DD HH:MM:SS UTC <offset>``.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S UTC %z", time.gmtime())
    print(stamp)
    # The extra '\n' argument leaves an empty line after the message.
    print(content, '\n')


def open_and_wait(cmd, logfile):
    """Run *cmd* through the shell and block until it terminates.

    Both stdout and stderr of the child process are written to *logfile*,
    which is truncated first. A timestamped message is printed when the
    command is launched and again when it has finished. The exit status of
    the command is not inspected.

    Args:
        cmd: Command line, passed as a single string to the shell.
        logfile: Path of the file receiving the command's output.
    """
    started_msg = 'open cmd: {}'.format(cmd)
    finished_msg = 'finished cmd: {}'.format(cmd)

    with open(logfile, 'w') as log_handle:
        child = subprocess.Popen(cmd, shell=True, stdout=log_handle, stderr=log_handle)
        print_with_time(started_msg)
        # Keep the log file open until the child has exited.
        child.wait()
        print_with_time(finished_msg)


def _cache_file(cache_dir, template, *fields):
    """Return the path of a cache file inside *cache_dir*.

    Only the file-name *template* is formatted with *fields*; the directory
    part is joined afterwards, so braces occurring in *cache_dir* are never
    interpreted as format placeholders.
    """
    return os.path.join(cache_dir, template.format(*fields))


def main():
    """Partition the data, estimate one graph per partition, merge and score.

    Steps, all driven by the module-level ``args``:

    1. Load the data matrix from ``args.data`` and split its columns into
       partitions with ``improved_modified_hierarchical_clustering``.
    2. Store each partition's columns under ``<output>/cache`` and launch
       ``./algorithms/pcstable_params_runner_bic.py`` on every one of them,
       running up to ``args.processes`` commands at the same time.
    3. Merge the per-partition estimates into one directed graph and write it
       to ``<output>/est_graph.json``. Partitions whose estimate file is
       missing are reported and skipped.
    4. Load the truth graph from ``args.truth``, compute the metrics and write
       them to ``<output>/metrics.json``.
    """
    os.makedirs(args.output, exist_ok=True)

    n_data = np.load(args.data)
    print('{} -- {} loaded. start algorithm running.'.format(datetime.datetime.now(), args.data))

    s_time = time.time()
    # partition
    print('start partition...')
    # Correlation between columns (variables); rows are treated as samples.
    corr_mat = np.corrcoef(n_data.T)
    partitions = improved_modified_hierarchical_clustering(corrs=corr_mat, beta=args.beta)
    print('{} partitions with size {}, time {}'.format(
        len(partitions),
        [len(partition) for partition in partitions],
        time.time() - s_time))

    cache_dir = os.path.join(args.output, 'cache')
    print('store sub data in {}'.format(cache_dir))

    os.makedirs(cache_dir, exist_ok=True)
    for pi, partition in enumerate(partitions):
        # Each partition is used as a list of column indices of the data.
        np.save(_cache_file(cache_dir, 'data_{}.npy', pi), n_data[:, partition])
    print('store finished... time {}'.format(time.time() - s_time))

    ci = 'default'

    def report_failure(exc):
        # Without an error callback, exceptions raised inside apply_async
        # workers would be dropped silently.
        print_with_time('sub-task failed: {!r}'.format(exc))

    tp = ThreadPool(args.processes)
    for pi in range(len(partitions)):
        # NOTE(review): the command is a shell string, so cache paths that
        # contain spaces or shell metacharacters are not supported.
        cmd = ('python ./algorithms/pcstable_params_runner_bic.py'
               ' --data {} --est {} --metrics {} --cpulimit {}').format(
                   _cache_file(cache_dir, 'data_{}.npy', pi),
                   _cache_file(cache_dir, 'est_{}_{}.json', pi, ci),
                   _cache_file(cache_dir, 'metrics_{}_{}.json', pi, ci),
                   7200)
        tp.apply_async(open_and_wait,
                       (cmd, _cache_file(cache_dir, 'logger_{}_{}.log', pi, ci)),
                       error_callback=report_failure)
    tp.close()
    tp.join()
    print('{} -- algorithm running finished.'.format(datetime.datetime.now()))

    # One node per data column, named by its index as a string.
    est_nx_graph = nx.DiGraph()
    est_nx_graph.add_nodes_from(str(sn) for sn in range(n_data.shape[1]))
    for pi, partition in enumerate(partitions):
        est_path = _cache_file(cache_dir, 'est_{}_{}.json', pi, ci)
        if not os.path.exists(est_path):
            print_with_time('file not exists - {}'.format(est_path))
            continue
        with open(est_path) as f:
            sub_nx_graph = nx.node_link_graph(json.load(f))
        # Map node names local to the partition (position inside the
        # partition) back to the column index in the full data set.
        local_to_global = {str(inode): str(node) for inode, node in enumerate(partition)}
        est_nx_graph.add_edges_from(nx.relabel_nodes(sub_nx_graph, local_to_global).edges)
    # Runtime covers partitioning, estimation and merging, not the scoring.
    elapsed = time.time() - s_time

    est_graph_path = os.path.join(args.output, 'est_graph.json')
    with open(est_graph_path, 'w') as f:
        json.dump(nx.node_link_data(est_nx_graph), f, indent='\t')
    print('{} -- saved estimation graph, {}.'.format(datetime.datetime.now(), est_graph_path))
    est_t_graph = networkx_graph_to_tetrad(est_nx_graph, None)

    with open(args.truth) as f:
        truth_nx_graph = nx.node_link_graph(json.load(f))

    # Rename truth nodes to their rank in sorted order so they use the same
    # "0", "1", ... naming as the estimated graph.
    truth_nx_graph = nx.relabel_nodes(truth_nx_graph, {n: str(i) for i, n in enumerate(sorted(truth_nx_graph))})
    truth_t_graph = networkx_graph_to_tetrad(truth_nx_graph, est_t_graph)
    print('{} -- {} loaded. start metrics evaluation.'.format(datetime.datetime.now(), args.truth))

    t_data = numpy_data_to_tetrad(n_data)
    metrics = statistic(truth_t_graph, est_t_graph)
    metrics['bic'] = ts.SemBicScorer.scoreDag(ts.SearchGraphUtils.dagFromCPDAG(est_t_graph), t_data)
    metrics['bic_truth'] = ts.SemBicScorer.scoreDag(ts.SearchGraphUtils.dagFromCPDAG(truth_t_graph), t_data)
    metrics['runtime'] = elapsed

    metrics_path = os.path.join(args.output, 'metrics.json')
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent='\t')
    print('{} -- metrics saved, {}.'.format(datetime.datetime.now(), metrics_path))


# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()