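"""explainer_main.py

Main user interface for the explainer module: loads a trained GCN checkpoint
and runs the explainer selected by the parsed arguments (``configs3.arg_parse``).
"""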
import os

import sklearn.metrics as metrics

from tensorboardX import SummaryWriter

import random

import pickle
import shutil
import torch
torch.cuda.empty_cache()

import models
import utils.io_utils as io_utils

from explainer3 import explain_rcexplainer
from explainer3 import explain_rcnoiseexplainer
from explainer3 import explain_rcexplainer_noldb

import utils.accuracy_utils3 as accuracy_utils

from gcn import *

import configs3

def _seed_everything(prog_args):
    """Seed the torch, Python ``random`` and numpy RNGs; select CUDA devices on GPU."""
    if prog_args.gpu:
        # Restrict visible devices before any CUDA seeding so the setting is
        # already in place whenever CUDA gets initialised.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(prog_args.cuda)
    torch.manual_seed(prog_args.seed)
    random.seed(prog_args.seed)
    np.random.seed(prog_args.seed)
    if prog_args.gpu:
        torch.cuda.manual_seed(prog_args.seed)
        torch.cuda.manual_seed_all(prog_args.seed)


def _make_writer(prog_args):
    """Return a ``SummaryWriter`` for this run, or ``None`` when logging is disabled.

    When ``prog_args.clean_log`` is set and the log directory already exists,
    the user is asked for confirmation before it is deleted. Any answer not
    starting with "y" exits the program with status 1.
    """
    if not prog_args.writer:
        return None
    path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
    if os.path.isdir(path) and prog_args.clean_log:
        answer = input("Are you sure you want to remove this directory? (y/n): ")
        if answer.lower().strip()[:1] != "y":
            # Equivalent to sys.exit(1) without relying on an implicit `sys` import.
            raise SystemExit(1)
        shutil.rmtree(path)
    return SummaryWriter(path)


def _build_model(prog_args, input_dim, num_classes, graph_mode, device):
    """Build the graph-level or node-level GCN encoder and move it to *device*."""
    model_cls = models.GcnEncoderGraph if graph_mode else models.GcnEncoderNode
    model = model_cls(
        input_dim=input_dim,
        hidden_dim=prog_args.hidden_dim,
        embedding_dim=prog_args.output_dim,
        label_dim=num_classes,
        num_layers=prog_args.num_gc_layers,
        pred_hidden_dims=[prog_args.pred_hidden_dim] * prog_args.pred_num_layers,
        bn=prog_args.bn,
        args=prog_args,
        device=device,
    )
    return model.to(device)


def _merge_train_val(cg_dict):
    """Concatenate the train and validation splits stored in *cg_dict*.

    Returns:
        ``(num_nodes, adj, feat, label, pred)``. ``num_nodes`` is ``None`` when
        the checkpoint has no ``"num_nodes"`` entry. If any ``val_*`` entry is
        missing, or the arrays cannot be concatenated, every value falls back to
        the train-only entries. The merge is all-or-nothing, so ``num_nodes`` is
        never merged while the other arrays are not.
    """
    train_num_nodes = cg_dict.get("num_nodes")
    try:
        num_nodes = train_num_nodes
        if "num_nodes" in cg_dict:
            num_nodes = torch.cat([train_num_nodes, cg_dict["val_num_nodes"]])
        adj = torch.cat([cg_dict["adj"], cg_dict["val_adj"]])
        feat = torch.cat([cg_dict["feat"], cg_dict["val_feat"]])
        label = torch.cat([cg_dict["label"], cg_dict["val_label"]])
        # pred is concatenated along axis 1 (it is indexed as pred[0, :, :] below).
        pred = np.concatenate((cg_dict["pred"], cg_dict["val_pred"]), axis=1)
    except (KeyError, TypeError, ValueError, RuntimeError):
        # KeyError: no validation split saved; TypeError/ValueError/RuntimeError:
        # entries are not concatenable tensors/arrays (e.g. numpy passed to torch.cat).
        return (
            train_num_nodes,
            cg_dict["adj"],
            cg_dict["feat"],
            cg_dict["label"],
            cg_dict["pred"],
        )
    return num_nodes, adj, feat, label, pred


def _build_explainer(prog_args, model, cg_dict, writer, graph_mode, device):
    """Construct the explainer selected by ``prog_args.explainer_method``.

    Mutates ``cg_dict`` by setting ``cg_dict["emb"] = None`` when it is absent.

    Raises:
        ValueError: if ``explainer_method`` is not a supported explainer name.
    """
    num_nodes, adj, feat, label, pred = _merge_train_val(cg_dict)

    # Embeddings are optional in the checkpoint; the RC explainers accept None.
    cg_dict.setdefault("emb", None)

    # Keyword arguments shared by every explainer class.
    common = dict(
        model=model,
        train_idx=cg_dict["train_idx"],
        num_nodes=num_nodes,
        args=prog_args,
        writer=writer,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
        device=device,
    )

    method = prog_args.explainer_method
    if method == "gnnexplainer":
        # NOTE(review): `explain_gnnexplainer` is not imported in this file; it
        # must come from `from gcn import *` or this branch raises NameError.
        return explain_gnnexplainer.ExplainerGnnExplainer(
            adj=adj,
            feat=feat,
            label=label,
            pred=pred,
            print_training=False,
            **common,
        )

    # The RC explainers receive the train-split arrays (not the merged ones)
    # plus the stored embeddings.
    if method == "rcexplainer":
        print('NORMAL EXPLAINER')
        explainer_cls = explain_rcexplainer.ExplainerRCExplainer
    elif method == "rcexp_noldb":
        explainer_cls = explain_rcexplainer_noldb.ExplainerRCExplainerNoLDB
    elif method == "rcnoiseexplainer":
        explainer_cls = explain_rcnoiseexplainer.ExplainerRCExplainer
    else:
        raise ValueError(
            "Unknown explainer_method {!r}; expected one of 'rcexplainer', "
            "'rcexp_noldb', 'rcnoiseexplainer', 'gnnexplainer'".format(method)
        )
    return explainer_cls(
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        emb=cg_dict["emb"],
        print_training=True,
        **common,
    )


def _run_explainer(prog_args, explainer, cg_dict, writer, graph_mode):
    """Dispatch to the single-node, multi-graph, single-graph or multi-node mode."""
    # TODO: API should definitely be cleaner
    # Let's define exactly which modes we support
    # We could even move each mode to a different method (even file)
    if prog_args.explain_node is not None:
        explainer.explain(prog_args.explain_node, unconstrained=False)
        return

    if graph_mode:
        if prog_args.multigraph_class >= 0:
            # Explain every graph whose *predicted* class (argmax over the last
            # axis of pred[0]) equals multigraph_class.
            preds = np.argmax(cg_dict['pred'][0, :, :], axis=1)
            graph_indices = [
                i for i, p in enumerate(preds) if p == prog_args.multigraph_class
            ]

            print(
                prog_args.multigraph_class,
                " : ",
            )

            # The full list is kept as the test set; training may use a subsample.
            orig_graph_indices = graph_indices
            if prog_args.train_data_sparsity is not None:
                graph_indices = random.sample(
                    graph_indices, int(len(graph_indices) * prog_args.train_data_sparsity)
                )
            explainer.explain_graphs(
                prog_args, graph_indices=graph_indices, test_graph_indices=orig_graph_indices
            )
        else:
            explainer.explain(
                node_idx=0,
                graph_idx=prog_args.graph_idx,
                graph_mode=True,
                unconstrained=False,
            )
            # Only plot the colormap when a TensorBoard writer was created.
            if writer is not None:
                io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap")
        return

    if prog_args.multinode_class >= 0:
        print(cg_dict["label"])
        # Only run for nodes whose label equals multinode_class.
        labels = cg_dict["label"][0]  # already numpy matrix
        node_indices = []
        for i, node_label in enumerate(labels):
            # Collect at most 5 matching nodes.
            if len(node_indices) > 4:
                break
            if node_label == prog_args.multinode_class:
                node_indices.append(i)
        print(
            "Node indices for label ",
            prog_args.multinode_class,
            " : ",
            node_indices,
        )
        explainer.explain_nodes(node_indices, prog_args)


def main(prog_args):
    """Load a trained GCN checkpoint and explain its predictions.

    Args:
        prog_args: parsed arguments (e.g. from ``configs3.arg_parse()``). Reads
            seeding/device options (``seed``, ``gpu``, ``cuda``), logging
            options (``writer``, ``logdir``, ``clean_log``), model
            hyper-parameters, ``explainer_method`` and the explanation mode
            (``explain_node``, ``graph_mode``, ``graph_idx``,
            ``multigraph_class``, ``multinode_class``, ``train_data_sparsity``).

    Raises:
        ValueError: if ``prog_args.explainer_method`` is not supported.
        SystemExit: with status 1 if the user declines deleting the log dir.
    """
    _seed_everything(prog_args)
    writer = _make_writer(prog_args)

    # Load a model checkpoint and its saved computation graph.
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]

    # feat is indexed as (n, nodes, dim); the class count is read from the
    # last axis of pred (presumably (1, num_graphs, num_classes) — TODO confirm).
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]

    # Graph-level explanation is implied by any graph-selection option.
    graph_mode = (
        prog_args.graph_mode
        or prog_args.multigraph_class >= 0
        or prog_args.graph_idx >= 0
    )
    device = 'cuda:0' if prog_args.gpu else 'cpu'

    model = _build_model(prog_args, input_dim, num_classes, graph_mode, device)
    # state_dict was obtained by model.state_dict() when saving the checkpoint.
    model.load_state_dict(ckpt["model_state"])

    explainer = _build_explainer(prog_args, model, cg_dict, writer, graph_mode, device)
    _run_explainer(prog_args, explainer, cg_dict, writer, graph_mode)



if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the explainer pipeline.
    parsed_args = configs3.arg_parse()
    main(parsed_args)

