from collections import Counter
import numpy as np
import argparse
import torch
import torch.nn.functional as F
from models.stone_age_gnn import StoneAgeGNN
from models.stone_age_gnn_dt import StoneAgeGNNDT
from sklearn.model_selection import ShuffleSplit, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import compute_class_weight
from torch_geometric.loader import DataLoader
from tqdm import trange
from datasets.data_loader import load_dataset
from utils.feature_extractor import extract_features
from utils.prune_tree import prune_trees_all_val, get_num_nodes, scale_pooling_num_node_samples
from utils.utils import get_linear_features_used, linear_combo_features
from collections import defaultdict
from datasets.data_loader import neighbors
from ogb.graphproppred import Evaluator
import time
from torch_geometric.utils import degree
from torch_geometric.utils import subgraph
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.utils import remove_isolated_nodes

# Module-wide default compute device: the current CUDA device when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def score_model_mask(m, d, mask):
    """Return the accuracy of model ``m`` on the nodes of graph ``d`` selected by ``mask``.

    ``m`` is called as ``m(d.x, d.edge_index)``. The argmax over its last dimension is
    compared with ``d.y``. ``mask`` may be an index array/tensor or a boolean mask.

    Labels stored as an ``(N, 1)`` column (OGBA labels are squeezed this way elsewhere in
    this file) are reshaped to match the predictions. Previously they were broadcast
    against the predictions, producing a k x k comparison and a meaningless score.

    Returns 0.0 when ``mask`` selects no nodes.
    """
    m.eval()
    with torch.no_grad():
        pred = m(d.x, d.edge_index).argmax(dim=-1)[mask]
    target = d.y[mask].reshape(pred.shape)
    # Count the selected nodes, not len(mask): the two differ for boolean masks.
    num_selected = pred.numel()
    if num_selected == 0:
        return 0.0
    return int((pred == target).sum()) / num_selected


def _dt_only_graph_features(graph, with_degrees):
    """Build the fixed-length feature vector of one graph for the DT-only baseline.

    The vector is the column-wise sum of ``graph.x``. With ``with_degrees`` it is prefixed by
    a 50-bin histogram of node degrees 0..49, counted over ``graph.edge_index[0]``. Degrees of
    50 or more are not counted.
    """
    feature_sum = torch.sum(graph.x, dim=0)
    if not with_degrees:
        return feature_sum
    node_degrees = degree(graph.edge_index[0], graph.num_nodes).long()
    histogram = torch.bincount(node_degrees, minlength=50)[:50]
    return torch.cat([histogram, feature_sum], dim=0)


def _dt_only_graph_accuracy(train_dataset, test_dataset, params):
    """Fit a decision tree on per-graph feature vectors and return its test accuracy."""
    def featurize(graphs):
        # One pass over the graphs, collecting both features and scalar labels.
        features, targets = [], []
        for graph in graphs:
            features.append(_dt_only_graph_features(graph, params["withdegrees"]))
            targets.append(graph.y.item())
        return np.array(torch.stack(features)), targets

    X_train, y_train = featurize(train_dataset)
    clf = DecisionTreeClassifier(random_state=0, max_leaf_nodes=params["max_leaf_nodes"])
    clf.fit(X_train, y_train)
    X_test, y_test = featurize(test_dataset)
    return clf.score(X_test, y_test)


def _dt_only_node_accuracy(dataset, dataset_name, train_idx, test_idx, params):
    """Fit a decision tree on raw node features (plus degree, optionally) and return its test accuracy.

    For "Saturation" and "Infection" the tree is trained on every node of ``dataset[0]`` and
    tested on ``dataset[6]``. Otherwise ``train_idx`` / ``test_idx`` select nodes of ``dataset[0]``.
    """
    def node_features(graph):
        if not params["withdegrees"]:
            return graph.x
        # Each graph's degrees are computed with its own node count.
        node_degrees = degree(graph.edge_index[0], graph.num_nodes).unsqueeze(1)
        return torch.cat([graph.x, node_degrees], dim=1)

    graph = dataset[0]
    features = node_features(graph)
    y = graph.y
    if dataset_name not in ["Saturation", "Infection"]:
        train_x, test_x = features[train_idx, :], features[test_idx, :]
        train_y, test_y = y[train_idx], y[test_idx]
    else:
        test_graph = dataset[6]
        train_x, train_y = features, y
        test_x, test_y = node_features(test_graph), test_graph.y
    clf = DecisionTreeClassifier(random_state=0, max_leaf_nodes=params["max_leaf_nodes"])
    clf.fit(train_x, train_y)
    return clf.score(test_x, test_y)


