### Run after the "posthoc 0513" file (NOTE: the original comment is cut off after "that include").
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ripser import ripser
from persim import plot_diagrams
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from scipy.stats import f_oneway, pearsonr
import shap

def compute_betti_numbers(points, return_diagrams=False):
    """Count the persistence-diagram entries per homology dimension.

    Runs ``ripser`` on *points* with ``maxdim=2`` and takes the length of
    each returned diagram as that dimension's count, i.e. every interval
    in the diagram contributes one, whatever its persistence.

    Args:
        points: Input passed straight through to ``ripser``.
        return_diagrams: When truthy, also return the raw diagrams.

    Returns:
        A list of counts padded with zeros to at least three entries, or
        the tuple ``(counts, diagrams)`` when *return_diagrams* is truthy.
    """
    dgms = ripser(points, maxdim=2)['dgms']
    counts = list(map(len, dgms))
    # Guarantee entries for dimensions 0, 1 and 2 even if fewer came back.
    counts.extend([0] * (3 - len(counts)))
    return (counts, dgms) if return_diagrams else counts

def curvature_index(data):
    """Return one minus the variance ratio of the first principal component.

    A value near 0 means a single linear direction explains almost all of
    the variance in *data*; larger values mean the variance is spread over
    more directions.

    Args:
        data: 2-D array-like of shape ``(n_samples, n_features)``.

    Returns:
        ``1.0 - explained_variance_ratio_[0]`` of a PCA fitted on *data*.

    Raises:
        ValueError: Propagated from scikit-learn when *data* cannot be
            fitted at all (e.g. it is not 2-D or is empty).
    """
    # PCA rejects n_components larger than min(n_samples, n_features), so a
    # hard-coded 3 fails on 2-column embeddings or on very short inputs.
    # Only the first ratio is read, so cap the request at what the data
    # allows while still asking for 3 whenever that is possible.
    n_components = min((3, *np.shape(data)))
    pca = PCA(n_components=n_components)
    pca.fit(data)
    return 1.0 - pca.explained_variance_ratio_[0]

def geometry_aware_ctis(beta_base, beta_pert, gamma_g, weights=(1, 1, 1)):
    """Compute a weighted absolute difference of two count vectors, scaled by gamma_g.

    Both vectors are zero-padded to at least three entries before they are
    compared. Padding is done on local copies, so the caller's sequences
    are never modified and tuples or arrays are accepted as well as lists.

    Args:
        beta_base: Sequence of baseline counts.
        beta_pert: Sequence of perturbed counts.
        gamma_g: Multiplicative scale applied to the weighted sum.
        weights: Per-position weights. ``zip`` stops at the shortest of the
            three sequences, so positions beyond ``len(weights)`` are ignored.

    Returns:
        ``gamma_g * sum(w * |base - pert|)`` over the zipped positions.
    """
    base = list(beta_base)
    base += [0] * (3 - len(base))
    pert = list(beta_pert)
    pert += [0] * (3 - len(pert))
    return gamma_g * sum(w * abs(b1 - b2) for b1, b2, w in zip(base, pert, weights))

