import os
import torch
import numpy as np
import torch.nn.functional as F

import matplotlib.pyplot as plt
import argparse as argparse
from torch_geometric.data import Data
from utils import *
import math
import pickle
import torch_geometric.transforms as T
from sklearn.feature_selection import mutual_info_classif

from matplotlib import rcParams as rc
# Global matplotlib styling: serif 12pt fonts, with text typeset through LaTeX
# (usetex) and the amsmath package loaded in the LaTeX preamble.
rc.update({
    "font.family": "serif",
    "text.usetex": True,
    "text.latex.preamble": r'\usepackage{amsmath}',
    "font.size": 12,
})
# Compute device selection. The second GPU ('cuda:1') is preferred, as in the
# original configuration; on single-GPU machines 'cuda:1' is an invalid device
# ordinal, so fall back to 'cuda:0'. Without CUDA everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')



def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter accepts the usual
    spellings (case-insensitive) and rejects anything else.

    Args:
        value: The raw string from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept falsy, matching what ``bool("")`` used to return.
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the feature-pruning experiment.
parser = argparse.ArgumentParser(description='Single feature analysis (one outer loop).')
parser.add_argument('--dataset', type=str, default='Cora',
                    help='Dataset name (default: Cora) or dataSim_Y_A,X_A,Y_X for synthetic data')
parser.add_argument('--N_iter', type=int, default=5,
                    help='Number of iterations for averaging results (default: 5)')
parser.add_argument('--N_shuffle', type=int, default=10,
                    help='Number of shuffles per feature (default: 10)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Initial learning rate (default: 0.01)')
parser.add_argument('--decay', type=float, default=0.9,
                    help='Decay rate (default: 0.9)')
# ``--isMLP`` alone, ``--isMLP True`` and ``--isMLP False`` are all accepted.
parser.add_argument('--isMLP', type=_str2bool, nargs='?', const=True, default=False,
                    help='Also do MLP (default: False)')
# Restricting the choices fails fast instead of crashing at the first
# checkpoint when an unknown importance method leaves no scores to rank.
parser.add_argument('--FI', type=str, default='GFI', choices=['GFI', 'TFI', 'MI'],
                    help='Feature importance used for dropping (default: GFI)')
parser.add_argument('--r', type=float, default=0.5,
                    help='Fraction of features to prune at each step (default: 0.5)')
# Parse the command line and expose each option under the module-level name
# used by the rest of the script.
args = parser.parse_args()
DS, N_iter, N_shuffle = args.dataset, args.N_iter, args.N_shuffle
lr_init, decay_rate = args.lr, args.decay
useMLP, FI_method, r = args.isMLP, args.FI, args.r

print("r:", r)

# Training schedule: 'Computers' and 'Photo' train for 800 epochs with a
# checkpoint every 100; every other dataset uses 400 epochs / every 50.
epochs, verbos_every = (800, 100) if DS in ('Computers', 'Photo') else (400, 50)

# Number of feature-importance / pruning checkpoints per run.
num_checkpoints = epochs // verbos_every

print(f"Dataset: {DS}, N_iter: {N_iter}, N_shuffle: {N_shuffle}, lr: {lr_init}, decay: {decay_rate}, useMLP: {useMLP}, FI_method: {FI_method}")
print(num_checkpoints)
# Initial load of the dataset; only used here to size the result buffers
# (every run reloads the data after re-seeding).
set_seed_all(0)
dataset, data_obj, num_node_features, num_classes = load_dataset(DS, device)
# Rebind data_obj itself (the original assigned to a misspelled, unused name).
data_obj = data_obj.to(device)
print(f"Training for {epochs} epochs, dropping every {verbos_every} epochs")

# Per-feature importance score recorded at each checkpoint of each run:
# shape (num_node_features, num_checkpoints, N_iter).
feature_drop_GNN = np.zeros((num_node_features, num_checkpoints, N_iter))

# Allocated with shape (num_checkpoints, N_iter); not written anywhere in the
# visible part of this script.
Compared_to_baseline = np.zeros((num_checkpoints, N_iter))

if useMLP:
    N_nodes = data_obj.x.size(0)
    idx = torch.arange(N_nodes, device=device, dtype=torch.long)
    # Edge index containing only the self-loops (i, i): evaluating the model
    # on it removes all neighbour edges (the "MLP" view of the model).
    edge_index_I = torch.stack([idx, idx], dim=0)  # shape (2, N_nodes)
    # Same layout as feature_drop_GNN, for the self-loop-only evaluation.
    feature_drop_mlp = np.zeros((num_node_features, num_checkpoints, N_iter))


