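# Demo of a GNN explainer: for one node of a chosen class, learn an edge mask and
# a node-feature mask over its k-hop subgraph that reproduce a trained
# GraphConvModel's prediction, then visualize the learned edge importances.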

import argparse
import os
from pathlib import Path

import torch
import torch.nn.functional as F
from dgl import load_graphs

from ex.dgl.models import GraphConvModel
from ex.dgl.explainer import NodeExplainerModule
from ex.dgl.graph_utils import extract_subgraph, visualize_subgraph


def main(args):
    ckpt_data = torch.load(args.model_path)
    model_stat_dict = ckpt_data['model_state']

    # load graph, feat, and label
    g_list, label_dict = load_graphs(str(args.input_path))
    graph = g_list[0]
    labels = graph.ndata['label']
    feats = graph.ndata['feat']
    num_classes = max(labels).item() + 1
    # feat_dim = feats.shape[1]
    # hid_dim = label_dict['hid_dim'].item()
    feat_dim = ckpt_data['feat_dim']
    hidden_dim = ckpt_data['hidden_dim']

    # create a model and load from state_dict
    dummy_model = GraphConvModel(feat_dim, hidden_dim, num_classes)
    dummy_model.load_state_dict(model_stat_dict)

    # Choose a node of the target class to be explained and extract its subgraph.
    # Here just pick the first one of the target class.
    target_list = [i for i, e in enumerate(labels) if e == args.target_class]
    n_idx = torch.tensor([target_list[0]])

    # Extract the computation graph within k-hop of target node and
    # use it for explainability
    sub_graph, ori_n_idxes, new_n_idx = extract_subgraph(
        graph, n_idx, hops=args.hop
    )

    # Sub-graph features.
    sub_feats = feats[ori_n_idxes, :]

    # create an explainer
    explainer = NodeExplainerModule(
        model=dummy_model,
        num_edges=sub_graph.number_of_edges(),
        node_feat_dim=feat_dim,
    )

    # define optimizer
    optim = torch.optim.Adam(
        [explainer.edge_mask, explainer.node_feat_mask],
        lr=args.lr,
        weight_decay=args.wd,
    )

    # train the explainer for the given node
    dummy_model.eval()
    model_logits = dummy_model(sub_graph, sub_feats)
    model_predict = F.one_hot(torch.argmax(model_logits, dim=-1), num_classes)

    for epoch in range(args.epochs):
        explainer.train()
        exp_logits = explainer(sub_graph, sub_feats)
        loss = explainer._loss(exp_logits[new_n_idx], model_predict[new_n_idx])

        optim.zero_grad()
        loss.backward()
        optim.step()

    # visualize the importance of edges
    edge_weights = explainer.edge_mask.sigmoid().detach()
    visualize_subgraph(sub_graph, edge_weights.numpy(), ori_n_idxes, n_idx)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Demo of GNN explainer in DGL'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default='syn1',
        choices=['syn1', 'syn2', 'syn3', 'syn4', 'syn5'],
    )
    parser.add_argument('--target-class', type=int, default='1')
    parser.add_argument('--hop', type=int, default='3')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--input-path', type=Path, required=True)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--model-path', type=Path, required=True)
    parser.add_argument('--output-dir', type=Path, default='./output')
    parser.add_argument('--wd', type=float, default=0.0)
    args = parser.parse_args()

    main(args)