# ---------------- SHAP Analysis ----------------
def run_shap_analysis(features, targets, feature_names, output_path):
    """Fit a random forest and export SHAP feature attributions.

    The data is split 80/20 (seed 42), a 100-tree ``RandomForestRegressor``
    (seed 42) is fitted on the training part, and SHAP values are computed
    for the held-out part using the training part as background data.

    Outputs written into *output_path*:
        * ``shap_feature_importance.png`` - beeswarm plot at 300 dpi.
        * ``shap_values.csv`` - mean absolute SHAP value per feature,
          sorted in descending order.

    Args:
        features: Feature matrix handed to ``train_test_split``.
        targets: Regression targets aligned with *features*.
        feature_names: One label per feature column.
        output_path: Directory that receives both output files.
    """
    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    X_fit, X_eval, y_fit, _y_eval = train_test_split(
        features, targets, test_size=0.2, random_state=42
    )
    forest.fit(X_fit, y_fit)

    # Explain the held-out rows against the training background.
    shap_explainer = shap.Explainer(forest, X_fit, feature_names=feature_names)
    explanation = shap_explainer(X_eval)

    shap.plots.beeswarm(explanation, show=False)
    plt.tight_layout()
    figure_file = os.path.join(output_path, "shap_feature_importance.png")
    plt.savefig(figure_file, dpi=300)
    plt.close()

    importance = pd.DataFrame({
        "Feature": feature_names,
        "Mean_SHAP": np.abs(explanation.values).mean(axis=0),
    })
    ranked = importance.sort_values("Mean_SHAP", ascending=False)
    ranked.to_csv(os.path.join(output_path, "shap_values.csv"), index=False)

# ---------------- Diagnostic Plots ----------------
def plot_grad_clustering(grad_files, output_path):
    """Scatter-plot a 2-component PCA projection of per-session gradient arrays.

    Every path in *grad_files* that exists is loaded with ``np.load`` and
    flattened. Nothing is plotted unless at least two arrays were loaded.
    Points are labelled ``S1``, ``S2``, ... by their position among the
    *loaded* arrays, so a missing file shifts the numbering of later ones.

    Args:
        grad_files: Iterable of ``.npy`` paths; missing paths are skipped.
        output_path: Directory that receives ``gradU_pca_sessions.png``.

    Returns:
        None.
    """
    vectors = []
    for path in grad_files:
        if os.path.exists(path):
            vectors.append(np.load(path).flatten())

    # A projection of fewer than two points carries no information.
    if len(vectors) < 2:
        return

    projected = PCA(n_components=2).fit_transform(vectors)
    plt.figure(figsize=(6, 4))
    plt.scatter(
        projected[:, 0],
        projected[:, 1],
        c=range(len(projected)),
        cmap='viridis',
    )
    for label_no, (px, py) in enumerate(projected, start=1):
        plt.text(px, py, f"S{label_no}", fontsize=8)
    plt.title("Session-wise GradU PCA Projection")
    plt.tight_layout()
    plt.savefig(os.path.join(output_path, "gradU_pca_sessions.png"))
    plt.close()

def _read_betti1_series(path):
    """Parse one Betti time-series text file and return its Betti-1 values.

    A usable line looks like ``<integer time>,<list expression>``. Lines
    without a comma are ignored; lines that fail to parse are reported and
    skipped. Only the second element of each parsed list is kept.
    """
    values = []
    with open(path, 'r') as f:
        for line in f:
            if ',' not in line:
                continue
            try:
                t_str, bvec_str = line.strip().split(',', 1)
                # Validation only: a non-integer time field rejects the line.
                int(t_str.strip())
                # SECURITY: eval() executes arbitrary code contained in the
                # file. Only run this on files produced by a trusted
                # pipeline. Not replaced with ast.literal_eval here because
                # that would reject non-literal entries (e.g. NumPy scalar
                # reprs) that eval currently accepts.
                bvec = eval(bvec_str.strip())  # Assumes trusted input
                if isinstance(bvec, list) and len(bvec) > 1:
                    values.append(bvec[1])  # Betti-1
            except Exception as e:
                print(f"  Parse error: {e} | Line: {line.strip()}")
    return values