# Allocated but not written anywhere in the visible part of this script.
Checkpoint_acc = np.zeros((num_checkpoints, N_iter))
Full_model_acc_all = np.zeros(N_iter)

# Per-epoch validation / test accuracy of the unpruned baseline model,
# one column per run: shape (epochs, N_iter).
val_acc_base_all = np.zeros((epochs, N_iter))
test_acc_base_all = np.zeros((epochs, N_iter))

# Per-epoch validation / test accuracy of the progressively pruned model.
val_acc_prune_all = np.zeros((epochs, N_iter))
test_acc_prune_all = np.zeros((epochs, N_iter))

def _build_model(ds_name, in_dim, out_dim):
    """Instantiate the node classifier used for dataset ``ds_name`` on ``device``.

    Texas/Cornell/Wisconsin/Actor use ``GCNNodeClassifier_Het``, Photo/Computers
    use ``GIN``, and every other dataset uses ``GCNNodeClassifier``.
    """
    if ds_name in {'Texas', 'Cornell', 'Wisconsin', 'Actor'}:
        net = GCNNodeClassifier_Het(in_dim, [512, 512], out_dim)
    elif ds_name in {'Photo', 'Computers'}:
        net = GIN(in_dim=in_dim, hid_dim=512, out_dim=out_dim)
    else:
        net = GCNNodeClassifier(in_dim, [512, 512], out_dim)
    return net.to(device)


def _split_accuracy(preds, labels, node_mask):
    """Return the fraction (0..1) of correct predictions on the nodes in ``node_mask``."""
    return (preds[node_mask] == labels[node_mask]).float().mean().item()


