import argparse
import glob
import json
import os
import sys

import networkx as nx
import pandas as pd
from matplotlib import pyplot as plt
from numpy import sort


sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from benchmark.data.generate_data import get_pag_skel_with_ada_orientations
from benchmark.utils.causal_graphs import MixedGraph
from benchmark.utils.cache_source_files import copy_referenced_files_to

if __name__ == '__main__':
    # Re-plot the graphs of an experiment run: for every data subdirectory and
    # every dataset index, draw side by side the ground-truth graph, the
    # "marginal" ground truth computed over the observed CSV columns, and the
    # graph loaded for each requested algorithm.
    parser = argparse.ArgumentParser(description='Resulting graphs from experiment.')
    parser.add_argument('--algorithms', nargs='+', default=['ridge', 'camuv'])
    parser.add_argument('--dir', default='.',
                        help='Experiment directory containing params.json, data/ and graphs/.')
    params = vars(parser.parse_args())

    file_dir = params['dir']
    result_dir = file_dir
    # Copy the source files referenced by this script into the result directory.
    copy_referenced_files_to(__file__, os.path.join(result_dir, "plot_graph_results_dump/"))

    with open(os.path.join(file_dir, 'params.json'), encoding='utf-8') as file:
        experiment_params = json.load(file)

    # Loop-invariant flag passed to get_pag_skel_with_ada_orientations: False only
    # when the experiment's 'mechanism' is 'linear' (a missing key counts as True).
    indicate_ucp = experiment_params.get('mechanism') != 'linear'
    # Two fixed panels (GT, Marginal GT) plus one per algorithm.
    num_panels = 2 + len(params['algorithms'])

    # Plotting
    # glob order is arbitrary: sort for a reproducible figure order, and skip
    # stray files under data/ (only subdirectories have matching graphs/ entries).
    data_subdirs = sorted(
        path for path in glob.glob(os.path.join(file_dir, 'data', '*')) if os.path.isdir(path)
    )
    for data_subdir in data_subdirs:
        subdir = os.path.basename(os.path.normpath(data_subdir))
        graph_dir = os.path.join(file_dir, 'graphs', subdir)
        for i in range(experiment_params['num_datasets']):
            # Only the column names are used below, so parse just the header row
            # instead of loading the whole dataset.
            data_path = os.path.join(file_dir, 'data', subdir, 'data_{}.csv'.format(i))
            observed_columns = list(pd.read_csv(data_path, index_col=0, nrows=0).columns)
            ground_truth = MixedGraph.load_graph(os.path.join(graph_dir, 'ground_truth_{}.gml'.format(i))).graph
            marginal_gt = get_pag_skel_with_ada_orientations(ground_truth,
                                                             observed_columns,
                                                             indicate_ucp
                                                             )

            fig, axs = plt.subplots(1, num_panels)
            # A single layout computed on the ground truth is reused for every
            # panel so node positions line up across the subplots.
            pos = nx.spring_layout(ground_truth, k=2)
            nx.draw(ground_truth, pos=pos, ax=axs[0], with_labels=True)
            axs[0].set_title('GT')
            nx.draw(marginal_gt, pos=pos, ax=axs[1], with_labels=True)
            axs[1].set_title('Marginal GT')
            for ax, algo in zip(axs[2:], params['algorithms']):
                g_hat = MixedGraph.load_graph(os.path.join(graph_dir, 'g_hat_{}_{}.gml'.format(algo, i))).graph
                nx.draw(g_hat, pos=pos, ax=ax, with_labels=True)
                ax.set_title(algo)

            plt.show()
            # Free the figure explicitly: with a non-interactive backend
            # plt.show() returns immediately and open figures would accumulate.
            plt.close(fig)