""" explainer_main.py

     Main user interface for the explainer module.
"""
import argparse
import os

import sklearn.metrics as metrics

from tensorboardX import SummaryWriter

import pickle
import shutil
import torch

import models
import utils.io_utils as io_utils
import utils.parser_utils as parser_utils
# from explainer import explain_node_bondary as explain
# from explainer import explain_boundary_nn as explain

# from explainer import explain_boundary_joint_group as explain
from explainer import explain as explain
# from explainer import explain_boundary_joint as explain


from genome_models import *
from gcn import *


def arg_parse(argv=None):
    """Build the explainer command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace holding the parsed options, with the defaults
        installed by ``parser.set_defaults`` below.
    """
    parser = argparse.ArgumentParser(description="GNN Explainer arguments.")

    # --dataset and --pkl are mutually exclusive input sources.
    io_parser = parser.add_mutually_exclusive_group(required=False)
    io_parser.add_argument("--dataset", dest="dataset", help="Input dataset.")
    # --bmname used to be added through an argument group nested inside the
    # mutually exclusive group.  Nesting groups that way is deprecated since
    # Python 3.11 and rejected by Python 3.14; it also never made --bmname
    # exclusive and hid it from --help.  Registering it on the parser keeps the
    # parsing behaviour and works on every Python version.
    parser.add_argument(
        "--bmname", dest="bmname", help="Name of the benchmark dataset"
    )
    io_parser.add_argument("--pkl", dest="pkl_fname", help="Name of the pkl data file")

    # Optimizer options are registered by the project helper; presumably this is
    # where opt / opt_scheduler / lr / clip / batch_size are declared.
    parser_utils.parse_optimizer(parser)

    parser.add_argument("--clean-log", action="store_true", help="If true, cleans the specified log directory before running.")
    parser.add_argument("--logdir", dest="logdir", help="Tensorboard log directory")
    parser.add_argument("--ckptdir", dest="ckptdir", help="Model checkpoint directory")
    parser.add_argument("--cuda", dest="cuda", help="CUDA.")
    # NOTE(review): const and default are both True, so this flag can never
    # switch the GPU off.  Left unchanged because callers rely on gpu=True.
    parser.add_argument(
        "--gpu",
        dest="gpu",
        action="store_const",
        const=True,
        default=True,
        help="whether to use GPU.",
    )
    parser.add_argument(
        "--epochs", dest="num_epochs", type=int, help="Number of epochs to train."
    )
    parser.add_argument(
        "--hidden-dim", dest="hidden_dim", type=int, help="Hidden dimension"
    )

    parser.add_argument(
        "--output-dim", dest="output_dim", type=int, help="Output dimension"
    )
    parser.add_argument(
        "--num-gc-layers",
        dest="num_gc_layers",
        type=int,
        help="Number of graph convolution layers before each pooling",
    )
    parser.add_argument(
        "--bn",
        dest="bn",
        action="store_const",
        const=True,
        default=False,
        help="Whether batch normalization is used",
    )
    parser.add_argument("--dropout", dest="dropout", type=float, help="Dropout rate.")
    parser.add_argument(
        "--nobias",
        dest="bias",
        action="store_const",
        const=False,
        default=True,
        help="Whether to add bias. Default to True.",
    )
    parser.add_argument(
        "--no-writer",
        dest="writer",
        action="store_const",
        const=False,
        default=True,
        # The help text used to be a copy of the --nobias description.
        help="Disable the Tensorboard writer. The writer is enabled by default.",
    )

    # Explainer
    parser.add_argument("--mask-act", dest="mask_act", type=str, help="sigmoid, ReLU.")
    parser.add_argument(
        "--mask-bias",
        dest="mask_bias",
        action="store_const",
        const=True,
        default=False,
        # The help text used to claim "Default to True" although default=False.
        help="Whether to add a bias to the mask. Default to False.",
    )
    parser.add_argument(
        "--explain-node", dest="explain_node", type=int, help="Node to explain."
    )
    parser.add_argument(
        "--graph-idx", dest="graph_idx", type=int, help="Graph to explain."
    )
    parser.add_argument(
        "--graph-mode",
        dest="graph_mode",
        action="store_const",
        const=True,
        default=False,
        help="whether to run Explainer on Graph Classification task.",
    )
    parser.add_argument(
        "--multigraph-class",
        dest="multigraph_class",
        type=int,
        help="whether to run Explainer on multiple Graphs from the Classification task for examples in the same class.",
    )
    parser.add_argument(
        "--multinode-class",
        dest="multinode_class",
        type=int,
        help="whether to run Explainer on multiple nodes from the Classification task for examples in the same class.",
    )
    parser.add_argument(
        "--align-steps",
        dest="align_steps",
        type=int,
        help="Number of iterations to find P, the alignment matrix.",
    )

    # Result file and regularisation coefficients.
    parser.add_argument(
        "--fname", dest="fname", type=str, help="result file"
    )
    parser.add_argument(
        "--lap_c", dest="lap_c", type=float, help="laplacian coeffecient"
    )
    parser.add_argument(
        "--ent_c", dest="ent_c", type=float, help="entropy coeffecient"
    )
    parser.add_argument(
        "--size_c", dest="size_c", type=float, help="size coeffecient"
    )

    parser.add_argument(
        "--method", dest="method", type=str, help="Method. Possible values: base, att."
    )
    parser.add_argument(
        "--name-suffix", dest="name_suffix", help="suffix added to the output filename"
    )
    # NOTE(review): no type/action is given, so any value passed on the command
    # line (even "False") is stored as a non-empty, truthy string.
    parser.add_argument(
        "--add_embedding", dest="add_embedding", default=False, help="add embedding layer "
    )

    parser.add_argument(
        "--explainer-suffix",
        dest="explainer_suffix",
        help="suffix added to the explainer log",
    )

    # TODO: Check argument usage
    parser.set_defaults(
        logdir="log",
        ckptdir="ckpt",
        dataset="syn1",
        opt="adam",
        opt_scheduler="none",
        cuda="1",
        lr=0.1,
        clip=2.0,
        batch_size=20,
        num_epochs=100,
        hidden_dim=20,
        output_dim=20,
        num_gc_layers=3,
        dropout=0.0,
        method="base",
        name_suffix="",
        explainer_suffix="",
        align_steps=1000,
        explain_node=None,
        graph_idx=-1,
        mask_act="sigmoid",
        multigraph_class=-1,
        multinode_class=-1,
        add_embedding=False,
    )
    return parser.parse_args(argv)


