import os
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import matplotlib.pyplot as plt
import argparse as argparse
from utils import *
import pickle
import scipy.sparse as sp
from torch_geometric.utils import to_dense_adj
from sklearn.feature_selection import mutual_info_classif
from matplotlib import rcParams as rc

# ---- Plot config
rc["font.family"] = "serif"
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')
rc["font.size"] = 12

# Prefer the second GPU when one exists; fall back to cuda:0 on single-GPU
# machines (a hard-coded 'cuda:1' is an invalid device ordinal there and the
# first `.to(device)` call would raise), and to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# ------- Args
parser = argparse.ArgumentParser(description='Single feature analysis (one outer loop).')
parser.add_argument('--dataset', type=str, default='Cora',
                    help='Dataset name (default: Cora) or dataSim_Y_A,X_A,Y_X for synthetic data')
# Help strings use %(default)s so they cannot drift from the real defaults.
parser.add_argument('--N_iter', type=int, default=5,
                    help='Number of iterations for averaging results (default: %(default)s)')
parser.add_argument('--N_shuffle', type=int, default=10,
                    help='Number of shuffles per feature (default: %(default)s)')
args = parser.parse_args()

# N_iter == 0 would leave num_node_features undefined for the summary, and
# N_shuffle == 0 makes every permutation importance NaN; reject both early.
if args.N_iter < 1:
    parser.error(f"--N_iter must be >= 1 (got {args.N_iter})")
if args.N_shuffle < 1:
    parser.error(f"--N_shuffle must be >= 1 (got {args.N_shuffle})")

DS = args.dataset
N_iter = args.N_iter  # number of iterations for averaging results
N_shuffle = args.N_shuffle
# Fraction of features kept by every feature-selection method (top r*100 %).
r = 0.05 if DS in ['PubMed', 'Computers', 'Photo'] else 0.02

# Computers/Photo get twice the training budget and log less often.
if DS in ['Computers', 'Photo']:
    epochs = 800
    verbos_every = 100
else:
    epochs = 400
    verbos_every = 50

os.makedirs('data', exist_ok=True)
os.makedirs('Plots', exist_ok=True)

# ---------- Helpers
def compute_feature_drop(model, data, feature_idx, n_shuffle, seed_base):
    """Permutation importance: mean validation-accuracy drop per feature.

    For each column in ``feature_idx`` the feature is permuted ``n_shuffle``
    times via ``permute_single_feature(data, feat_i, seed=seed_base + s)``.
    The model is re-evaluated on the ``data.val_mask`` nodes each time, and
    the accuracy drops are averaged.

    Args:
        model: node classifier called as ``model(x, edge_index)``; its output
            is arg-maxed over the last dimension to obtain predictions.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and ``val_mask``.
        feature_idx: sequence of column indices to score.
        n_shuffle: number of permutations per feature; must be >= 1.
        seed_base: offset for the per-shuffle seeds (``seed_base + s``).

    Returns:
        np.ndarray of length ``len(feature_idx)`` holding
        ``base_acc - permuted_acc`` in percentage points, aligned with
        ``feature_idx``.

    Raises:
        ValueError: if ``n_shuffle < 1`` (the mean would otherwise be NaN).
    """
    if n_shuffle < 1:
        raise ValueError(f"n_shuffle must be >= 1, got {n_shuffle}")

    model.eval()
    val_mask = data.val_mask
    with torch.no_grad():
        preds_orig = model(data.x, data.edge_index).argmax(-1)
        base_acc = 100 * (preds_orig[val_mask] == data.y[val_mask]).float().mean().item()

    n_feats = len(feature_idx)
    drops = np.zeros(n_feats, dtype=float)
    for j, feat_i in enumerate(feature_idx):
        # Progress is keyed on the position j, not the column id, so the
        # "k/N" counter stays correct when feature_idx is a subset.
        if j % 500 == 0 or j == n_feats - 1:
            print(f"  - Permuting Feature {j+1}/{n_feats}")
        per_shuf = []
        for s in range(n_shuffle):
            data_perm = permute_single_feature(data, feat_i, seed=seed_base + s)
            with torch.no_grad():
                preds_perm = model(data_perm.x, data_perm.edge_index).argmax(-1)
                acc_perm = 100 * (preds_perm[val_mask] == data_perm.y[val_mask]).float().mean().item()
            per_shuf.append(base_acc - acc_perm)
        drops[j] = float(np.mean(per_shuf))
    return drops

