### May 13 2025

import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from scipy.stats import f_oneway, pearsonr, ttest_rel
from sklearn.decomposition import PCA
from scipy.stats import f_oneway, pearsonr, ttest_rel
from ripser import ripser
from sklearn.decomposition import PCA
import shap
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import seaborn as sns


def compute_betti_numbers(points):
    """Count persistence-diagram entries per homology dimension.

    Runs ``ripser`` on ``points`` with ``maxdim=2`` and returns a list with
    one integer per diagram in the ``'dgms'`` entry of the result, in the
    order ripser provides them. Each integer is simply the number of rows
    in that diagram; no persistence threshold is applied here.
    """
    diagrams = ripser(points, maxdim=2)['dgms']
    counts = []
    for diagram in diagrams:
        counts.append(len(diagram))
    return counts

def curvature_index(data):
    """Return one minus the variance fraction of the first principal component.

    Fits a PCA to ``data`` and returns ``1.0 - explained_variance_ratio_[0]``,
    so a value near 0 means a single linear direction explains almost all of
    the variance.

    Args:
        data: 2-D array-like passed to ``PCA.fit`` (presumably
            ``(n_samples, n_features)`` embedding coordinates -- confirm
            against the files written by the upstream pipeline).

    Returns:
        A float in ``[0, 1]``.
    """
    # Only the first ratio is read, so the number of fitted components does
    # not matter; clamp it to what the data supports so that inputs with
    # fewer than three columns (or rows) do not make PCA raise.
    n_components = min((3, *np.shape(data)))
    pca = PCA(n_components=n_components)
    pca.fit(data)
    return 1.0 - pca.explained_variance_ratio_[0]

def geometry_aware_ctis(beta_base, beta_pert, gamma_g, weights=(1, 1, 1)):
    """Return the curvature-scaled, weighted absolute difference of two count vectors.

    Computes ``gamma_g * sum(w_i * |base_i - pert_i|)``. Both vectors are
    zero-padded to length three first; because the terms are paired with
    ``zip``, only as many entries as the shortest of the two padded vectors
    and ``weights`` contribute.

    Args:
        beta_base: Sequence of per-dimension counts for the baseline.
        beta_pert: Sequence of per-dimension counts for the perturbed case.
        gamma_g: Scalar multiplier applied to the weighted sum.
        weights: Per-dimension weights; defaults to ``(1, 1, 1)``.

    Returns:
        The scaled weighted sum as a number.
    """
    padded = []
    for seq in (beta_base, beta_pert):
        # Pad a private copy: the caller's sequences must not be mutated,
        # and tuples/arrays (which have no ``append``) must be accepted.
        values = list(seq)
        values.extend([0] * (3 - len(values)))
        padded.append(values)
    base, pert = padded
    return gamma_g * sum(w * abs(b1 - b2) for b1, b2, w in zip(base, pert, weights))