def plot_betti1_spikes_txt(betti_paths, output_path):
    """Draw a raster of abrupt Betti-1 changes, one row per session file.

    For each existing file the Betti-1 series is parsed, the absolute
    first difference is taken, and every step larger than
    ``mean + 2 * std`` of those differences is marked as a spike. Rows are
    numbered by the file's 1-based position in *betti_paths* (missing files
    keep their number but get no row).

    Args:
        betti_paths: Iterable of paths to ``betti_timeseries.txt``-style files.
        output_path: Directory that receives ``betti1_anomaly_raster.png``.
    """
    print("[DEBUG] Plotting Betti-1 anomalies from .txt files")

    fig, ax = plt.subplots(figsize=(12, 4.5))
    try:
        yticks, ylabels = [], []

        for sid, file in enumerate(betti_paths, 1):
            print(f"Session {sid}: checking file {file}")
            if not os.path.exists(file):
                print(" File missing.")
                continue

            betti1_vals = _read_betti1_series(file)
            if not betti1_vals:
                print("  No Betti-1 data parsed.")
                continue

            diffs = np.abs(np.diff(np.array(betti1_vals)))
            if diffs.size:
                threshold = diffs.mean() + 2 * diffs.std()
                spikes = np.where(diffs > threshold)[0]
            else:
                # A single sample has no differences; mean()/std() of an
                # empty array would only yield NaN plus RuntimeWarnings.
                spikes = np.array([], dtype=int)

            print(f" Found {len(spikes)} spikes at indices: {spikes.tolist()}")

            # NOTE(review): spikes are positioned by their index in the
            # difference series, not by the time value parsed from the file.
            for t in spikes:
                ax.vlines(t, sid - 0.3, sid + 0.3, color='crimson', alpha=0.8, linewidth=1.5)

            yticks.append(sid)
            ylabels.append(f"Session {sid}")

        ax.set_yticks(yticks)
        ax.set_yticklabels(ylabels, fontsize=10)
        ax.set_xlabel("Time Index", fontsize=11)
        ax.set_title("Session-wise Betti-1 Anomaly Events", fontsize=13, pad=12)

        ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
        fig.tight_layout()
        out_path = os.path.join(output_path, "betti1_anomaly_raster.png")
        fig.savefig(out_path, dpi=300)
    finally:
        # Close this specific figure (not merely the "current" one), and do
        # so even when parsing or saving fails, so figures do not pile up.
        plt.close(fig)

    print(f" Saved improved raster plot to: {out_path}")


def plot_persistence_barcode(diagrams, output_path):
    """Render persistence diagrams to ``barcode_example.png``.

    The drawing is delegated to ``persim.plot_diagrams`` on a fresh
    7x3 inch figure, which is saved into *output_path* and then closed.

    Args:
        diagrams: Persistence diagrams in the form ``plot_diagrams`` accepts.
        output_path: Directory that receives the image.
    """
    plt.figure(figsize=(7, 3))
    plot_diagrams(diagrams, show=False)
    plt.tight_layout()
    image_file = os.path.join(output_path, "barcode_example.png")
    plt.savefig(image_file)
    plt.close()

# ---------------- Posthoc Main ----------------
def _list_sessions(root):
    """Return the sorted entry names in *root* that start with ``session_``."""
    return sorted(name for name in os.listdir(root) if name.startswith("session_"))