def _make_writer(prog_args):
    """Create the Tensorboard writer for one run, or return None if disabled.

    When ``--clean-log`` is set and the log directory already exists, the user
    is asked for confirmation; any answer not starting with "y" exits with
    status 1, otherwise the directory is removed before the writer is created.
    """
    import sys  # not imported at file level; the original raised NameError here

    if not prog_args.writer:
        return None

    path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
    if os.path.isdir(path) and prog_args.clean_log:
        print('Removing existing log dir: ', path)
        answer = input("Are you sure you want to remove this directory? (y/n): ")
        if answer.lower().strip()[:1] != "y":
            sys.exit(1)
        shutil.rmtree(path)
    return SummaryWriter(path)


def _load_computation_graph(prog_args):
    """Return ``(cg_dict, ckpt)``; ``ckpt`` is None for the genome dataset."""
    if prog_args.dataset == 'genome':
        # NOTE: pickle must only be used on trusted files.
        # The context manager closes the handle the original left open.
        with open("./genome_data/cg_genome_test_80.p", "rb") as handle:
            cg_dict = pickle.load(handle)
        cg_dict['train_idx'] = list(range(1000))
        return cg_dict, None

    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    print("Loaded model from {}".format(prog_args.ckptdir))
    return cg_dict, ckpt


def _build_model(prog_args, cg_dict, ckpt, graph_mode, input_dim, num_classes):
    """Instantiate the model for the selected mode and load its saved weights."""
    is_genome = prog_args.dataset == 'genome'

    if graph_mode and is_genome:
        model = GCN(name='genome_gcn',
                    cuda=True,
                    num_epochs=100,
                    batch_size=1,
                    num_layer=2,
                    channels=32,
                    embedding=64, dropout=True, score=980)
    elif graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # class weight in CE loss for handling imbalanced label classes
            prog_args.loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float).cuda()
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )

    if prog_args.gpu:
        model = model.cuda()

    # load state_dict (obtained by model.state_dict() when saving checkpoint)
    if is_genome:
        model.setup_model(cg_dict["feat"], cg_dict["label"], cg_dict["adj"])
        model.load_state_dict(torch.load("./genome_data/gcn_80.pth"))
    else:
        model.load_state_dict(ckpt["model_state"])
    return model


def _explain_graph_mode(explainer, prog_args, cg_dict, writer):
    """Run the explainer for the graph-classification modes."""
    import numpy as np  # not imported at file level in the original

    if prog_args.multigraph_class >= 0:
        # Only explain graphs whose predicted class is multigraph_class.
        preds = np.argmax(cg_dict['pred'][0, :, :], axis=1)
        graph_indices = [
            idx for idx, pred in enumerate(preds)
            if pred == prog_args.multigraph_class
        ]
        print(
            "Graph indices for label ",
            prog_args.multigraph_class,
            " : ",
            graph_indices,
        )
        explainer.explain_graphs(graph_indices=graph_indices)
    elif prog_args.graph_idx == -1:
        # No graph selected: run for a hand-picked set of indices.
        explainer.explain_graphs(graph_indices=[74, 274, 374, 474, 574, 674, 774, 874])
    else:
        explainer.explain(
            node_idx=0,
            graph_idx=prog_args.graph_idx,
            graph_mode=True,
            unconstrained=False,
        )
        # NOTE(review): writer is None under --no-writer; confirm that
        # plot_cmap_tb tolerates that.
        io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap")