def run_shap_analysis(features, targets, feature_names, output_path):
    """Fit a random forest, explain it with SHAP, and save an importance chart.

    The data are split 80/20 with a fixed seed, a 100-tree
    ``RandomForestRegressor`` is fitted on the training part, and SHAP values
    are computed on the held-out part using the training part as background.
    Features are ranked by mean absolute SHAP value and drawn as a horizontal
    bar chart saved to ``<output_path>/shap_feature_importance.png``.

    Args:
        features: Feature matrix handed to ``train_test_split``.
        targets: Regression targets aligned with ``features``.
        feature_names: Names used for the explainer and the chart labels.
        output_path: Directory in which the PNG is written.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        features, targets, test_size=0.2, random_state=42
    )
    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(X_train, y_train)

    shap_explainer = shap.Explainer(forest, X_train, feature_names=feature_names)
    explanation = shap_explainer(X_test)

    # Rank features from most to least important.
    importance = np.abs(explanation.values).mean(axis=0)
    ranking = np.argsort(importance)[::-1]
    ranked_names = np.array(feature_names)[ranking]
    ranked_scores = importance[ranking]

    # Compact horizontal bar chart, most important feature on top.
    fig, ax = plt.subplots(figsize=(5, 3))
    rects = ax.barh(ranked_names, ranked_scores, color='#1f77b4')
    ax.set_xlabel("Mean |SHAP value|", fontsize=6)
    ax.set_title("SHAP Feature Importance", fontsize=6, pad=10)
    ax.invert_yaxis()
    ax.grid(axis='x', linestyle='--', alpha=0.4)

    # Annotate each bar with its numeric value just past the bar's end.
    for rect in rects:
        bar_len = rect.get_width()
        mid_y = rect.get_y() + rect.get_height() / 2
        ax.text(bar_len + 0.001, mid_y, f"{bar_len:.3f}",
                va='center', ha='left', fontsize=6)

    plt.tight_layout()
    save_path = os.path.join(output_path, "shap_feature_importance.png")
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"SHAP feature importance saved to: {save_path}")



def _iter_session_dirs(results_root):
    """Yield ``(name, path)`` for every ``session_*`` entry under ``results_root``, sorted by name."""
    # Sorting makes the row order (and hence the seeded train/test split
    # downstream) reproducible; os.listdir order is arbitrary.
    for name in sorted(os.listdir(results_root)):
        if name.startswith("session_"):
            yield name, os.path.join(results_root, name)


def run_all_posthoc_analyses(summary_path, results_root, data_dir):
    """Run the post-hoc statistics and print their results.

    Steps, in order:
      1. One-way ANOVA of ``CTIS`` and ``DeltaBetti2`` across ``LesionType``
         groups of the summary CSV.
      2. Pearson correlation between ``CTIS`` and ``AURC`` in the summary.
      3. SHAP feature-importance analysis built from each session's
         ``metrics.csv`` and ``umap_base.npy``.
      4. Curvature-modulation ablation comparing the CTIS computed with the
         curvature index as multiplier against a multiplier of 1.0, using
         each session's ``umap_base.npy`` and ``umap_lesion.npy``.

    Sessions that are missing files are skipped; sessions that raise while
    being processed are reported and skipped.

    Args:
        summary_path: Path to the summary CSV (needs ``CTIS``,
            ``DeltaBetti2``, ``AURC`` and ``LesionType`` columns).
        results_root: Directory containing the ``session_*`` folders; the
            SHAP plot is also written here.
        data_dir: Not used by this function; kept so existing callers keep
            working.
    """
    print(f"\n===> Loading summary from: {summary_path}")
    df = pd.read_csv(summary_path)

    # === ANOVA ===
    print("\n[ANOVA Tests]")
    for metric in ["CTIS", "DeltaBetti2"]:
        groups = [group[metric].values for _, group in df.groupby("LesionType")]
        # f_oneway raises when given fewer than two samples, so require at
        # least two groups in addition to more than one value per group.
        if len(groups) >= 2 and all(len(g) > 1 for g in groups):
            stat, pval = f_oneway(*groups)
            print(f"  {metric}: F = {stat:.3f}, p = {pval:.4f}")
        else:
            print(f"  Not enough samples for ANOVA on {metric}")

    # === Correlation ===
    print("\n[Correlation]")
    if len(df) >= 3:
        r, p = pearsonr(df["CTIS"], df["AURC"])
        print(f"  CTIS vs AURC: r = {r:.3f}, p = {p:.4f}")
    else:
        print("  Not enough points for correlation")

    # === SHAP Analysis ===
    print("\n[SHAP Analysis]")
    feature_rows, ctis_targets = [], []
    feature_names = ["DeltaBetti2", "AURC", "Curvature"]

    for sess, sess_path in _iter_session_dirs(results_root):
        metric_file = os.path.join(sess_path, "metrics.csv")
        umap_file = os.path.join(sess_path, "umap_base.npy")
        if not os.path.exists(metric_file) or not os.path.exists(umap_file):
            continue
        try:
            curvature = curvature_index(np.load(umap_file))
            dfm = pd.read_csv(metric_file)
            for _, row in dfm.iterrows():
                # Read the target first: if "CTIS" is missing this raises
                # before anything is appended, keeping features and targets
                # the same length.
                target = row["CTIS"]
                feature_rows.append([
                    row.get("DeltaBetti2", 0),
                    row.get("AURC", 0),
                    curvature,
                ])
                ctis_targets.append(target)
        except Exception as e:
            print(f"  SHAP load error in {sess}: {e}")

    # The 80/20 split inside run_shap_analysis needs at least two rows to
    # leave a non-empty training set.
    if len(feature_rows) >= 2:
        run_shap_analysis(np.array(feature_rows), np.array(ctis_targets), feature_names, results_root)
    else:
        print("  Not enough data for SHAP")

    # === Curvature Modulation Ablation ===
    print("\n[Curvature Modulation Ablation]")
    curved_ctis, flat_ctis = [], []
    for sess, sess_path in _iter_session_dirs(results_root):
        emb_base = os.path.join(sess_path, "umap_base.npy")
        emb_lesion = os.path.join(sess_path, "umap_lesion.npy")
        if not (os.path.exists(emb_base) and os.path.exists(emb_lesion)):
            continue
        try:
            base = np.load(emb_base)
            lesion = np.load(emb_lesion)
            gamma = curvature_index(base)
            betti_base = compute_betti_numbers(base)
            betti_lesion = compute_betti_numbers(lesion)
            curved = geometry_aware_ctis(betti_base, betti_lesion, gamma)
            flat = geometry_aware_ctis(betti_base, betti_lesion, gamma_g=1.0)
            # Append both only after both succeeded so the lists stay paired.
            curved_ctis.append(curved)
            flat_ctis.append(flat)
        except Exception as e:
            print(f"  Error processing {sess}: {e}")

    if curved_ctis and flat_ctis:
        ctis_diff = np.array(curved_ctis) - np.array(flat_ctis)
        print(f"  Mean ΔCTIS (curved - flat): {ctis_diff.mean():.4f} ± {ctis_diff.std():.4f}")
        # A paired t-test on a single pair is undefined (NaN), so require two.
        if len(curved_ctis) >= 2:
            t_stat, p_val = ttest_rel(curved_ctis, flat_ctis)
            print(f"  Paired t-test: t = {t_stat:.3f}, p = {p_val:.4f}")
        else:
            print("  Not enough sessions for paired t-test")
    else:
        print("  No CTIS data available for ablation analysis")

if __name__ == "__main__":
    # Script entry point: run every post-hoc analysis against the
    # hard-coded results folder and its summary CSV.
    run_all_posthoc_analyses(
        summary_path="ctis_master_results_0512/ctis_summary.csv",
        results_root="ctis_master_results_0512",
        data_dir="data",
    )

