import numpy as np
from scipy.interpolate import UnivariateSpline, LSQUnivariateSpline
from scipy.spatial import cKDTree
import math
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import ConvexHull
from shapely.geometry import Point
from shapely.ops import unary_union
import networkx as nx
from scipy.sparse import csr_matrix
from scipy.spatial.distance import pdist, squareform
from scipy.sparse.csgraph import minimum_spanning_tree
import os
import json
import pickle
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression

def _fit_principal_curve_2d(X, n_iter=10, s=None, k=2, n_grid=200, tol=1e-4, use_lsq=True, n_knots=0, verbose=False):
    """
    Fit a 2D principal curve via iterative smoothing splines.
    Returns: t (N,), curve_points (N,2), distances (N,), curve_length (float).
    """
    X = np.asarray(X, dtype=float)
    assert X.ndim == 2 and X.shape[1] == 2, "X must be (N,2)"

    # Center for numerical stability
    mu = X.mean(axis=0, keepdims=True)
    Xc = X - mu

    # --- Initialize parameter t on first PC ---
    # PCA in 2D: eigenvector of covariance with max eigenvalue
    C = np.cov(Xc.T)
    eigvals, eigvecs = np.linalg.eigh(C)
    v1 = eigvecs[:, np.argmax(eigvals)]
    t = Xc @ v1  # (N,)

    # Make t strictly increasing for spline fitting stability
    order = np.argsort(t)
    t_sorted = t[order]
    X_sorted = Xc[order]

    # Choose spline smoothness if not given: a gentle default
    if s is None:
        # scale with N; smaller s -> tighter curve
        s = 50 * len(t_sorted)  # tweakable

    # Iterations: fit x(t), y(t) splines and reproject points to closest curve point
    prev_t = t.copy()
    for it in range(n_iter):
        # Fit splines x(t), y(t) with error handling
        try:
            if use_lsq:
                # choose few interior knots at quantiles of current parameterization
                if n_knots > 0:
                    qs = np.linspace(0.2, 0.8, n_knots)
                    t_int = np.quantile(t_sorted, qs)
                else:
                    t_int = None  # no interior knots => global polynomial of degree k
                # LSQ fits with fixed knots ⇒ controlled flexibility
                spl_x = LSQUnivariateSpline(t_sorted, X_sorted[:, 0], t_int, k=k)
                spl_y = LSQUnivariateSpline(t_sorted, X_sorted[:, 1], t_int, k=k)
            else:
                spl_x = UnivariateSpline(t_sorted, X_sorted[:, 0], s=s, k=k)
                spl_y = UnivariateSpline(t_sorted, X_sorted[:, 1], s=s, k=k)
        except Exception as e:
            if verbose:
                print(f"[principal-curve] spline fitting failed with s={s}, trying with s=None (automatic)")
            # Fallback: let scipy choose s automatically
            spl_x = UnivariateSpline(t_sorted, X_sorted[:, 0], s=None, k=k)
            spl_y = UnivariateSpline(t_sorted, X_sorted[:, 1], s=None, k=k)

        # Dense curve sampling for nearest-point projection
        t_min, t_max = t_sorted[0], t_sorted[-1]
        t_grid = np.linspace(t_min, t_max, n_grid)
        curve_grid = np.column_stack([spl_x(t_grid), spl_y(t_grid)])  # (G,2)

        # Check for NaN or inf values in curve_grid
        if not np.all(np.isfinite(curve_grid)):
            if verbose:
                print(f"[principal-curve] curve_grid contains NaN/inf values, falling back to linear interpolation")
            # Fallback: use linear interpolation instead
            curve_grid = np.column_stack([
                np.interp(t_grid, t_sorted, X_sorted[:, 0]),
                np.interp(t_grid, t_sorted, X_sorted[:, 1])
            ])

        # Project each point to nearest curve sample
        tree = cKDTree(curve_grid)
        dists, idx = tree.query(Xc, k=1)
        t = t_grid[idx]
        # Reorder helper for next iteration’s spline fit
        order = np.argsort(t)
        t_sorted = t[order]
        X_sorted = Xc[order]

        # Convergence on t changes
        delta = np.mean((t - prev_t) ** 2) ** 0.5
        if verbose:
            print(f"[principal-curve] iter {it+1}: Δt={delta:.3e}")
        if delta < tol:
            break
        prev_t = t.copy()

    # Final curve positions at each point’s parameter
    # (evaluate splines directly at final t; no need to snap to grid again)
    try:
        curve_points = np.column_stack([spl_x(t), spl_y(t)])  # (N,2)
        # Check for NaN or inf values
        if not np.all(np.isfinite(curve_points)):
            raise ValueError("curve_points contains NaN/inf values")
    except Exception:
        if verbose:
            print("[principal-curve] final spline evaluation failed, using linear interpolation")
        # Fallback: linear interpolation
        curve_points = np.column_stack([
            np.interp(t, t_sorted, X_sorted[:, 0]),
            np.interp(t, t_sorted, X_sorted[:, 1])
        ])
    
    # Orthogonal distances
    distances = np.linalg.norm(Xc - curve_points, axis=1)
    
    # Curve length along a fine grid
    t_grid = np.linspace(t_sorted[0], t_sorted[-1], n_grid)
    try:
        Cg = np.column_stack([spl_x(t_grid), spl_y(t_grid)])
        if not np.all(np.isfinite(Cg)):
            raise ValueError("Cg contains NaN/inf values")
    except Exception:
        if verbose:
            print("[principal-curve] curve length computation failed, using linear interpolation")
        # Fallback: linear interpolation for curve length
        Cg = np.column_stack([
            np.interp(t_grid, t_sorted, X_sorted[:, 0]),
            np.interp(t_grid, t_sorted, X_sorted[:, 1])
        ])
    
    segs = np.linalg.norm(np.diff(Cg, axis=0), axis=1)
    curve_length = float(np.sum(segs))

    # Undo centering if you need absolute coordinates (not needed for thickness)
    # curve_points_abs = curve_points + mu

    return t, curve_points, distances, curve_length


def thicknessPrincipalCurve(embs, outfile, n_iter=12, s=None, k=3, n_grid=300, tol=1e-4,
                            bootstrap=0, seed=0):
    """
    Compute principal-curve thickness per layer and save it to ``outfile``.

    Args:
        embs: sequence of (N,2) array-likes, one per layer (e.g. a layer's UMAP embedding).
        outfile: destination passed to ``np.savez``.
        n_iter, s, k, n_grid, tol: principal-curve fit hyperparams
            (forwarded to ``_fit_principal_curve_2d``).
        bootstrap: if >0, refit each layer with >=10 points on this many
            bootstrap resamples and store 2.5/97.5 percentile intervals.
            Resamples whose fit raises are skipped.
        seed: RNG seed for bootstrap.

    Saves (npz keys; each an array of length L = len(embs)):
        - 'coeffs': RMS orthogonal distance to the fitted curve per layer
        - 'coeffs_ci_lower', 'coeffs_ci_upper': bootstrap CI of 'coeffs'
          (NaN where no CI was computed)
        - 'length': fitted curve length per layer
        - 'length_ci_lower', 'length_ci_upper': bootstrap CI of 'length'

    Returns:
        None
    """
    rng = np.random.default_rng(seed)
    n_layers = len(embs)
    thickness = np.zeros(n_layers, dtype=float)
    length = np.zeros(n_layers, dtype=float)
    thickness_ci = np.full((n_layers, 2), np.nan)
    length_ci = np.full((n_layers, 2), np.nan)
    fit_kwargs = dict(n_iter=n_iter, s=s, k=k, n_grid=n_grid, tol=tol, verbose=False)

    for layer, X in enumerate(embs):
        # Array conversion so bootstrap fancy-indexing also works for list inputs.
        X = np.asarray(X, dtype=float)

        # Fit on full data
        _, _, dists, curve_len = _fit_principal_curve_2d(X, **fit_kwargs)
        thickness[layer] = np.sqrt(np.mean(dists ** 2))  # RMS thickness
        length[layer] = curve_len

        # Optional bootstrap CIs
        if bootstrap and len(X) >= 10:
            t_boot, L_boot = [], []
            N = len(X)
            for _ in range(bootstrap):
                Xb = X[rng.integers(0, N, size=N)]
                try:
                    _, _, db, Lb = _fit_principal_curve_2d(Xb, **fit_kwargs)
                except Exception:
                    # occasionally splines can fail on degenerate resamples; skip
                    continue
                t_boot.append(np.sqrt(np.mean(db ** 2)))
                L_boot.append(Lb)
            if t_boot:
                thickness_ci[layer] = np.percentile(t_boot, [2.5, 97.5])
                length_ci[layer] = np.percentile(L_boot, [2.5, 97.5])

    np.savez(outfile,
             coeffs=thickness,
             coeffs_ci_lower=thickness_ci[:, 0],
             coeffs_ci_upper=thickness_ci[:, 1],
             length=length,
             length_ci_lower=length_ci[:, 0],
             length_ci_upper=length_ci[:, 1])