def run_posthoc(summary_path, results_root, betti_root=None):
    """Run the post-hoc statistics and diagnostic plots for a results folder.

    Steps, in order:
        1. One-way ANOVA of ``CTIS`` and ``DeltaBetti2`` across ``LesionType``.
        2. Pearson correlation and regression plot of ``CTIS`` vs ``AURC``.
        3. SHAP analysis on per-session ``metrics.csv`` rows plus a
           curvature index computed from each session's ``umap_base.npy``.
        4. PCA scatter of each session's ``grad_U.npy``.
        5. Betti-1 spike raster from ``betti_timeseries.txt`` files found
           under *betti_root*.
        6. Persistence diagram plot for the first session's embedding.

    Session directories are processed in sorted name order so that session
    numbering in the plots and the SHAP train/test split are reproducible
    (``os.listdir`` order is otherwise arbitrary).

    Args:
        summary_path: CSV with columns ``LesionType``, ``CTIS``,
            ``DeltaBetti2`` and ``AURC``.
        results_root: Directory holding ``session_*`` sub-directories; all
            output images and CSVs are written here.
        betti_root: Directory holding ``session_*`` sub-directories with
            ``betti_timeseries.txt``. When omitted, a module-level
            ``betti_root`` variable is used if one exists (the behaviour of
            earlier versions); if neither is available the Betti-1 step is
            skipped instead of failing with ``NameError``.
    """
    df = pd.read_csv(summary_path)

    print("\n[ANOVA Tests]")
    for metric in ["CTIS", "DeltaBetti2"]:
        groups = [group[metric].values for _, group in df.groupby("LesionType")]
        # f_oneway needs at least two groups; all() alone is vacuously true
        # for an empty or single-group list and would let the call fail.
        if len(groups) >= 2 and all(len(g) > 1 for g in groups):
            stat, pval = f_oneway(*groups)
            print(f"{metric}: F = {stat:.3f}, p = {pval:.4f}")

    print("\n[CTIS vs AURC Correlation]")
    if len(df) >= 3:
        r, p = pearsonr(df["CTIS"], df["AURC"])
        print(f"CTIS vs AURC: r = {r:.3f}, p = {p:.4f}")
        sns.regplot(data=df, x="CTIS", y="AURC", scatter_kws={'s': 10})
        plt.tight_layout()
        plt.savefig(os.path.join(results_root, "ctis_vs_aurc.png"))
        plt.close()

    print("\n[SHAP Analysis]")
    # Listed once and reused below; outputs written in between never start
    # with "session_", so the listing cannot change between steps.
    sessions = _list_sessions(results_root)
    feature_rows, ctis_targets = [], []
    feature_names = ["DeltaBetti2", "AURC", "Curvature"]
    for sess in sessions:
        # Best-effort: a session with missing or unreadable files is
        # reported and skipped rather than aborting the whole run.
        try:
            mfile = os.path.join(results_root, sess, "metrics.csv")
            emb = np.load(os.path.join(results_root, sess, "umap_base.npy"))
            curv = curvature_index(emb)
            dfm = pd.read_csv(mfile)
            for _, row in dfm.iterrows():
                feature_rows.append([row.get("DeltaBetti2", 0), row.get("AURC", 0), curv])
                ctis_targets.append(row["CTIS"])
        except Exception as e:
            print(f"Skip {sess}: {e}")
    if feature_rows:
        run_shap_analysis(np.array(feature_rows), np.array(ctis_targets), feature_names, results_root)

    print("\n[Gradient Attribution Clustering]")
    grad_files = [os.path.join(results_root, sess, "grad_U.npy") for sess in sessions]
    plot_grad_clustering(grad_files, results_root)

    print("\n[Betti-1 Spike Co-occurrence]")
    if betti_root is None:
        # Legacy fallback: earlier versions read a module-level global.
        betti_root = globals().get("betti_root")
    if betti_root is None:
        print("Betti-1 raster skipped: no betti_root provided.")
    else:
        betti_txt_files = [
            os.path.join(betti_root, sess, "betti_timeseries.txt")
            for sess in _list_sessions(betti_root)
        ]
        plot_betti1_spikes_txt(betti_txt_files, results_root)

    print("\n[Persistence Barcode Example]")
    try:
        # IndexError on an empty session list is handled like any other
        # failure here: the example plot is optional.
        example = os.path.join(results_root, sessions[0], "umap_base.npy")
        embed = np.load(example)
        _, diag = compute_betti_numbers(embed, return_diagrams=True)
        plot_persistence_barcode(diag, results_root)
    except Exception as e:
        print(f"Barcode plot skipped: {e}")


if __name__ == "__main__":
    summary_path = "ctis_master_results_0512/ctis_summary.csv"
    results_root = "ctis_master_results_0512"
    betti_root = "pfc7_result"
    run_posthoc(summary_path, results_root, betti_root=betti_root)