def _explain_node_mode(explainer, prog_args, cg_dict):
    """Run the explainer for the node-classification modes."""
    if prog_args.multinode_class >= 0:
        print(cg_dict["label"])
        # only run for nodes with label specified by multinode_class
        labels = cg_dict["label"][0]  # already numpy matrix

        # Collect at most the first five matching nodes.
        node_indices = []
        for idx, label in enumerate(labels):
            if len(node_indices) > 4:
                break
            if label == prog_args.multinode_class:
                node_indices.append(idx)
        print(
            "Node indices for label ",
            prog_args.multinode_class,
            " : ",
            node_indices,
        )
        explainer.explain_nodes(node_indices, prog_args)
        return

    # Default: explain the hard-coded node range of the selected benchmark.
    bmname = prog_args.bmname
    if bmname == "syn1":
        node_list = range(400, 670, 1)
    elif bmname == "syn2":
        node_list = list(range(400, 700, 1)) + list(range(1100, 1400, 1))
    elif bmname in ("syn3", "repeat_syn3"):
        prog_args.bmname = "syn3"  # variants share the canonical name
        node_list = list(range(300, 1020, 1))
    elif bmname in ("syn4", "dense_syn4"):
        prog_args.bmname = "syn4"  # variants share the canonical name
        node_list = list(range(511, 871, 1))
    else:
        # Unknown benchmarks are skipped silently, as before.
        return
    explainer.explain_nodes_gnn_stats(node_list, prog_args)


def _run_configuration(prog_args):
    """Load the model and run the explainer once for the current prog_args."""
    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    writer = _make_writer(prog_args)
    try:
        # Load a model checkpoint
        cg_dict, ckpt = _load_computation_graph(prog_args)

        # Last axis of feat / pred; the original notes feat as n*nodes*dim.
        input_dim = cg_dict["feat"].shape[2]
        num_classes = cg_dict["pred"].shape[2]
        print("input dim: ", input_dim, "; num classes: ", num_classes)

        # Determine explainer mode
        graph_mode = (
            prog_args.graph_mode
            or prog_args.multigraph_class >= 0
            or prog_args.graph_idx >= 0
        )

        # build model
        print("Method: ", prog_args.method)
        model = _build_model(
            prog_args, cg_dict, ckpt, graph_mode, input_dim, num_classes
        )

        # Create explainer
        dict_num_nodes = cg_dict["num_nodes"] if "num_nodes" in cg_dict else None
        explainer = explain.Explainer(
            model=model,
            adj=cg_dict["adj"],
            feat=cg_dict["feat"],
            label=cg_dict["label"],
            pred=cg_dict["pred"],
            train_idx=cg_dict["train_idx"],
            num_nodes=dict_num_nodes,
            args=prog_args,
            writer=writer,
            print_training=True,
            graph_mode=graph_mode,
            graph_idx=prog_args.graph_idx,
        )

        # TODO: API should definitely be cleaner
        # Let's define exactly which modes we support
        if prog_args.explain_node is not None:
            explainer.explain(prog_args.explain_node, unconstrained=False)
        elif graph_mode:
            _explain_graph_mode(explainer, prog_args, cg_dict, writer)
        else:
            _explain_node_mode(explainer, prog_args, cg_dict)
    finally:
        # A writer is created per configuration; close it so event files are
        # flushed and handles do not pile up over the sweep.
        if writer is not None:
            writer.close()


def main():
    """Run the explainer over a grid of benchmarks and loss coefficients.

    Arguments are parsed once; then, for every benchmark in ``datasets`` and
    every (size, laplacian, entropy) coefficient combination,
    ``prog_args.bmname``, ``fname``, ``size_c``, ``lap_c`` and ``ent_c`` are
    overwritten and one explainer run is executed.  Values given on the command
    line for those five options are therefore ignored.
    """
    import itertools

    # Load a configuration
    prog_args = arg_parse()
    datasets = ['syn4', 'syn1', 'syn2', 'syn3']
    # datasets = ['dense_syn4']

    # Previously tried grids, kept for reference:
    # size_coeff = [0.001, 0.003, 0.006, 0.009, 0.015, 0.025, 0.04, 0.06]
    # lap_coeff = [0., 1.0, 4.0]
    # ent_coeff = [0., 1.0, 4.0]
    # size_coeff = [0.000005, 0.00005, 0.0005, 0.005, 0.05, 0.5]
    # size_coeff = [0.0000005, 0.000001, 0.000005, 0.000008]
    # size_coeff = [0.00008, 0.0001, 0.0004, 0.0008]
    # size_coeff = [0.0006, 0.0008, 0.001, 0.003, 0.006, 0.009]
    size_coeff = [0.005, 0.05, 0.0005, 0.00005]
    lap_coeff = [0.5]
    ent_coeff = [1.0]

    for ds in datasets:
        prog_args.fname = ds + ".txt"

        # product() iterates with the rightmost factor fastest, i.e. in the
        # same order as the original nested loops.
        for sz_c, lap_c, ent_c in itertools.product(size_coeff, lap_coeff, ent_coeff):
            prog_args.bmname = ds
            prog_args.size_c = sz_c
            prog_args.lap_c = lap_c
            prog_args.ent_c = ent_c

            print(prog_args.bmname, " ", prog_args.size_c, " ", prog_args.lap_c, " ", prog_args.ent_c)

            _run_configuration(prog_args)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