def compute_feature_drop_mask(model, data, feature_idx, *, fill_value=0.0, ref_mask=None):
    """Occlusion importance: accuracy drop when one feature column is set to a constant.

    Each feature in ``feature_idx`` is overwritten, one at a time, with
    ``fill_value`` for every node. The drop in accuracy on ``ref_mask`` relative
    to the unmodified input is recorded.

    Args:
        model: node classifier called as ``model(x, edge_index)``; its output
            is arg-maxed over the last dimension to obtain predictions.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and
            ``val_mask``; it is not modified.
        feature_idx: sequence of column indices to occlude.
        fill_value: constant written into the occluded column (default 0.0).
        ref_mask: boolean node mask to evaluate on; defaults to ``data.val_mask``.

    Returns:
        np.ndarray of length ``len(feature_idx)`` holding
        ``base_acc - masked_acc`` in percentage points, aligned with
        ``feature_idx``.
    """
    model.eval()
    if ref_mask is None:
        ref_mask = data.val_mask
    y_ref = data.y[ref_mask]
    n_feats = len(feature_idx)
    drops = np.zeros(n_feats, dtype=float)

    # Inference only: without no_grad every forward pass below would build an
    # autograd graph that is never used.
    with torch.no_grad():
        preds_orig = model(data.x, data.edge_index).argmax(-1)
        base_acc = 100 * (preds_orig[ref_mask] == y_ref).float().mean().item()

        for j, feat_i in enumerate(feature_idx):
            if j % 500 == 0 or j == n_feats - 1:
                print(f"  - Masking Feature {j+1}/{n_feats}")
            # Copy only the feature matrix (not the whole graph object).
            x_masked = data.x.clone()
            x_masked[:, feat_i] = fill_value
            preds_mask = model(x_masked, data.edge_index).argmax(-1)
            acc_mask = 100 * (preds_mask[ref_mask] == y_ref).float().mean().item()
            drops[j] = base_acc - acc_mask
    return drops


# ---------- Accumulators (lazy-init after first load to honor "load inside loop")
# Per-feature importance sums. They stay None until the first iteration, which
# knows num_node_features and replaces them with zero arrays.
FI_gnn_sum = FI_mlp_sum = FI_mi_sum = FI_tfi_sum = None
FI_hattr_sum = FI_heuc_sum = FI_hge_sum = FI_mask_sum = None
feature_idx_ref = None

# Per-iteration accuracies: the all-feature baseline first, then one list per
# feature-selection method. Each name is bound to its own, distinct list.
(base_Acc, FS_GNN, FS_MLP, FS_TFI, FS_MI, FS_RAND,
 FS_HATTR, FS_HEUC, FS_HGE, FS_MASK) = ([] for _ in range(10))


# ---------- Shared helpers for the feature-selection (FS) experiments
# Datasets routed to GCNNodeClassifier_Het / GIN; everything else uses GCNNodeClassifier.
_HET_DATASETS = {'Texas', 'Cornell', 'Wisconsin', 'Actor'}
_GIN_DATASETS = {'Photo', 'Computers'}


def _build_fs_model(in_dim, n_classes):
    """Return the dataset-specific node classifier for ``in_dim`` inputs, moved to ``device``."""
    if DS in _HET_DATASETS:
        model = GCNNodeClassifier_Het(in_dim, [512, 512], n_classes)
    elif DS in _GIN_DATASETS:
        model = GIN(in_dim=in_dim, hid_dim=512, out_dim=n_classes)
    else:
        model = GCNNodeClassifier(in_dim, [512, 512], n_classes)
    return model.to(device)


def _num_selected(n_features):
    """Feature budget shared by every FS method (random baseline included): floor(r*n), clamped to [1, n].

    The clamp matters: with k == 0, ``np.argsort(x)[-0:]`` is the WHOLE array,
    so a "top-0" selection silently kept every feature.
    """
    return max(1, min(n_features, int(r * n_features)))


