
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, Amazon, Planetoid, Amazon, WebKB, WikipediaNetwork, Actor, WikiCS

from torch_geometric.transforms import NormalizeFeatures
import matplotlib.pyplot as plt
import argparse as argparse
from torch_geometric.data import Data
from utils import *
import math
import pickle
import torch_geometric.transforms as T
from sklearn.feature_selection import mutual_info_classif
import os, math, pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams as rc
# Global Matplotlib styling: serif fonts at 12pt, text rendered through LaTeX
# with amsmath loaded. Because usetex is on, every string handed to Matplotlib
# in this script is interpreted as LaTeX source.
rc.update({"font.family": "serif", "font.size": 12})
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

# Compute device: CUDA device index 1 when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1')
else:
    device = torch.device('cpu')




# ---------------- Command-line configuration ----------------
parser = argparse.ArgumentParser(description='Single feature analysis (one outer loop).')
parser.add_argument('--dataset', type=str, default='Cora',
                    help='Dataset name (default: Cora) or dataSim_Y_A,X_A,Y_X for synthetic data')
parser.add_argument('--r', type=float, default=0.5,
                    help='Fraction of features to prune at each step (default: 0.5)')
args = parser.parse_args()

# The interval plot places its points at 100 * (1 - r)^k on a log-scaled axis.
# r >= 1 makes those positions zero or negative and r < 0 makes them grow past
# 100%, so reject such values up front (NaN fails the comparison as well).
if not 0.0 <= args.r < 1.0:
    parser.error(f"--r must satisfy 0 <= r < 1, got {args.r}")

DS = args.dataset
drop_rate = args.r

# Per-dataset defaults: nominal epoch count and checkpoint spacing (in epochs).
# NOTE: `epochs` is only a default here; it is reassigned from the shape of the
# loaded baseline array further down in this script.
epochs, verbos_every = (800, 100) if DS in ('Computers', 'Photo') else (400, 50)

# Width, in epochs, of the windows used for checkpoint-wise model selection.
interval = verbos_every

# Output directory for the figures saved by this script.
os.makedirs("Plots", exist_ok=True)

def load_2d(path):
    """Unpickle an accuracy history and return it as a 2-D (epochs, runs) array.

    A 1-D input becomes a single-run column of shape (epochs, 1). A 2-D input
    is transposed when it has fewer rows than columns, i.e. the longer axis is
    assumed to be the epoch axis. Any other dimensionality raises ValueError.

    NOTE(review): pickle.load executes arbitrary code embedded in the file;
    only point this at files produced by trusted runs.
    """
    with open(path, "rb") as fh:
        data = np.asarray(pickle.load(fh))

    if data.ndim not in (1, 2):
        raise ValueError(f"Unexpected array shape {data.shape} in {path}")

    if data.ndim == 1:
        # Single run: add a trailing runs axis -> (epochs, 1).
        return data[:, None]

    # Normalize orientation so that rows index epochs and columns index runs.
    n_rows, n_cols = data.shape
    return data.T if n_rows < n_cols else data

def best_test_per_interval_1run(val_1d: np.ndarray, test_1d: np.ndarray, interval: int):
    """Checkpoint-wise model selection for a single run.

    The epoch axis is cut into consecutive windows of `interval` epochs (the
    last window may be shorter). Within each window the epoch with the highest
    validation value is located (first one on ties, per np.argmax) and the test
    value at that same epoch is reported.

    val_1d / test_1d: per-epoch histories of shape (epochs,).
    Returns an array with one selected test value per window.
    """
    total = len(val_1d)
    n_windows = math.ceil(total / interval)
    windows = (
        (k * interval, min((k + 1) * interval, total))
        for k in range(n_windows)
    )
    # argmax is relative to the window start, so shift it back to a global index.
    selected = [test_1d[lo + np.argmax(val_1d[lo:hi])] for lo, hi in windows]
    return np.array(selected)

def mean_std_percent(arr_2d):
    """Per-epoch mean and standard deviation across runs, scaled by 100.

    arr_2d has shape (epochs, runs). Returns a (mean, std) pair of 1-D arrays,
    one entry per epoch.
    """
    n_runs = arr_2d.shape[1]
    # Sample std (ddof=1) needs at least two runs; with a single run it would
    # be NaN, so use the population std (ddof=0) in that case.
    ddof = 0 if n_runs <= 1 else 1
    mean_pct = 100.0 * arr_2d.mean(axis=1)
    std_pct = 100.0 * arr_2d.std(axis=1, ddof=ddof)
    return mean_pct, std_pct

# ---------------- Load all runs ----------------
# Baseline (full model) accuracy histories, each normalized by load_2d to
# (epochs, runs). The epoch count used for all plots comes from the baseline
# validation array.
val_base_all = load_2d(f"data/ValAccBase_{DS}.pkl")
test_base_all = load_2d(f"data/TestAccBase_{DS}.pkl")
epochs, runs_base = val_base_all.shape

# Pruning methods are optional: a method is included only when both its
# validation and its test history files are present on disk.
methods_all = {}
for tag in ("GFI", "TFI", "MI"):
    val_file, test_file = f"data/ValAcc{tag}_{DS}.pkl", f"data/TestAcc{tag}_{DS}.pkl"
    if not (os.path.exists(val_file) and os.path.exists(test_file)):
        continue
    methods_all[tag] = {"val": load_2d(val_file), "test": load_2d(test_file)}

# --------------- Color map (consistent across plots) ---------------
# The baseline is always black; methods take tab10 colors in alphabetical
# order of their tag so that both figures use the same color per method.
palette = plt.get_cmap('tab10')
color_map = {"BASE": "black"}
color_map.update({tag: palette(idx) for idx, tag in enumerate(sorted(methods_all))})