# One full experiment per run: (optionally) train an unpruned baseline, then
# train a second model while pruning the least important features at every
# checkpoint.
for run_idx in range(N_iter):
    print(f"\n=== Iteration {run_idx+1}/{N_iter} ===")
    # Seed first, then load data (so masks change per load if dataset shuffles masks by seed)
    set_seed_all(run_idx)

    dataset, data_obj, num_node_features, num_classes = load_dataset(DS, device)
    data_obj = data_obj.to(device)

    # Baseline on all features; only trained (and later saved) for GFI runs.
    if FI_method == 'GFI':
        print("Full model training")
        set_seed_all(run_idx)
        model_base = _build_model(DS, num_node_features, num_classes)
        opt_base = torch.optim.Adam(model_base.parameters(), lr=lr_init)

        loss_base, val_acc_base, test_acc_base = [], [], []
        for epoch in range(1, epochs + 1):
            # One full-batch optimisation step on the training nodes.
            model_base.train()
            opt_base.zero_grad()
            out = model_base(data_obj.x, data_obj.edge_index)
            loss = F.cross_entropy(out[data_obj.train_mask], data_obj.y[data_obj.train_mask])
            loss.backward()
            opt_base.step()

            # Per-epoch evaluation in eval mode.
            model_base.eval()
            with torch.no_grad():
                preds = model_base(data_obj.x, data_obj.edge_index).argmax(-1)
            loss_base.append(loss.item())
            val_acc_base.append(_split_accuracy(preds, data_obj.y, data_obj.val_mask))
            test_acc_base.append(_split_accuracy(preds, data_obj.y, data_obj.test_mask))

        val_acc_base_all[:, run_idx] = val_acc_base
        test_acc_base_all[:, run_idx] = test_acc_base

    print("Subset model training ...")

    feature_mask = torch.ones(num_node_features, device=device)  # 1 = keep, 0 = pruned
    active_features = np.arange(num_node_features, dtype=int)    # numpy index array for convenience

    # Re-seed so the pruned model starts from the same RNG state as the baseline.
    set_seed_all(run_idx)
    model = _build_model(DS, num_node_features, num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)
    # Feature importances are always measured on the validation nodes.
    FI_mask = data_obj.val_mask

    checkpoint = 0
    loss_all, val_acc_all, test_acc_all = [], [], []
    checkpoint_epochs = []
    # best_te is initialised explicitly: otherwise it is undefined (or stale
    # from the previous run) if validation accuracy never exceeds 0.
    best_va, best_te = 0, 0.0
    for epoch in range(1, epochs + 1):

        # One full-batch optimisation step on the currently active features.
        # NOTE(review): apply_feature_mask comes from utils; presumably it
        # zeroes the columns whose mask entry is 0 - confirm.
        model.train()
        optimizer.zero_grad()
        x_masked = apply_feature_mask(data_obj.x, feature_mask)
        out = model(x_masked, data_obj.edge_index)
        loss = F.cross_entropy(out[data_obj.train_mask], data_obj.y[data_obj.train_mask])
        loss.backward()
        optimizer.step()

        # evaluation
        model.eval()
        with torch.no_grad():
            x_masked = apply_feature_mask(data_obj.x, feature_mask)
            preds = model(x_masked, data_obj.edge_index).argmax(-1)
        tr = _split_accuracy(preds, data_obj.y, data_obj.train_mask)
        va = _split_accuracy(preds, data_obj.y, data_obj.val_mask)
        te = _split_accuracy(preds, data_obj.y, data_obj.test_mask)
        loss_all.append(loss.item())
        val_acc_all.append(va)
        test_acc_all.append(te)
        # Track the test accuracy at the epoch with the best validation accuracy.
        if va > best_va:
            best_te = te
            best_va = va

        # Everything below only runs at checkpoints.
        if not (epoch % verbos_every == 0 or epoch == epochs):
            continue

        print(f"Epoch {epoch:03d} | Loss {loss:.4f} | Train {tr:.3f} | Best Val {best_va:.3f} | Best Test {best_te:.3f}")
        print("  - Computing feature importances for checkpoint", checkpoint+1)
        with torch.no_grad():
            # Unpermuted reference accuracy (in percent) on the FI nodes.
            x_masked = apply_feature_mask(data_obj.x, feature_mask)
            logits = model(x_masked, data_obj.edge_index)
            preds_base = logits.argmax(-1)
            base_acc = 100 * _split_accuracy(preds_base, data_obj.y, FI_mask)

            if useMLP:
                # Same model evaluated with self-loop edges only.
                preds_base_I = model(x_masked, edge_index_I).argmax(-1)
                base_acc_I = 100 * _split_accuracy(preds_base_I, data_obj.y, FI_mask)

        checkpoint_epochs.append(epoch)
        # --- Importance scores ONLY for active features
        feat_drops = {}  # feat_id -> importance score
        feat_drops_mlp = {}  # feat_id -> score under self-loop-only edges (GFI + useMLP only)
        if FI_method == 'GFI':
            # Permutation importance: mean accuracy drop over N_shuffle
            # permutations of a single feature column.
            for feat_i in active_features:
                drops = []
                drops_I = []
                for s in range(N_shuffle):
                    x_perm = permute_single_feature_on_tensor(x_masked, feat_i, seed=run_idx + s)
                    with torch.no_grad():
                        preds_perm = model(x_perm, data_obj.edge_index).argmax(-1)
                        acc_perm = 100 * _split_accuracy(preds_perm, data_obj.y, FI_mask)
                        if useMLP:
                            preds_perm_I = model(x_perm, edge_index_I).argmax(-1)
                            acc_perm_I = 100 * _split_accuracy(preds_perm_I, data_obj.y, FI_mask)

                    drops.append(base_acc - acc_perm)
                    if useMLP:
                        drops_I.append(base_acc_I - acc_perm_I)

                feat_i_drop = float(np.mean(drops))
                feat_drops[feat_i] = feat_i_drop
                feature_drop_GNN[feat_i, checkpoint, run_idx] = feat_i_drop  # still record
                if useMLP:
                    feat_i_drop_I = float(np.mean(drops_I))
                    feat_drops_mlp[feat_i] = feat_i_drop_I
                    feature_drop_mlp[feat_i, checkpoint, run_idx] = feat_i_drop_I  # still record
        elif FI_method == 'TFI':
            # Mutual information between the adjacency-multiplied features
            # (A @ X) and the labels on the FI nodes.
            A_np = get_adj_np(data_obj)
            # dense X (on CPU)
            if hasattr(x_masked, 'is_sparse') and x_masked.is_sparse:
                X_np = x_masked.to_dense().cpu().numpy()
            else:
                X_np = x_masked.cpu().numpy()
            Xn_np = A_np @ X_np
            tv_mask_np = FI_mask.cpu().numpy()
            Xn_tv_np = Xn_np[tv_mask_np]
            y_tv_np = data_obj.y[FI_mask].cpu().numpy().astype(int)
            TFI_FI = mutual_info_classif(Xn_tv_np, y_tv_np, random_state=run_idx).astype(float)
            for feat_i in active_features:
                feat_i_drop = TFI_FI[feat_i]
                feat_drops[feat_i] = feat_i_drop
                feature_drop_GNN[feat_i, checkpoint, run_idx] = feat_i_drop  # still record
        elif FI_method == 'MI':
            # Mutual information between the node features returned by
            # get_dense_x_split and the labels on the FI nodes.
            X_tv_np = get_dense_x_split(data_obj, FI_mask)
            y_tv_np = data_obj.y[FI_mask].cpu().numpy().astype(int)
            MI_FI = mutual_info_classif(X_tv_np, y_tv_np, random_state=run_idx).astype(float)
            for feat_i in active_features:
                feat_i_drop = MI_FI[feat_i]
                feat_drops[feat_i] = feat_i_drop
                feature_drop_GNN[feat_i, checkpoint, run_idx] = feat_i_drop  # still record

        checkpoint = checkpoint + 1
        # --- Pruning: drop the fraction r of ACTIVE features with the lowest scores
        if len(active_features) > 0:
            scores1 = np.array([feat_drops[i] for i in active_features])  # aligned with active_features

            k = int(np.ceil(r * len(scores1)))
            lowest_idx = np.argsort(scores1)[:k]
            to_prune1 = active_features[lowest_idx]

            # MLP scores only exist for GFI; without the emptiness guard
            # useMLP combined with TFI/MI raised a KeyError here.
            if useMLP and feat_drops_mlp:
                scores2 = np.array([feat_drops_mlp[i] for i in active_features])
                # np.percentile expects q in [0, 100], while r is a fraction.
                threshold2 = np.percentile(scores2, 100.0 * r)
                to_prune2 = active_features[scores2 < threshold2]
                # Prune only features that are unimportant under both views.
                to_prune = np.intersect1d(to_prune1, to_prune2, assume_unique=True)
            else:
                to_prune = to_prune1

            if len(to_prune) > 0:
                feature_mask[torch.tensor(to_prune, device=device, dtype=torch.long)] = 0.0
                active_features = np.setdiff1d(active_features, to_prune, assume_unique=True)
                print(f"  - Pruned {len(to_prune)} features at checkpoint {checkpoint} "
                    f"(remaining active: {len(active_features)})")

                # Decay the learning rate after every successful pruning step.
                for g in optimizer.param_groups:
                    g['lr'] = g['lr'] * decay_rate

            else:
                print("  - No features pruned at this checkpoint.")

    val_acc_prune_all[:, run_idx] = val_acc_all
    test_acc_prune_all[:, run_idx] = test_acc_all


