"""
Should support plots for best performing parameters, both for GES and PC
"""

import os
import argparse
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from distutils.dir_util import copy_tree
from utils._metrics import get_metrics
from utils._utils import is_cpdag, dag_to_cpdag, tunable_parameters

# Root directory of the hyperparameter logs. The script entry point joins it
# with the scenario and method names to build the ``logs_dir`` argument of
# ``plot_graphs``.
LOGS_DIR = "/efs/tmp/causal-benchmark-logs/hyperparameters/"
# Root directory of the workspace logs. The script entry point joins it with
# the scenario and method names to build the ``workspace_dir`` argument of
# ``plot_graphs``.
WORKSPACE_LOGS = "/home/ec2-user/causal-benchmark/tmp/logs/hyperparameters/"


# Name of the hyperparameter tuned for each supported inference method.
# NOTE(review): a ``tunable_parameters`` helper is also imported from
# ``utils._utils``; the original code shadowed it with this local table, which
# is kept here so the behavior does not depend on the imported implementation.
_TUNABLE_PARAMETERS = {
    "ges": "lambda",
    "pc": "alpha",
    "das": "alpha",
    "score": "alpha",
    "cam": "alpha",
    "nogam": "alpha",
    "diffan": "alpha",
}


def _tunable_parameter_name(method):
    """Return the name of the hyperparameter tuned for ``method``."""
    return _TUNABLE_PARAMETERS[method]


def _log_header(method):
    """Return the CSV header line (without newline) of the logging file."""
    param_name = _tunable_parameter_name(method)
    # NOTE: the rows written by ``plot_graphs`` carry no value for "time [s]";
    # the column is kept to leave the output format unchanged.
    return f"seed_id,N,size,density,{param_name},shd,tpr,fpr,tnr,fnr,f1,time [s]"


def _sort_logs(output_file):
    """Rewrite the CSV ``output_file`` in place, sorted by ascending seed_id."""
    logs_df = pd.read_csv(output_file)
    logs_df = logs_df.sort_values(by=["seed_id"], ascending=[True])
    logs_df.to_csv(output_file, index=False)


def _save_graph_plot(adj, path, confounded=False):
    """
    Draw the directed graph with adjacency matrix ``adj`` and save it to ``path``.

    When ``confounded`` is True the first half of the nodes is drawn in red and
    labelled "C<k>", the remaining nodes are drawn in green and labelled
    starting again from 0. Otherwise all nodes are green and labelled by index.
    """
    plt.figure(figsize=(10, 10))
    try:
        graph = nx.from_numpy_array(adj, create_using=nx.DiGraph())
        # Fixed seed so that the layout is reproducible across runs
        pos = nx.spring_layout(graph, seed=3113794652)
        node_options = {"edgecolors": "gray", "node_size": 400, "alpha": 0.7}

        n_nodes = graph.number_of_nodes()
        labels = {k: k for k in range(n_nodes)}
        if confounded:
            # Same split point for labels and colors, also for odd n_nodes
            half = n_nodes // 2
            for k in range(n_nodes):
                labels[k] = f"C{k}" if k < half else k - half
            nx.draw_networkx_nodes(
                graph, pos, nodelist=range(half), node_color="r", **node_options
            )
            nx.draw_networkx_nodes(
                graph, pos, nodelist=range(half, n_nodes), node_color="g", **node_options
            )
        else:
            nx.draw_networkx_nodes(
                graph, pos, nodelist=range(n_nodes), node_color="g", **node_options
            )

        # edges
        nx.draw_networkx_edges(graph, pos, width=1.0, alpha=0.5)

        # labels
        nx.draw_networkx_labels(graph, pos, labels, font_size=13, font_color="whitesmoke")
        plt.savefig(path)
    finally:
        # Release the figure even when drawing or saving fails
        plt.close('all')


