"""
UMAP of per-layer activations
Age sweep 1–100 with three subject variants
"""
import math, torch, umap, matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
from transformers import AutoModelForCausalLM, AutoTokenizer
from functools import partial
import numpy as np
import argparse
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from collections import defaultdict
import pickle
import os
import json
from scipy import stats
from matplotlib.gridspec import GridSpec


from saliency import merge_saliency_dicts, load_saliency, print_top_saliencies, save_saliency

from act import Act
from maps import *

class AgeAct(Act):
    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, MIN_AGE, MAX_AGE, TEMPERATURE, USE_ACCELERATE=False):
        """Initialise the age analysis on top of the ``Act`` base class.

        ``MIN_AGE`` and ``MAX_AGE`` are stored on the instance; all other
        arguments are forwarded unchanged to ``Act.__init__``.
        """
        super().__init__(
            tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE,
            USE_ACCELERATE=USE_ACCELERATE,
        )
        self.analysis_name = "age"
        self.MIN_AGE, self.MAX_AGE = MIN_AGE, MAX_AGE
        # Called last, after the attributes above exist; presumably it derives
        # file paths from them (defined in the base class — TODO confirm).
        self.set_files()

    def plot_umap(self, embs, replies, subj_markers, ages, out_file, layer_list=None, nr_cols=5, fontsize=24, plot_title=False, prediction_layer=None, gender_prediction_discontinuity='she', special_plots=0):
        """Plot per-layer 2-D embeddings coloured by age and save the figure.

        One scatter panel is drawn per layer, with one marker shape per
        subject variant. When ``special_plots`` is positive, extra panels are
        appended after the layer panels: first a true-vs-predicted age
        regression, then (if ``special_plots >= 2``) a layer-by-age
        discontinuity heatmap.

        Args:
            embs: Per-layer embeddings, indexed by layer; each entry is read
                as ``emb[rows, 0]`` / ``emb[rows, 1]``.
            replies: Not used by this method; kept for caller compatibility.
            subj_markers: One marker code ("o", "s" or "*") per sample.
            ages: One age per sample, aligned with ``subj_markers``.
            out_file: Path the figure is written to.
            layer_list: Exactly four layer indices to plot, or None to plot
                all layers (subsampled for deep models).
            nr_cols: Number of grid columns.
            fontsize: Base font size; label and tick sizes derive from it.
            plot_title: Whether to add a figure-level title.
            prediction_layer: Layer used for the regression panel. None
                selects the layer with the largest summed discontinuity.
            gender_prediction_discontinuity: Subject variant ("he", "she" or
                "someone") whose samples feed the special panels.
            special_plots: Number of extra panel slots to reserve.
        """
        cmap = get_cmap("YlGnBu")
        norm = Normalize(vmin=1, vmax=self.MAX_AGE)
        colors = np.array([cmap(norm(a)) for a in ages])

        if layer_list is not None:
            # Selected-layers figure: exactly four layer panels.
            assert len(layer_list) == 4, "Layer list must be 4 layers"
            n_layers_to_plot = 4
        else:
            # Full figure: thin out the layers for deep models. The checks
            # must be exclusive, otherwise the coarser stride is overwritten.
            if self.num_layers > 85:
                stride = 3
            elif self.num_layers > 43:
                stride = 2
            else:
                stride = 1
            layer_list = list(range(0, self.num_layers, stride))
            n_layers_to_plot = len(layer_list)
        total_plots = n_layers_to_plot + special_plots

        n_cols, n_rows = nr_cols, math.ceil(total_plots / nr_cols)
        fig, axes = plt.subplots(n_rows, n_cols,
                                figsize=(n_cols*4, n_rows*4),
                                squeeze=False)

        marker_map = {"o": "He is X years old. How old is he?", "s": "She is X years old. How old is she?", "*": "Someone is X years old. How old are they?"}  # for legend

        # Sample indices per marker do not depend on the layer: compute once.
        marker_rows = {
            mark: [i for i, m in enumerate(subj_markers) if m == mark]
            for mark in marker_map
        }
        # Axis labels and ticks are only shown on figures with special panels.
        show_axes = special_plots > 0

        # Plot UMAP layers
        for plot_idx, layer_idx in enumerate(layer_list[:n_layers_to_plot]):
            ax = axes[plot_idx // n_cols, plot_idx % n_cols]
            emb = embs[layer_idx]
            for mark, rows in marker_rows.items():      # plot each subject separately
                ax.scatter(emb[rows, 0], emb[rows, 1],
                        c=colors[rows],
                        marker=mark, s=240, edgecolors='k', linewidths=0.2)
            ax.set_title(f"Layer {layer_idx}", fontsize=fontsize)

            if show_axes:
                ax.set_xlabel("UMAP1", fontsize=fontsize-4)
                ax.set_ylabel("UMAP2", fontsize=fontsize-4)
                ax.tick_params(axis='both', which='major', labelsize=fontsize-6)
            else:
                ax.set_xticks([])
                ax.set_yticks([])
            ax.grid(True, ls="--", alpha=0.15)

        if special_plots > 0:
            # Restrict everything below to the selected subject variant.
            gender_marker_map = {"he": "o", "she": "s", "someone": "*"}
            selected_marker = gender_marker_map[gender_prediction_discontinuity]
            gender_mask = np.array(subj_markers) == selected_marker
            gender_ages = np.array(ages)[gender_mask]
            gender_embs = [emb[gender_mask] for emb in embs]

            # The same matrix drives layer auto-selection and the heatmap.
            discontinuity_matrix = None
            if prediction_layer is None or special_plots >= 2:
                discontinuity_matrix = self._age_discontinuity_matrix(gender_embs, gender_ages)

            # Auto-detect layer with highest age discontinuity if prediction_layer is None
            if prediction_layer is None:
                prediction_layer = int(np.argmax(discontinuity_matrix.sum(axis=1)))
                print(f"Auto-selected layer {prediction_layer} for age prediction (highest discontinuity in {gender_prediction_discontinuity} embeddings)")

                print(f"Largest discontinuities by layer ({gender_prediction_discontinuity} embeddings):")
                for layer_idx, row in enumerate(discontinuity_matrix):
                    peak = row.max()
                    # Column j holds the jump from age j+1 to j+2; report 0
                    # when the layer has no measured jump at all.
                    peak_age = int(np.argmax(row)) + 1 if peak > 0 else 0
                    print(f"  Layer {layer_idx}: age {peak_age} (discontinuity: {peak:.4f})")

            # ── Regression panel: true age vs age predicted from the embedding ──
            row_idx, col_idx = divmod(n_layers_to_plot, n_cols)
            ax_pred = axes[row_idx, col_idx]
            pred_emb = gender_embs[prediction_layer]

            reg = LinearRegression()
            reg.fit(pred_emb, gender_ages)
            predicted_ages = reg.predict(pred_emb)
            r2_score = reg.score(pred_emb, gender_ages)

            ax_pred.scatter(gender_ages, predicted_ages, c=gender_ages, cmap=cmap, marker=selected_marker, s=150, alpha=0.7, edgecolors='k', linewidths=0.2)

            # Reference line: predictions along the straight segment joining
            # the embeddings of the youngest and oldest samples.
            line_ages = np.linspace(gender_ages.min(), gender_ages.max(), 100)
            min_emb = pred_emb[np.argmin(gender_ages)]
            max_emb = pred_emb[np.argmax(gender_ages)]
            fractions = np.linspace(0.0, 1.0, 100)[:, None]
            line_pred = reg.predict(min_emb + fractions * (max_emb - min_emb))
            ax_pred.plot(line_ages, line_pred, 'r-', linewidth=2, label=f'R² = {r2_score:.3f}')

            ax_pred.set_xlabel('True Age', fontsize=fontsize-4)
            ax_pred.set_ylabel('Predicted Age', fontsize=fontsize-4)
            ax_pred.set_title(f'Age Prediction (Layer {prediction_layer})', fontsize=fontsize)
            ax_pred.grid(True, alpha=0.3)
            ax_pred.legend(fontsize=fontsize-6)

            # ── Heatmap panel: only when a second special slot was reserved ──
            if special_plots >= 2:
                row_idx, col_idx = divmod(n_layers_to_plot + 1, n_cols)
                ax_heat = axes[row_idx, col_idx]

                im = ax_heat.imshow(discontinuity_matrix, cmap='Reds', aspect='auto', origin='lower')
                ax_heat.set_xlabel('Age', fontsize=fontsize-4)
                ax_heat.set_ylabel('Layer', fontsize=fontsize-4)
                ax_heat.set_title(f'Age Discontinuity Heatmap ({gender_prediction_discontinuity})', fontsize=fontsize)
                # Column j corresponds to age j+1, hence the shifted labels.
                ax_heat.set_xticks(range(0, self.MAX_AGE, 20))
                ax_heat.set_xticklabels(range(1, self.MAX_AGE+1, 20))

                heat_cbar = plt.colorbar(im, ax=ax_heat, orientation='vertical')
                heat_cbar.set_label('Embedding Distance', fontsize=fontsize-2)

        # turn off any empty panels
        for i in range(total_plots, n_rows*n_cols):
            axes[i // n_cols, i % n_cols].axis("off")

        # shared age colorbar, drawn in an inset above the top-right panel
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        top_right_ax = axes[0][-1]
        cax = inset_axes(
                top_right_ax,
                width="60%", height="6%",          # % of parent axis
                loc='lower center',
                bbox_to_anchor=(0, 1.33, 1, 1),   # (x, y, w, h) in axes fraction
                bbox_transform=top_right_ax.transAxes,
                borderpad=0)

        cbar = fig.colorbar(sm, cax=cax, orientation='horizontal')
        cbar.ax.tick_params(labelsize=fontsize+1)          # slightly larger than other fonts
        cbar.set_label("Age", fontsize=fontsize+1, labelpad=-20)         # moved closer to colorbar
        ticks = [self.MIN_AGE, self.MAX_AGE]   # only the two endpoints are labelled
        cbar.set_ticks(ticks)
        cbar.set_ticklabels([str(t) for t in ticks])

        # single legend for markers
        handles_leg = [plt.Line2D([], [], marker=m, color='blue', markerfacecolor='gray',
                                markeredgecolor='k', markersize=25, linestyle='',
                                label=lbl)
                    for m, lbl in marker_map.items()]
        fig.legend(handles=handles_leg, loc='lower center',
                ncol=len(marker_map), frameon=False, fontsize=fontsize-4)

        if plot_title:
            fig.suptitle(f"{self.MODEL_NAME} - UMAP representations for ages {self.MIN_AGE}-{self.MAX_AGE}", fontsize=fontsize+4)
        fig.tight_layout(
            rect=[0, 0.03, 1, 0.94],   # outer margins
            pad=2.5,                   # global padding around the whole figure
            w_pad=2.0,                 # width padding between subplots (in font-size units)
            h_pad=2.5                  # height padding between subplots
        )

        fig.savefig(out_file, dpi=300)
        print(f"\nUMAP figure saved → {out_file}")
        plt.show()
        # Release the figure; this method is called several times per run.
        plt.close(fig)

    def _age_discontinuity_matrix(self, layer_embs, sample_ages):
        """Return a (layers x MAX_AGE) matrix of consecutive-age embedding jumps.

        Entry ``[layer, age - 1]`` is the Euclidean distance between the mean
        embedding of samples aged ``age`` and of samples aged ``age + 1``.
        Entries stay 0 where either age has no samples; the last column is
        always 0 because ``MAX_AGE`` has no successor.
        """
        matrix = np.zeros((len(layer_embs), self.MAX_AGE))
        # Masks depend only on the ages, so build them once for all layers.
        masks = {age: sample_ages == age for age in range(1, self.MAX_AGE + 1)}
        present = {age for age, mask in masks.items() if np.any(mask)}

        for layer_idx, emb in enumerate(layer_embs):
            for age in range(1, self.MAX_AGE):
                if age in present and age + 1 in present:
                    gap = np.mean(emb[masks[age]], axis=0) - np.mean(emb[masks[age + 1]], axis=0)
                    matrix[layer_idx, age - 1] = np.linalg.norm(gap)
        return matrix

    

# def runAgeAnalysis(tok, model, MODEL_NAME, MIN_AGE, MAX_AGE, PRINT_REPLIES):
    def runUmap(self, load_cached=True):
        """Compute (or load) per-layer UMAP embeddings, plot them, and save stats.

        When ``load_cached`` is true and ``self.umapCacheFile`` exists, the
        embeddings, replies, markers and ages are read from that pickle.
        Otherwise prompts are built for every age/subject combination, the
        model is run with activation hooks attached, the activations are
        reduced with ``self.umap_fit`` and the result is written to the cache.

        Afterwards the full-layer figure and one selected-layers figure per
        subject variant are drawn, and the per-layer manifold statistics for
        the "someone" samples are saved to ``self.umapAggFile``.
        """
        if load_cached and os.path.exists(self.umapCacheFile):
            # NOTE: pickle can execute arbitrary code on load; only point this
            # at cache files written by this pipeline.
            with open(self.umapCacheFile, "rb") as f:
                cache = pickle.load(f)
            print(f"Loaded cached UMAP data from {self.umapCacheFile}")
            embs = cache.get('embs')
            replies = cache.get('replies')
            subj_markers = cache.get('subj_markers')
            ages = cache.get('ages')
        else:
            # ═══════════ Build prompts ═══════════════════════════════════════════
            subjects = [("He", "o"), ("She", "s"), ("Someone", "*")]  # (subject, marker)

            prompts, ages, subj_markers = [], [], []
            # NOTE(review): range() excludes MAX_AGE, so with MAX_AGE=100 the
            # sweep is 1-99 although the plots label it up to MAX_AGE — confirm
            # whether the upper bound should be inclusive.
            for age in range(self.MIN_AGE, self.MAX_AGE):
                for subj, mark in subjects:
                    # NOTE(review): the plot legend shows a follow-up question
                    # ("How old is he?") that is not part of the prompt — confirm.
                    prompts.append(f"{subj} is {age} years old.")
                    ages.append(age)
                    subj_markers.append(mark)

            # ═══════════ Run prompts with activation hooks attached ═════════════
            handles, layer_buffers = self.register_hooks()
            try:
                replies = self.run_model(prompts)
            finally:
                # Always detach the hooks, even if generation fails, so they
                # do not linger on the model.
                for h in handles:
                    h.remove()

            # ═══════════ Per-layer UMAP reductions ══════════════════════════════
            embs = self.umap_fit(layer_buffers, n_components=2)

            # Save everything needed to redraw the figures without the model.
            cache_out = {
                'embs': embs,
                'replies': replies,
                'subj_markers': subj_markers,
                'ages': ages,
            }
            with open(self.umapCacheFile, "wb") as f:
                pickle.dump(cache_out, f)
            print(f"Saved UMAP cache → {self.umapCacheFile}")

        # ═══════════ Plot grids ══════════════════════════════════════════════════
        # Figures are written under results/; make sure it exists first.
        os.makedirs("results", exist_ok=True)

        out_file = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}.pdf"
        self.plot_umap(embs, replies, subj_markers, ages, out_file)

        selected_layers = self.selectLayersForPlots(nrPlots=4)
        print(f"Selected layers: {selected_layers}")
        for variant in ('he', 'she', 'someone'):
            out_file_selected = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}_selected_{variant}.pdf"
            self.plot_umap(embs, replies, subj_markers, ages, out_file_selected, layer_list=selected_layers, nr_cols=3, fontsize=16, plot_title=False, gender_prediction_discontinuity=variant, special_plots=2)

        # ═══════════ Compute thickness of principal curve ════════════════════════
        print("Computing principal curve thickness...")
        # Only the "someone" samples (marker "*") enter the manifold tests.
        someone_mask = np.array(subj_markers) == "*"
        embs_someone = [emb[someone_mask] for emb in embs]
        # Named to avoid shadowing the module-level scipy `stats` import.
        manifold_stats = manifold_1d_tests_per_layer(embs_someone, k_graph=3, graph_mode="mutual-knn")
        # Save stats to NPZ (read back by plotMap)
        np.savez(self.umapAggFile, **manifold_stats)
        

    def gen_saliency_prompts(self):
        """Return the prompts used for saliency analysis.

        Returns:
            A list of ``(prompt, target, name)`` tuples: the full prompt text,
            the target token on which saliency is computed, and a short
            identifier for the prompt.
        """
        prefix = 'Answer with a single number, nothing else.'
        questions = [
            ("At what age can a person vote in national elections in the United States?", "18", "voting_age"),
            ("At what age is someone legally considered an adult?", "18", "adult_age"),
            ("At what age can someone obtain a driver's license in the United States?", "16", "driver_age"),
            ("At what age can someone legally buy alcohol in the United States?", "21", "alcohol_age"),
            ("At what age does a person typically finish high school?", "18", "high_school_age"),
            ("At what age is someone eligible for senior citizen discounts in the United States?", "65", "senior_discount_age"),
            ("At what age can someone legally buy cigarettes in the United States?", "21", "cigarette_age"),
            ("At what age can someone enlist in the military with parental consent in the United States?", "17", "military_age"),
            ("At what age can someone rent a car from most companies in the United States?", "25", "rent_car_age"),
            ("At what age can a person run for president of the United States?", "35", "president_age"),
            ("At what age does a typical child start kindergarten?", "5", "kindergarten_age"),
        ]
        # Join with a space: the instruction and the question were previously
        # glued together ("nothing else.At what age ...").
        return [(f"{prefix} {question}", target, name) for question, target, name in questions]


    def plotMap(self):
        """Plot the per-layer summary ("LLM map") figure for this analysis.

        Up to four panels are rendered, each only if its input loads:
        - UMAP local anisotropy (``aniso_mean`` in ``self.umapAggFile``)
        - Saliency mean with its confidence band, as returned by
          ``self.compute_saliency_per_layer`` (from ``self.saliencyFile``)
        - Layer-ablation scores, mean ± 95% CI across prompts
          (``results/<analysis>_lesioning_<model>.json``)
        - Mean non-zero activation-patching effect per layer
          (``results/<analysis>_activation_patching_<model>.json``)

        Every input is optional: a missing or malformed file produces a
        warning and its panel is skipped. The figure is written to
        ``results/<analysis>_llm-maps_<model>.pdf``; if no panel has data,
        nothing is drawn or saved.
        """

        print('Plotting LLM Maps ...')

        # ---------- Saliency (from pickle) - optional ----------
        sal_mean = sal_ci_lower = sal_ci_upper = None
        try:
            # NOTE: pickle can execute arbitrary code on load; only read
            # saliency files written by this pipeline.
            with open(self.saliencyFile, "rb") as f:
                sal_data = pickle.load(f)
            avg_sal_per_prompt = sal_data["avg_sal_per_prompt"]
            sal_mean, sal_ci_lower, sal_ci_upper = self.compute_saliency_per_layer(avg_sal_per_prompt)
            print(f"Loaded saliency data from {self.saliencyFile}")
        except FileNotFoundError:
            print(f"Warning: Saliency file {self.saliencyFile} not found. Skipping saliency plot.")
        except Exception as e:
            print(f"Warning: Error loading saliency data: {e}. Skipping saliency plot.")

        # ---------- Lesioning data - optional ----------
        # Stays None unless loading AND parsing both succeed, so a malformed
        # file can never leave a half-built panel behind.
        lesion_panel = None
        try:
            lesion_file = f"results/{self.analysis_name}_lesioning_{self.MODEL_NAME}.json"
            with open(lesion_file, "r") as f:
                lesion_data = json.load(f)

            # Gather the scores of all prompts per layer.
            layer_scores = {}
            for prompt_result in lesion_data["prompt_scores"].values():
                for layer_key, score_info in prompt_result["scores_and_justifications"].items():
                    layer_scores.setdefault(int(layer_key), []).append(score_info["score"])

            confidence_level = 0.95
            alpha = 1 - confidence_level
            lesion_layers = sorted(layer_scores)
            lesion_scores, lesion_ci_lower, lesion_ci_upper = [], [], []
            for layer_idx in lesion_layers:
                scores = layer_scores[layer_idx]
                n_prompts = len(scores)
                avg_score = sum(scores) / n_prompts
                if n_prompts > 1:
                    # 95% CI from the t-distribution with the sample std.
                    std_score = np.std(scores, ddof=1)
                    t_critical = stats.t.ppf(1 - alpha/2, df=n_prompts-1)
                    ci_margin = t_critical * (std_score / np.sqrt(n_prompts))
                else:
                    # A single score has no defined interval; NaN leaves the
                    # band empty at this layer without numeric warnings.
                    ci_margin = float("nan")
                lesion_scores.append(avg_score)
                lesion_ci_lower.append(avg_score - ci_margin)
                lesion_ci_upper.append(avg_score + ci_margin)

            lesion_panel = (lesion_layers, lesion_scores, lesion_ci_lower, lesion_ci_upper)
            print(f"Loaded lesioning data with {len(lesion_scores)} layers")
        except Exception as e:
            print(f"Warning: Could not load lesioning data: {e}")

        # ---------- Activation patching data - optional ----------
        activation_patching_data = None
        try:
            activation_patching_file = f"results/{self.analysis_name}_activation_patching_{self.MODEL_NAME}.json"
            with open(activation_patching_file, "r") as f:
                activation_patching_data = json.load(f)
        except Exception as e:
            print(f"Warning: Could not load activation patching data: {e}")

        # ---------- Prepare panels: (label, kind, payload) ----------
        panels = []

        # UMAP panel
        try:
            with np.load(self.umapAggFile, allow_pickle=True) as agg:
                if 'aniso_mean' in agg.files:
                    panels.append(("UMAP Local Anisotropy", "line", agg['aniso_mean']))
        except Exception as e:
            print(f"Warning: Could not load UMAP data: {e}")

        # Saliency panel
        if sal_mean is not None:
            panels.append(("Saliency Analysis (mean ± 95% CI)", "saliency",
                        (sal_mean, sal_ci_lower, sal_ci_upper)))

        # Lesioning panel
        if lesion_panel is not None:
            panels.append(("Layer Ablation Analysis (mean ± 95% CI)", "lesioning", lesion_panel))

        # Activation patching panel
        if activation_patching_data is not None and 'all_patching_results' in activation_patching_data:
            layer_patching_effects = {}
            for layer_key, layer_result in activation_patching_data["all_patching_results"].items():
                layer_idx = int(layer_key)
                if 'patching_effect' not in layer_result:
                    continue
                effects = layer_patching_effects.setdefault(layer_idx, [])
                # Zero effects are skipped and do not enter the average.
                effects.extend(e for e in layer_result["patching_effect"].values() if e != 0)

            # One value per model layer; layers without data plot as 0.
            activation_patching_scores = []
            for layer_idx in range(self.num_layers):
                effects = layer_patching_effects.get(layer_idx)
                activation_patching_scores.append(sum(effects) / len(effects) if effects else 0.0)

            panels.append(("Activation Patching Effects", "line", activation_patching_scores))

        if not panels:
            # An empty grid cannot be laid out; bail out instead of crashing.
            print("Warning: No data available for any panel. Skipping LLM map plot.")
            return

        # ---------- Layout (rows x cols) ----------
        n_cols = 3
        n_rows = int(np.ceil(len(panels) / n_cols))

        fig = plt.figure(figsize=(5.0 * n_cols, 3.6 * n_rows))
        gs = GridSpec(n_rows, n_cols, figure=fig, wspace=0.28, hspace=0.38)

        # ---------- Render panels ----------
        # Each panel builds its x-axis from its own data, so inputs covering
        # different numbers of layers can share one figure.
        for i, (label, kind, val) in enumerate(panels):
            ax = fig.add_subplot(gs[i // n_cols, i % n_cols])
            ax.grid(True, alpha=0.25, linestyle="--")

            if kind == "line":
                y = np.asarray(val)
                ax.plot(np.arange(len(y)), y, linewidth=2)
                ax.set_title(label)
                ax.set_xlabel("Layer")
                ax.set_ylabel("Value")

            elif kind == "saliency":
                mean, lo, hi = (np.asarray(x) for x in val)
                sal_layers = np.arange(len(mean))
                ax.plot(sal_layers, mean, linewidth=2, label="Total Saliency")
                ax.fill_between(sal_layers, lo, hi, alpha=0.2)
                ax.set_title(label)
                ax.set_xlabel("Layer")
                ax.set_ylabel("Value")

            elif kind == "lesioning":
                # Plot against the actual layer indices found in the file.
                x_layers, scores, ci_lower, ci_upper = (np.asarray(x) for x in val)
                ax.plot(x_layers, scores, linewidth=2, marker='o', color='red', label="Layer Ablation Impact")
                ax.fill_between(x_layers, ci_lower, ci_upper, color='red', alpha=0.2)
                ax.set_title(label)
                ax.set_xlabel("Layer")
                ax.set_ylabel("Change from Original")
                ax.set_ylim(1, 10)
                ax.set_yticks([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
                ax.set_yticklabels(['1\n(no change)', '2', '3', '4', '5', '6', '7', '8', '9', '10\n(significant\n disruption)'])
                ax.legend()
                ax.grid(True, alpha=0.3)

            else:
                ax.text(0.5, 0.5, f"Unknown kind: {kind}", ha="center", va="center")
                ax.axis("off")

        # ---------- Title + save ----------
        fig.suptitle(f"LLM Map: {self.analysis_name} — {self.MODEL_NAME}", y=0.995, fontsize=14)
        os.makedirs("results", exist_ok=True)
        out_file = f"results/{self.analysis_name}_llm-maps_{self.MODEL_NAME}.pdf"
        fig.savefig(out_file, dpi=300, bbox_inches="tight")
        print(f"Saved → {out_file}")
        plt.show()
        plt.close(fig)

    def generate_lesioning_prompts(self):
        """Return the open-ended, age-related prompts used for layer lesioning."""
        longevity = (
            "How long do people live around the world?",
            "What is the age at which people start to have memory problems and why?",
        )
        milestones = (
            "What is the legal age of consent in the world?",
            "At what age do most people retire globally, and how is this changing?",
            "At what age do people get married globally, and how is this changing?",
        )
        trends = (
            "How do life expectancy trends differ between countries and what drives them?",
            "How do cognitive abilities typically change with age, and what mechanisms explain this?",
        )
        # Concatenated in the same order the prompts have always been run.
        return [*longevity, *milestones, *trends]

    # def runLesioning(self, load_cached=True):
    #     """
    #     Performs layer lesioning analysis on age-related prompts.
    #     """
    #     print("Running layer lesioning analysis for age...")
        
    #     # generate the prompts
    #     prompts = self.generate_lesioning_prompts()
        
        
    #     lesion_file = f"results/{self.analysis_name}_lesioning_{self.MODEL_NAME}.json"
    #     if load_cached and os.path.exists(lesion_file):
    #         with open(lesion_file, "r") as f:
    #             lesion_results = json.load(f)
    #         print(f"Loaded cached lesioning results from {lesion_file}")
    #     else:
    #         # Run layer lesioning
    #         lesion_results = self.run_layer_lesioning(test_prompts, output_file=lesion_file)
    #         print(f"Age lesioning analysis complete. Results saved to {lesion_file}")
        
        
    #     self.print_lesioning_results(lesion_results)
    #     self.print_lesioning_table(lesion_results)
        
    def generate_activation_patching_prompts(self):
        """
        Build the clean/corrupt prompt pairs used for age activation patching.

        Each pair poses the same complaint for two patients of different ages.
        Every prompt starts with an instruction asking for a single-word cause,
        and each side carries its expected single-word (or short) answer.
        A status line is printed before the pairs are built.

        Returns:
            List of tuples (clean_prompt, clean_answer, corrupt_prompt, corrupt_answer)
        """

        print("Running activation patching analysis for age...")

        instruction = "Answer with a single word only describing a cause. "

        def ask(patient, complaint):
            # Full prompt text for one patient, instruction prefix included.
            return f"{instruction}A {patient} patient has {complaint}. The most likely cause is ___"

        # (complaint, clean patient, clean answer, corrupt patient, corrupt answer)
        cases = [
            # Elderly vs Child - Memory
            ("memory problems", "75-year-old", "Dementia", "5-year-old", "Epilepsy"),
            # Elderly vs Young - Heart
            ("chest pain", "70-year-old", "Angina", "25-year-old", "Anxiety"),
            # Elderly vs Middle-aged - Bones (corrupt answer alternative: trauma)
            ("bone pain", "80-year-old", "Osteoporosis", "40-year-old", "Arthritis"),
            # Elderly vs Young - Vision (clean answer alternative: Maculopathy)
            ("vision loss", "75-year-old", "Cataracts", "20-year-old", "Injury"),
            # Elderly vs Child - Walking (clean answer alternative: Parkinsonism)
            ("walking problems", "85-year-old", "Arthritis", "3-year-old", "Cerebral palsy"),
            # Disabled: Middle-aged vs Young - Blood
            # ("high blood pressure", "55-year-old", "Heredity", "25-year-old", "Stress"),
            # Elderly vs Young - Sleep
            ("sleep problems", "70-year-old", "Insomnia", "30-year-old", "Stress"),
            # Elderly vs Child - Appetite
            ("poor appetite", "80-year-old", "Depression", "4-year-old", "Teething"),
            # Middle-aged vs Young - Weight (the clean side is the younger patient here)
            ("weight gain", "20-year-old female", "Diet", "50-year-old female", "Menopause"),
            # Disabled: Elderly vs Young - Energy
            # ("low energy", "75-year-old", "Fatigue", "25-year-old", "Stress"),
        ]

        return [
            (ask(clean_patient, complaint), clean_answer,
             ask(corrupt_patient, complaint), corrupt_answer)
            for complaint, clean_patient, clean_answer, corrupt_patient, corrupt_answer in cases
        ]