def train_eval_model_cv(dataset, dataset_name, params, val_size=0.1, use_pooling=True, debug=False, folds=10,
                        gumbel_noise=True, dataset_mask=False):
    """Cross-validate the DT-GNN pipeline on ``dataset`` and print summary statistics.

    Folds come from ``StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)``.
    The exceptions are the OGB graph datasets ("OGB-molhiv", "OGB-ppa", "OGB-code2") and
    "OGBA", which use the dataset's official ``get_idx_split()`` in every fold.

    With ``params["onlydt"]`` only a plain decision-tree baseline is fitted per fold, and just
    its accuracies are printed. Otherwise each fold calls ``train_model`` and the per-fold GNN,
    DT-GNN, pruned DT-GNN, tree-size, tree-depth and explanation metrics are averaged.

    Args:
        dataset: dataset object. For mask-based datasets ``dataset[0]`` is the single graph.
        dataset_name: selects dataset-specific splitting and baselines.
        params: hyper-parameter dict. Reads "onlydt", "withdegrees" and "max_leaf_nodes" here.
            The whole dict is forwarded to ``train_model``.
        val_size, use_pooling, debug, gumbel_noise, dataset_mask: forwarded to ``train_model``.
            ``use_pooling`` and ``dataset_mask`` also select how fold labels are built.
        folds: number of CV folds.
    """
    ogb_graph_datasets = ["OGB-molhiv", "OGB-ppa", "OGB-code2"]
    graph_classification_datasets = ["MUTAG", "BA_2Motifs", "PROTEINS", "IMDB-BINARY", "REDDIT-BINARY",
                                     "Mutagenicity", "BBBP", "COLLAB"]
    print(f'----- {dataset_name} -----')

    if dataset_name in ogb_graph_datasets:
        # The official OGB split is used for every fold, so one placeholder entry per fold is
        # enough. There used to be only two, which crashed for folds > 2.
        idx_list = [([], [])] * folds
    else:
        if use_pooling:
            labels = [graph.y[0] for graph in dataset]
        elif dataset_mask:
            labels = dataset[0].y.detach().numpy()
            if dataset_name == "OGBA":
                labels = labels.flatten()
        else:
            labels = [0] * len(dataset)
        skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)
        idx_list = list(skf.split(np.zeros(len(labels)), labels))

    gnn_accs = []
    gnn_dt_accs = []
    after_pruning_accs = []
    num_nodes_remaining = []
    depth_before = []
    depth_after = []
    explanation_accs = []
    num_nodes_before = []
    dt_only_accs = []

    for fold in range(folds):
        print(f'----- Fold {fold + 1}/{folds} -----')
        train_idx, test_idx = idx_list[fold]
        valid_idx = None

        if dataset_name == "OGBA":
            split_idx = dataset.get_idx_split()
            train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]

        if dataset_mask:
            train_dataset, val_dataset, test_dataset = train_idx, valid_idx, test_idx
        elif dataset_name in ogb_graph_datasets:
            split_idx = dataset.get_idx_split()
            train_dataset = dataset[split_idx["train"]]
            val_dataset = dataset[split_idx["valid"]]
            # Evaluate on the official test split. This previously re-used the validation split.
            test_dataset = dataset[split_idx["test"]]
        else:
            train_dataset = dataset[train_idx]
            val_dataset = None
            test_dataset = dataset[test_idx]

        if params["onlydt"]:
            if dataset_name in graph_classification_datasets:
                dt_only_accs.append(_dt_only_graph_accuracy(train_dataset, test_dataset, params))
            else:
                dt_only_accs.append(_dt_only_node_accuracy(dataset, dataset_name, train_idx, test_idx, params))
            continue

        gnn_accs_fold, gnn_dt_accs_fold, after_pruning_accs_fold, num_nodes_remaining_fold, explanation_acc_fold, extras = train_model(
            dataset,
            dataset_name,
            params,
            val_size=val_size,
            train_val_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            use_pooling=use_pooling,
            debug=debug,
            gumbel_noise=gumbel_noise,
            dataset_mask=dataset_mask,
            fold=fold)
        gnn_accs.append(gnn_accs_fold)
        gnn_dt_accs.append(gnn_dt_accs_fold)
        after_pruning_accs.append(after_pruning_accs_fold)
        num_nodes_remaining.append(num_nodes_remaining_fold)
        depth_before.append(extras["depth_before"])
        depth_after.append(extras["depth_after"])
        num_nodes_before.append(extras["num_nodes_before"])
        explanation_accs.append(explanation_acc_fold)

    if params["onlydt"]:
        print("DT accuracies", dt_only_accs)
        print(np.mean(dt_only_accs), np.std(dt_only_accs))
        # No GNN was trained, so the summary below would only average empty lists.
        return

    print('---------- CV Results ----------')
    print('---------- GNN ----------')
    print("Acc.:", np.mean(gnn_accs, axis=0), "Std.:", np.std(gnn_accs, axis=0))
    print('---------- GNN DT ----------')
    print("Acc.:", np.mean(gnn_dt_accs, axis=0), "Std.:", np.std(gnn_dt_accs, axis=0))
    print('---------- Prunned GNN DT ----------')
    print("Acc.:", np.mean(after_pruning_accs, axis=0), "Std.:", np.std(after_pruning_accs, axis=0))
    print('---------- Nodes in Trees (Before) ----------')
    print("Nodes:", np.mean(num_nodes_before), "Std.:", np.std(num_nodes_before))
    print('---------- Nodes in Trees (After) ----------')
    print("Nodes:", np.mean(num_nodes_remaining), "Std.:", np.std(num_nodes_remaining))
    print('---------- Tree Depth (Before) ----------')
    print("Nodes:", np.mean(depth_before), "Std.:", np.std(depth_before))
    print('---------- Tree Depth (After) ----------')
    print("Nodes:", np.mean(depth_after), "Std.:", np.std(depth_after))
    print('---------- Explanation Acc ----------')
    print("Nodes:", np.mean(explanation_accs), "Std.:", np.std(explanation_accs))