def plot_graphs(logs_dir, workspace_dir, config, method="pc", scenario=None):
    """
    Create a directory with the plots of ground truth and predicted graphs.
    In the directory, add raw logs (one row for each sample dataset).
    Logging and plotting are done only for the best performing hyperparameter,
    i.e. the one with the highest f1 in ``stats_{method}.csv``.

    Parameters
    ----------
    logs_dir : str
        Input and output directory. Must contain ``metadata.csv`` and
        ``stats_{method}.csv``; the output is written to
        ``<logs_dir>/graphs/<config>``.
    workspace_dir : str
        Directory where ``<logs_dir>/graphs`` is copied (under ``graphs``)
        once all plots are created.
    config : str
        Data configuration of the graphs to visualize, as "<size>_<density>".
        E.g. "small_dense"
    method : str
        Method of inference of the predictions of interest
    scenario : str, optional
        Scenario of the experiments. With "confounded", the confounded ground
        truth is plotted too. When None, the module-level ``scenario`` set by
        the script entry point is used if defined, else "vanilla".

    Raises
    ------
    KeyError
        If ``method`` has no known tunable hyperparameter.
    ValueError
        If ``config`` is not of the form "<size>_<density>" or if the stats
        file has no entry for ``config``.
    """
    if scenario is None:
        # Backward compatibility: the scenario used to be read from a global
        # that exists only when this file is run as a script.
        scenario = globals().get("scenario", "vanilla")

    param_name = _tunable_parameter_name(method)

    ##################### Logs processing #####################

    metadata_df = pd.read_csv(os.path.join(logs_dir, "metadata.csv"))
    stats_df = pd.read_csv(os.path.join(logs_dir, f"stats_{method}.csv"))

    # Select subset of logs of interest
    config_size, config_density = config.split("_")
    metadata_df = metadata_df[
        (metadata_df["size"] == config_size) &
        (metadata_df["density"] == config_density)
    ].reset_index(drop=True)
    stats_df = stats_df[
        (stats_df["size"] == config_size) &
        (stats_df["density"] == config_density)
    ].reset_index(drop=True)
    if stats_df.empty:
        raise ValueError(f"No stats found for configuration '{config}' and method '{method}'")

    # Find best hyperparameter for which to plot the predictions.
    # f1 entries are strings of the form "<mean>+-<std>": rank by the mean.
    f1_means = stats_df["f1"].apply(lambda entry: float(entry.split("+-")[0]))
    best_param = stats_df[param_name].iloc[f1_means.argmax()]
    best_rows = metadata_df.index[metadata_df[param_name] == best_param]

    graphs_dir = os.path.join(logs_dir, "graphs")
    output_dir = os.path.join(graphs_dir, config)
    os.makedirs(output_dir, exist_ok=True)

    output_logs = os.path.join(output_dir, "logs.csv")
    with open(output_logs, "w") as log_file:
        log_file.write(_log_header(method) + '\n')

        for i in best_rows:
            pred_path = metadata_df.loc[i, "pred_location"]
            gt_path = metadata_df.loc[i, "gt_location"]
            pred = np.genfromtxt(pred_path, delimiter=",")
            gt = np.genfromtxt(gt_path, delimiter=",")
            # TODO: metrics should use dag_to_cpdag(gt) rather than gt
            metrics = list(get_metrics(is_cpdag(method), pred, gt))  # shd, tpr, fpr, tnr, fnr, f1
            seed_id = metadata_df.loc[i, "seed_id"]

            # Update logs
            row = [
                seed_id,
                metadata_df.loc[i, "N"],
                metadata_df.loc[i, "size"],
                metadata_df.loc[i, "density"],
                metadata_df.loc[i, param_name],
            ] + metrics
            log_file.write(",".join([str(val) for val in row]) + "\n")

            # Plot graphs
            _save_graph_plot(pred, os.path.join(output_dir, f"pred{seed_id}.png"))
            _save_graph_plot(gt, os.path.join(output_dir, f"gt{seed_id}.png"))

            # Additionally plot confounded graph in case of confounded scenario
            if scenario == "confounded":
                # The confounded ground truth is stored next to the ground truth file
                confounded_path = os.path.join(os.path.dirname(gt_path), f"confounded{seed_id}.csv")
                confounded_gt = np.genfromtxt(confounded_path, delimiter=",")
                confounded_graph_path = os.path.join(output_dir, f"confounded{seed_id}.png")
                _save_graph_plot(confounded_gt, confounded_graph_path, confounded=True)

    _sort_logs(output_logs)

    # Copy in workspace logs
    dest_dir = os.path.join(workspace_dir, "graphs")
    copy_tree(graphs_dir, dest_dir)


if __name__ == "__main__":
    # Command line entry point: plot the graphs of every requested method for
    # every requested data configuration of the given scenario.
    parser = argparse.ArgumentParser(description="Visualize predictions and groundtruth graphs")
    parser.add_argument(
        '--methods',
        nargs='+',
        help='Methods for which visualization is required',
        type=str,
        required=True
    )
    parser.add_argument(
        "--scenario",
        default="vanilla",
        type=str,
        help="Scenario for which visualization is required"
    )
    parser.add_argument(
        "--configs",
        nargs='+',
        default=["small_sparse", "small_dense"],
        type=str,
        help="Data configurations (<size>_<density>) for which visualization is required"
    )

    args = parser.parse_args()
    # Must stay a module-level name: plot_graphs looks up `scenario` in the
    # module globals to decide whether to plot the confounded graphs.
    scenario = args.scenario

    data_configs = args.configs
    for method in args.methods:
        # Input logs and workspace copy are both organized as <root>/<scenario>/<method>
        method_logs_dir = os.path.join(LOGS_DIR, scenario, method)
        method_workspace_dir = os.path.join(WORKSPACE_LOGS, scenario, method)
        for config in data_configs:
            plot_graphs(method_logs_dir, method_workspace_dir, config, method)