def _dump_pickle(path, obj):
    """Pickle ``obj`` to ``path``, creating the parent directory if needed.

    Without the ``makedirs`` call a missing ``data/`` (or ``data/heatmap/``)
    directory would raise ``FileNotFoundError`` after all training has
    finished, losing every result.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)


# Persist the accuracy curves; the file name encodes the importance method.
# Baseline curves and the per-feature importance tensor are only stored for GFI.
if FI_method == 'GFI':
    _dump_pickle(f"data/ValAccBase_{DS}.pkl", val_acc_base_all)
    _dump_pickle(f"data/TestAccBase_{DS}.pkl", test_acc_base_all)
    _dump_pickle(f"data/ValAccGFI_{DS}.pkl", val_acc_prune_all)
    _dump_pickle(f"data/TestAccGFI_{DS}.pkl", test_acc_prune_all)

    # save feature_drop_gnn with pickle
    _dump_pickle(f"data/heatmap/GFI_drop_{DS}.pkl", feature_drop_GNN)

elif FI_method == 'TFI':
    _dump_pickle(f"data/ValAccTFI_{DS}.pkl", val_acc_prune_all)
    _dump_pickle(f"data/TestAccTFI_{DS}.pkl", test_acc_prune_all)
elif FI_method == 'MI':
    _dump_pickle(f"data/ValAccMI_{DS}.pkl", val_acc_prune_all)
    _dump_pickle(f"data/TestAccMI_{DS}.pkl", test_acc_prune_all)