def _load_full_model(path, map_location):
    """Load a whole model object that was stored with ``torch.save(model, path)``.

    Recent PyTorch releases default ``torch.load`` to ``weights_only=True``, which refuses to
    unpickle ``nn.Module`` objects, so the flag is passed explicitly. Releases that predate the
    keyword raise ``TypeError`` for it, in which case the plain call is used.
    Only use this on checkpoints written by this script: unpickling can execute arbitrary code.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        return torch.load(path, map_location=map_location)


def _plot_node_explanations(model_dt, data, graph, egt, target_degree, out_path):
    """Draw the top-10 explanation subgraph of every ground-truth node with degree ``target_degree``.

    Degrees come from ``graph`` (``dataset[0]``). Each drawing is saved to ``out_path``, which is
    overwritten for every plotted node, and then shown.
    """
    _, explanations = model_dt.explain(data.x, data.edge_index)
    degrees = list(degree(graph.edge_index[0], graph.num_nodes))
    for n in range(data.x.size()[0]):
        if n not in egt or degrees[n] != target_degree:
            continue
        top_importance = np.argsort(explanations[n, :])[-10:]
        sub = subgraph(top_importance, data.edge_index)[0]
        sub = remove_isolated_nodes(sub)[0]
        g = to_networkx(Data(x=None, edge_index=sub), to_undirected=True)
        nx.draw(g)
        plt.savefig(out_path)
        plt.show()
        # Clear the figure so drawings do not pile up on non-interactive backends.
        plt.clf()


def _plot_graph_explanations(model_dt, loader, out_path, with_node_features=False, verbose=False):
    """For every graph in ``loader``, show it, then draw the subgraph of its 10 most important nodes.

    Importance is taken from row 0 of ``model_dt.explain``. With ``with_node_features`` the
    subgraph gets one dummy feature per selected node, so isolated selected nodes are drawn too.
    ``out_path`` is overwritten for every graph.
    """
    for d in loader:
        g = to_networkx(d, to_undirected=True)
        nx.draw(g)
        plt.show()
        plt.clf()
        _, explanations = model_dt.explain(d.x, d.edge_index)
        top_importance = np.argsort(explanations[0])[-10:]
        if verbose:
            print(top_importance)
        sub = subgraph(top_importance, d.edge_index)[0]
        # Relabel the selected nodes to 0..k-1 for the standalone subgraph.
        index = {k: i for i, k in enumerate(top_importance)}
        sub_index_src = [index[k.item()] for k in sub[0]]
        sub_index_dest = [index[k.item()] for k in sub[1]]
        x = torch.ones(len(top_importance), 1) if with_node_features else None
        subdata = Data(x=x, edge_index=torch.tensor([sub_index_src, sub_index_dest]))
        if verbose:
            print(subdata)
        g = to_networkx(subdata, to_undirected=True)
        nx.draw(g)
        plt.savefig(out_path)
        plt.show()
        plt.clf()


def _explanation_ground_truth(dataset_name, dataset, data, params):
    """Return ``(explanation_gt_size, egt)`` for datasets with known explanations, else ``None``.

    ``egt`` maps a node to the list of nodes that form its ground-truth explanation.
    """
    if dataset_name == "BA_shapes":
        return 5, dataset.egt
    if dataset_name == "Tree_Cycle":
        return 6, dataset.egt
    if dataset_name == "Tree_Grid":
        return 9, dataset.egt
    if dataset_name == "Infection":
        d = dataset.dataset[0]
        egt = {d.unique_solution_nodes[i]: d.unique_solution_explanations[i]
               for i in range(len(d.unique_solution_nodes))}
        return params["number_of_layers"], egt
    if dataset_name == "Saturation":
        # Nodes 0-9 are blue and 10-19 red. Every later node's explanation is itself plus its
        # neighbours of the colour matching its label (label 0 -> blue, otherwise red).
        nbs = neighbors(dataset.dataset[0])
        egt = defaultdict(list)
        blue_list = list(range(10))
        red_list = list(range(10, 20))
        for n in range(20, data.x.size()[0]):
            egt[n].append(n)
            nb_to_check = blue_list if data.y[n] == 0 else red_list
            for nb in nbs[n]:
                if nb in nb_to_check:
                    egt[n].append(nb)
        return 11, egt
    return None


def _explanation_accuracy(model_dt, data, egt, explanation_gt_size):
    """Return the fraction of ground-truth explanation nodes found among each node's top-k explanation nodes.

    k is ``explanation_gt_size``. Also prints the mean per-node accuracy for each ground-truth size.
    Returns 0.0 when ``egt`` covers no node.
    """
    accs = defaultdict(list)
    expl_correct = 0
    expl_total = 0
    _, explanations = model_dt.explain(data.x, data.edge_index)
    for n in range(data.x.size()[0]):
        if n not in egt or not egt[n]:
            continue
        top_importance = np.argpartition(explanations[n, :], -explanation_gt_size)[-explanation_gt_size:]
        curr_correct = 0
        for true_node in egt[n]:
            expl_total += 1
            if true_node in top_importance:
                expl_correct += 1
                curr_correct += 1
        accs[len(egt[n])].append(curr_correct / len(egt[n]))
    print("##########################")
    for v in accs.values():
        print(np.mean(v))
    print("----- run expl accuracy -------")
    accuracy = expl_correct * 1.0 / expl_total if expl_total else 0.0
    print(accuracy)
    return accuracy


def train_model(dataset, dataset_name, params, test_size=0.1, val_size=0.1, train_val_dataset=None, val_dataset=None,
                test_dataset=None, use_pooling=True, debug=False, gumbel_noise=True, dataset_mask=False, fold=0):
    """Train a StoneAgeGNN, distil it into per-layer decision trees, prune them and report results.

    Steps:
    1. Split the data.
    2. Train the GNN with early stopping on the validation loss. The best checkpoint is saved
       under ``<data_dir>/model_checkpoints``.
    3. Reload the checkpoint on the CPU.
    4. Fit one ``DecisionTreeClassifier`` per layer on the extracted layer inputs/outputs.
    5. Build a ``StoneAgeGNNDT`` from the trees.
    6. Prune the trees until a round removes nothing.
    7. For datasets with ground-truth explanations, score (and optionally plot) them.

    Args:
        dataset: dataset object. ``dataset[0]`` is the full graph for mask-based datasets.
        dataset_name: selects dataset-specific handling (OGB splits/evaluators, OGBA label
            shape, explanation ground truth, plots).
        params: hyper-parameter dict (see the ``__main__`` block). Optional keys read here:
            ``plot_explanations`` (default True) and ``plot_dir`` (default "/home/user/Desktop").
        test_size: test fraction, used only when ``test_dataset`` is None.
        val_size: validation fraction carved from the training data when no validation set is given.
        train_val_dataset, val_dataset, test_dataset: graph datasets, or node index arrays when
            ``dataset_mask`` is True.
        use_pooling: graph-level (True) or node-level (False) prediction.
        debug: print per-layer tree statistics and per-round pruning accuracies.
        gumbel_noise: call ``model.set_beta(1.0)`` before training.
        dataset_mask: train on the single graph ``dataset[0]`` with node index masks.
        fold: fold number, used in the file name of the saved DT-GNN.

    Returns:
        ``(gnn_accs, gnn_dt_accs, after_pruning_accs, num_nodes_remaining, explanation_accs, extras)``.
        The three accuracy lists are ``[train, val, test]``. ``explanation_accs`` is 0 unless
        ground-truth explanations exist. ``extras`` holds ``depth_before``, ``depth_after``,
        ``num_nodes_before`` and ``num_nodes_after``.
    """
    import os  # function-local: only needed for checkpoint and plot paths

    ogb_graph_datasets = ["OGB-molhiv", "OGB-ppa", "OGB-code2"]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = dataset[0]
    data = data.to(device)

    # ----- Splits. ShuffleSplit depends only on the sample count, so a placeholder array stands in for X. -----
    if test_dataset is None:
        sss = ShuffleSplit(n_splits=1, test_size=test_size, random_state=41)
        train_index, test_index = next(sss.split(np.zeros(len(dataset))))
        train_val_dataset = dataset[train_index]
        test_dataset = dataset[test_index]
    if dataset_name in ogb_graph_datasets:
        train_dataset = train_val_dataset
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    elif dataset_mask:
        if val_dataset is None:
            sss = ShuffleSplit(n_splits=1, test_size=val_size, random_state=41)
            train_index, val_index = next(sss.split(np.zeros(len(train_val_dataset))))
            train_mask = train_val_dataset[train_index]
            val_mask = train_val_dataset[val_index]
        else:
            train_mask = train_val_dataset
            val_mask = val_dataset
        test_mask = test_dataset
    else:
        sss = ShuffleSplit(n_splits=1, test_size=val_size, random_state=41)
        train_index, val_index = next(sss.split(np.zeros(len(train_val_dataset))))
        train_dataset = train_val_dataset[train_index]
        val_dataset = train_val_dataset[val_index]

        train_loader = DataLoader(train_dataset, batch_size=params["batch_size"])
        val_loader = DataLoader(val_dataset, batch_size=params["batch_size"])
        test_loader = DataLoader(test_dataset, batch_size=1)

    # ----- 'balanced' class weights over every label in `dataset` -----
    if use_pooling:
        y = [graph.y.detach().numpy()[0] for graph in dataset]
        if dataset_name in ogb_graph_datasets:
            y = [label[0] for label in y]
    else:
        y = []
        for graph in dataset:
            labels = graph.y.detach().numpy()
            if dataset_name == "OGBA":
                labels = labels.flatten()
            y += list(labels)
    classes = np.unique(y)
    computed_class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
    if len(computed_class_weights) < dataset.num_classes:
        # Classes that never occur in `y` keep weight 1.
        class_weights = np.ones(dataset.num_classes)
        for cls, weight in zip(classes, computed_class_weights):
            class_weights[cls] = weight
    else:
        class_weights = computed_class_weights
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    model = StoneAgeGNN(dataset.num_node_features, dataset.num_classes, bounding_parameter=params["bounding_parameter"],
                        state_size=params["state_space"], num_layers=params["number_of_layers"],
                        gumbel=params["gumbel"], softmax_temp=params["softmax_temp"], network=params["network"],
                        use_pooling=use_pooling, skip_connection=params["skip_connection"], dropout=params["dropout"],
                        hidden_units=params["hidden_units"]).to(device)
    if gumbel_noise:
        model.set_beta(1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])

    # Official OGB evaluator name and result key per dataset.
    # NOTE(review): ogbg-code2's evaluator is documented to expect sequence predictions and to
    # report "F1", so this entry probably does not work. Kept unchanged.
    ogb_metrics = {"OGB-molhiv": ("ogbg-molhiv", "rocauc"),
                   "OGB-code2": ("ogbg-code2", "rocauc"),
                   "OGB-ppa": ("ogbg-ppa", "acc")}
    evaluator = Evaluator(name=ogb_metrics[dataset_name][0]) if dataset_name in ogb_metrics else None

    @torch.no_grad()
    def score_model(m, d):
        # Accuracy of model `m` over loader `d`, or the OGB metric for OGB datasets.
        m.eval()
        total_correct = 0
        all_pred = []
        all_true = []
        for batch in d:
            batch = batch.to(device)
            out = m(batch.x, batch.edge_index, batch.batch)
            pred = out.argmax(-1)
            correct = int((pred == batch.y).sum())
            if not use_pooling:
                correct /= len(batch.y)
            total_correct += correct
            all_pred.append(torch.unsqueeze(pred, dim=1))
            all_true.append(batch.y)
        acc = total_correct / len(d.dataset)
        if evaluator is not None:
            input_dict = {"y_true": torch.cat(all_true, dim=0),
                          "y_pred": torch.cat(all_pred, dim=0)}
            acc = evaluator.eval(input_dict)[ogb_metrics[dataset_name][1]]
        return acc

    def mask_target(mask):
        # OGBA labels carry a trailing singleton dimension that must be dropped.
        target = data.y[mask]
        return target.squeeze(1) if dataset_name == "OGBA" else target

    def train_with_mask(mask):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.nll_loss(out[mask], mask_target(mask))
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test_with_mask(mask):
        model.eval()
        model.set_argmax(True)
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=-1)
        target = mask_target(mask)
        loss = F.nll_loss(out[mask], target)
        acc = int((pred[mask] == target).sum()) / len(mask)
        model.set_argmax(False)
        return acc, loss

    def train(loader):
        model.train()
        total_loss = 0
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            output = model(batch.x, batch.edge_index, batch.batch)
            target = torch.flatten(batch.y) if dataset_name in ogb_graph_datasets else batch.y
            loss = F.nll_loss(output, target, weight=class_weights)
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * batch.num_graphs
        return total_loss / len(loader.dataset)

    @torch.no_grad()
    def test(loader):
        model.eval()
        model.set_argmax(True)
        total_correct = 0
        total_loss = 0
        for batch in loader:
            batch = batch.to(device)
            output = model(batch.x, batch.edge_index, batch.batch)
            target = torch.flatten(batch.y) if dataset_name in ogb_graph_datasets else batch.y
            loss = F.nll_loss(output, target, weight=class_weights)
            correct = int((output.argmax(-1) == target).sum())
            if not use_pooling:
                correct /= len(target)
            total_correct += correct
            total_loss += float(loss) * batch.num_graphs
        model.set_argmax(False)
        return total_correct / len(loader.dataset), total_loss / len(loader.dataset)

    # ----- GNN training with early stopping on validation loss -----
    checkpoint_dir = f'{params["data_dir"]}/model_checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f'{checkpoint_dir}/{dataset_name}.pt'
    early_stopping_enabled = True
    es_patience = params["es_patience"]
    es_counter = 0
    best_test_acc = 0.0
    best_val_loss = np.inf
    best_val_acc = 0
    pbar = trange(1, params["epochs"] + 1)
    print(time.strftime('%X %x %Z'))
    for epoch in pbar:
        if dataset_mask:
            loss = train_with_mask(train_mask)
            train_acc, _ = test_with_mask(train_mask)
            val_acc, val_loss = test_with_mask(val_mask)
            test_acc, _ = test_with_mask(test_mask)
        else:
            loss = train(train_loader)
            train_acc, _ = test(train_loader)
            val_acc, val_loss = test(val_loader)
            test_acc = score_model(model, test_loader)
        if val_loss < best_val_loss:
            es_counter = 0
            best_val_loss = val_loss
            best_val_acc = val_acc
            best_test_acc = test_acc
            torch.save(model, checkpoint_path)

        pbar.set_description(f'Epoch: {epoch:04d}, Loss: {loss:.3f} Train: {train_acc:.3f},'
                             f' Val: {val_acc:.3f}, Test: {test_acc:.3f},'
                             f' Best Val|Test Acc.:  {best_val_acc:.3f} | {best_test_acc:.3f}')

        if early_stopping_enabled and es_counter > es_patience:
            print("-- Early Stopping --")
            break
        es_counter += 1
    print(time.strftime('%X %x %Z'))
    print('---------- GNN (Acc, Loss) ----------')
    # Everything below runs on the CPU. Closures defined above read the rebound names.
    device = torch.device('cpu')
    model = _load_full_model(checkpoint_path, device).to(device)
    class_weights = class_weights.to(device)
    data = data.to(device)

    if dataset_mask:
        train_res, val_res, test_res = (test_with_mask(m) for m in (train_mask, val_mask, test_mask))
        gnn_accs = [train_res[0], val_res[0], test_res[0]]
        print(f'Train: {train_res}, Val: {val_res}, Test: {test_res}')
    else:
        gnn_accs = [score_model(model, train_loader), score_model(model, val_loader),
                    score_model(model, test_loader)]
        print("Train:", gnn_accs[0], "Val:", gnn_accs[1], "Test:", gnn_accs[2])

    # ----- Extract per-layer input/output datasets from the trained GNN -----
    layer_names = ['input']
    for i in range(params["number_of_layers"]):
        layer_names.append(f"stone_age.{i}.linear_softmax")
    layer_names.append('output')

    model.eval()
    model.set_gumbel(False)
    model.set_argmax(True)

    if dataset_mask:
        input_outputs_train = extract_features(model, layer_names, data, device, mask=train_mask)
        input_outputs_test = extract_features(model, layer_names, data, device, mask=test_mask)
    else:
        input_outputs_train = extract_features(model, layer_names, train_dataset, device,
                                               batch_size=params['batch_size'])
        input_outputs_test = extract_features(model, layer_names, test_dataset, device, batch_size=params['batch_size'])

    # ----- One decision tree per layer -----
    trees = {}
    if debug:
        print('---------- Tree Test Accuracy ----------')
    print("Start tree training", time.strftime('%X %x %Z'))
    for layer_name in layer_names:
        data_input_train = input_outputs_train[layer_name]["inputs"]
        data_output_train = np.argmax(input_outputs_train[layer_name]["outputs"], axis=1)
        data_input_test = input_outputs_test[layer_name]["inputs"]
        data_output_test = np.argmax(input_outputs_test[layer_name]["outputs"], axis=1)

        # Hidden layers (and the pooled output layer) are fed linear feature combinations.
        num_features = None
        if layer_name not in ('input', 'output'):
            num_features = len(data_input_train[0]) // 2
        elif layer_name == 'output' and use_pooling:
            num_features = len(data_input_train[0])
        if num_features is not None:
            data_input_train = linear_combo_features(data_input_train, num_features)
            data_input_test = linear_combo_features(data_input_test, num_features)

        clf = DecisionTreeClassifier(random_state=0, max_leaf_nodes=params["max_leaf_nodes"])
        clf.fit(data_input_train, data_output_train)
        if layer_name == 'output' and use_pooling:
            clf = scale_pooling_num_node_samples(clf, data_input_train, train_dataset)
        trees[layer_name] = clf

        if debug:
            print(
                f"{layer_name}: Train: {clf.score(data_input_train, data_output_train)}, Test: {clf.score(data_input_test, data_output_test)}")
            print("Feature Importance:", clf.feature_importances_)
            print("Output Classes:", Counter(data_output_test))
            print("Tree Depth:", clf.get_depth())
            # Only meaningful where linear feature combinations were built for this layer.
            if num_features is not None:
                print("Linear Features Uses:", get_linear_features_used(clf, num_features))

    # ----- Build & evaluate the decision-tree GNN -----
    model_dt = StoneAgeGNNDT(dataset.num_node_features, dataset.num_classes,
                             bounding_parameter=params["bounding_parameter"],
                             trees=trees,
                             num_layers=params["number_of_layers"],
                             state_size=params["state_space"],
                             use_pooling=use_pooling,
                             skip_connection=params["skip_connection"],
                             linear_feature_combinations=True).to(device)

    model_dt.eval()
    print(time.strftime('%X %x %Z'))
    print('---------- DT GNN Acc ----------')
    print("Tree Depths:", [trees[layer_name].get_depth() for layer_name in layer_names])
    if dataset_mask:
        gnn_dt_accs = [score_model_mask(model_dt, data, train_mask), score_model_mask(model_dt, data, val_mask),
                       score_model_mask(model_dt, data, test_mask)]
    else:
        gnn_dt_accs = [score_model(model_dt, train_loader), score_model(model_dt, val_loader),
                       score_model(model_dt, test_loader)]
    print("Train:", gnn_dt_accs[0], "Val:", gnn_dt_accs[1], "Test:", gnn_dt_accs[2])
    depth_before = np.mean([trees[layer_name].get_depth() for layer_name in layer_names])
    num_nodes_before = get_num_nodes(trees, layer_names)
    nodes_pruned = 0
    print("start pruning", time.strftime('%X %x %Z'))

    @torch.no_grad()
    def score_prunning(m, mask):
        # Masked accuracy used by the pruner. Labels are reshaped to the prediction shape, so
        # OGBA's (N, 1) labels are not broadcast against the (k,) predictions.
        m.eval()
        pred = m(data.x, data.edge_index).argmax(dim=-1)[mask]
        target = data.y[mask].reshape(pred.shape)
        return int((pred == target).sum()) / len(mask)

    # ----- Prune repeatedly until a round removes nothing -----
    while True:
        if dataset_mask:
            trees, num_pruned_nodes = prune_trees_all_val(trees, model_dt, layer_names, train_mask, val_mask,
                                                          debug=debug, score_model=score_prunning, REP_train=True)
        else:
            trees, num_pruned_nodes = prune_trees_all_val(trees, model_dt, layer_names, train_loader, val_loader,
                                                          debug=debug, REP_train=True)
        model_dt.update_trees(trees)
        nodes_pruned += num_pruned_nodes
        if num_pruned_nodes == 0:
            break
        if debug:
            print('---------- GNN DT Test Accuracy After Pruning ----------')
            if dataset_mask:
                print(score_model_mask(model_dt, data, train_mask), score_model_mask(model_dt, data, val_mask),
                      score_model_mask(model_dt, data, test_mask))
            else:
                print(score_model(model_dt, train_loader), score_model(model_dt, val_loader),
                      score_model(model_dt, test_loader))
            print("Tree Depths:", [trees[layer_name].get_depth() for layer_name in layer_names])
    print(time.strftime('%X %x %Z'))
    print('---------- Tree Pruning ----------')
    num_nodes_remaining = get_num_nodes(trees, layer_names)
    print("Nodes Pruned:", nodes_pruned, ", Remaining Nodes:", num_nodes_remaining, ", Tree Depths:",
          [trees[layer_name].get_depth() for layer_name in layer_names])
    if dataset_mask:
        after_pruning_accs = [score_model_mask(model_dt, data, train_mask), score_model_mask(model_dt, data, val_mask),
                              score_model_mask(model_dt, data, test_mask)]
    else:
        after_pruning_accs = [score_model(model_dt, train_loader), score_model(model_dt, val_loader),
                              score_model(model_dt, test_loader)]
    depth_after = np.mean([trees[layer_name].get_depth() for layer_name in layer_names])
    print("Mean Depth Before:", depth_before, "Mean Depth After:", depth_after)
    print("Train:", after_pruning_accs[0], "Val:", after_pruning_accs[1], "Test:", after_pruning_accs[2])

    extras = {
        "depth_before": depth_before,
        "depth_after": depth_after,
        "num_nodes_before": num_nodes_before,
        "num_nodes_after": num_nodes_remaining,
    }

    # ----- Explanation accuracy against ground truth, where available -----
    explanation_accs = 0
    egt = None
    ground_truth = _explanation_ground_truth(dataset_name, dataset, data, params)
    if ground_truth is not None:
        explanation_gt_size, egt = ground_truth
        explanation_accs = _explanation_accuracy(model_dt, data, egt, explanation_gt_size)

    # ----- Optional explanation plots -----
    if params.get("plot_explanations", True):
        plot_dir = params.get("plot_dir", "/home/user/Desktop")
        node_plots = {"BA_shapes": (4, "graphchef_bashapes.pdf"),
                      "Tree_Cycle": (3, "graphchef_treecycle.pdf"),
                      "Tree_Grid": (4, "graphchef_treegrid.pdf")}
        if dataset_name in node_plots:
            target_degree, filename = node_plots[dataset_name]
            _plot_node_explanations(model_dt, data, dataset[0], egt, target_degree,
                                    os.path.join(plot_dir, filename))
        elif dataset_name == "MUTAG":
            _plot_graph_explanations(model_dt, test_loader, os.path.join(plot_dir, "graphchef_mutag.pdf"))
        elif dataset_name == "REDDIT-BINARY":
            _plot_graph_explanations(model_dt, test_loader, os.path.join(plot_dir, "graphchef_reddit.pdf"),
                                     with_node_features=True, verbose=True)

    torch.save(model_dt, f'{checkpoint_dir}/dt_gnn_{dataset_name}_fold_{fold}.pt')
    return gnn_accs, gnn_dt_accs, after_pruning_accs, num_nodes_remaining, explanation_accs, extras




if __name__ == "__main__":
    # Command-line entry point. Only the dataset, paths and a few size settings are exposed;
    # the remaining hyper-parameters are fixed in `params` below.
    parser = argparse.ArgumentParser(
        description='Train and evaluate the DT-GNN',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--dataset", default="MUTAG", help="Name of the dataset to run the experiment on")
    parser.add_argument("--data_dir", default="datasets/data", help="Path to the dataset directory")
    # Integer options, registered in the same order as they appear in --help.
    for option, default_value in (("--batch_size", 128), ("--state_space", 10), ("--number_of_layers", 5),
                                  ("--epochs", 1500), ("--mlp_hidden_units", 16), ("--folds", 10)):
        parser.add_argument(option, default=default_value, type=int)
    # Boolean switches (False unless given).
    for switch in ("--onlydt", "--withdegrees"):
        parser.add_argument(switch, action='store_true')
    args = parser.parse_args()

    params = dict(
        state_space=args.state_space,
        bounding_parameter=1000,
        number_of_layers=args.number_of_layers,
        learning_rate=0.01,
        gumbel=True,
        skip_connection=True,
        softmax_temp=1.0,
        network="mlp",
        epochs=args.epochs,
        folds=args.folds,
        gumbel_noise=True,
        es_patience=100,
        batch_size=args.batch_size,
        data_dir=args.data_dir,
        max_leaf_nodes=100,
        val_size=0.1,
        dropout=0.0,
        hidden_units=args.mlp_hidden_units,
        score_explanation=False,
        dataset_name=args.dataset,
        onlydt=args.onlydt,
        withdegrees=args.withdegrees,
    )

    dataset_name = params["dataset_name"]
    dataset, dataset_args = load_dataset(dataset_name, args)
    train_eval_model_cv(
        dataset,
        dataset_name,
        params,
        val_size=params["val_size"],
        use_pooling=dataset_args["use_pooling"],
        folds=params["folds"],
        gumbel_noise=params["gumbel_noise"],
        dataset_mask=dataset_args["dataset_mask"],
        debug=False,
    )