def plotPrincipalCurveGrid(
    embs,
    outfile_png,
    n_cols=5,
    point_size=8,
    point_alpha=0.35,
    curve_lw=2.0,
    curve_alpha=0.9,
    rasterized=True,
    # pass-through to _fit_principal_curve_2d
    n_iter=12, s=None, k=3, n_grid=300, tol=1e-4, verbose=False,
):
    """
    Draw one panel per layer embedding with its fitted principal curve.

    Each (N,2) embedding is fitted with ``_fit_principal_curve_2d``. Its points
    are scattered and the fitted curve (shifted back to the original
    coordinates and ordered by t) is overlaid. The panel title carries the RMS
    point-to-curve distance τ. Panels are laid out ``n_cols`` per row. Leftover
    panels are hidden. The figure is saved to ``outfile_png`` at 300 dpi and
    then shown.
    """
    n_layers = len(embs)
    n_rows = math.ceil(n_layers / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    panels = axes.ravel()  # row-major: panel i == axes[i // n_cols, i % n_cols]

    for layer_idx, emb in enumerate(embs):
        ax = panels[layer_idx]
        pts = np.asarray(emb, dtype=float)
        assert pts.ndim == 2 and pts.shape[1] == 2, "Each embedding must be (N,2)"

        t, centred_curve, dists, _ = _fit_principal_curve_2d(
            pts, n_iter=n_iter, s=s, k=k, n_grid=n_grid, tol=tol, verbose=verbose
        )

        # The fitter works in mean-centred coordinates: undo that, then order by t
        # so the curve draws as a clean line.
        curve = (centred_curve + pts.mean(axis=0, keepdims=True))[np.argsort(t)]

        ax.scatter(pts[:, 0], pts[:, 1], s=point_size, alpha=point_alpha,
                   edgecolors='none', rasterized=rasterized)
        ax.plot(curve[:, 0], curve[:, 1], linewidth=curve_lw, alpha=curve_alpha)

        rms = float(np.sqrt(np.mean(dists ** 2)))
        ax.set_title(f"Layer {layer_idx} — τ={rms:.3g}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(True, ls="--", alpha=0.15)

    for unused_ax in panels[n_layers:]:
        unused_ax.axis("off")

    plt.tight_layout()
    fig.savefig(outfile_png, dpi=300)
    print(f"Saved principal-curve overlays → {outfile_png}")
    plt.show()




# ---------- 1) TWO-NN intrinsic dimension ----------
def two_nn_id(X):
    """
    Estimate intrinsic dimension with the TWO-NN maximum-likelihood estimator.

    For each point, mu_i = r2 / r1 is the ratio of its second- to first-nearest
    neighbour distance. Under the TWO-NN model mu follows a Pareto law with shape
    d, whose MLE is d = N / sum(log mu_i) = 1 / mean(log mu_i).

    Args:
        X: (N, D) array of points (callers in this module pass 2-D embeddings).

    Returns:
        The estimated intrinsic dimension as a Python float
        (inf, with a numpy warning, if every ratio equals 1).
    """
    nbrs = NearestNeighbors(n_neighbors=3).fit(X)
    dists, _ = nbrs.kneighbors(X)  # column 0 has distance 0 (the point itself)
    r1 = dists[:, 1]
    r2 = dists[:, 2]
    # 1e-12 guards against 0/0 for coincident points.
    # NOTE(review): an exact duplicate (r1 == 0) yields a huge ratio that biases the
    # estimate downwards; the TWO-NN paper discards such points — confirm whether
    # duplicates can occur in these embeddings.
    log_ratios = np.log((r2 + 1e-12) / (r1 + 1e-12))
    return float(1.0 / np.mean(log_ratios))

# ---------- 2) Local anisotropy (mean±std) ----------
def local_anisotropy(X, k=20):
    """
    Local anisotropy of each point's k-neighbourhood, summarised as (mean, std).

    For every point, the covariance of its k nearest neighbours (the point itself
    included) is eigen-decomposed. The anisotropy is ``1 - lambda_2 / lambda_1``,
    where lambda_1 >= lambda_2 are the two largest eigenvalues. Values near 1 mean
    the neighbourhood is line-like; values near 0 mean it is isotropic.

    Args:
        X: (N, D) array-like of points, D >= 2.
        k: neighbourhood size including the point itself; clamped to N.

    Returns:
        (mean, std) of the per-point anisotropy, as Python floats.
    """
    X = np.asarray(X, dtype=float)
    k_eff = min(k, len(X))
    nbrs = NearestNeighbors(n_neighbors=k_eff).fit(X)
    _, idx = nbrs.kneighbors(X)

    # All neighbourhood covariances at once: (N, k_eff, D) -> (N, D, D).
    nbhd = X[idx]
    centred = nbhd - nbhd.mean(axis=1, keepdims=True)
    cov = np.einsum('nki,nkj->nij', centred, centred) / max(k_eff - 1, 1)

    eig = np.linalg.eigvalsh(cov)  # ascending eigenvalues per neighbourhood
    lam1 = eig[:, -1]
    lam2 = np.clip(eig[:, -2], 0.0, None)  # eigvalsh can return tiny negatives
    # Epsilon in the denominator so a neighbourhood of coincident points
    # (lam1 == 0) does not divide by zero and poison the mean.
    A = 1.0 - lam2 / (lam1 + 1e-12)
    return float(A.mean()), float(A.std())

# ---------- 3) Graph 1D-ness on kNN graph ----------
# def graph_1dness(X, k=10):
#     nbrs = NearestNeighbors(n_neighbors=k).fit(X)
#     _, idx = nbrs.kneighbors(X)
#     G = nx.Graph()
#     G.add_nodes_from(range(len(X)))
#     for i in range(len(X)):
#         for j in idx[i,1:]:
#             if i != j:
#                 G.add_edge(int(i), int(j))
#     degs = np.array([d for _, d in G.degree()])
#     V = G.number_of_nodes(); E = G.number_of_edges()
#     C = nx.number_connected_components(G)
#     cycle_rank = E - V + C
#     frac_deg2 = np.mean(degs == 2)
#     endpoint_frac = np.mean(degs == 1)
#     branch_frac = np.mean(degs >= 3)
#     return dict(frac_deg2=float(frac_deg2),
#                 endpoint_frac=float(endpoint_frac),
#                 branch_frac=float(branch_frac),
#                 cycle_rank=float(cycle_rank))

def graph_1dness(
    X,
    mode="mutual-knn",
    k=3,
    eps=None,
    eps_scale=1.5,
    return_graph=False,
):
    """
    Build a sparse proximity graph and report 1D-ness degree stats.

    mode:
      - 'mutual-knn' : undirected edge i--j iff i in kNN(j) AND j in kNN(i)
      - 'mst'        : Euclidean minimum spanning tree (single linkage)
      - 'epsilon'    : connect if ||xi-xj|| <= eps; if eps is None, use
                       eps = eps_scale * median r_k(i) (adaptive)

    ``k`` is passed to NearestNeighbors as ``n_neighbors`` and therefore counts
    the query point itself: with k=3 each neighbour list holds 2 other points,
    and r_k(i) is the distance to the last of them. k is clamped to n-1.

    Returns dict with: frac_deg2, endpoint_frac, branch_frac, cycle_rank, avg_degree
    (fractions of nodes with degree ==2, ==1, >=3; E - V + C; mean degree).
    If return_graph=True, also returns the NetworkX graph under key 'G'.
    Inputs with fewer than 3 points return all-zero stats.
    """
    X = np.asarray(X, float)
    n = len(X)
    if n < 3:
        out = dict(frac_deg2=0.0, endpoint_frac=0.0, branch_frac=0.0,
                   cycle_rank=0.0, avg_degree=0.0)
        if return_graph:
            out["G"] = nx.Graph()
        return out

    if mode == "mutual-knn":
        k_eff = min(k, n - 1)
        nbrs = NearestNeighbors(n_neighbors=k_eff).fit(X)
        _, idx = nbrs.kneighbors(X)  # idx[:, 0] is (normally) the point itself
        # Directed kNN edges i -> j for every non-self column.
        rows = np.repeat(np.arange(n), idx.shape[1] - 1)
        cols = idx[:, 1:].ravel()
        A = csr_matrix((np.ones(len(rows), dtype=bool), (rows, cols)), shape=(n, n))
        M = A.multiply(A.T)  # keep i--j only when both i->j and j->i exist
        # With duplicate points column 0 may be a twin rather than i, so i can
        # appear in its own list; drop such self-loops.
        M.setdiag(False)
        M.eliminate_zeros()
        G = nx.from_scipy_sparse_array(M, create_using=nx.Graph)

    elif mode == "mst":
        # Euclidean MST over all points (O(n^2) distances; fine for n~100–2k)
        # NOTE(review): csr_matrix(D) stores only non-zero entries, so exact
        # duplicate points are not joined to each other directly — confirm this
        # is acceptable for the degree statistics.
        D = squareform(pdist(X))
        T = minimum_spanning_tree(csr_matrix(D))
        T = T + T.T  # make undirected
        G = nx.from_scipy_sparse_array(T, create_using=nx.Graph)

    elif mode == "epsilon":
        # Adaptive epsilon from kNN radii if not provided
        if eps is None:
            k_eff = min(k, n - 1)
            nbrs = NearestNeighbors(n_neighbors=k_eff).fit(X)
            dists, _ = nbrs.kneighbors(X)
            eps = float(eps_scale * np.median(dists[:, -1]))
        # Build ε-graph. Exclude only the diagonal: coincident points (distance 0)
        # are within eps of each other and must be connected.
        D = squareform(pdist(X))
        mask = (D <= eps) & ~np.eye(n, dtype=bool)
        rows, cols = np.nonzero(mask)
        G = nx.Graph()
        G.add_nodes_from(range(n))
        G.add_edges_from(zip(rows.tolist(), cols.tolist()))

    else:
        raise ValueError("mode must be 'mutual-knn', 'mst', or 'epsilon'")

    # Degree stats
    degs = np.array([d for _, d in G.degree()], dtype=int)
    V = G.number_of_nodes()
    E = G.number_of_edges()
    C = nx.number_connected_components(G)
    cycle_rank = float(E - V + C)   # 0 for a forest; >0 implies cycles

    out = dict(
        frac_deg2=float(np.mean(degs == 2)),
        endpoint_frac=float(np.mean(degs == 1)),
        branch_frac=float(np.mean(degs >= 3)),
        cycle_rank=cycle_rank,
        avg_degree=float(degs.mean() if V else 0.0),
    )
    if return_graph:
        out["G"] = G
    return out

# ---------- 4) Fill ratio (areal occupancy) ----------
def fill_ratio(X, r_scale=0.05):
    """
    Areal occupancy of a 2-D point cloud relative to its convex hull.

    Every point is dilated into a disk of radius
    ``r = r_scale * sqrt(trace(cov(X)))``. The result is
    area(union of disks) / area(convex hull), clipped to [0, 1].

    Args:
        X: (N, 2) array-like of points; ConvexHull needs at least 3
            non-collinear points.
        r_scale: disk radius as a fraction of the cloud's overall spread.

    Returns:
        Fill ratio in [0, 1] as a Python float.
    """
    X = np.asarray(X, dtype=float)
    hull = ConvexHull(X)
    # For 2-D input scipy's ConvexHull.volume is the enclosed AREA, while
    # ConvexHull.area is the PERIMETER.
    A_hull = float(hull.volume)
    # Union of small disks around points (kernel dilation);
    # radius set by data scale: r = r_scale * overall std
    s = np.sqrt(np.trace(np.cov(X.T)))
    r = r_scale * s
    union = unary_union([Point(*p).buffer(r) for p in X])
    # Disks around boundary points spill outside the hull, so the raw ratio
    # can exceed 1; it is clipped.
    A_union = float(union.area)
    return float(np.clip(A_union / (A_hull + 1e-12), 0.0, 1.0))

# ---------- 5) Principal-line thickness (RMS) ----------
def principal_line_thickness(X):
    """
    RMS distance of the points from their first principal axis.

    Returns:
        (tau, tau_norm): tau is the root-mean-square orthogonal distance of the
        mean-centred points to the line along the covariance's top eigenvector;
        tau_norm is tau / (sqrt(total variance) + 1e-12), making layers of
        different scale comparable.
    """
    centred = X - X.mean(0, keepdims=True)
    evals, evecs = np.linalg.eigh(np.cov(centred.T))
    axis = evecs[:, np.argmax(evals)]
    residual = centred - np.outer(centred @ axis, axis)
    tau = float(np.sqrt(np.mean(np.linalg.norm(residual, axis=1) ** 2)))
    spread = float(np.sqrt(np.sum(evals)))
    return tau, (tau / (spread + 1e-12))


def manifold_1d_tests_per_layer(embs, k_aniso=20, k_graph=3, r_scale=0.05, graph_mode="mutual-knn"):
    """
    Run every 1-D-manifold diagnostic on each layer's embedding.

    Returns a dict of float arrays (one entry per layer) with keys:
    id (TWO-NN intrinsic dimension), aniso_mean / aniso_std (local anisotropy),
    frac_deg2 / endpoint_frac / branch_frac / cycle_rank / avg_degree (graph
    1D-ness stats), fill_ratio, and tau / tau_norm (principal-line thickness).
    """
    metric_names = (
        "id", "aniso_mean", "aniso_std",
        "frac_deg2", "endpoint_frac", "branch_frac", "cycle_rank", "avg_degree",
        "fill_ratio", "tau", "tau_norm",
    )
    graph_keys = ("frac_deg2", "endpoint_frac", "branch_frac", "cycle_rank", "avg_degree")
    n_layers = len(embs)
    metrics = {name: np.zeros(n_layers) for name in metric_names}

    for layer, emb in enumerate(embs):
        pts = np.asarray(emb, float)
        metrics["id"][layer] = two_nn_id(pts)
        metrics["aniso_mean"][layer], metrics["aniso_std"][layer] = local_anisotropy(pts, k=k_aniso)

        graph_stats = graph_1dness(pts, mode=graph_mode, k=k_graph)
        for key in graph_keys:
            metrics[key][layer] = graph_stats[key]

        metrics["fill_ratio"][layer] = fill_ratio(pts, r_scale=r_scale)
        metrics["tau"][layer], metrics["tau_norm"][layer] = principal_line_thickness(pts)

    return metrics


def _per_layer_means(values_by_layer, num_layers):
    """Mean of each layer's collected values for layers 0..num_layers-1 (0.0 where none)."""
    return [np.mean(values_by_layer[i]) if i in values_by_layer else 0.0
            for i in range(num_layers)]


def _load_umap_metric(act_obj):
    """
    Best-effort load of a per-layer UMAP metric for ``act_obj``; None if unavailable.

    Reads ``act_obj.umapAggFile`` when it ends in '.npz'. If that gives nothing,
    it reads the optional pickle side-cache ``act_obj.umapCacheFile``. In each
    source, 'aniso_mean' is preferred over 'coeffs'. Any error abandons the
    lookup and yields None.
    """
    try:
        if act_obj.umapAggFile.endswith('.npz'):
            # Context manager closes the underlying zip file handle.
            with np.load(act_obj.umapAggFile, allow_pickle=True) as umap_data:
                for key in ('aniso_mean', 'coeffs'):
                    if key in umap_data.files:
                        return umap_data[key]
        cache_path = getattr(act_obj, 'umapCacheFile', None)
        if cache_path and os.path.exists(cache_path):
            # NOTE: pickle can execute arbitrary code; only load trusted local caches.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            for key in ('aniso_mean', 'coeffs'):
                if key in cache:
                    return cache[key]
    except Exception:
        pass
    return None


def _load_saliency_mean(act_obj):
    """Best-effort per-layer mean saliency via ``act_obj.compute_saliency_per_layer``; None on any error."""
    try:
        # NOTE: pickle can execute arbitrary code; only load trusted local files.
        with open(act_obj.saliencyFile, 'rb') as f:
            sal_data = pickle.load(f)
        if 'avg_sal_per_prompt' in sal_data:
            saliency_mean, _, _ = act_obj.compute_saliency_per_layer(sal_data['avg_sal_per_prompt'])
            return saliency_mean
    except Exception:
        pass
    return None


def _load_lesion_scores(act_obj):
    """
    Best-effort per-layer mean lesioning score; None on any error.

    Reads ``results/<analysis_name>_lesioning_<MODEL_NAME>.json`` (relative to
    the current working directory). Layers >= num_layers are ignored; layers
    with no score get 0.0.
    """
    try:
        lesion_file = f"results/{act_obj.analysis_name}_lesioning_{act_obj.MODEL_NAME}.json"
        with open(lesion_file, 'r') as f:
            lesion_data = json.load(f)
        layer_scores = {}
        for prompt_result in lesion_data.get('prompt_scores', {}).values():
            for layer_idx, score_info in prompt_result.get('scores_and_justifications', {}).items():
                li = int(layer_idx)
                if li >= act_obj.num_layers:
                    continue
                layer_scores.setdefault(li, []).append(score_info['score'])
        return _per_layer_means(layer_scores, act_obj.num_layers)
    except Exception:
        return None


def _load_activation_patching_scores(act_obj):
    """
    Best-effort per-layer mean of non-zero activation-patching effects; None on any error.

    Reads ``results/<analysis_name>_activation_patching_<MODEL_NAME>.json``
    (relative to the current working directory). Returns None if the file lacks
    'all_patching_results'. Layers >= num_layers are ignored; layers with no
    non-zero effect get 0.0.
    """
    try:
        ap_file = f"results/{act_obj.analysis_name}_activation_patching_{act_obj.MODEL_NAME}.json"
        with open(ap_file, 'r') as f:
            ap_data = json.load(f)
        if 'all_patching_results' not in ap_data:
            return None
        layer_effects = {}
        for layer_idx_str, layer_result in ap_data['all_patching_results'].items():
            li = int(layer_idx_str)
            if li >= act_obj.num_layers:
                continue
            for effect in layer_result.get('patching_effect', {}).values():
                if effect != 0:
                    layer_effects.setdefault(li, []).append(effect)
        return _per_layer_means(layer_effects, act_obj.num_layers)
    except Exception:
        return None


def get_data_for_maps(ageAct=None, diseaseAct=None, top100DrugAct=None, symptomAct=None, dosageAct=None):
    """
    Collect per-layer metrics for every provided analysis object.

    Objects that are None are skipped. The rest are processed in the order
    Age, Symptoms, Diseases, Drugs, Dosage. Each produces a dict with keys
    'name', 'umap_metric', 'saliency', 'lesioning', 'activation_patching',
    'num_layers' and 'color'. Every metric is loaded best-effort and is None
    when its source is missing or unreadable.

    Returns:
        list of per-analysis dicts.
    """
    specs = (
        (ageAct, 'Age', 'red'),
        (symptomAct, 'Symptoms', 'brown'),
        (diseaseAct, 'Diseases', 'purple'),
        (top100DrugAct, 'Drugs', 'orange'),
        # NOTE(review): same colour as 'Diseases' — confirm this is intended.
        (dosageAct, 'Dosage', 'purple'),
    )
    analyses = []
    for act_obj, name, color in specs:
        if act_obj is None:
            continue
        try:
            umap_metric = _load_umap_metric(act_obj)
            saliency_mean = _load_saliency_mean(act_obj)
            lesion_scores = _load_lesion_scores(act_obj)
            act_patch_scores = _load_activation_patching_scores(act_obj)
            analyses.append({
                'name': name,
                'umap_metric': umap_metric,
                'saliency': saliency_mean,
                'lesioning': lesion_scores,
                'activation_patching': act_patch_scores,
                'num_layers': act_obj.num_layers,
                'color': color
            })
        except Exception as e:
            # The loaders swallow their own errors; this catches e.g. a missing num_layers.
            print(f"Error loading {name} data: {e}")

    return analyses

def main_map(ageAct=None, diseaseAct=None, top100DrugAct=None, symptomAct=None, dosageAct=None, text_on_bars=False):
    """
    Creates a main LLM map integrating all analyses showing layers where key events happen.
    
    Args:
        ageAct: AgeAct object with results
        diseaseAct: DiseaseAct object with results  
        top100DrugAct: Top100DrugAct object with results
        symptomAct: SymptomAct object with results
        dosageAct: DosageAct object with results
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.ndimage import gaussian_filter1d
    import pickle
    
    print('Creating main LLM map...')
    
    # Define analysis configurations
    analyses = get_data_for_maps(ageAct=ageAct, diseaseAct=diseaseAct, top100DrugAct=top100DrugAct, symptomAct=symptomAct, dosageAct=dosageAct)
    

    # Helper function to find peak increase interval
    def find_peak_increase_interval(data, window_size=3):
        """Find the interval with the highest rate of increase after smoothing"""
        if data is None or len(data) < window_size:
            return None, None
        
        try:
            # Apply Gaussian smoothing
            smoothed = gaussian_filter1d(data, sigma=1.0)
            
            # Compute rate of change (first derivative)
            diff = np.diff(smoothed)
            
            # Find the window with highest average increase
            max_avg_increase = -np.inf
            best_start = None
            
            for i in range(len(diff) - window_size + 1):
                window_avg = np.mean(diff[i:i+window_size])
                if window_avg > max_avg_increase:
                    max_avg_increase = window_avg
                    best_start = i
            
            if best_start is not None:
                return best_start + 1, best_start + window_size + 1  # Convert to layer indices
            return None, None
        except Exception as e:
            print(f"Error in find_peak_increase_interval: {e}")
            return None, None
    
    # Helper function to find high saliency intervals
    def find_high_saliency_intervals(data, threshold_percentile=75, min_interval_length=2, max_intervals=3):
        """Find intervals where saliency is higher than the threshold"""
        if data is None or len(data) < min_interval_length:
            return []
        
        try:
            # Apply Gaussian smoothing
            smoothed = gaussian_filter1d(data, sigma=1.0)
            
            # Calculate threshold based on percentile
            threshold = np.percentile(smoothed, threshold_percentile)
            
            # Find regions above threshold
            above_threshold = smoothed > threshold
            
            # Find connected components (intervals)
            intervals = []
            start = None
            
            for i, is_above in enumerate(above_threshold):
                if is_above and start is None:
                    start = i
                elif not is_above and start is not None:
                    if i - start >= min_interval_length:
                        intervals.append((start, i))
                    start = None
            
            # Handle case where interval extends to the end
            if start is not None and len(data) - start >= min_interval_length:
                intervals.append((start, len(data)))
            
            # Sort by average saliency in each interval and take top max_intervals
            if len(intervals) > max_intervals:
                interval_scores = []
                for start, end in intervals:
                    avg_saliency = np.mean(smoothed[start:end])
                    interval_scores.append((avg_saliency, start, end))
                
                # Sort by average saliency (descending) and take top max_intervals
                interval_scores.sort(key=lambda x: x[0], reverse=True)
                intervals = [(start, end) for _, start, end in interval_scores[:max_intervals]]
                intervals.sort(key=lambda x: x[0])  # Sort by start position
            
            return intervals
            
        except Exception as e:
            print(f"Error in find_high_saliency_intervals: {e}")
            return []
    
    # Helper function to find high lesioning intervals
    def find_high_lesioning_intervals(data, threshold_percentile=75, min_interval_length=2, max_intervals=3):
        """Find intervals where lesioning scores are highest (indicating significant degradation)"""
        if data is None or len(data) < min_interval_length:
            return []
        
        try:
            # Apply Gaussian smoothing
            smoothed = gaussian_filter1d(data, sigma=1.0)
            
            # Calculate threshold based on percentile (higher scores = more degradation)
            threshold = np.percentile(smoothed, threshold_percentile)
            
            # Find regions above threshold
            above_threshold = smoothed > threshold
            
            # Find connected components (intervals)
            intervals = []
            start = None
            
            for i, is_above in enumerate(above_threshold):
                if is_above and start is None:
                    start = i
                elif not is_above and start is not None:
                    if i - start >= min_interval_length:
                        intervals.append((start, i))
                    start = None
            
            # Handle case where interval extends to the end
            if start is not None and len(data) - start >= min_interval_length:
                intervals.append((start, len(data)))
            
            # Sort by average lesioning score in each interval and take top max_intervals
            if len(intervals) > max_intervals:
                interval_scores = []
                for start, end in intervals:
                    avg_lesioning = np.mean(smoothed[start:end])
                    interval_scores.append((avg_lesioning, start, end))
                
                # Sort by average lesioning score (descending) and take top max_intervals
                interval_scores.sort(key=lambda x: x[0], reverse=True)
                intervals = [(start, end) for _, start, end in interval_scores[:max_intervals]]
                intervals.sort(key=lambda x: x[0])  # Sort by start position
            
            return intervals
            
        except Exception as e:
            print(f"Error in find_high_lesioning_intervals: {e}")
            return []
    
    # Helper function to find high activation patching intervals
    def find_high_activation_patching_intervals(data, threshold_percentile=75, min_interval_length=2, max_intervals=3):
        """Find intervals where activation patching effects are highest (indicating significant causal role)"""
        if data is None or len(data) < min_interval_length:
            return []
        
        try:
            # Apply Gaussian smoothing
            smoothed = gaussian_filter1d(data, sigma=1.0)
            
            # Calculate threshold based on percentile (higher absolute values = more significant effect)
            threshold = np.percentile(smoothed, threshold_percentile)
            
            # Find regions above threshold (both positive and negative effects)
            above_threshold = smoothed > threshold
            
            # Find connected components (intervals)
            intervals = []
            start = None
            
            for i, is_above in enumerate(above_threshold):
                if is_above and start is None:
                    start = i
                elif not is_above and start is not None:
                    if i - start >= min_interval_length:
                        intervals.append((start, i))
                    start = None
            
            # Handle case where interval extends to the end
            if start is not None and len(data) - start >= min_interval_length:
                intervals.append((start, len(data)))
            
            # Sort by average absolute activation patching effect in each interval and take top max_intervals
            if len(intervals) > max_intervals:
                interval_scores = []
                for start, end in intervals:
                    avg_effect = np.mean(smoothed[start:end])
                    interval_scores.append((avg_effect, start, end))
                
                # Sort by average absolute effect (descending) and take top max_intervals
                interval_scores.sort(key=lambda x: x[0], reverse=True)
                intervals = [(start, end) for _, start, end in interval_scores[:max_intervals]]
                intervals.sort(key=lambda x: x[0])  # Sort by start position
            
            return intervals
            
        except Exception as e:
            print(f"Error in find_high_activation_patching_intervals: {e}")
            return []
    
    # Create the plot
    fig, ax = plt.subplots(figsize=(15, 8))
    
    # Set up the plot
    # Shift layer indexing to start at 1 instead of 0 for display
    ax.set_xlim(1, ageAct.num_layers)
    ax.set_ylim(-0.5, len(analyses) * 2.5 - 0.5)  # Increased height to accommodate activation patching
    ax.invert_yaxis()  # Invert y-axis so higher y values appear at bottom
    ax.set_xlabel('Layer', fontsize=12)
    ax.set_title(f'Main LLM Map: {ageAct.MODEL_NAME}', fontsize=14)
    # Only show vertical grid lines; remove horizontal ones
    ax.grid(False)
    ax.grid(axis='x', alpha=0.3)
    
    # Add horizontal separator lines between analyses
    # Place halfway between the bottom row (activation patching) of the current
    # category and the top row (UMAP) of the next category.
    for i in range(len(analyses) - 1):
        separator_y = (i * 2.5) + 2.0
        ax.axhline(y=separator_y, color='black', linewidth=1.0, alpha=0.4)
    
    # Plot each analysis (reverse order so age is at top)
    for i, analysis in enumerate(analyses):
        y_pos = i * 2.5  # Increased spacing between analysis categories
        
        print(f"\nProcessing {analysis['name']}:")
        
        # Find peak intervals
        umap_start, umap_end = find_peak_increase_interval(analysis['umap_metric'])
        sal_intervals = find_high_saliency_intervals(analysis['saliency'])
        lesion_intervals = find_high_lesioning_intervals(analysis['lesioning'])
        activation_patching_intervals = find_high_activation_patching_intervals(analysis['activation_patching'])

        # Shift all intervals by +1 for 1-based layer indexing in the figure
        def shift_interval(a, b):
            if a is None or b is None:
                return None, None
            # Ensure we don't exceed the actual number of layers
            return a + 1, min(b + 1, ageAct.num_layers)

        def shift_intervals(intervals):
            return [(s + 1, min(e + 1, ageAct.num_layers)) for (s, e) in intervals] if intervals else []

        umap_start_p, umap_end_p = shift_interval(umap_start, umap_end)
        sal_intervals_p = shift_intervals(sal_intervals)
        lesion_intervals_p = shift_intervals(lesion_intervals)
        activation_patching_intervals_p = shift_intervals(activation_patching_intervals)
        
        print(f"  UMAP interval: {umap_start_p}-{umap_end_p}")
        print(f"  Saliency intervals: {sal_intervals_p}")
        print(f"  Lesioning intervals: {lesion_intervals_p}")
        # Convert activation patching intervals to 1-based for display
        activation_patching_intervals_1based = [(s+1, min(e+1, ageAct.num_layers)) for (s, e) in activation_patching_intervals]
        print(f"  Activation patching intervals: {activation_patching_intervals_1based}")
        
        # Plot UMAP interval (blue) - top row
        if umap_start_p is not None and umap_end_p is not None:
            ax.hlines(y=y_pos, xmin=umap_start_p, xmax=umap_end_p, 
                        colors='blue', linewidth=8, alpha=0.7, label='UMAP' if i == 0 else "")
            # Add black boundary
            ax.hlines(y=y_pos, xmin=umap_start_p, xmax=umap_end_p, 
                        colors='black', linewidth=10, alpha=0.3)
            ax.hlines(y=y_pos, xmin=umap_start_p, xmax=umap_end_p, 
                        colors='blue', linewidth=8, alpha=0.7)
            if text_on_bars:
                # Text inside the bar on single line
                text_x = (umap_start_p + umap_end_p) / 2
                umap_label = (
                    f"{analysis['name']} ({umap_start_p}-{umap_end_p})"
                    if (umap_end_p - umap_start_p) == 2 else
                    f"{analysis['name']} UMAP (layers {umap_start_p}-{umap_end_p})"
                )
                ax.text(text_x, y_pos, umap_label, fontsize=9, ha='center', va='center', color='white', weight='bold')
        
        # Plot saliency intervals (green) - second row
        for j, (sal_start, sal_end) in enumerate(sal_intervals_p):
            # Offset multiple intervals slightly to avoid overlap
            y_offset = y_pos + 0.5 + (j * 0.1)
            # Add black boundary
            ax.hlines(y=y_offset, xmin=sal_start, xmax=sal_end, 
                        colors='black', linewidth=10, alpha=0.3)
            ax.hlines(y=y_offset, xmin=sal_start, xmax=sal_end, 
                        colors='green', linewidth=8, alpha=0.7, label='Saliency' if i == 0 and j == 0 else "")
            if text_on_bars:
                # Text inside the bar on single line
                text_x = (sal_start + sal_end) / 2
                sal_label = (
                    f"{analysis['name']} ({sal_start}-{sal_end})"
                    if (sal_end - sal_start) == 2 else
                    f"{analysis['name']} Saliency (layers {sal_start}-{sal_end})"
                )
                ax.text(text_x, y_offset, sal_label, fontsize=9, ha='center', va='center', color='white', weight='bold')
        
        # Plot lesioning intervals (red) - third row
        for j, (lesion_start, lesion_end) in enumerate(lesion_intervals_p):
            # Offset multiple intervals slightly to avoid overlap
            y_offset = y_pos + 1.0 + (j * 0.1)
            # Add black boundary
            ax.hlines(y=y_offset, xmin=lesion_start, xmax=lesion_end, 
                        colors='black', linewidth=10, alpha=0.3)
            ax.hlines(y=y_offset, xmin=lesion_start, xmax=lesion_end, 
                        colors='red', linewidth=8, alpha=0.7, label='Lesioning' if i == 0 and j == 0 else "")
            if text_on_bars:
                # Text inside the bar on single line
                text_x = (lesion_start + lesion_end) / 2
                lesion_label = (
                    f"{analysis['name']} ({lesion_start}-{lesion_end})"
                    if (lesion_end - lesion_start) == 2 else
                    f"{analysis['name']} Lesioning (layers {lesion_start}-{lesion_end})"
                )
                ax.text(text_x, y_offset, lesion_label, fontsize=9, ha='center', va='center', color='white', weight='bold')
        
        # Plot activation patching intervals (orange) - fourth row (bottom)
        for j, (patch_start, patch_end) in enumerate(activation_patching_intervals):
            # Convert to 1-based indexing and ensure we don't exceed the actual number of layers
            patch_start_1based = patch_start + 1
            patch_end_1based = min(patch_end + 1, ageAct.num_layers)
            # Offset multiple intervals slightly to avoid overlap
            y_offset = y_pos + 1.5 + (j * 0.1)
            # Add black boundary
            ax.hlines(y=y_offset, xmin=patch_start_1based, xmax=patch_end_1based, 
                        colors='black', linewidth=10, alpha=0.3)
            ax.hlines(y=y_offset, xmin=patch_start_1based, xmax=patch_end_1based, 
                        colors='orange', linewidth=8, alpha=0.7, label='Act. Pat.' if i == 0 and j == 0 else "")
            if text_on_bars:
                # Text inside the bar on single line
                text_x = (patch_start_1based + patch_end_1based) / 2
                patch_label = (
                    f"{analysis['name']} ({patch_start_1based}-{patch_end_1based})"
                    if (patch_end_1based - patch_start_1based) == 2 else
                    f"{analysis['name']} Act. Pat. (layers {patch_start_1based}-{patch_end_1based})"
                )
                ax.text(text_x, y_offset, patch_label, fontsize=9, ha='center', va='center', color='white', weight='bold')
        
        # Add analysis name on the left
        # ax.text(-2, y_pos + 0.25, analysis['name'], fontsize=12, ha='right', va='center', 
                # weight='bold', color=analysis['color'])
    
    # Add legend
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], color='blue', lw=8, alpha=0.7, label='UMAP-based'),
        Line2D([0], [0], color='green', lw=8, alpha=0.7, label='Saliency-based'),
        Line2D([0], [0], color='red', lw=8, alpha=0.7, label='Lesioning-based'),
        Line2D([0], [0], color='orange', lw=8, alpha=0.7, label='Activation Patching-based')
    ]
    ax.legend(handles=legend_elements, loc='upper right', framealpha=1.0, facecolor='white')
    
    # Set y-ticks (normal order since we inverted the y-axis)
    # Place ticks midway between Saliency (y_pos+0.5) and Lesioning (y_pos+1.0)
    ax.set_yticks([i * 2.5 + 0.75 for i in range(len(analyses))])
    ax.set_yticklabels([analysis['name'] for analysis in analyses])
    
    # Move legend up by ~ two bar-heights (1.0 in data coords)
    y_data_range = len(analyses) * 3.3
    y_frac_shift = 1.0 / y_data_range if y_data_range > 0 else 0.05
    # Rebuild legend with new anchor
    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1 + y_frac_shift), framealpha=1.0, facecolor='white')
    
    # Save the plot
    out_file = f"results/main_llm_map_{ageAct.MODEL_NAME}.pdf"
    fig.savefig(out_file, dpi=300, bbox_inches="tight")
    print(f"Saved main LLM map → {out_file}")
    plt.show()


def layer_scores_to_opacity(scores, total_layers, *, low_opacity_percentile=0.75,
                            transition_sharpness=10.0, min_opacity=0.0, max_opacity=1.0):
    """
    Map per-layer scores to per-layer opacities for the continuous LLM map.

    Pipeline:
      1. Pad/truncate ``scores`` to ``total_layers``. Missing entries and
         non-finite values become 0.
      2. Min-max normalize to [0, 1].
      3. Compute a floor threshold at ``low_opacity_percentile`` of the
         normalized values.
      4. Apply a logistic ramp centered midway between that threshold and
         the maximum.
      5. Rescale so the threshold maps to ``min_opacity`` and the maximum
         maps to ``max_opacity``.

    Layers at or below the threshold get exactly ``min_opacity``.

    Args:
        scores: per-layer values (sequence or array), or None.
        total_layers: length of the returned arrays (negative treated as 0).
        low_opacity_percentile: floor percentile, given either as a fraction
            (<= 1.0, e.g. 0.75) or on the 0-100 scale (e.g. 75).
        transition_sharpness: logistic slope; larger -> crisper transition.
        min_opacity: alpha assigned to the weakest layers.
        max_opacity: alpha assigned to the strongest layer.

    Returns:
        (alpha, s_norm): two float arrays of length ``total_layers``.
        ``alpha`` is clipped to [0, 1]; ``s_norm`` is the rescaled logistic
        in [0, 1]. Both are all zeros when ``scores`` is None, empty, or
        constant.
    """
    total_layers = max(int(total_layers), 0)

    def _zeros():
        return np.zeros(total_layers, dtype=float)

    if scores is None:
        return _zeros(), _zeros()
    raw = np.asarray(scores, dtype=float)
    # Guard total_layers == 0 too: min()/max() below would fail on an empty array.
    if raw.size == 0 or total_layers == 0:
        return _zeros(), _zeros()

    vals = _zeros()
    n = min(raw.size, total_layers)
    vals[:n] = raw[:n]
    vals[~np.isfinite(vals)] = 0.0

    lo, hi = float(vals.min()), float(vals.max())
    if hi - lo < 1e-12:
        # Constant signal carries no layer preference -> fully transparent.
        return _zeros(), _zeros()

    norm = (vals - lo) / (hi - lo + 1e-8)
    pct = low_opacity_percentile * 100.0 if low_opacity_percentile <= 1.0 else low_opacity_percentile
    low_thresh = float(np.percentile(norm, pct))
    top = float(norm.max())
    center = 0.5 * (low_thresh + top)

    def _sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    sig = _sigmoid(transition_sharpness * (norm - center))
    sig_low = _sigmoid(transition_sharpness * (low_thresh - center))
    sig_high = _sigmoid(transition_sharpness * (top - center))
    denom = sig_high - sig_low
    if not np.isfinite(denom) or abs(denom) < 1e-8:
        s_norm = np.zeros_like(norm)
    else:
        # Rescale so low_thresh -> 0 and the maximum -> 1.
        s_norm = np.clip((sig - sig_low) / (denom + 1e-8), 0.0, 1.0)

    alpha = min_opacity + (max_opacity - min_opacity) * s_norm
    # Hard floor: everything at/below the percentile threshold is exactly min_opacity.
    alpha[norm <= low_thresh] = float(min_opacity)
    alpha = np.nan_to_num(alpha, nan=float(min_opacity), posinf=float(max_opacity),
                          neginf=float(min_opacity))
    return np.clip(alpha, 0.0, 1.0), s_norm


def main_map_continuous(ageAct=None, diseaseAct=None, top100DrugAct=None, symptomAct=None, dosageAct=None,
                        max_opacity=1, min_opacity=0, transition_sharpness=10.0, show_text=True,
                        text_on_bars=False,
                        low_opacity_percentile=0.75):
    """
    Continuous variant of main_map.

    Instead of discrete intervals per metric, each metric row is rendered with
    a per-layer opacity derived from the metric's magnitude (see
    ``layer_scores_to_opacity``). Layers whose normalized score is at or below
    ``low_opacity_percentile`` get ``min_opacity``. Above it, opacity follows a
    logistic ramp (steepness ``transition_sharpness``) that reaches
    ``max_opacity`` at the layer with the largest score.

    Args:
      - ageAct, diseaseAct, top100DrugAct, symptomAct, dosageAct: analysis
        objects; any may be None. Each object's attributes read here:
        umapAggFile, saliencyFile, compute_saliency_per_layer, analysis_name,
        MODEL_NAME, num_layers.
      - max_opacity: maximum alpha for strongest regions.
      - min_opacity: minimum alpha for weakest regions.
      - transition_sharpness: larger -> crisper transition near threshold.
      - low_opacity_percentile: values at/below this percentile (fraction
        <= 1.0 or 0-100 scale) are clamped to min_opacity (default 0.75 =
        75th percentile).
      - show_text, text_on_bars: per-row labels are drawn only when both are
        True.

    Side effects:
      - Prints progress.
      - Saves results/main_llm_map_continuous_<MODEL_NAME>.pdf.
      - Calls plt.show().

    Returns None.
    """
    import json
    import os
    import pickle

    import numpy as np

    print('Creating main LLM map (continuous opacity)...')

    analyses = []

    # --- Per-metric artifact loaders. Each returns per-layer values or None, and may raise. ---

    def _load_umap_metric(act_obj):
        # UMAP aggregate: prefer 'aniso_mean', fall back to 'coeffs' (npz or pickle cache).
        path = act_obj.umapAggFile
        if path.endswith('.npz'):
            with np.load(path, allow_pickle=True) as umap_data:
                if 'aniso_mean' in umap_data.files:
                    return umap_data['aniso_mean']
                if 'coeffs' in umap_data.files:
                    return umap_data['coeffs']
            return None
        # NOTE: pickle is only safe because these are locally produced cache files.
        with open(path, 'rb') as f:
            cache = pickle.load(f)
        if 'aniso_mean' in cache:
            return cache['aniso_mean']
        if 'coeffs' in cache:
            return cache['coeffs']
        return None

    def _load_saliency(act_obj):
        # Per-layer mean saliency computed by the analysis object from the cached per-prompt saliency.
        with open(act_obj.saliencyFile, 'rb') as f:
            sal_data = pickle.load(f)
        if 'avg_sal_per_prompt' not in sal_data:
            return None
        saliency_mean, _, _ = act_obj.compute_saliency_per_layer(sal_data['avg_sal_per_prompt'])
        return saliency_mean

    def _load_lesioning(act_obj):
        # Mean lesioning score per layer; layers without any score get 0.0.
        path = f"results/{act_obj.analysis_name}_lesioning_{act_obj.MODEL_NAME}.json"
        with open(path, 'r') as f:
            lesion_data = json.load(f)
        per_layer = {}
        for prompt_result in lesion_data.get('prompt_scores', {}).values():
            for layer_key, score_info in prompt_result.get('scores_and_justifications', {}).items():
                li = int(layer_key)
                if li >= act_obj.num_layers:
                    continue
                per_layer.setdefault(li, []).append(score_info['score'])
        return [np.mean(per_layer[i]) if i in per_layer else 0.0 for i in range(act_obj.num_layers)]

    def _load_activation_patching(act_obj):
        # Mean non-zero patching effect per layer; layers without any effect get 0.0.
        path = f"results/{act_obj.analysis_name}_activation_patching_{act_obj.MODEL_NAME}.json"
        with open(path, 'r') as f:
            ap_data = json.load(f)
        if 'all_patching_results' not in ap_data:
            return None
        per_layer = {}
        for layer_key, layer_result in ap_data['all_patching_results'].items():
            li = int(layer_key)
            if li >= act_obj.num_layers:
                continue
            for effect in layer_result.get('patching_effect', {}).values():
                if effect != 0:
                    per_layer.setdefault(li, []).append(effect)
        return [np.mean(per_layer[i]) if i in per_layer else 0.0 for i in range(act_obj.num_layers)]

    loaders = (
        ('umap_metric', 'UMAP metric', _load_umap_metric),
        ('saliency', 'saliency', _load_saliency),
        ('lesioning', 'lesioning scores', _load_lesioning),
        ('activation_patching', 'activation patching scores', _load_activation_patching),
    )

    def add_analysis(act_obj, name, color):
        if act_obj is None:
            return
        try:
            entry = {'name': name}
            for key, label, loader in loaders:
                try:
                    entry[key] = loader(act_obj)
                except Exception as e:
                    # Best effort: a missing artifact for one metric must not drop the whole category,
                    # but report it so an empty row is not mistaken for a zero signal.
                    print(f"  [{name}] {label} unavailable: {e}")
                    entry[key] = None
            entry['num_layers'] = act_obj.num_layers
            entry['color'] = color
            analyses.append(entry)
        except Exception as e:
            print(f"Error loading {name} data (continuous): {e}")

    add_analysis(ageAct, 'Age', 'red')
    add_analysis(symptomAct, 'Symptoms', 'brown')
    add_analysis(diseaseAct, 'Diseases', 'purple')
    add_analysis(top100DrugAct, 'Drugs', 'orange')
    add_analysis(dosageAct, 'Dosage', 'purple')

    if not analyses:
        print('No analysis data available for continuous map')
        return

    total_layers = analyses[0]['num_layers'] if analyses[0]['num_layers'] else 0
    if total_layers <= 0:
        print('No layers to plot for continuous map')
        return

    # Model name from the first provided analysis object (ageAct first), so a missing ageAct
    # no longer crashes the title/output path.
    model_name = next(
        (getattr(a, 'MODEL_NAME', None) for a in (ageAct, symptomAct, diseaseAct, top100DrugAct, dosageAct)
         if a is not None and getattr(a, 'MODEL_NAME', None) is not None),
        'unknown',
    )

    # Deferred: plotting deps are only needed once there is something to draw.
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    # Prepare canvas similar to main_map
    fig, ax = plt.subplots(figsize=(15, 8))
    ax.set_xlim(1, total_layers)
    ax.set_ylim(-0.5, len(analyses) * 2.5 - 0.5)
    ax.invert_yaxis()
    ax.set_xlabel('Layer', fontsize=12)
    ax.set_title(f'Main LLM Map (Continuous): {model_name}', fontsize=14)
    ax.grid(False)
    ax.grid(axis='x', alpha=0.3)

    # Separator lines between categories
    for i in range(len(analyses) - 1):
        ax.axhline(y=(i * 2.5) + 2.0, color='black', linewidth=1.0, alpha=0.4)

    def _draw_row(y, color_name, alpha_arr):
        # One short segment per layer, so opacity can vary continuously along the layer axis.
        # NOTE(review): layer L occupies [L+1, L+2] clipped to total_layers (same convention as
        # main_map), so the last layer collapses to a zero-width segment — confirm intended.
        for layer, a in enumerate(alpha_arr):
            a = float(a)
            if a <= 0:
                continue
            x0 = layer + 1
            x1 = min(layer + 2, total_layers)
            # Black boundary backdrop scales with opacity
            ax.hlines(y=y, xmin=x0, xmax=x1, colors='black', linewidth=10, alpha=0.15 * a)
            # Colored bar on top
            ax.hlines(y=y, xmin=x0, xmax=x1, colors=color_name, linewidth=8, alpha=a)

    # (analysis key, debug label, bar text label, color, y offset within category), top to bottom
    rows = (
        ('umap_metric', 'UMAP', 'UMAP', 'blue', 0.0),
        ('saliency', 'Saliency', 'Saliency', 'green', 0.5),
        ('lesioning', 'Lesioning', 'Lesioning', 'red', 1.0),
        ('activation_patching', 'Act.Pat.', 'Act. Pat.', 'orange', 1.5),
    )

    for i, analysis in enumerate(analyses):
        y_pos = i * 2.5
        for key, debug_label, _, color_name, offset in rows:
            alpha, s_norm = layer_scores_to_opacity(
                analysis[key], total_layers,
                low_opacity_percentile=low_opacity_percentile,
                transition_sharpness=transition_sharpness,
                min_opacity=min_opacity, max_opacity=max_opacity,
            )
            # Debug: roughly low_opacity_percentile of layers should be fully transparent.
            zeros = int((alpha <= 1e-8).sum())
            print(f"{analysis['name']} {debug_label}: zeros {zeros}/{total_layers}, alpha[min={alpha.min():.3f}, max={alpha.max():.3f}], s_norm[min={s_norm.min():.3f}, max={s_norm.max():.3f}]")
            _draw_row(y_pos + offset, color_name, alpha)

        if show_text and text_on_bars:
            # Non-adaptive labels centered on the layer axis
            for _, _, bar_label, _, offset in rows:
                ax.text((total_layers + 1)/2, y_pos + offset, f"{analysis['name']} {bar_label}", fontsize=9,
                        ha='center', va='center', color='white', weight='bold')

    legend_elements = [
        Line2D([0], [0], color='blue', lw=8, alpha=0.7, label='UMAP-based'),
        Line2D([0], [0], color='green', lw=8, alpha=0.7, label='Saliency-based'),
        Line2D([0], [0], color='red', lw=8, alpha=0.7, label='Lesioning-based'),
        Line2D([0], [0], color='orange', lw=8, alpha=0.7, label='Act. Pat.-based')
    ]

    ax.set_yticks([i * 2.5 + 0.75 for i in range(len(analyses))])
    ax.set_yticklabels([analysis['name'] for analysis in analyses])

    # Legend anchored slightly above the axes (built once; previously created twice)
    y_data_range = len(analyses) * 3.3
    y_frac_shift = 1.0 / y_data_range if y_data_range > 0 else 0.05
    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1 + y_frac_shift),
              framealpha=1.0, facecolor='white')

    out_file = f"results/main_llm_map_continuous_{model_name}.pdf"
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    print(f"Saved continuous main LLM map → {out_file}")
    plt.show()


def metrics_table_all(results_dir: str = "results", out_tex: str = "results/concept_metrics.tex", out_tex_avg: str = "results/concept_metrics_avg.tex", out_tex_last: str = "results/concept_metrics_last.tex"):
    """
    Aggregate per-concept interpretability metrics across models and emit a LaTeX table.

    Concepts and metrics:
      - Age: R^2 from linear regression predicting age from 2D UMAP coordinates.
              Compute per-layer R^2 separately for pronouns {he, she, someone} and average; take best layer.
      - Symptoms: Silhouette score (best layer) using 30-D UMAP and symptom group labels.
      - Diseases: Silhouette score (best layer) using 30-D UMAP and specialty labels.
      - Drugs: Silhouette score (best layer) using 30-D UMAP and mechanism labels.
      - Dosages: Activation patching effect (mean non-zero effect per layer; best layer).

    The function discovers available models by scanning cached artifacts in `results_dir`.
    Writes a LaTeX table to `out_tex` and returns the path.
    """
    os.makedirs(os.path.dirname(out_tex), exist_ok=True)

    # Pretty model name mapping and preferred column order
    preferred_models = [
        ("Llama-3.3-70B-Instruct", "Llama 70B"),
        ("Gemma-3-27b-it", "Gemma 27B"),
        ("MedGemma-27b-text-it", "MedGemma 27B"),
        ("Qwen3-32B", "Qwen 32B"),
        ("gpt-oss-120b", "GPT-OSS 120B"),
    ]

    # Discover present models from age artifacts first, then union with others
    present = set()
    try:
        for fn in os.listdir(results_dir):
            if fn.startswith("age_umap_cache_") and fn.endswith(".pkl"):
                present.add(fn[len("age_umap_cache_"):-4])
            if fn.startswith("age_umap_agg_") and fn.endswith(".npz"):
                present.add(fn[len("age_umap_agg_"):-4])
            if fn.startswith("symptom_umap_cache_") and fn.endswith(".pkl"):
                present.add(fn[len("symptom_umap_cache_"):-4])
            if fn.startswith("disease_umap_cache_") and fn.endswith(".pkl"):
                present.add(fn[len("disease_umap_cache_"):-4])
            if fn.startswith("top100-drugs_umap_cache_") and fn.endswith(".pkl"):
                present.add(fn[len("top100-drugs_umap_cache_"):-4])
            if fn.startswith("dosage_activation_patching_") and fn.endswith(".json"):
                # skip finegrained variants from model list
                mk = fn[len("dosage_activation_patching_"):-5]
                if not mk.startswith("finegrained_"):
                    present.add(mk)
    except FileNotFoundError:
        pass

    # Build ordered list of (model_key, pretty_name)
    ordered_models = []
    seen = set()
    for key, pretty in preferred_models:
        if key in present:
            ordered_models.append((key, pretty))
            seen.add(key)
    
    # # Append any extra models discovered
    # for key in sorted(present):
    #     if key not in seen:
    #         ordered_models.append((key, key))

    if not ordered_models:
        print("No models detected in results/. Skipping metrics table generation.")
        return out_tex

    # Helpers
    def _safe_load_pickle(path):
        try:
            with open(path, "rb") as f:
                return pickle.load(f)
        except Exception:
            return None

    def _silhouette_best(embs_list, labels, npz_path=None):
        if embs_list is None or labels is None:
            return None, None, None
        labels = np.asarray(labels)
        if len(np.unique(labels)) < 2:
            return None, None, None
        best_val, best_layer, best_std = None, None, None
        for li, emb in enumerate(embs_list):
            try:
                if len(emb) == len(labels) and len(np.unique(labels)) > 1:
                    s = silhouette_score(emb, labels)
                    if best_val is None or s > best_val:
                        best_val, best_layer = float(s), li
                        # Try to get std from bootstrap confidence intervals
                        std = None
                        if npz_path and os.path.exists(npz_path):
                            try:
                                agg = np.load(npz_path, allow_pickle=True)
                                if 'coeffs_ci_lower' in agg.files and 'coeffs_ci_upper' in agg.files:
                                    ci_lower = agg['coeffs_ci_lower'][li]
                                    ci_upper = agg['coeffs_ci_upper'][li]
                                    # Approximate std from 95% CI: std ≈ (CI_upper - CI_lower) / (2 * 1.96)
                                    std = float((ci_upper - ci_lower) / (2 * 1.96))
                            except Exception:
                                pass
                        best_std = std
            except Exception:
                continue
        return best_val, best_layer, best_std

    def _age_r2_best(cache):
        if cache is None:
            return None, None, None
        embs = cache.get('embs')
        subj_markers = np.array(cache.get('subj_markers', []))
        ages = np.array(cache.get('ages', []), dtype=float)
        if embs is None or ages.size == 0 or subj_markers.size == 0:
            return None, None, None
        markers = ['o', 's', '*']  # he, she, someone
        best_val, best_layer, best_std = None, None, None
        for li, X in enumerate(embs):
            per_pron = []
            for m in markers:
                mask = (subj_markers == m)
                Xi = X[mask]
                yi = ages[mask]
                if Xi.shape[0] >= 3 and Xi.shape[1] >= 2 and np.std(yi) > 0:
                    try:
                        reg = LinearRegression().fit(Xi, yi)
                        r2 = float(reg.score(Xi, yi))
                        per_pron.append(r2)
                    except Exception:
                        pass
            if per_pron:
                mean_r2 = float(np.mean(per_pron))
                std_r2 = float(np.std(per_pron))
                if best_val is None or mean_r2 > best_val:
                    best_val, best_layer, best_std = mean_r2, li, std_r2
        return best_val, best_layer, best_std

    def _circularity_scores(json_path):
        """Compute CSFS (lower is better) and CSLS (higher is better) from circularity JSON."""
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            
            closest_to_first = data.get('closest_to_first_stage', {})
            closest_to_end = data.get('closest_to_end_stage', {})
            
            if not closest_to_first or not closest_to_end:
                return (None, None), (None, None)
            
            # CSFS: Average across all diseases, lower is better
            csfs_scores = []
            csfs_layers = []
            for disease, layer_data in closest_to_first.items():
                if layer_data:
                    # Convert layer keys to int and get values
                    layer_scores = {int(k): v for k, v in layer_data.items() if v is not None}
                    if layer_scores:
                        # Find layer with minimum score (closest to first stage)
                        min_score = min(layer_scores.values())
                        min_layer = min(k for k, v in layer_scores.items() if v == min_score)
                        csfs_scores.append(min_score)
                        csfs_layers.append(min_layer)
            
            # CSLS: Average across all diseases, higher is better  
            csls_scores = []
            csls_layers = []
            for disease, layer_data in closest_to_end.items():
                if layer_data:
                    # Convert layer keys to int and get values
                    layer_scores = {int(k): v for k, v in layer_data.items() if v is not None}
                    if layer_scores:
                        # Find layer with maximum score (closest to end stage)
                        max_score = max(layer_scores.values())
                        max_layer = max(k for k, v in layer_scores.items() if v == max_score)
                        csls_scores.append(max_score)
                        csls_layers.append(max_layer)
            
            # Average across diseases
            csfs_avg = float(np.mean(csfs_scores)) if csfs_scores else None
            csfs_std = float(np.std(csfs_scores)) if csfs_scores else None
            csfs_layer = int(np.mean(csfs_layers)) if csfs_layers else None
            
            csls_avg = float(np.mean(csls_scores)) if csls_scores else None
            csls_std = float(np.std(csls_scores)) if csls_scores else None
            csls_layer = int(np.mean(csls_layers)) if csls_layers else None
            
            return (csfs_avg, csfs_layer, csfs_std), (csls_avg, csls_layer, csls_std)
            
        except Exception:
            return (None, None, None), (None, None, None)

    def _age_r2_avg_std(cache):
        """Compute average and std R^2 across all layers for age."""
        if cache is None:
            return None, None
        embs = cache.get('embs')
        subj_markers = np.array(cache.get('subj_markers', []))
        ages = np.array(cache.get('ages', []), dtype=float)
        if embs is None or ages.size == 0 or subj_markers.size == 0:
            return None, None
        markers = ['o', 's', '*']  # he, she, someone
        r2_layers = []
        for X in embs:
            per_pron = []
            for m in markers:
                mask = (subj_markers == m)
                Xi = X[mask]
                yi = ages[mask]
                if Xi.shape[0] >= 3 and Xi.shape[1] >= 2 and np.std(yi) > 0:
                    try:
                        reg = LinearRegression().fit(Xi, yi)
                        per_pron.append(float(reg.score(Xi, yi)))
                    except Exception:
                        pass
            if per_pron:
                r2_layers.append(np.mean(per_pron))
        if r2_layers:
            return float(np.mean(r2_layers)), float(np.std(r2_layers))
        return None, None

    def _silhouette_avg_std(embs_list, labels):
        """Compute average and std silhouette score across all layers."""
        if embs_list is None or labels is None:
            return None, None
        labels = np.asarray(labels)
        if len(np.unique(labels)) < 2:
            return None, None
        scores = []
        for emb in embs_list:
            try:
                if len(emb) == len(labels) and len(np.unique(labels)) > 1:
                    s = silhouette_score(emb, labels)
                    scores.append(float(s))
            except Exception:
                continue
        if scores:
            return float(np.mean(scores)), float(np.std(scores))
        return None, None

    def _circularity_avg_std(json_path):
        """Compute average and std CSFS and CSLS across all layers."""
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            
            closest_to_first = data.get('closest_to_first_stage', {})
            closest_to_end = data.get('closest_to_end_stage', {})
            
            if not closest_to_first or not closest_to_end:
                return (None, None), (None, None)
            
            # CSFS: Average across all diseases and all layers
            csfs_all_scores = []
            for disease, layer_data in closest_to_first.items():
                if layer_data:
                    layer_scores = [v for v in layer_data.values() if v is not None]
                    if layer_scores:
                        csfs_all_scores.extend(layer_scores)
            
            # CSLS: Average across all diseases and all layers
            csls_all_scores = []
            for disease, layer_data in closest_to_end.items():
                if layer_data:
                    layer_scores = [v for v in layer_data.values() if v is not None]
                    if layer_scores:
                        csls_all_scores.extend(layer_scores)
            
            csfs_avg = float(np.mean(csfs_all_scores)) if csfs_all_scores else None
            csfs_std = float(np.std(csfs_all_scores)) if csfs_all_scores else None
            csls_avg = float(np.mean(csls_all_scores)) if csls_all_scores else None
            csls_std = float(np.std(csls_all_scores)) if csls_all_scores else None
            
            return (csfs_avg, csfs_std), (csls_avg, csls_std)
            
        except Exception:
            return (None, None), (None, None)

    def _dosages_avg_std(json_path):
        """Compute average and std of all non-zero patching effects across all layers."""
        try:
            with open(json_path, 'r') as f:
                ap = json.load(f)
            all_effects = []
            all_res = ap.get('all_patching_results', {})
            for layer_idx_str, layer_res in all_res.items():
                pe = layer_res.get('patching_effect', {})
                vals = [v for v in pe.values() if isinstance(v, (int, float)) and np.isfinite(v) and v != 0]
                all_effects.extend(vals)
            if all_effects:
                return float(np.mean(all_effects)), float(np.std(all_effects))
            return None, None
        except Exception:
            return None, None

    def _age_r2_last(cache):
        """Compute R^2 for age at the last layer with std."""
        if cache is None:
            return None, None
        embs = cache.get('embs')
        subj_markers = np.array(cache.get('subj_markers', []))
        ages = np.array(cache.get('ages', []), dtype=float)
        if embs is None or ages.size == 0 or subj_markers.size == 0:
            return None, None
        markers = ['o', 's', '*']  # he, she, someone
        last_layer = len(embs) - 1
        X = embs[last_layer]
        per_pron = []
        for m in markers:
            mask = (subj_markers == m)
            Xi = X[mask]
            yi = ages[mask]
            if Xi.shape[0] >= 3 and Xi.shape[1] >= 2 and np.std(yi) > 0:
                try:
                    reg = LinearRegression().fit(Xi, yi)
                    r2 = float(reg.score(Xi, yi))
                    per_pron.append(r2)
                except Exception:
                    pass
        if per_pron:
            return float(np.mean(per_pron)), float(np.std(per_pron))
        return None, None

    def _silhouette_last(embs_list, labels, npz_path=None):
        """Compute silhouette score at the last layer with std from bootstrap CI."""
        if embs_list is None or labels is None:
            return None, None
        labels = np.asarray(labels)
        if len(np.unique(labels)) < 2:
            return None, None
        last_layer = len(embs_list) - 1
        try:
            if len(embs_list[last_layer]) == len(labels) and len(np.unique(labels)) > 1:
                score = float(silhouette_score(embs_list[last_layer], labels))
                
                # Try to get std from bootstrap confidence intervals
                std = None
                if npz_path and os.path.exists(npz_path):
                    try:
                        agg = np.load(npz_path, allow_pickle=True)
                        if 'coeffs_ci_lower' in agg.files and 'coeffs_ci_upper' in agg.files:
                            ci_lower = agg['coeffs_ci_lower'][last_layer]
                            ci_upper = agg['coeffs_ci_upper'][last_layer]
                            # Approximate std from 95% CI: std ≈ (CI_upper - CI_lower) / (2 * 1.96)
                            std = float((ci_upper - ci_lower) / (2 * 1.96))
                    except Exception:
                        pass
                
                return score, std
        except Exception:
            pass
        return None, None

    def _circularity_last(json_path):
        """Compute CSFS and CSLS at the last layer with std."""
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            
            closest_to_first = data.get('closest_to_first_stage', {})
            closest_to_end = data.get('closest_to_end_stage', {})
            
            if not closest_to_first or not closest_to_end:
                return (None, None), (None, None)
            
            # CSFS: Average across all diseases at last layer
            csfs_scores = []
            for disease, layer_data in closest_to_first.items():
                if layer_data:
                    # Get the last layer key
                    layer_keys = [int(k) for k in layer_data.keys() if layer_data[k] is not None]
                    if layer_keys:
                        last_layer = max(layer_keys)
                        if str(last_layer) in layer_data and layer_data[str(last_layer)] is not None:
                            csfs_scores.append(layer_data[str(last_layer)])
            
            # CSLS: Average across all diseases at last layer
            csls_scores = []
            for disease, layer_data in closest_to_end.items():
                if layer_data:
                    # Get the last layer key
                    layer_keys = [int(k) for k in layer_data.keys() if layer_data[k] is not None]
                    if layer_keys:
                        last_layer = max(layer_keys)
                        if str(last_layer) in layer_data and layer_data[str(last_layer)] is not None:
                            csls_scores.append(layer_data[str(last_layer)])
            
            csfs_avg = float(np.mean(csfs_scores)) if csfs_scores else None
            csfs_std = float(np.std(csfs_scores)) if csfs_scores else None
            csls_avg = float(np.mean(csls_scores)) if csls_scores else None
            csls_std = float(np.std(csls_scores)) if csls_scores else None
            
            return (csfs_avg, csfs_std), (csls_avg, csls_std)
            
        except Exception:
            return (None, None), (None, None)

    def _dosages_last(json_path):
        """Compute patching effect at the last layer with std."""
        try:
            with open(json_path, 'r') as f:
                ap = json.load(f)
            all_res = ap.get('all_patching_results', {})
            if not all_res:
                return None, None
            
            # Get the last layer
            layer_keys = [int(k) for k in all_res.keys()]
            if not layer_keys:
                return None, None
            last_layer = max(layer_keys)
            last_layer_str = str(last_layer)
            
            if last_layer_str in all_res:
                pe = all_res[last_layer_str].get('patching_effect', {})
                vals = [v for v in pe.values() if isinstance(v, (int, float)) and np.isfinite(v) and v != 0]
                if vals:
                    return float(np.mean(vals)), float(np.std(vals))
            return None, None
        except Exception:
            return None, None

    # Collect metrics per concept per model
    concepts = [
        ("Age", "R$^2$ (linear) $\\uparrow$"),
        ("Symptoms", "Silhouette $\\uparrow$"),
        ("Diseases", "Silhouette $\\uparrow$"),
        ("Disease Progression", "CSFS $\\downarrow$"),
        ("Disease Progression", "CSLS $\\uparrow$"),
        ("Drugs", "Silhouette (mech.) $\\uparrow$"),
        ("Drugs", "Silhouette (spec.) $\\uparrow$"),
        ("Dosages", "Patching Effect $\\uparrow$"),
    ]

    # Three parallel result tables, each indexed as table[concept][model_key]:
    #   metrics      -> best layer:       (score, layer, std)
    #   avg_metrics  -> over all layers:  (mean, std)
    #   last_metrics -> last layer:       (mean, std)
    # "Drugs" entries are {"mech": ..., "spec": ...} and "Disease Progression"
    # entries are {"CSFS": ..., "CSLS": ...}; each inner value has the tuple
    # shape of its table. Missing data is a tuple of Nones, never a bare None,
    # so the table builders can always index element 0.
    metrics = {c[0]: {} for c in concepts}
    avg_metrics = {c[0]: {} for c in concepts}
    last_metrics = {c[0]: {} for c in concepts}

    def _symptom_labels(cache):
        """Group label for each symptom in a symptom UMAP cache, or None."""
        labels = cache.get('symptom_groups_list')
        if labels is None and cache.get('SYMPTOM_TO_GROUP') and cache.get('symptoms_list'):
            stg = cache['SYMPTOM_TO_GROUP']
            labels = [stg[s] for s in cache['symptoms_list']]
        return labels

    def _disease_labels(cache):
        """Specialty label for each disease in a disease UMAP cache, or None."""
        d2s = cache.get('DISEASE_TO_SPEC')
        dlist = cache.get('diseases_list')
        return [d2s[d] for d in dlist] if (d2s and dlist) else None

    def _best_from_agg_coeffs(agg_path):
        """Best-layer fallback from the per-layer ``coeffs`` array in an .npz file.

        Returns (score, layer, None), or (None, None, None) when the file is
        missing/unreadable or holds no finite coefficient.
        """
        if not os.path.exists(agg_path):
            return (None, None, None)
        try:
            # NpzFile keeps the archive open until closed; use it as a context manager.
            with np.load(agg_path, allow_pickle=True) as agg:
                if 'coeffs' not in agg.files:
                    return (None, None, None)
                arr = np.asarray(agg['coeffs'], dtype=float)
            if arr.size > 0 and np.isfinite(arr).any():
                li = int(np.nanargmax(arr))
                return (float(arr[li]), li, None)
        except Exception:
            pass
        return (None, None, None)

    def _drug_coeff_summary(path):
        """Best / average / last-layer summaries of per-layer drug silhouettes.

        Reads the ``coeffs`` array (one silhouette per layer) from ``path`` and
        returns ``(best, avg, last)`` shaped (score, layer, std), (mean, std)
        and (score, std). The NaN-ignoring std across layers is used as the
        approximate spread for all three. Anything unavailable stays None.
        """
        best, avg, last = (None, None, None), (None, None), (None, None)
        if not os.path.exists(path):
            return best, avg, last
        try:
            with np.load(path, allow_pickle=True) as data:
                if 'coeffs' not in data.files:
                    return best, avg, last
                coeffs = np.asarray(data['coeffs'])
            if len(coeffs) > 0 and np.isfinite(coeffs).any():
                spread = float(np.nanstd(coeffs))
                best_idx = int(np.nanargmax(coeffs))
                best = (float(coeffs[best_idx]), best_idx, spread)
                avg = (float(np.nanmean(coeffs)), spread)
                last_val = float(coeffs[-1])
                # A NaN last layer would print as "nan"; report it as missing.
                last = (last_val if np.isfinite(last_val) else None, spread)
        except Exception:
            pass
        return best, avg, last

    for model_key, _pretty in ordered_models:
        # --- Age: linear-probe R^2 from the UMAP cache ---
        age_cache = _safe_load_pickle(os.path.join(results_dir, f"age_umap_cache_{model_key}.pkl"))
        age_best_mean, age_best_layer, age_best_std = _age_r2_best(age_cache)
        metrics["Age"][model_key] = (age_best_mean, age_best_layer, age_best_std)
        age_avg, age_std = _age_r2_avg_std(age_cache)
        avg_metrics["Age"][model_key] = (age_avg, age_std)
        age_last_mean, age_last_std = _age_r2_last(age_cache)
        last_metrics["Age"][model_key] = (age_last_mean, age_last_std)

        # --- Symptoms: silhouette (prefer cache; fallback to aggregated coeffs) ---
        sym_cache = _safe_load_pickle(os.path.join(results_dir, f"symptom_umap_cache_{model_key}.pkl"))
        sym_npz = os.path.join(results_dir, f"symptom_umap_agg_{model_key}.npz")
        sym_best = (None, None, None)
        sym_avg, sym_std = None, None
        sym_last_mean, sym_last_std = None, None
        if sym_cache is not None:
            embsC = sym_cache.get('embsClustering')
            labels = _symptom_labels(sym_cache)
            sym_best = _silhouette_best(embsC, labels, sym_npz)
            sym_avg, sym_std = _silhouette_avg_std(embsC, labels)
            sym_last_mean, sym_last_std = _silhouette_last(embsC, labels, sym_npz)
        if sym_best == (None, None, None):
            sym_best = _best_from_agg_coeffs(sym_npz)
        metrics["Symptoms"][model_key] = sym_best
        avg_metrics["Symptoms"][model_key] = (sym_avg, sym_std)
        last_metrics["Symptoms"][model_key] = (sym_last_mean, sym_last_std)

        # --- Diseases: silhouette (prefer cache; fallback to aggregated coeffs) ---
        dis_cache = _safe_load_pickle(os.path.join(results_dir, f"disease_umap_cache_{model_key}.pkl"))
        dis_npz = os.path.join(results_dir, f"disease_umap_agg_{model_key}.npz")
        dis_best = (None, None, None)
        dis_avg, dis_std = None, None
        dis_last_mean, dis_last_std = None, None
        if dis_cache is not None:
            embsC = dis_cache.get('embsClustering')
            labels = _disease_labels(dis_cache)
            dis_best = _silhouette_best(embsC, labels, dis_npz)
            dis_avg, dis_std = _silhouette_avg_std(embsC, labels)
            dis_last_mean, dis_last_std = _silhouette_last(embsC, labels, dis_npz)
        if dis_best == (None, None, None):
            dis_best = _best_from_agg_coeffs(dis_npz)
        metrics["Diseases"][model_key] = dis_best
        avg_metrics["Diseases"][model_key] = (dis_avg, dis_std)
        last_metrics["Diseases"][model_key] = (dis_last_mean, dis_last_std)

        # --- Drugs: silhouette by mechanism and by specialty (one file each) ---
        drug_mech_best, drug_mech_avg, drug_mech_last = _drug_coeff_summary(
            os.path.join(results_dir, f"drugs_umap-mechanism_{model_key}.npz"))
        drug_spec_best, drug_spec_avg, drug_spec_last = _drug_coeff_summary(
            os.path.join(results_dir, f"drugs_umap-specialty_{model_key}.npz"))
        metrics["Drugs"][model_key] = {"mech": drug_mech_best, "spec": drug_spec_best}
        avg_metrics["Drugs"][model_key] = {"mech": drug_mech_avg, "spec": drug_spec_avg}
        last_metrics["Drugs"][model_key] = {"mech": drug_mech_last, "spec": drug_spec_last}

        # --- Dosages: mean non-zero activation-patching effect per layer ---
        pe_best = (None, None, None)
        ap_path = os.path.join(results_dir, f"dosage_activation_patching_{model_key}.json")
        if os.path.exists(ap_path):
            try:
                with open(ap_path, 'r') as f:
                    ap = json.load(f)
                layer_means = {}
                layer_stds = {}
                for layer_idx_str, layer_res in ap.get('all_patching_results', {}).items():
                    pe = layer_res.get('patching_effect', {})
                    # Same filter as _dosages_avg_std / _dosages_last: finite, non-zero numbers.
                    vals = [v for v in pe.values() if isinstance(v, (int, float)) and np.isfinite(v) and v != 0]
                    if vals:
                        layer_means[int(layer_idx_str)] = float(np.mean(vals))
                        layer_stds[int(layer_idx_str)] = float(np.std(vals))
                if layer_means:
                    li_best = max(layer_means, key=layer_means.get)
                    pe_best = (layer_means[li_best], li_best, layer_stds.get(li_best))
            except Exception:
                pass
        metrics["Dosages"][model_key] = pe_best
        dosages_avg, dosages_std = _dosages_avg_std(ap_path)
        avg_metrics["Dosages"][model_key] = (dosages_avg, dosages_std)
        dosages_last_mean, dosages_last_std = _dosages_last(ap_path)
        last_metrics["Dosages"][model_key] = (dosages_last_mean, dosages_last_std)

        # --- Disease Progression: circularity scores (CSFS lower-is-better, CSLS higher) ---
        circularity_path = os.path.join(results_dir, f"progression_circularity_{model_key}.json")
        if os.path.exists(circularity_path):
            (csfs_score, csfs_layer, csfs_std), (csls_score, csls_layer, csls_std) = _circularity_scores(circularity_path)
            metrics["Disease Progression"][model_key] = {
                "CSFS": (csfs_score, csfs_layer, csfs_std),
                "CSLS": (csls_score, csls_layer, csls_std)
            }
            (csfs_avg, csfs_std), (csls_avg, csls_std) = _circularity_avg_std(circularity_path)
            avg_metrics["Disease Progression"][model_key] = {
                "CSFS": (csfs_avg, csfs_std),
                "CSLS": (csls_avg, csls_std)
            }
            (csfs_last_mean, csfs_last_std), (csls_last_mean, csls_last_std) = _circularity_last(circularity_path)
            last_metrics["Disease Progression"][model_key] = {
                "CSFS": (csfs_last_mean, csfs_last_std),
                "CSLS": (csls_last_mean, csls_last_std)
            }
        else:
            metrics["Disease Progression"][model_key] = {
                "CSFS": (None, None, None),
                "CSLS": (None, None, None)
            }
            avg_metrics["Disease Progression"][model_key] = {
                "CSFS": (None, None),
                "CSLS": (None, None)
            }
            # Same (mean, std) shape as the branch above, so consumers can
            # index [0] without special-casing a missing file.
            last_metrics["Disease Progression"][model_key] = {
                "CSFS": (None, None),
                "CSLS": (None, None)
            }
    # Build LaTeX table
    def fmt(score_layer_std, bold=False):
        """Format a best-layer result as the two LaTeX cells "score & layer".

        ``score_layer_std`` is (score, layer) or (score, layer, std). With a
        non-None std the score cell reads "score ± std". A missing tuple,
        score, or layer gives "-- & --". When ``bold`` is set and a score is
        present, only the score cell is wrapped in a textbf command.
        """
        incomplete = (
            score_layer_std is None
            or len(score_layer_std) < 2
            or score_layer_std[0] is None
            or score_layer_std[1] is None
        )
        if incomplete:
            score_cell, layer_cell = "--", "--"
        else:
            score = score_layer_std[0]
            layer = score_layer_std[1]
            std = score_layer_std[2] if len(score_layer_std) > 2 else None
            score_cell = f"{score:.2f}" if std is None else f"{score:.2f} ± {std:.2f}"
            layer_cell = f"{layer}"
        wants_bold = bold and score_layer_std is not None and score_layer_std[0] is not None
        if wants_bold:
            score_cell = f"\\textbf{{{score_cell}}}"
        return f"{score_cell} & {layer_cell}"

    # Prepare column headers in chosen order
    headers = [pretty for _, pretty in ordered_models]

    def _best_row_cells(values, lower_is_better=False):
        """Format one (score, layer, std) tuple per model, bolding the best score.

        ``values`` is in ``ordered_models`` order. The best score is the min
        (``lower_is_better``) or max over non-None scores; every cell within
        1e-12 of it is bolded, so ties are all highlighted.
        """
        scores = [v[0] for v in values if v[0] is not None]
        best = (min(scores) if lower_is_better else max(scores)) if scores else None
        return [
            fmt(v, bold=(v is not None and v[0] is not None and best is not None and abs(v[0] - best) < 1e-12))
            for v in values
        ]

    lines = []
    lines.append("\\begin{table}[htbp]")
    lines.append("\\centering")
    lines.append("\\caption{Per-concept interpretability metrics across models. Values are best layer scores (max across depth) with the layer index.}")
    lines.append("\\label{tab:concept_metrics}")
    lines.append("\\resizebox{\\textwidth}{!}{%")
    # Column spec: l (concept), c (metric), then two columns (score, layer) per
    # model, i.e. 2 + 2*n columns, matching the cmidrule range below.
    colspec = 'lc' + 'c' * (2 * len(headers))
    lines.append(f"\\begin{{tabular}}{{{colspec}}}")
    lines.append("\\toprule")
    # Header row: model names span their Score/Layer column pair
    header_top = ["\\multirow{2}{*}{\\textbf{Concept}}", "\\multirow{2}{*}{\\textbf{Metric}}"]
    for h in headers:
        header_top.append(f"\\multicolumn{{2}}{{c}}{{\\textbf{{{h}}}}}")
    lines.append(' & '.join(header_top) + " \\\\")
    # Rule under the model names, spanning all Score/Layer columns (3 .. 2 + 2*n)
    lines.append(f"\\cmidrule(lr){{3-{2 + 2 * len(headers)}}}")
    header_bot = ["", ""] + [x for _ in headers for x in ("Score", "Layer")]
    lines.append(' & '.join(header_bot) + " \\\\")
    lines.append("\\midrule")

    # Rows per concept. Nested results (Disease Progression, Drugs) are read via
    # a sub-key chosen from the metric label; CSFS is the only lower-is-better row.
    for concept, metric_label in concepts:
        if concept == "Disease Progression":
            if "CSFS" in metric_label:
                sub_key, lower_is_better = "CSFS", True
            elif "CSLS" in metric_label:
                sub_key, lower_is_better = "CSLS", False
            else:
                continue
        elif concept == "Drugs":
            if "mech" in metric_label:
                sub_key, lower_is_better = "mech", False
            elif "spec" in metric_label:
                sub_key, lower_is_better = "spec", False
            else:
                continue
        else:
            sub_key, lower_is_better = None, False

        if sub_key is None:
            row_vals = [metrics[concept].get(mk, (None, None, None)) for mk, _ in ordered_models]
        else:
            row_vals = [metrics[concept].get(mk, {}).get(sub_key, (None, None, None)) for mk, _ in ordered_models]

        row = [concept, metric_label] + _best_row_cells(row_vals, lower_is_better)
        lines.append(' & '.join(row) + " \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}%")
    lines.append("}")
    lines.append("\\end{table}")

    # Also print the LaTeX table to stdout for quick inspection
    table_text = '\n'.join(lines) + '\n'
    print(table_text)
    with open(out_tex, 'w') as f:
        f.write(table_text)
    print(f"Wrote LaTeX metrics table → {out_tex}")
    
    # Generate second table with average metrics
    def fmt_avg(mean_std, bold=False):
        """Format a (mean, std) pair as a single LaTeX cell.

        Returns "--" when the pair or its mean is missing. Otherwise returns
        "mean" or "mean ± std" (two decimals), wrapped in a textbf command
        when ``bold`` is set.
        """
        if mean_std is None or mean_std[0] is None:
            return "--"
        mean, std = mean_std
        cell = f"{mean:.2f}" if std is None else f"{mean:.2f} ± {std:.2f}"
        return f"\\textbf{{{cell}}}" if bold else cell

    # Build average LaTeX table
    lines_avg = []
    lines_avg.append("\\begin{table}[htbp]")
    lines_avg.append("\\centering")
    lines_avg.append("\\caption{Per-concept interpretability metrics across models. Values are mean ± std across all layers.}")
    lines_avg.append("\\label{tab:concept_metrics_avg}")
    lines_avg.append("\\resizebox{\\textwidth}{!}{%")
    # Build tabular column spec: l (concept) c (metric) then for each model one column (mean ± std)
    colspec = 'lc' + 'c' * len(headers)
    lines_avg.append(f"\\begin{{tabular}}{{{colspec}}}")
    lines_avg.append("\\toprule")
    # Header row
    header_top = ["\\textbf{Concept}", "\\textbf{Metric}"]
    for h in headers:
        header_top.append(f"\\textbf{{{h}}}")
    lines_avg.append(' & '.join(header_top) + " \\\\")
    lines_avg.append("\\midrule")

    # Rows per concept
    for concept, metric_label in concepts:
        if concept == "Disease Progression":
            # Handle Disease Progression metrics separately
            if "CSFS" in metric_label:
                # CSFS row (lower is better)
                row_vals = []
                min_val = None
                for model_key, _ in ordered_models:
                    dp_data = avg_metrics[concept].get(model_key, {})
                    mean_std = dp_data.get("CSFS", (None, None))
                    if mean_std[0] is not None and (min_val is None or mean_std[0] < min_val):
                        min_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble CSFS row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_avg(mean_std, bold=(mean_std is not None and mean_std[0] is not None and min_val is not None and abs(mean_std[0] - min_val) < 1e-12)))
                lines_avg.append(' & '.join(row) + " \\\\")
                
            elif "CSLS" in metric_label:
                # CSLS row (higher is better)
                row_vals = []
                max_val = None
                for model_key, _ in ordered_models:
                    dp_data = avg_metrics[concept].get(model_key, {})
                    mean_std = dp_data.get("CSLS", (None, None))
                    if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                        max_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble CSLS row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_avg(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
                lines_avg.append(' & '.join(row) + " \\\\")
        elif concept == "Drugs":
            # Handle Drugs metrics separately for mechanism and specialty
            if "mech" in metric_label:
                # Mechanism row (higher is better)
                row_vals = []
                max_val = None
                for model_key, _ in ordered_models:
                    drug_data = avg_metrics[concept].get(model_key, {})
                    mean_std = drug_data.get("mech", (None, None))
                    if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                        max_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble mechanism row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_avg(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
                lines_avg.append(' & '.join(row) + " \\\\")
                
            elif "spec" in metric_label:
                # Specialty row (higher is better)
                row_vals = []
                max_val = None
                for model_key, _ in ordered_models:
                    drug_data = avg_metrics[concept].get(model_key, {})
                    mean_std = drug_data.get("spec", (None, None))
                    if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                        max_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble specialty row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_avg(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
                lines_avg.append(' & '.join(row) + " \\\\")
        else:
            # Handle other concepts normally
            row_vals = []
            max_val = None
            for model_key, _ in ordered_models:
                mean_std = avg_metrics[concept].get(model_key, (None, None))
                if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                    max_val = mean_std[0]
                row_vals.append(mean_std)

            # Assemble row
            row = [concept, metric_label]
            for mean_std in row_vals:
                row.append(fmt_avg(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
            lines_avg.append(' & '.join(row) + " \\\\")

    lines_avg.append("\\bottomrule")
    lines_avg.append("\\end{tabular}%")
    lines_avg.append("}")
    lines_avg.append("\\end{table}")

    # Print and save average table
    table_avg_text = '\n'.join(lines_avg) + '\n'
    print("\n" + "="*80)
    print("AVERAGE METRICS TABLE:")
    print("="*80)
    print(table_avg_text)
    with open(out_tex_avg, 'w') as f:
        f.write(table_avg_text)
    print(f"Wrote LaTeX average metrics table → {out_tex_avg}")
    
    # Generate third table with last layer metrics
    def fmt_last(mean_std, bold=False):
        if mean_std is None or mean_std[0] is None:
            return "--"
        else:
            mean, std = mean_std
            if std is None:
                cell = f"{mean:.2f}"
            else:
                cell = f"{mean:.2f} ± {std:.2f}"
        if bold and mean_std is not None and mean_std[0] is not None:
            cell = f"\\textbf{{{cell}}}"
        return cell

    # Build last layer LaTeX table
    lines_last = []
    lines_last.append("\\begin{table}[htbp]")
    lines_last.append("\\centering")
    lines_last.append("\\caption{Per-concept interpretability metrics across models. Values are scores at the last layer (representations right before output).}")
    lines_last.append("\\label{tab:concept_metrics_last}")
    lines_last.append("\\resizebox{\\textwidth}{!}{%")
    # Build tabular column spec: l (concept) c (metric) then for each model one column (last layer score)
    colspec = 'lc' + 'c' * len(headers)
    lines_last.append(f"\\begin{{tabular}}{{{colspec}}}")
    lines_last.append("\\toprule")
    # Header row
    header_top = ["\\textbf{Concept}", "\\textbf{Metric}"]
    for h in headers:
        header_top.append(f"\\textbf{{{h}}}")
    lines_last.append(' & '.join(header_top) + " \\\\")
    lines_last.append("\\midrule")

    # Rows per concept
    for concept, metric_label in concepts:
        if concept == "Disease Progression":
            # Handle Disease Progression metrics separately
            if "CSFS" in metric_label:
                # CSFS row (lower is better)
                row_vals = []
                min_val = None
                for model_key, _ in ordered_models:
                    dp_data = last_metrics[concept].get(model_key, {})
                    mean_std = dp_data.get("CSFS", (None, None))
                    if mean_std[0] is not None and (min_val is None or mean_std[0] < min_val):
                        min_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble CSFS row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_last(mean_std, bold=(mean_std is not None and mean_std[0] is not None and min_val is not None and abs(mean_std[0] - min_val) < 1e-12)))
                lines_last.append(' & '.join(row) + " \\\\")
                
            elif "CSLS" in metric_label:
                # CSLS row (higher is better)
                row_vals = []
                max_val = None
                for model_key, _ in ordered_models:
                    dp_data = last_metrics[concept].get(model_key, {})
                    mean_std = dp_data.get("CSLS", (None, None))
                    if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                        max_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble CSLS row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_last(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
                lines_last.append(' & '.join(row) + " \\\\")
        elif concept == "Drugs":
            # Handle Drugs metrics separately for mechanism and specialty
            if "mech" in metric_label:
                # Mechanism row (higher is better)
                row_vals = []
                max_val = None
                for model_key, _ in ordered_models:
                    drug_data = last_metrics[concept].get(model_key, {})
                    mean_std = drug_data.get("mech", (None, None))
                    if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                        max_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble mechanism row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_last(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
                lines_last.append(' & '.join(row) + " \\\\")
                
            elif "spec" in metric_label:
                # Specialty row (higher is better)
                row_vals = []
                max_val = None
                for model_key, _ in ordered_models:
                    drug_data = last_metrics[concept].get(model_key, {})
                    mean_std = drug_data.get("spec", (None, None))
                    if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                        max_val = mean_std[0]
                    row_vals.append(mean_std)
                
                # Assemble specialty row
                row = [concept, metric_label]
                for mean_std in row_vals:
                    row.append(fmt_last(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
                lines_last.append(' & '.join(row) + " \\\\")
        else:
            # Handle other concepts normally
            row_vals = []
            max_val = None
            for model_key, _ in ordered_models:
                mean_std = last_metrics[concept].get(model_key, (None, None))
                if mean_std[0] is not None and (max_val is None or mean_std[0] > max_val):
                    max_val = mean_std[0]
                row_vals.append(mean_std)

            # Assemble row
            row = [concept, metric_label]
            for mean_std in row_vals:
                row.append(fmt_last(mean_std, bold=(mean_std is not None and mean_std[0] is not None and max_val is not None and abs(mean_std[0] - max_val) < 1e-12)))
            lines_last.append(' & '.join(row) + " \\\\")

    lines_last.append("\\bottomrule")
    lines_last.append("\\end{tabular}%")
    lines_last.append("}")
    lines_last.append("\\end{table}")

    # Print and save last layer table
    table_last_text = '\n'.join(lines_last) + '\n'
    print("\n" + "="*80)
    print("LAST LAYER METRICS TABLE:")
    print("="*80)
    print(table_last_text)
    with open(out_tex_last, 'w') as f:
        f.write(table_last_text)
    print(f"Wrote LaTeX last layer metrics table → {out_tex_last}")
    
    return out_tex


def plot_selected_metrics_per_layer(results_dir: str = "results",
                               model_name_map=None,
                               out_pdf: str = None, save=True,
                               success_threshold: float = 0.5):
    """
    Plot a 4x4 grid of selected metrics over layers, overlaying curves for all attached models.

    Only models from the built-in preferred list whose result files are found in
    ``results_dir`` are plotted. Each model keeps one color across all panels.

    Grid layout (rows x cols):
      1,1 - Age R^2 over layers
      1,2 - Age anisotropy (1D-ness measure) over layers
      1,3 - Age saliency
      1,4 - Age activation patching (fraction of successful patches)
      2,1 - Legend (consistent colors across plots)
      2,2 - Symptoms Silhouette
      2,3 - Symptoms saliency
      2,4 - Symptoms activation patching (fraction of successful patches)
      3,1 - Dosage saliency
      3,2 - Diseases Silhouette
      3,3 - Diseases saliency
      3,4 - Diseases activation patching (fraction of successful patches)
      4,1 - Drugs Silhouette (Mechanism)
      4,2 - Drugs Silhouette (Med. Specialty)
      4,3 - Drugs saliency
      4,4 - Drugs activation patching (fraction of successful patches)

    Args:
        results_dir: Directory containing results files (created if missing).
        model_name_map: Optional mapping of model keys to display names. Entries
            override the built-in display names; keys absent from the map keep
            their built-in name.
        out_pdf: Output PDF path (default: ``<results_dir>/selected_metrics_per_layer.pdf``).
        save: Whether to save the plot.
        success_threshold: Threshold for defining a successful patch (default: 0.5).
                         A successful patch is one where the patching effect > success_threshold.

    Returns:
        The output PDF path (returned even when ``save`` is False), or None when
        no preferred model has results in ``results_dir``.
    """
    import numpy as np
    import os, pickle, json

    os.makedirs(results_dir, exist_ok=True)

    # Preferred models, in plot/legend order, with their default display names.
    preferred_models = [
        ("Llama-3.3-70B-Instruct", "Llama 70B"),
        ("Gemma-3-27b-it", "Gemma 27B"),
        ("MedGemma-27b-text-it", "MedGemma 27B"),
        ("Qwen3-32B", "Qwen 32B"),
        ("gpt-oss-120b", "GPT-OSS 120B"),
        # ("Llama-3.2-1B-Instruct", "Llama 1B"),
    ]
    if model_name_map is None:
        model_name_map = {}

    # ---- Discover which models have result files in results_dir ----
    allowed_exts = {'.pkl', '.npz', '.json'}
    discovery_prefixes = (
        "age_umap_cache_", "age_umap_agg_",
        "symptom_umap_cache_", "symptom_umap_agg_",
        "disease_umap_cache_", "disease_umap_agg_",
        "top100-drugs_umap_cache_",
        "dosage_activation_patching_",
        # Keep saliency pickles that are exactly model-level
        "age_saliency_", "symptom_saliency_", "disease_saliency_", "top100-drugs_saliency_", "dosage_saliency_",
    )
    try:
        filenames = os.listdir(results_dir)
    except FileNotFoundError:
        filenames = []
    present = set()
    for fn in filenames:
        stem, ext = os.path.splitext(fn)
        if ext not in allowed_exts:
            continue  # ignore PDFs, TXTs, etc.
        for prefix in discovery_prefixes:
            if not fn.startswith(prefix):
                continue
            # No prefix contains a '.', so the stem also starts with the prefix.
            key = stem[len(prefix):]
            # Fine-grained dosage patching files are not model-level results.
            if prefix == "dosage_activation_patching_" and key.startswith("finegrained_"):
                continue
            present.add(key)

    # Only preferred models are plotted, in preferred order.
    ordered = [(key, model_name_map.get(key, pretty))
               for key, pretty in preferred_models if key in present]
    if not ordered:
        print("No models detected in results/. Nothing to plot.")
        return None

    # Plotting / regression libraries are only needed once there is something to draw.
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LinearRegression

    # Consistent colors per model
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
    ]
    model_to_color = {mk: palette[i % len(palette)] for i, (mk, _) in enumerate(ordered)}

    # ---- Loaders: each returns None / (None, None, None) when data is unavailable ----
    def _safe_pickle(path):
        """Load a pickle file, returning None if it is missing or unreadable."""
        # NOTE: pickle.load can execute arbitrary code; only point results_dir at trusted files.
        try:
            with open(path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            return None

    def _age_r2_per_layer(cache):
        """Per-layer in-sample linear-regression R^2 of age, averaged over subject-marker groups."""
        if cache is None:
            return None
        embs = cache.get('embs')
        subj_markers = np.array(cache.get('subj_markers', []))
        ages = np.array(cache.get('ages', []), dtype=float)
        if embs is None or ages.size == 0 or subj_markers.size == 0:
            return None
        r2_layers = []
        for X in embs:
            X = np.asarray(X)
            per_marker = []
            for marker in ('o', 's', '*'):
                mask = (subj_markers == marker)
                Xi, yi = X[mask], ages[mask]
                # Need a few samples, >= 2 features and non-constant targets for a meaningful fit.
                if Xi.shape[0] >= 3 and Xi.shape[1] >= 2 and np.std(yi) > 0:
                    try:
                        reg = LinearRegression().fit(Xi, yi)
                        per_marker.append(float(reg.score(Xi, yi)))
                    except Exception:
                        pass
            r2_layers.append(np.mean(per_marker) if per_marker else np.nan)
        return np.asarray(r2_layers, dtype=float)

    def _anisotropy_from_npz(npz_path):
        """Return (mean, mean - std, mean + std) anisotropy per layer; band is None without 'aniso_std'."""
        try:
            # The NpzFile keeps its file open; the context manager closes it.
            with np.load(npz_path, allow_pickle=True) as agg:
                if 'aniso_mean' in agg.files:
                    mean = np.asarray(agg['aniso_mean'], dtype=float)
                    lo = hi = None
                    if 'aniso_std' in agg.files:
                        std = np.asarray(agg['aniso_std'], dtype=float)
                        lo, hi = mean - std, mean + std
                    return mean, lo, hi
        except Exception:
            pass  # best-effort: missing/corrupt files simply leave the panel empty
        return None, None, None

    def _layer_index(param_name):
        """Layer index of a parameter: its first purely numeric dotted component, or None."""
        for part in param_name.split('.'):
            if part.isdigit():
                return int(part)
        return None

    def _saliency_mean_ci(pkl_path):
        """Per-layer saliency mean across prompts with a 95% t-based CI, or (None, None, None)."""
        try:
            data = _safe_pickle(pkl_path)
            if not data or 'avg_sal_per_prompt' not in data:
                return None, None, None
            per_prompt = []
            for sal_dict in data['avg_sal_per_prompt'].values():
                # Sum saliency of '*weight' parameters into their layer bucket. The same
                # index rule sizes the vector, so no phantom trailing layers appear.
                layer_vals = []
                for name, val in sal_dict.items():
                    if not name.endswith("weight"):
                        continue
                    li = _layer_index(name)
                    if li is not None:
                        layer_vals.append((li, float(val)))
                num_layers = (max(li for li, _ in layer_vals) + 1) if layer_vals else 0
                vec = np.zeros(num_layers, dtype=float)
                for li, v in layer_vals:
                    vec[li] += v
                per_prompt.append(vec)
            if not per_prompt:
                return None, None, None
            # Zero-pad prompts to a common number of layers.
            width = max(len(v) for v in per_prompt)
            M = np.zeros((len(per_prompt), width), dtype=float)
            for i, v in enumerate(per_prompt):
                M[i, :len(v)] = v
            mean = M.mean(axis=0)
            n = M.shape[0]
            if n > 1:
                import scipy.stats as stats
                std = M.std(axis=0, ddof=1)
                tcrit = stats.t.ppf(0.975, df=n - 1)
                lo = mean - tcrit * (std / np.sqrt(n))
                hi = mean + tcrit * (std / np.sqrt(n))
            else:
                lo = hi = mean.copy()
            return mean, lo, hi
        except Exception:
            pass
        return None, None, None

    def _normalized_saliency(pkl_path):
        """Saliency mean/CI divided by the mean's finite positive peak (else by 1)."""
        mean, lo, hi = _saliency_mean_ci(pkl_path)
        if mean is None or len(mean) == 0:
            # Nothing to normalize; np.nanmax would raise on an empty array.
            return None, None, None
        peak = np.nanmax(mean)
        denom = float(peak) if np.isfinite(peak) and peak > 0 else 1.0
        return (mean / denom,
                lo / denom if lo is not None else None,
                hi / denom if hi is not None else None)

    def _silhouette_from_npz(npz_path):
        """Return per-layer silhouette coefficients with optional CI bounds."""
        try:
            with np.load(npz_path, allow_pickle=True) as agg:
                if 'coeffs' in agg.files:
                    coeffs = np.asarray(agg['coeffs'], dtype=float)
                    lo = (np.asarray(agg['coeffs_ci_lower'], dtype=float)
                          if 'coeffs_ci_lower' in agg.files else None)
                    hi = (np.asarray(agg['coeffs_ci_upper'], dtype=float)
                          if 'coeffs_ci_upper' in agg.files else None)
                    return coeffs, lo, hi
        except Exception:
            pass  # best-effort, as above
        return None, None, None

    def _activation_patching_success_fraction(json_path, threshold, num_layers_hint=None):
        """
        Calculate the fraction of successful patches per layer.
        A successful patch is one where the patching effect > threshold.
        Layers with no finite effects get 0.0; returns None if unreadable/empty.
        """
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            res = data.get('all_patching_results', {})
            layer_to_vals = {}
            for lstr, entry in res.items():
                li = int(lstr)
                for v in entry.get('patching_effect', {}).values():
                    if isinstance(v, (int, float)) and np.isfinite(v):
                        layer_to_vals.setdefault(li, []).append(float(v))
            if not layer_to_vals and num_layers_hint is None:
                return None
            L = (max(layer_to_vals) + 1) if layer_to_vals else (num_layers_hint or 0)
            arr = np.zeros(L, dtype=float)
            for li in range(L):
                vals = layer_to_vals.get(li, [])
                if vals:
                    arr[li] = float(sum(1 for v in vals if v > threshold)) / len(vals)
            return arr
        except Exception:
            return None

    # Styling constants
    TITLE_FS = 18
    LABEL_FS = 14
    TICK_FS = 12
    LEGEND_FS = 18
    LINEWIDTH = 5.0

    # Prepare figure 4x4
    fig, axes = plt.subplots(4, 4, figsize=(20, 16), squeeze=False)

    def plot_series(ax, x, mean, lo=None, hi=None, label=None, color=None):
        """Plot one curve with an optional shaded band when lo/hi match its length."""
        if mean is None or len(mean) == 0:
            return
        ax.plot(x, mean, label=label, color=color, linewidth=LINEWIDTH)
        if lo is not None and hi is not None and len(lo) == len(mean) and len(hi) == len(mean):
            ax.fill_between(x, lo, hi, color=color, alpha=0.15)
        ax.grid(True, alpha=0.25, linestyle='--', linewidth=0.6)
        ax.tick_params(axis='both', labelsize=TICK_FS)

    def _draw(rc, mean, lo, hi, pretty, color, title, ylabel, ylim=None):
        """Plot a per-layer series on axes[rc] and set its title, labels and optional y-limits."""
        ax = axes[rc]
        plot_series(ax, np.arange(len(mean)), mean, lo, hi, label=pretty, color=color)
        ax.set_title(title, fontsize=TITLE_FS)
        ax.set_xlabel('Layer', fontsize=LABEL_FS)
        ax.set_ylabel(ylabel, fontsize=LABEL_FS)
        if ylim is not None:
            ax.set_ylim(*ylim)

    # Panel tables: (grid position, filename template with '{}' for the model key, title).
    silhouette_panels = [
        ((1, 1), "symptom_umap_agg_{}.npz", 'Symptoms Silhouette'),
        ((2, 1), "disease_umap_agg_{}.npz", 'Diseases Silhouette'),
        ((3, 0), "drugs_umap-mechanism_{}.npz", 'Drugs Silhouette (Mechanism)'),
        ((3, 1), "drugs_umap-specialty_{}.npz", 'Drugs Silhouette (Med. Specialty)'),
    ]
    saliency_panels = [
        ((0, 2), "age_saliency_{}.pkl", 'Age saliency (normalized)'),
        ((1, 2), "symptom_saliency_{}.pkl", 'Symptoms saliency (relative)'),
        ((2, 2), "disease_saliency_{}.pkl", 'Diseases saliency (relative)'),
        ((3, 2), "top100-drugs_saliency_{}.pkl", 'Drugs saliency (relative)'),
        ((2, 0), "dosage_saliency_{}.pkl", 'Dosage saliency (relative)'),
    ]
    patching_panels = [
        ((0, 3), "age_activation_patching_{}.json", 'Age activation patching'),
        ((1, 3), "symptom_activation_patching_{}.json", 'Symptoms activation patching'),
        ((2, 3), "disease_activation_patching_{}.json", 'Diseases activation patching'),
        ((3, 3), "top100-drugs_activation_patching_{}.json", 'Drugs activation patching'),
    ]

    def _path(template, model_key):
        return os.path.join(results_dir, template.format(model_key))

    # For each model, load and plot all panels
    for model_key, pretty in ordered:
        color = model_to_color[model_key]

        # Age R^2 per layer
        age_r2 = _age_r2_per_layer(_safe_pickle(_path("age_umap_cache_{}.pkl", model_key)))
        if age_r2 is not None:
            _draw((0, 0), age_r2, None, None, pretty, color, 'Age R$^2$ (linear)', 'R$^2$')

        # Age anisotropy (from age npz)
        a_mean, a_lo, a_hi = _anisotropy_from_npz(_path("age_umap_agg_{}.npz", model_key))
        if a_mean is not None:
            _draw((0, 1), a_mean, a_lo, a_hi, pretty, color, 'Age anisotropy (1D-ness)', 'Value')

        for rc, template, title in silhouette_panels:
            coeffs, lo, hi = _silhouette_from_npz(_path(template, model_key))
            if coeffs is not None:
                _draw(rc, coeffs, lo, hi, pretty, color, title, 'Score')

        # Saliency is normalized per model by its peak so models are comparable.
        for rc, template, title in saliency_panels:
            s_mean, s_lo, s_hi = _normalized_saliency(_path(template, model_key))
            if s_mean is not None:
                _draw(rc, s_mean, s_lo, s_hi, pretty, color, title, 'Relative value',
                      ylim=(-0.1, 1.25))

        for rc, template, title in patching_panels:
            frac = _activation_patching_success_fraction(_path(template, model_key), success_threshold)
            if frac is not None and frac.size > 0:
                _draw(rc, frac, None, None, pretty, color, title, 'Fraction successful',
                      ylim=(0, 1))

    # Legend panel (2,1)
    leg_ax = axes[1, 0]
    leg_ax.axis('off')
    for model_key, pretty in ordered:
        leg_ax.plot([], [], color=model_to_color[model_key], label=pretty, linewidth=LINEWIDTH + 1.0)
    leg_ax.legend(loc='center', frameon=False, ncol=1, fontsize=LEGEND_FS)
    leg_ax.set_title('Models', fontsize=TITLE_FS)

    # Tight layout inside rect: 12% margin on the left, 4% on the other sides.
    plt.tight_layout(rect=[0.12, 0.04, 0.96, 0.96])
    if out_pdf is None:
        out_pdf = os.path.join(results_dir, "selected_metrics_per_layer.pdf")
    if save:
        fig.savefig(out_pdf, dpi=300, bbox_inches='tight')
        print(f"Saved selected metrics grid → {out_pdf}")
    plt.show()
    return out_pdf


def plot_lesioning_metrics_per_layer(results_dir: str = "results",
                                   model_name_map=None,
                                   out_pdf: str = None, save=True):
    """
    Plot a 2x3 grid of lesioning degradation metrics over layers, overlaying curves for all attached models.

    Grid layout (rows x cols):
      1,1 - Age Lesioning Degradation over layers
      1,2 - Symptoms Lesioning Degradation over layers
      1,3 - Diseases Lesioning Degradation over layers
      2,1 - Drugs Lesioning Degradation over layers
      2,2 - Dosage Lesioning Degradation over layers
      2,3 - Legend

    Uses cached lesioning JSON files from each analysis, named
    ``<prefix><model_key>.json`` inside ``results_dir``. Each file is read as
    ``{"prompt_scores": {<prompt>: {"scores_and_justifications":
    {<layer>: {"score": <number>}}}}}``. For every layer the mean score across
    prompts is plotted with a 95% (Student-t) confidence band.

    Args:
        results_dir: Directory holding the lesioning JSON files (created if missing).
        model_name_map: Optional ``{model_key: display_name}``. It overrides the
            built-in display names. Models without an entry keep their built-in
            name, or their raw key if they are not one of the preferred models.
        out_pdf: Output path. Defaults to
            ``<results_dir>/lesioning_metrics_per_layer.pdf``.
        save: Whether to write the figure to ``out_pdf``.

    Returns:
        The output PDF path (returned even when ``save`` is False), or ``None``
        when no lesioning results were found in ``results_dir``.
    """
    import numpy as np
    import os, json

    os.makedirs(results_dir, exist_ok=True)

    # Preferred model order and their default display names.
    preferred_models = [
        ("Llama-3.3-70B-Instruct", "Llama 70B"),
        ("Gemma-3-27b-it", "Gemma 27B"),
        ("MedGemma-27b-text-it", "MedGemma 27B"),
        ("Qwen3-32B", "Qwen 32B"),
        ("gpt-oss-120b", "GPT-OSS 120B"),
    ]
    if model_name_map is None:
        model_name_map = dict(preferred_models)

    # One entry per data panel: (display name, JSON file prefix, grid position,
    # whether to draw the scale-interpretation labels beside the y-axis).
    # This table drives both model discovery and plotting, so the two cannot drift apart.
    panels = [
        ("Age", "age_lesioning_", (0, 0), True),
        ("Symptoms", "symptom_lesioning_", (0, 1), False),
        ("Diseases", "disease_lesioning_", (0, 2), False),
        ("Drugs", "top100-drugs_lesioning_", (1, 0), True),
        ("Dosage", "dosage_lesioning_", (1, 1), False),
    ]

    # Probe results_dir to discover which model keys have any lesioning file.
    present = set()
    try:
        filenames = os.listdir(results_dir)
    except FileNotFoundError:
        filenames = []
    for fn in filenames:
        if not fn.endswith(".json"):
            continue
        for _, prefix, _, _ in panels:
            if fn.startswith(prefix):
                key = fn[len(prefix):-len(".json")]
                if key:  # ignore a bare "<prefix>.json" with no model key
                    present.add(key)

    # Order models: preferred first if present, then any extras alphabetically.
    preferred_keys = {k for k, _ in preferred_models}
    ordered = [(k, model_name_map.get(k, pretty))
               for k, pretty in preferred_models if k in present]
    ordered += [(k, model_name_map.get(k, k))
                for k in sorted(present - preferred_keys)]

    if not ordered:
        print(f"No models detected in {results_dir}/. Nothing to plot.")
        return None

    # Deferred until there is something to plot, so the no-data path does not
    # need a plotting backend.
    import matplotlib.pyplot as plt
    from scipy import stats

    # Consistent colors per model
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
    ]
    model_to_color = {mk: palette[i % len(palette)] for i, (mk, _) in enumerate(ordered)}

    def _load_lesioning_data(json_path, confidence_level=0.95):
        """Load one lesioning JSON and summarize it per layer.

        Returns ``(layers, means, ci_lower, ci_upper)`` with layers sorted
        ascending, or ``None`` if the file is unreadable or holds no scores.
        """
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)

            # Group every prompt's score by layer index.
            layer_scores = {}
            for prompt_result in data.get("prompt_scores", {}).values():
                per_layer = prompt_result.get("scores_and_justifications", {})
                for layer_idx, score_info in per_layer.items():
                    layer_scores.setdefault(int(layer_idx), []).append(score_info["score"])

            if not layer_scores:
                return None

            layers = sorted(layer_scores)
            means, lows, highs = [], [], []
            for layer in layers:
                scores = layer_scores[layer]
                n_prompts = len(scores)
                avg_score = sum(scores) / n_prompts
                if n_prompts < 2:
                    # A CI is undefined for a single sample; collapse the band
                    # instead of producing NaNs (and a ddof warning).
                    ci_margin = 0.0
                else:
                    # Two-sided Student-t interval on the mean.
                    alpha = 1 - confidence_level
                    t_critical = stats.t.ppf(1 - alpha / 2, df=n_prompts - 1)
                    std_score = np.std(scores, ddof=1)  # sample standard deviation
                    ci_margin = t_critical * (std_score / np.sqrt(n_prompts))
                means.append(avg_score)
                lows.append(avg_score - ci_margin)
                highs.append(avg_score + ci_margin)

            return np.asarray(layers), means, lows, highs
        except Exception as e:  # best-effort: a missing/bad file only skips that panel
            print(f"Warning: Could not load lesioning data from {json_path}: {e}")
            return None

    # Styling constants
    TITLE_FS = 16
    LABEL_FS = 14
    TICK_FS = 12
    LEGEND_FS = 16
    LINEWIDTH = 5.0
    SCORE_TICKS = list(range(1, 11))  # degradation scores are plotted on a 1..10 axis

    fig, axes = plt.subplots(2, 3, figsize=(20, 8), squeeze=False)

    # Helper to plot with CI band
    def plot_series(ax, x, mean, lo=None, hi=None, label=None, color=None):
        if mean is None or len(mean) == 0:
            return
        ax.plot(x, mean, label=label, color=color, linewidth=LINEWIDTH)
        if lo is not None and hi is not None and len(lo) == len(mean) and len(hi) == len(mean):
            ax.fill_between(x, lo, hi, color=color, alpha=0.15)
        ax.grid(True, alpha=0.25, linestyle='--', linewidth=0.6)
        ax.tick_params(axis='both', labelsize=TICK_FS)

    # For each model, load and plot all panels (model-major order keeps the
    # per-axes line stacking identical to plotting one panel block per model).
    panels_with_data = set()
    for model_key, pretty in ordered:
        color = model_to_color[model_key]
        for name, prefix, (row, col), _ in panels:
            loaded = _load_lesioning_data(os.path.join(results_dir, f"{prefix}{model_key}.json"))
            if loaded is None:
                continue
            layers, means, lows, highs = loaded
            print(f"Plotting {name} lesioning scores")
            ax = axes[row, col]
            # Plot against the real layer indices from the file, not 0..n-1.
            plot_series(ax, layers, means, lows, highs, label=pretty, color=color)
            ax.set_title(f"{name} Lesioning Degradation", fontsize=TITLE_FS)
            ax.set_xlabel('Layer', fontsize=LABEL_FS)
            ax.set_ylabel('Degradation Score', fontsize=LABEL_FS)
            ax.set_ylim(1, 10)
            ax.set_yticks(SCORE_TICKS)
            ax.set_yticklabels([str(t) for t in SCORE_TICKS])
            panels_with_data.add((row, col))

    # Scale-interpretation labels, drawn once per annotated panel (not once per model).
    # They sit just outside the y-range (0 and 11 on the 1..10 axis), left of layer 0.
    for _, _, (row, col), annotate in panels:
        if annotate and (row, col) in panels_with_data:
            ax = axes[row, col]
            ax.text(-10, 0, 'no\ndegradation', transform=ax.transData,
                    ha='center', va='center', fontsize=LABEL_FS - 2)
            ax.text(-10, 11, 'significant\ndegradation', transform=ax.transData,
                    ha='center', va='center', fontsize=LABEL_FS - 2)

    # Legend panel (2,3)
    leg_ax = axes[1, 2]
    leg_ax.axis('off')
    for model_key, pretty in ordered:
        color = model_to_color[model_key]
        leg_ax.plot([], [], color=color, label=pretty, linewidth=LINEWIDTH + 1.0)
    leg_ax.legend(loc='center', frameon=False, ncol=1, fontsize=LEGEND_FS)
    leg_ax.set_title('Models', fontsize=TITLE_FS)

    # The wide left margin (20% of figure width) presumably leaves room for the
    # out-of-axes scale labels on the leftmost panels.
    plt.tight_layout(rect=[0.2, 0.04, 0.96, 0.96])

    if out_pdf is None:
        out_pdf = os.path.join(results_dir, "lesioning_metrics_per_layer.pdf")
    if save:
        fig.savefig(out_pdf, dpi=300, bbox_inches='tight')
        print(f"Saved lesioning metrics grid → {out_pdf}")
    plt.show()
    return out_pdf