def _top_ranked_features(scores):
    """Indices of the ``_num_selected(len(scores))`` highest scores, best first.

    NaN scores are ranked lowest; np.argsort would otherwise sort them last
    and thereby select them as the *most* important features.
    """
    scores = np.asarray(scores)
    if np.isnan(scores).any():
        scores = np.where(np.isnan(scores), -np.inf, scores)
    k = _num_selected(len(scores))
    top = np.argsort(scores)[-k:]
    return top[np.argsort(scores[top])[::-1]]


def _fs_accuracy(data, feat_idx, n_classes, seed):
    """Reseed, train a fresh model on ``data.x[:, feat_idx]`` and return run_experiment's best_test_at_val."""
    data_sub = data.clone()
    data_sub.x = data.x[:, feat_idx]
    set_seed_all(seed)
    model = _build_fs_model(len(feat_idx), n_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    _, best_test_at_val_fs, _ = run_experiment(model, optimizer, data_sub, epochs=epochs, verbose_every=verbos_every)
    return best_test_at_val_fs


for run_idx in range(N_iter):
    print(f"\n=== Iteration {run_idx+1}/{N_iter} ===")
    # Seed first, then load data (so masks change per load if dataset shuffles masks by seed)
    set_seed_all(run_idx)

    # Load dataset inside the loop per request
    dataset, data, num_node_features, num_classes = load_dataset(DS, device)
    data = data.to(device)
    feature_idx = list(range(num_node_features))
    # Nodes used for the MI / TFI scores: validation split only (the train
    # split is deliberately left out, see the commented-out alternative).
    tv_mask = data.val_mask  # | data.train_mask)

    # Lazy initialize accumulators on first pass
    if FI_gnn_sum is None:
        FI_gnn_sum = np.zeros(num_node_features, dtype=float)
        FI_mlp_sum = np.zeros(num_node_features, dtype=float)
        FI_mi_sum = np.zeros(num_node_features, dtype=float)
        FI_tfi_sum = np.zeros(num_node_features, dtype=float)
        FI_hattr_sum = np.zeros(num_node_features, dtype=float)
        FI_heuc_sum = np.zeros(num_node_features, dtype=float)
        FI_hge_sum = np.zeros(num_node_features, dtype=float)
        FI_mask_sum = np.zeros(num_node_features, dtype=float)
        feature_idx_ref = feature_idx

    # --- Train GNN on all features (baseline accuracy)
    set_seed_all(run_idx)
    model_gnn = _build_fs_model(num_node_features, num_classes)
    opt_gnn = torch.optim.Adam(model_gnn.parameters(), lr=0.01, weight_decay=5e-4)
    best_val, best_test_at_val, best_state = run_experiment(model_gnn, opt_gnn, data, epochs=epochs, verbose_every=verbos_every)
    print(f"GNN  | Best Test@Val: {best_test_at_val:.4f}")
    model_gnn.load_state_dict(best_state)
    base_Acc.append(best_test_at_val)

    # --- GNN FI: permutation and zero-masking
    set_seed_all(run_idx)
    fi_gnn_iter = compute_feature_drop(model_gnn, data, feature_idx, N_shuffle, seed_base=run_idx * 1000 + 1)
    FI_gnn_sum += fi_gnn_iter

    fi_mask_iter = compute_feature_drop_mask(model_gnn, data, feature_idx)
    FI_mask_sum += fi_mask_iter

    FS_GNN.append(_fs_accuracy(data, _top_ranked_features(fi_gnn_iter), num_classes, run_idx))
    FS_MASK.append(_fs_accuracy(data, _top_ranked_features(fi_mask_iter), num_classes, run_idx))

    # --- Train MLP on all features; its permutation FI picks features for a GNN
    set_seed_all(run_idx)
    model_mlp = MLPNodeClassifier(num_node_features, [512, 512], num_classes).to(device)
    opt_mlp = torch.optim.Adam(model_mlp.parameters(), lr=0.01, weight_decay=5e-4)
    best_val_mlp, best_test_at_val_mlp, best_state_mlp = run_experiment(model_mlp, opt_mlp, data, epochs=epochs, verbose_every=verbos_every)
    print(f"MLP  | Best Test@Val: {best_test_at_val_mlp:.4f}")
    model_mlp.load_state_dict(best_state_mlp)

    set_seed_all(run_idx)
    fi_mlp_iter = compute_feature_drop(model_mlp, data, feature_idx, N_shuffle, seed_base=run_idx * 1000 + 2)
    FI_mlp_sum += fi_mlp_iter
    FS_MLP.append(_fs_accuracy(data, _top_ranked_features(fi_mlp_iter), num_classes, run_idx))

    # --- MI(X, Y) on the tv_mask nodes for THIS iteration
    set_seed_all(run_idx)
    X_tv_np = get_dense_x_split(data, tv_mask)
    y_tv_np = data.y[tv_mask].cpu().numpy().astype(int)
    fi_mi_iter = mutual_info_classif(X_tv_np, y_tv_np, random_state=run_idx).astype(float)
    FI_mi_sum += fi_mi_iter
    FS_MI.append(_fs_accuracy(data, _top_ranked_features(fi_mi_iter), num_classes, run_idx))

    # --- TFI: MI(AX, Y) on the tv_mask nodes for THIS iteration
    set_seed_all(run_idx)
    A_np = get_adj_np(data)
    # dense X (on CPU)
    if hasattr(data.x, 'is_sparse') and data.x.is_sparse:
        X_np = data.x.to_dense().cpu().numpy()
    else:
        X_np = data.x.cpu().numpy()
    Xn_np = A_np @ X_np
    Xn_tv_np = Xn_np[tv_mask.cpu().numpy()]
    fi_tfi_iter = mutual_info_classif(Xn_tv_np, y_tv_np, random_state=run_idx).astype(float)
    FI_tfi_sum += fi_tfi_iter
    FS_TFI.append(_fs_accuracy(data, _top_ranked_features(fi_tfi_iter), num_classes, run_idx))

    # --- Random baseline: mean over 5 random subsets, using the SAME feature
    # budget as the scored methods (previously ceil(r*n) vs floor(r*n)).
    n_keep = _num_selected(len(feature_idx))
    rand_accs = []
    for rr in range(5):  # 5 random draws
        set_seed_all(run_idx * 100 + rr)  # keep reproducible randomness
        outlier_idx_rand = np.random.choice(feature_idx, size=n_keep, replace=False)
        rand_accs.append(_fs_accuracy(data, outlier_idx_rand, num_classes, run_idx))
    FS_RAND.append(np.mean(rand_accs))

    # --- h_sim-euc (homophily via Euclidean similarity on features)
    set_seed_all(run_idx)
    fi_heuc_iter = compute_hsim_euc_per_feature(data)
    FI_heuc_sum += fi_heuc_iter
    FS_HEUC.append(_fs_accuracy(data, _top_ranked_features(fi_heuc_iter), num_classes, run_idx))

    # --- h_attr (attribute assortativity per feature)
    set_seed_all(run_idx)
    fi_hattr_iter = compute_hattr_per_feature(data)
    FI_hattr_sum += fi_hattr_iter
    FS_HATTR.append(_fs_accuracy(data, _top_ranked_features(fi_hattr_iter), num_classes, run_idx))

    # --- h_GE (graph-entropy of neighbor similarity)
    set_seed_all(run_idx)
    fi_hge_iter = compute_hge_per_feature(data, nbins=30, sim_type="cos")  # or sim_type="euc"
    FI_hge_sum += fi_hge_iter
    FS_HGE.append(_fs_accuracy(data, _top_ranked_features(fi_hge_iter), num_classes, run_idx))




# # # ------- Averages across iterations
# FI_gnn = FI_gnn_sum / N_iter
# FI_mlp = FI_mlp_sum / N_iter
# FI_mi  = FI_mi_sum  / N_iter
# FI_tfi = FI_tfi_sum / N_iter

# # ------- Save
# np.save(f'data/{DS}_feature_importance_gnn.npy', FI_gnn)
# np.save(f'data/{DS}_feature_importance_mlp.npy', FI_mlp)
# np.save(f'data/{DS}_feature_importance_mi.npy',  FI_mi)
# np.save(f'data/{DS}_feature_importance_tfi.npy', FI_tfi)

# ------- Compare base ACC with FS ACC
# Each row prints mean ± std (over iterations) of best-test-at-best-val accuracy.
pct = int(r * 100)
print("\n=== Summary ===")
print(f"Feature number for FS methods: {int(r*num_node_features)} (top {pct}%)")
_summary_rows = [
    ("Base ACC (all features)       ", base_Acc),
    (f"FS GNN ACC (top {pct}% GNN FI)   ", FS_GNN),
    (f"FS MLP ACC (top {pct}% MLP FI)   ", FS_MLP),
    (f"FS TFI ACC (top {pct}% TFI)      ", FS_TFI),
    (f"FS MI ACC (top {pct}% MI)        ", FS_MI),
    (f"FS RAND ACC (random {pct}%) ", FS_RAND),
    (f"FS MASK ACC (top {pct}%) ", FS_MASK),
    (f"FS h-sim-euc ACC (top {pct}%) ", FS_HEUC),
    (f"FS h-attr    ACC (top {pct}%) ", FS_HATTR),
    (f"FS h-GE      ACC (top {pct}%) ", FS_HGE),
]
for label, accs in _summary_rows:
    print(f"{label}: {np.mean(accs):.2f} ± {np.std(accs):.2f}")


# # ------- Plot
# feature_idx = feature_idx_ref if feature_idx_ref is not None else list(range(len(FI_gnn)))
# plt.figure(figsize=(12, 8))

# plt.subplot(2, 2, 1)
# plt.scatter(feature_idx, FI_gnn)
# plt.xlabel('Feature Index')
# plt.ylabel('Mean Accuracy Drop')
# plt.title('FI GNN')
# plt.grid(axis='y')
# plt.tight_layout()

# plt.subplot(2, 2, 2)
# plt.scatter(feature_idx, FI_mlp)
# plt.xlabel('Feature Index')
# plt.ylabel('Mean Accuracy Drop')
# plt.title('FI MLP')
# plt.grid(axis='y')
# plt.tight_layout()

# plt.subplot(2, 2, 3)
# plt.scatter(feature_idx, FI_mi)
# plt.xlabel('Feature Index')
# plt.ylabel('MI(X,Y)')
# plt.title('FI MI (avg over runs)')
# plt.grid(axis='y')
# plt.tight_layout()

# plt.subplot(2, 2, 4)
# plt.scatter(feature_idx, FI_tfi)
# plt.xlabel('Feature Index')
# plt.ylabel('MI(AX,Y)')
# plt.title('FI TFI (avg over runs)')
# plt.grid(axis='y')
# plt.tight_layout()

# plt.savefig(f'Plots/{DS}_feature_importances.png')
# print(f"\nSaved: data/{DS}_feature_importance_*.npy and Plots/{DS}_feature_importances.png")

# # plot the histogram of feature importance scores for each method with 2,2 subplots
# plt.figure(figsize=(12, 8))
# plt.subplot(2, 2, 1)
# plt.hist(FI_gnn, bins=10)
# plt.xlabel('Mean Accuracy Drop')
# plt.ylabel('Count')
# plt.title('FI GNN')
# plt.grid(axis='y')
# plt.tight_layout()
# plt.subplot(2, 2, 2)
# plt.hist(FI_mlp, bins=10)
# plt.xlabel('Mean Accuracy Drop')
# plt.ylabel('Count')
# plt.title('FI MLP')
# plt.grid(axis='y')
# plt.tight_layout()
# plt.subplot(2, 2, 3)
# plt.hist(FI_mi, bins=10)
# plt.xlabel('MI(X,Y)')
# plt.ylabel('Count')
# plt.title('FI MI (avg over runs)')
# plt.grid(axis='y')
# plt.tight_layout()
# plt.subplot(2, 2, 4)
# plt.hist(FI_tfi, bins=10)
# plt.xlabel('MI(AX,Y)')
# plt.ylabel('Count')
# plt.title('FI TFI (avg over runs)')
# plt.grid(axis='y')
# plt.tight_layout()
# plt.savefig(f'Plots/{DS}_feature_importances_hist.png')