# --------------- X for interval plot: 100 * (1 - drop_rate)^k percent ---------------
# One x position per checkpoint window k = 0, 1, 2, ...: the share of features
# still in use after k pruning steps (100%, then shrinking geometrically).
interval = verbos_every
num_intervals = math.ceil(epochs / interval)
keep_fraction = 1 - drop_rate
x_int = 100 * np.power(keep_fraction, np.arange(num_intervals))

# --------------- Figure 1: per-epoch curves with shaded error ---------------
# usetex is enabled for this script, so the dataset name is typeset by LaTeX.
# Names such as "dataSim_Y_A,X_A,Y_X" contain characters that are special in
# LaTeX text mode; backslash-escape them so the title compiles.
tex_specials = str.maketrans({ch: "\\" + ch for ch in "#$%&_{}"})

x_ep = np.arange(1, epochs + 1)


def _plot_accuracy_panel(subplot_index, base_all, split_key, ylabel):
    """Draw one panel of Figure 1: mean +/- std over runs versus epoch.

    subplot_index: 1-based position in the 1x2 grid.
    base_all:      (epochs, runs) history of the full model for this split.
    split_key:     "val" or "test"; selects the matching array of each method.
    ylabel:        y-axis label (LaTeX source, since usetex is enabled).
    """
    plt.subplot(1, 2, subplot_index)

    # Full model first, then every loaded method, each as line + shaded band.
    curves = [('Full model', "BASE", base_all)]
    curves += [(f'{tag}', tag, d[split_key]) for tag, d in methods_all.items()]
    for label, color_key, history in curves:
        mean, std = mean_std_percent(history)
        plt.plot(x_ep, mean, label=label, color=color_map[color_key])
        plt.fill_between(x_ep, mean - std, mean + std,
                         alpha=0.18, edgecolor='none', color=color_map[color_key])

    # Dashed vertical markers at every checkpoint boundary.
    for k in range(1, num_intervals):
        plt.axvline(x=(k * verbos_every), color='purple', linestyle='--', linewidth=0.5, alpha=0.8)

    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.grid(True, alpha=0.3)
    plt.legend()


plt.figure(figsize=(10, 4), facecolor='white')
plt.suptitle(DS.translate(tex_specials), fontsize=18)
# (a) Validation
_plot_accuracy_panel(1, val_base_all, "val", r'Validation Accuracy (\%)')
# (b) Test
_plot_accuracy_panel(2, test_base_all, "test", r'Test Accuracy (\%)')

plt.tight_layout()
plt.savefig(f"Plots/AccCurves_{DS}.jpg", dpi=200)

# --------------- Interval-wise differences: mean & std over runs ---------------
# For every method: per-run difference (full model - method), in percentage
# points, between the test accuracies selected in each checkpoint window;
# then mean and std of that difference across runs.
method_diff_mean = {}
method_diff_std  = {}

# A baseline run is usable only if both its validation and test histories exist.
n_base_runs = min(runs_base, test_base_all.shape[1])

# The baseline selection does not depend on the method being compared, so
# compute it once per run instead of once per (method, run) pair.
base_best_by_run = [
    best_test_per_interval_1run(val_base_all[:, run], test_base_all[:, run], interval)
    for run in range(n_base_runs)
]

for tag, d in methods_all.items():
    val_m_all  = d["val"]   # (epochs, runs_m)
    test_m_all = d["test"]  # (epochs, runs_m)

    # Runs are paired by column index, so only the common prefix can be compared.
    runs = min(val_m_all.shape[1], test_m_all.shape[1], n_base_runs)
    diffs_runs = np.zeros((runs, num_intervals), dtype=float)

    for r in range(runs):
        meth_test_by_int_r = best_test_per_interval_1run(
            val_m_all[:, r], test_m_all[:, r], interval
        )
        diffs_runs[r, :] = (base_best_by_run[r] - meth_test_by_int_r) * 100.0

    method_diff_mean[tag] = diffs_runs.mean(axis=0)
    # Sample std needs at least two runs; fall back to population std otherwise.
    ddof = 1 if runs > 1 else 0
    method_diff_std[tag]  = diffs_runs.std(axis=0, ddof=ddof)

# --------------- Figure 2: interval-wise differences with error bars ---------------
# usetex is enabled for this script, so all text below is LaTeX source: the
# dataset name is escaped for the title and '%' is written as '\%' in the tick
# labels (a bare '%' would start a LaTeX comment).
tex_specials = str.maketrans({ch: "\\" + ch for ch in "#$%&_{}"})

plt.figure(figsize=(6, 4), facecolor='white')
for tag in sorted(method_diff_mean):
    plt.errorbar(
        x_int,
        # Stored differences are (full - method); negate to plot (method - full).
        -method_diff_mean[tag],
        yerr=method_diff_std[tag],
        fmt='-o',
        capsize=3,
        label=f'{tag} - Full',
        color=color_map[tag],
        ecolor=color_map[tag],
        elinewidth=1,
        markeredgecolor=color_map[tag]
    )
# add a dashed line at each checkpoint after the first (spaced verbos_every epochs apart)
for i in range(1, num_intervals):
    plt.axvline(x=x_int[i], color='purple', linestyle='--', linewidth=0.5, alpha=0.8)
plt.xscale('log')
plt.xticks(x_int, [rf"{v:.1f}\%" for v in x_int])
plt.xlabel(r'Percentage of features used (\%)')
plt.ylabel('Difference vs full model')
plt.title('Checkpoint-wise Test Accuracy Difference, ' + DS.translate(tex_specials))
plt.grid(True, alpha=0.3)
plt.legend()
plt.gca().invert_xaxis()  # left = 100%, right = smaller %
plt.tight_layout()
plt.savefig(f"Plots/Interval_Diff_{DS}.jpg", dpi=200)
