"""
Generic class for running UMAP and saliency analysis on a model.
"""
import math, torch, umap, matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D
from matplotlib.backends.backend_pdf import PdfPages
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList
from functools import partial
import numpy as np
import argparse
import random
import torch.nn.functional as F
from utils import *
import pickle
from scipy import stats
import json
import os

from saliency import *


class Act:
    def __init__(self, tokenizer, model, processor, MODEL_NAME, PRINT_REPLIES=False, TEMPERATURE=0, DEEPSPEED=False, USE_ACCELERATE=False):
        self.tokenizer = tokenizer
        self.model = model if not DEEPSPEED else model.module
        self.processor = processor
        self.MODEL_NAME = MODEL_NAME
        self.PRINT_REPLIES = PRINT_REPLIES
        self.TEMPERATURE = TEMPERATURE
        self.DEEPSPEED = DEEPSPEED
        self.USE_ACCELERATE = USE_ACCELERATE
        self.SEED = 0
        self.num_layers = 0 # will be set in register_hooks
        self.analysis_model = None  # Cache for Llama-70B analysis model
        self.analysis_processor = None
        self.analysis_tokenizer = None
        random.seed(self.SEED)
        np.random.seed(self.SEED)
        # Controls whether forward hooks should capture activations
        self.capture_activations = True
        
        if 'medgemma' in self.MODEL_NAME.lower():
            # self.model_layers = self.model.model.layers # medgemma models
            self.model_layers = self.model.model.layers
            self.layer_part = 2 # these are used only in plot_saliency_grid and denote which parts of the dictionary to use for extracting the layer index, the sub-index, and the operation
            self.sub_part = 3
            self.op_part = 4
        elif 'gemma' in self.MODEL_NAME.lower():
            # parts: ['model', 'language_model', 'layers', '61', 'self_attn', 'o_proj', 'weight']
            # parts: ['model', 'language_model', 'layers', '61', 'self_attn', 'q_norm', 'weight']
            # parts: ['model', 'language_model', 'layers', '61', 'self_attn', 'k_norm', 'weight']
            self.model_layers = self.model.model.language_model.layers # gemma models
            self.layer_part = 3 # +1 because it contains the "language_model" part
            self.sub_part = 4
            self.op_part = 5
        elif 'llama-4' in self.MODEL_NAME.lower():
             # take the layers from the language model only
            self.model_layers = self.model.language_model.model.layers
            self.layer_part = 3 # +1 because it contains the "language_model" part
            self.sub_part = 4
            self.op_part = 5
        elif hasattr(self.model.model, "layers"):
            self.model_layers = self.model.model.layers
            self.layer_part = 2
            self.sub_part = 3
            self.op_part = 4
        else:
            raise ValueError(f"Unknown model: {self.MODEL_NAME}")
        
        self.num_layers    = len(self.model_layers)
        

    def set_files(self):
        """
        Populate the cache/result path attributes for this analysis.

        Every path has the form
        ``results/<analysis_name>_<label>_<MODEL_NAME>.<ext>``; both
        ``self.analysis_name`` and ``self.MODEL_NAME`` must already be set.
        """
        stem = f"results/{self.analysis_name}"
        # (attribute name, label embedded in the file name, extension)
        outputs = (
            ("umapAggFile", "umap_agg", "npz"),
            ("umapCacheFile", "umap_cache", "pkl"),
            ("saliencyFile", "saliency", "pkl"),
            ("lesioningFile", "lesioning", "json"),
            ("lesioningFinegrainedFile", "lesioning_finegrained", "json"),
            ("activationPatchingFile", "activation_patching", "json"),
            ("activationPatchingFinegrainedFile", "activation_patching_finegrained", "json"),
        )
        for attr, label, ext in outputs:
            setattr(self, attr, f"{stem}_{label}_{self.MODEL_NAME}.{ext}")

    def load_cached(self, analysis_type):
        """Load cached results for different analysis types"""
        if analysis_type == 'saliency':
            with open(self.saliencyFile, 'rb') as f:
                return pickle.load(f)
        elif analysis_type == 'lesioning_finegrained':
            return self._assemble_finegrained_results_from_layers()
        elif analysis_type == 'activation_patching_finegrained':
            with open(self.activationPatchingFinegrainedFile, 'r') as f:
                return json.load(f)
        else:
            raise ValueError(f"Unknown analysis type: {analysis_type}")




    def run(self, do_umap=True, do_saliency=True, do_maps=True, do_lesioning=False, do_activation_patching=False, do_activation_patching_finegrained=False, do_lesioning_finegrained=False, do_heatmap=False, load_cached=True):
        """
        Run the requested analysis stages in a fixed order.

        Each ``do_*`` flag enables one stage; ``load_cached`` is forwarded to
        the stages that accept it (activation patching and map plotting are
        called without it).
        """
        # Lambdas keep attribute lookup lazy: a stage method is only resolved
        # when its flag is set.
        stages = (
            (do_umap, lambda: self.runUmap(load_cached)),
            (do_saliency, lambda: self.runSaliency(load_cached)),
            (do_lesioning, lambda: self.runLesioning(load_cached)),
            (do_lesioning_finegrained, lambda: self.runLesioningFinegrained(load_cached)),
            (do_activation_patching, lambda: self.runActivationPatching()),
            (do_activation_patching_finegrained, lambda: self.runActivationPatchingFinegrained(load_cached)),
            (do_maps, lambda: self.plotMap()),
            (do_heatmap, lambda: self.heatmap(load_cached=load_cached)),
        )
        for enabled, stage in stages:
            if enabled:
                stage()


    def umap_fit(self, layer_buffers, n_components=2):
        """
        Fit one UMAP embedding per layer.

        Args:
            layer_buffers: iterable whose items are lists of tensors; each
                list is concatenated along dim 0 to form that layer's matrix.
            n_components: dimensionality of the embedding.

        Returns:
            A list with one ``fit_transform`` result per entry of
            ``layer_buffers``.
        """
        print("UMAP fitting …")
        reducer_kwargs = {
            "n_neighbors": 15,
            "min_dist": 0.1,
            "metric": "cosine",
            "random_state": self.SEED,
            "n_components": n_components,
        }

        def _embed(chunks):
            # Cast to float32 on CPU before handing the data to NumPy
            stacked = torch.cat(chunks, dim=0).to(torch.float32).cpu().numpy()
            # Replace NaN/inf with zeros before fitting
            cleaned = np.nan_to_num(stacked, nan=0., posinf=0., neginf=0.)
            return umap.UMAP(**reducer_kwargs).fit_transform(cleaned)

        return [_embed(chunks) for chunks in layer_buffers]
    
    def register_hooks(self):
        
        layer_buffers = [[] for _ in range(self.num_layers)]

        def hook(layer_id, module, _, out):
            # Skip capturing when disabled (e.g., during optional long generation)
            if not self.capture_activations:
                return
            hidden = out[0] if isinstance(out, tuple) else out  # (B,T,H)
            layer_buffers[layer_id].append(hidden.mean(dim=1).cpu())

        handles = [
            blk.register_forward_hook(partial(hook, i))
            for i, blk in enumerate(self.model_layers)
        ]
        return handles, layer_buffers

    def run_model(self, prompts, long_response_length=0):
        print("Running prompts …")
        gen_cfg = dict(max_new_tokens=12, do_sample=False)
        replies = []
        with torch.no_grad():
            for txt in prompts:
                print("PROMPT:", txt)
            
                # ─── original Llama-3 flow ──────────────────────────────────────
                toks = self.tokenizer(txt, return_tensors="pt").to(self.model.device)
                # print('TEMPERATURE:', self.TEMPERATURE)

                # generate
                # reply_ids = self.model.generate(
                #     **toks,
                #     use_cache=False,
                #     max_new_tokens=3,
                #     do_sample=True,
                #     temperature=self.TEMPERATURE,
                #     top_k=5,
                #     pad_token_id=self.tok.eos_token_id,
                # )[0][toks.input_ids.shape[-1]:]

                # reply = self.tok.decode(reply_ids, skip_special_tokens=True).strip()
                # print("REPLY:", reply)
                # replies.append(reply)

                # forward‐pass with hooks
                out = self.model(**toks, temperature=self.TEMPERATURE)
                logits = out.logits

                # Last-position logits  →  probability distribution
                next_token_logits = logits[0, -1]                          # (V,)
                probs = torch.softmax(next_token_logits, dim=-1)           # (V,)
                
                # Inspect the top-k options
                k = 5
                topk = torch.topk(probs, k)
                for idx, p in zip(topk.indices, topk.values):
                    print(f"{self.tokenizer.decode(idx):<12} {p.item():.4f}")

                # Optionally generate a longer continuation if requested (non-Llama-4 path)
                if long_response_length > 0:
                    # Temporarily disable activation capture so hooks don't record generation activations
                    prev = self.capture_activations
                    self.capture_activations = False
                    gen_ids = self.model.generate(
                        **toks,
                        use_cache=False,
                        max_new_tokens=long_response_length,
                        do_sample=True,
                        temperature=max(self.TEMPERATURE, 0.7),
                        top_k=50,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )[0][toks.input_ids.shape[-1]:]
                    reply = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
                    if self.PRINT_REPLIES:
                        print("REPLY:", reply)
                    replies.append(reply)
                    # Re-enable capture for subsequent forward passes
                    self.capture_activations = prev
                else:
                    replies.append("")

                # # this single forward‐pass is the *only* one we hook
                # _ = model(**toks)

        return replies

    def _print_final_layer_head_saliency(self, per_head, prefix, suffix, h17_label):
        """
        Debug helper: print the saliency of every head found in the last layer.

        Args:
            per_head: {(layer_idx, head_idx) -> scalar saliency}.
            prefix, suffix: text placed before/after "(layer N)" in the header.
            h17_label: label used when additionally printing head 17.
        """
        try:
            final_layer = len(self.model_layers) - 1
        except (AttributeError, TypeError):
            # Fall back to the largest layer index present in the data
            final_layer = max(k[0] for k in per_head.keys()) if per_head else 0
        heads_in_final = sorted(h for (ly, h) in per_head.keys() if ly == final_layer)
        if not heads_in_final:
            return
        vals = [per_head.get((final_layer, h), float('nan')) for h in heads_in_final]
        print(f"{prefix} (layer {final_layer}){suffix}:")
        print("  heads:", heads_in_final)
        print("  vals:", [f"{v:.6f}" if isinstance(v, (int, float)) else str(v) for v in vals])
        if 17 in heads_in_final:
            print(f"  {h17_label}: {per_head.get((final_layer, 17), float('nan')):.6f}")

    def runSaliency(self, load_cached):
        """
        Compute (or load) parameter saliencies and write the report and plots.

        When ``load_cached`` is true and ``self.saliencyFile`` exists, the
        averaged results are unpickled from it. Otherwise a saliency map is
        computed for every (prompt, target, prompt_name) triple returned by
        ``self.gen_saliency_prompts()``, plotted per prompt, merged, and saved.
        In both cases the averaged text report, grid plot and 2D table are
        written under ``results/``.

        Args:
            load_cached: whether an existing cache file may be reused.
        """
        # Explicit import: the name is otherwise only reachable through star-imports.
        from collections import defaultdict

        # All outputs below are written under results/; make sure it exists.
        os.makedirs("results", exist_ok=True)

        # check if cache exists:
        if load_cached and os.path.exists(self.saliencyFile):
            # NOTE: pickle must only be used on cache files this project wrote itself.
            with open(self.saliencyFile, "rb") as f:
                data = pickle.load(f)
                print(data.keys())
                avg_sal = data["avg_sal"]
                avg_sal_per_prompt = data["avg_sal_per_prompt"]
                avg_sal_per_head = data["avg_sal_per_head"]
                avg_sal_per_mlp = data["avg_sal_per_mlp"]
            print(f"Loaded cached saliency data from {self.saliencyFile}")
        else:
            all_saliency_dicts = {}
            # Accumulate per-prompt per-head/MLP saliencies so we can average consistently
            per_head_accumulator = defaultdict(list)   # key: (layer_idx, head_idx) -> [vals]
            per_mlp_accumulator  = defaultdict(list)   # key: layer_idx -> [vals]

            prompts_targets = self.gen_saliency_prompts()

            for prompt, target, prompt_name in prompts_targets:
                # No vocabulary restriction: saliency is computed for the single target token
                allowed_ids = None

                # Run forward to get predicted token
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
                with torch.no_grad():
                    outputs = self.model(**inputs)
                logits = outputs.logits[:, -1, :]
                probs = torch.softmax(logits, dim=-1)
                top_token_id = torch.argmax(probs, dim=-1).item()
                top_token_str = self.tokenizer.decode([top_token_id]).strip()
                print(f"Prompt {prompt_name}: Predicted token: '{top_token_str}' (ID={top_token_id})")

                # Get saliency
                curr_saliency = self.saliency_for(prompt, target, allowed_ids)
                # NOTE(review): keyed by target, so two prompts sharing a target
                # overwrite each other here — confirm this is intended.
                all_saliency_dicts[target] = curr_saliency

                lines = print_top_saliencies(curr_saliency,  f"Prompt: {prompt_name}, Target: {target}")
                print(lines)

                # Plot saliency map for this particular prompt
                self.plot_saliency_grid(
                    self.model_layers, curr_saliency,
                    f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_{prompt_name}.pdf",
                    f"{self.MODEL_NAME} saliency for {prompt_name}"
                )

                # Extract per-head and per-MLP saliency for this prompt
                curr_saliency_per_head, curr_saliency_per_mlp = self.extract_saliency_components(curr_saliency)

                # Debug: print final-layer head saliencies per prompt
                self._print_final_layer_head_saliency(
                    curr_saliency_per_head,
                    "Per-prompt head saliency", f" for {prompt_name}", "H17 value",
                )

                # also plot the 2D saliency table for this particular prompt
                self.plot_saliency_2D(
                    curr_saliency_per_head, curr_saliency_per_mlp,
                    f"results/{self.analysis_name}_2Dsaliency_{self.MODEL_NAME}_{prompt_name}.pdf",
                    f"{self.MODEL_NAME} saliency for {prompt_name}"
                )

                # Accumulate for consistent averaging across prompts (non-finite values are dropped)
                for key, val in curr_saliency_per_head.items():
                    if np.isfinite(val):
                        per_head_accumulator[key].append(float(val))
                for key, val in curr_saliency_per_mlp.items():
                    if np.isfinite(val):
                        per_mlp_accumulator[key].append(float(val))

            avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp = merge_saliency_dicts(all_saliency_dicts, self.model)

            # Replace head/mlp averages with our prompt-wise means to ensure consistency with per-prompt tables
            if per_head_accumulator:
                avg_sal_per_head = { key: float(np.mean(vals)) for key, vals in per_head_accumulator.items() if len(vals) > 0 }
            if per_mlp_accumulator:
                avg_sal_per_mlp  = { key: float(np.mean(vals)) for key, vals in per_mlp_accumulator.items() if len(vals) > 0 }

            # Persist the averaged results to the saliency cache file
            save_saliency(self.saliencyFile, avg_sal, avg_sal_per_prompt,
                     avg_sal_per_head, avg_sal_per_mlp)
            print(f"Saved all saliency data → {self.saliencyFile}")

        # Debug: print averaged final-layer head saliencies
        self._print_final_layer_head_saliency(
            avg_sal_per_head, "Averaged head saliency", "", "H17 averaged value",
        )

        lines = print_top_saliencies(avg_sal, f"Average Saliency: {self.analysis_name}")
        ftxt = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}.txt"
        with open(ftxt, "w") as f:
            f.write("\n".join(lines))
        print(f"Saved saliency report → {ftxt}")

        # Plot average saliency
        avg_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_avg.pdf"
        self.plot_saliency_grid(
            self.model_layers, avg_sal,
            avg_path,
            f"{self.MODEL_NAME} saliency: {self.analysis_name}"
        )
        print(f"Plotted saliency map → {avg_path}")

        # Plot saliency table
        table_path = f"results/{self.analysis_name}_2Dsaliency_{self.MODEL_NAME}_avg.pdf"
        self.plot_saliency_2D(
            avg_sal_per_head, avg_sal_per_mlp,
            table_path,
            f"{self.MODEL_NAME} saliency: {self.analysis_name}"
        )
        print(f"Plotted saliency table → {table_path}")

    def saliency_for(self, prompt: str, desired_token: str, allowed_ids=None):
        device = next(self.model.parameters()).device
        """
        allowed_ids:
        - if None: no restriction on the vocabulary, this will compute a saliency for a single target token, and find the weights that minimize it the most
        - if not None: e.g. [dead_id, alive_id], it will restrict the vocabulary to the allowed_ids. this will compute a saliency for a single target token, and find the weights that maximally increase the probability of the target token
        """

        # (1) run in FP32 only—drop the .double() entirely
        # self.model.double()

        # Tokenize on CPU, then move to the embedding device to avoid CPU/GPU mismatch
        inputs = self.tokenizer(prompt, return_tensors="pt")
        try:
            emb_device = self.model.get_input_embeddings().weight.device
        except Exception:
            emb_device = next(self.model.parameters()).device
        inputs = {k: (v.to(emb_device) if torch.is_tensor(v) else v) for k, v in inputs.items()}
        # Disable KV cache to reduce memory during saliency backward
        outputs = self.model(**inputs, use_cache=False)
        last_logits = outputs.logits[:, -1, :]
        # Only compute/print restricted logits if a clamp is requested
        if allowed_ids is not None:
            allowed_logits = last_logits[:, allowed_ids]
            print('allowed_ids:', allowed_ids)
        # Try to tokenize the desired token to find its ID
        tokenized = self.tokenizer(desired_token, add_special_tokens=False).input_ids
        token_id = tokenized[0] # takes only the first token from the word and computes the saliency for it
        if allowed_ids is not None:
            allowed_logits = last_logits[:, allowed_ids]  # shape: (1, len(allowed_ids))
            target_idx = allowed_ids.index(token_id)
            # Ensure target lives on the same device as logits
            loss = F.cross_entropy(allowed_logits, torch.tensor([target_idx], device=allowed_logits.device))
        else:
            # Just maximize the raw logit of the desired token
            logit = last_logits[0, token_id]
            loss = -logit

        # (2) backward & grab grads
        # self.model.zero_grad()
        if hasattr(self.model, "module"):
            self.model.module.zero_grad()
        else:
            self.model.zero_grad()
        loss.backward()

        # (3) clone *to CPU* in float32
        grad_dict = {}
        for name, p in self.model.named_parameters():
            if p.grad is not None:
                grad = p.grad.detach().abs().cpu().clone()
                grad_dict[name] = grad

        self.model.zero_grad()      # clear GPU grads
        torch.cuda.empty_cache()    # free any fragmentation
        
        return grad_dict

    def extract_saliency_components(self, saliency_dict):
        """
        Extract per-head and per-MLP saliency components from a saliency dictionary.
        
        Args:
            saliency_dict: Dictionary with parameter names as keys and saliency tensors as values
            
        Returns:
            curr_saliency_per_head: {(layer_idx, head_idx) -> scalar saliency}
            curr_saliency_per_mlp: {layer_idx -> scalar saliency for MLP matrices}
        """
        curr_saliency_per_head = {}
        curr_saliency_per_mlp = {}

        # Model head geometry (best-effort; defaults if unavailable)
        try:
            num_heads = int(getattr(self.model.config, 'num_attention_heads'))
            hidden_size = int(getattr(self.model.config, 'hidden_size'))
        except Exception:
            num_heads = 32
            hidden_size = saliency_dict.get('model.layers.0.self_attn.q_proj.weight', None)
            hidden_size = hidden_size.shape[0] if hidden_size is not None else 4096
        head_dim = max(1, hidden_size // max(1, num_heads))

        # Temporary accumulators so we can combine Q/K/V/O contributions per head
        head_sums = defaultdict(float)
        head_counts = defaultdict(int)

        for param_name, saliency_tensor in saliency_dict.items():
            # Extract layer index from parameter name
            if 'layers.' in param_name:
                try:
                    # Extract layer number from parameter name like "model.layers.5.self_attn.q_proj.weight"
                    layer_part = param_name.split('layers.')[1].split('.')[0]
                    layer_idx = int(layer_part)
                except (ValueError, IndexError):
                    continue
            else:
                continue
            
            # Handle attention heads
            if 'self_attn' in param_name and any(proj in param_name for proj in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
                # Expect linear weight [out_features, in_features]
                if saliency_tensor.ndim == 2:
                    out_features, in_features = saliency_tensor.shape
                    # Q/K/V heads are laid out along rows (out_features)
                    if any(proj in param_name for proj in ['q_proj', 'k_proj', 'v_proj']):
                        for h in range(num_heads):
                            row_start = h * head_dim
                            row_end = min(out_features, (h + 1) * head_dim)
                            if row_start >= out_features:
                                break
                            block = saliency_tensor[row_start:row_end, :]
                            val = float(block.abs().mean().item()) if block.numel() > 0 else 0.0
                            head_sums[(layer_idx, h)] += val
                            head_counts[(layer_idx, h)] += 1
                    else:
                        # o_proj mixes heads across columns; slice by input columns
                        for h in range(num_heads):
                            col_start = h * head_dim
                            col_end = min(in_features, (h + 1) * head_dim)
                            if col_start >= in_features:
                                break
                            block = saliency_tensor[:, col_start:col_end]
                            val = float(block.abs().mean().item()) if block.numel() > 0 else 0.0
                            head_sums[(layer_idx, h)] += val
                            head_counts[(layer_idx, h)] += 1
            
            # Handle MLP components
            elif 'mlp' in param_name and any(proj in param_name for proj in ['gate_proj', 'up_proj', 'down_proj']):
                # Get the mean saliency for MLP matrices
                mlp_val = float(saliency_tensor.abs().mean().item())
                # Combine multiple MLP matrices by averaging per layer
                if layer_idx in curr_saliency_per_mlp:
                    curr_saliency_per_mlp[layer_idx] = 0.5 * (curr_saliency_per_mlp[layer_idx] + mlp_val)
                else:
                    curr_saliency_per_mlp[layer_idx] = mlp_val

        # Finalize head values by averaging contributions across Q/K/V/O
        for key, total in head_sums.items():
            count = max(1, head_counts.get(key, 1))
            curr_saliency_per_head[key] = total / count

        return curr_saliency_per_head, curr_saliency_per_mlp


    # --- Helper: print probabilities for only alive/dead tokens ---
    def print_clamped_probs(self, prompt: str, allowed_ids):
        # grab the device once
        device = next(self.model.parameters()).device

        # tokenize and move *all* tensors to device
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        last_logits = outputs.logits[:, -1, :]
        allowed_logits = last_logits[:, allowed_ids]
        probs = torch.softmax(allowed_logits, dim=-1)[0]
        print(f"\nPrompt: {prompt}")
        print(f"  'probs[0]': {probs[0]:.4f}")
        print(f"  'probs[1]' : {probs[1]:.4f}")

    # --- Example generation using logits_processor ---
    def generate_clamped(self, prompt: str, allowed_ids):
        """
        Greedily generate one token restricted to ``allowed_ids`` and print it.

        Args:
            prompt: text fed to the model.
            allowed_ids: token ids handed to OnlyAllowCertainLogitsProcessor.
        """
        # Everything the model consumes must live on the model's device
        target_device = next(self.model.parameters()).device
        encoded = {
            key: tensor.to(target_device)
            for key, tensor in self.tokenizer(prompt, return_tensors="pt").items()
        }

        clamp = LogitsProcessorList([OnlyAllowCertainLogitsProcessor(allowed_ids)])

        generated = self.model.generate(
            **encoded,
            max_new_tokens=1,
            logits_processor=clamp,
            do_sample=False
        )

        # generate() may hand back a tensor or a dict carrying "sequences"
        sequences = generated["sequences"] if isinstance(generated, dict) else generated

        # The new token sits right after the prompt tokens
        prompt_len = encoded["input_ids"].shape[-1]
        first_new_id = sequences[0, prompt_len].item()

        print(f"Generated (clamped): {self.tokenizer.decode([first_new_id])}")

    def plot_saliency_grid(self, layers, saliency_dict, out_path, title, doLog = True):
        """
        Generates and saves a horizontal grid plot of saliency weights.
        Each column is a sub-layer (Attention or MLP) per layer; 4 rows for operations.
        Rows: Q, K, V, O for attention; G, U, D, blank for MLP.
        layers:
            - the model layers
        saliency_dict:
            - the saliency dictionary
        out_path:
            - the path to save the plot
        title:
            - the title of the plot
        doLog:
            - if True: log-transform the saliency values
        """
        # 1) Infer model structure
        n_layers = len(layers)
        sublayers = ['self_attn', 'mlp']
        n_sub = len(sublayers)
        rows, cols = 4, n_layers * n_sub

        # 2) Initialize grid and letter map
        grid = np.zeros((rows, cols))
        letters = [['' for _ in range(cols)] for _ in range(rows)]

        # 3) Define operation mappings
        attn_map = {'q_proj': (0, 'Q'), 'k_proj': (1, 'K'),
                    'v_proj': (2, 'V'), 'o_proj': (3, 'O')}
        mlp_map  = {'gate_proj': (0, 'G'), 'up_proj': (1, 'U'),
                    'down_proj': (2, 'D')}
        
        counter = 0
        # print('saliency_dict.keys():', saliency_dict.keys())
        # 4) Populate grid with saliency values
        for name, val in saliency_dict.items():
            if not name.endswith('.weight'):
                continue
            parts = name.split('.')
            # print('parts:', parts)

            if parts[0] not in ['model', 'layers', 'language_model'] or parts[1] not in ['model', 'layers', 'language_model']:
                print('parts[:2] different from expected: ', parts[:2])
                continue
            
            # special layers in the gemma models
            if parts[2] in ['embed_tokens', 'norm']:
                continue

            layer_idx = int(parts[self.layer_part])
            sub, op = parts[self.sub_part], parts[self.op_part]
            # scalar saliency
            sal = val.mean().item() if hasattr(val, 'mean') else float(val)

            # print('layer_idx sub op: ', layer_idx, sub, op)

            if sub == 'self_attn' and op in attn_map:
                col = layer_idx * n_sub + 0
                row, letter = attn_map[op]
                counter += 1
            elif sub == 'mlp' and op in mlp_map:
                col = layer_idx * n_sub + 1
                row, letter = mlp_map[op]
                counter += 1
            else:
                continue

            grid[row, col] = sal
            letters[row][col] = letter

        if counter == 0:
            raise ValueError('No saliency values matched')
        

        if doLog:
            eps = 1e-8
            # create a mask of the blank slots
            blank_mask = np.array(letters) == ''
            # turn the blank slots into NaN so they don't pull down your min
            grid_masked = grid.astype(float)
            grid_masked[blank_mask] = np.nan

            # now log‐transform only the real entries
            grid = np.log(grid_masked + eps)
        
        # 5) Create figure
        fig, ax = plt.subplots(figsize=(cols * 0.6, rows * 0.6))
        cmap = plt.cm.get_cmap('Greens')
        norm = Normalize(vmin=np.nanmin(grid), vmax=np.nanmax(grid))

        # 6) Draw grid cells
        for i in range(rows):
            for j in range(cols):
                x, y = j, rows - 1 - i
                if letters[i][j] == '':
                    facecolor = 'lightgrey'  # blank/unused slot
                else:
                    facecolor = cmap(norm(grid[i, j]))
                ax.add_patch(plt.Rectangle((x, y), 1, 1, color=facecolor, ec='white'))
                if letters[i][j]:
                    ax.text(x + 0.5, y + 0.5, letters[i][j],
                            ha='center', va='center', fontsize=8)

        # 7) Configure axes
        ax.set_xlim(0, cols)
        ax.set_ylim(0, rows)
        ax.set_xticks([i * n_sub + n_sub / 2 for i in range(n_layers)])
        ax.set_xticklabels([str(i + 1) for i in range(n_layers)])
        ax.set_xlabel('Layer')
        ax.set_yticks([])

        # Bold vertical separators between layers
        for i in range(1, n_layers):
            ax.axvline(i * n_sub, color='black', linewidth=2)

        # 8) Add colorbar
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm._A = []
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.01)
        cbar.set_ticks([norm.vmin, norm.vmax])
        cbar.set_ticklabels(['low saliency', 'high saliency'])
        # cbar.set_label('Saliency')
        # move cbar to the left
        cbar.ax.set_position([0.05, 0.1, 0.02, 0.8])

        # 9) Legend mapping letters to operations
        letter_labels = {
            'Q': 'Query',     'K': 'Key',
            'V': 'Value',     'O': 'Output',
            'G': 'Gate',      'U': 'Up Proj.',
            'D': 'Down Proj.'
        }
        handles = [
            Line2D([0], [0], marker=f'${l}$', color='none', linestyle='None',
                label=f'{l} – {desc}')
            for l, desc in letter_labels.items()
        ]
        # Place legend within that reserved space
        legend_x = 1.2 if cols < 18 else 1.05
        ax.legend(handles=handles, loc='upper left',
                bbox_to_anchor=(legend_x, 1), frameon=False)

        # add more space to the right
        tl_x = 0.7 if cols < 18 else 0.7
        plt.tight_layout(rect=[0,0,tl_x,1], h_pad=1.5)

        # add title
        ax.set_title(title)

        # 10) Save figure
        fig.savefig(out_path, dpi=300)
        plt.show()
        # plt.close(fig)

    def plot_saliency_2D(self, avg_sal_per_head, avg_sal_per_mlp, out_path, title):
        """
        Draw a layers x (attention heads + MLP matrices) table of log-saliency values.

        Rows are layers (layer 0 at the top); columns are the attention heads followed by
        three MLP columns (Gate, Up, Down). Each cell is coloured from white (low) to green
        (high) by log(saliency + 1e-8) and annotated with that value. Cells without a valid
        value are drawn in light gray. The figure is saved to ``out_path`` and then shown.

        Args:
            avg_sal_per_head: {(layer_idx, head_idx) -> scalar saliency}
            avg_sal_per_mlp: {layer_idx -> scalar saliency for MLP matrices}; the same value
                is written into all three MLP columns.
            out_path: path to save the plot
            title: title for the plot

        Raises:
            AttributeError: if the model config exposes none of the known attribute names
                for the number of attention heads or the number of layers.
        """
        import matplotlib.patches as patches
        from matplotlib.colors import LinearSegmentedColormap

        # Resolve the config lazily: only fall back to self.model.model when the wrapper
        # itself has no config (an eager getattr default would touch .model unconditionally).
        cfg = getattr(self.model, "config", None)
        if cfg is None:
            cfg = self.model.model.config
        # Multimodal Gemma-3 nests the language-model settings under text_config
        text_cfg = getattr(cfg, "text_config", None) or cfg

        def _first_available(attr_names):
            # Return the first truthy config attribute among the candidate names.
            for attr_name in attr_names:
                value = getattr(text_cfg, attr_name, None)
                if value:
                    return value
            raise AttributeError(
                f"Model config defines none of {attr_names}; cannot size the saliency table."
            )

        num_heads = _first_available(("num_attention_heads", "num_heads"))
        num_layers = _first_available(("num_hidden_layers", "n_layers"))

        # Columns: attention heads + 3 MLP matrices (Gate, Up-proj, Down-proj)
        n_cols = num_heads + 3
        n_rows = num_layers

        # NaN marks "no data" so missing cells can be told apart from zero saliency
        data_matrix = np.full((n_rows, n_cols), np.nan)

        # Fill attention head data, ignoring entries outside the table
        for (layer_idx, head_idx), saliency in avg_sal_per_head.items():
            if 0 <= layer_idx < n_rows and 0 <= head_idx < num_heads:
                data_matrix[layer_idx, head_idx] = saliency

        # Fill MLP data (last 3 columns). Only one MLP value per layer is available, so
        # Gate / Up / Down all show the same number.
        for layer_idx, saliency in avg_sal_per_mlp.items():
            if 0 <= layer_idx < n_rows:
                data_matrix[layer_idx, num_heads:num_heads + 3] = saliency

        # Convert to log space; eps keeps exact zeros finite
        eps = 1e-8
        data_matrix_log = np.log(data_matrix + eps)
        has_value = ~np.isnan(data_matrix_log)

        fig, ax = plt.subplots(figsize=(n_cols * 0.8, n_rows * 0.6))

        # Colormap - white to green
        cmap = LinearSegmentedColormap.from_list('white_to_green', ['white', 'green'], N=256)

        # Colour range comes from the valid cells only
        valid_data = data_matrix_log[has_value]
        if len(valid_data) > 0:
            vmin, vmax = valid_data.min(), valid_data.max()
        else:
            vmin, vmax = 0, 1
        norm = Normalize(vmin=vmin, vmax=vmax)
        # Cells above the midpoint are dark green, so their label switches to white
        text_threshold = (vmin + vmax) / 2

        # Draw one rectangle per cell; row i is placed at height n_rows - 1 - i so that
        # layer 0 ends up at the top of the table.
        for i in range(n_rows):
            for j in range(n_cols):
                cell_value = data_matrix_log[i, j]
                color = cmap(norm(cell_value)) if has_value[i, j] else 'lightgray'

                rect = patches.Rectangle((j, n_rows - 1 - i), 1, 1,
                                       facecolor=color, edgecolor='white', linewidth=0.5)
                ax.add_patch(rect)

                if has_value[i, j]:
                    ax.text(j + 0.5, n_rows - 0.5 - i, f'{cell_value:.2f}',
                           ha='center', va='center', fontsize=14,
                           color='white' if cell_value > text_threshold else 'black')

        # Vertical separator between attention heads and MLP columns
        ax.axvline(x=num_heads, color='black', linewidth=2)

        # Set axis properties
        ax.set_xlim(0, n_cols)
        ax.set_ylim(0, n_rows)

        # Ticks and labels are centred in the grid cells
        ax.set_xticks([i + 0.5 for i in range(n_cols)])
        ax.set_xticklabels([f'H{i}' for i in range(num_heads)] + ['Gate', 'Up', 'Down'], fontsize=14)
        ax.set_yticks([i + 0.5 for i in range(n_rows)])
        ax.set_yticklabels([f'L{i}' for i in reversed(range(n_rows))], fontsize=14)

        ax.set_xlabel('Attention Heads / MLP Matrices', fontsize=16)
        ax.set_ylabel('Layers', fontsize=16)
        ax.set_title(title, fontsize=18)

        # Colorbar: the mappable carries no data of its own, only cmap + norm
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.01)
        cbar.set_label('Log Saliency', fontsize=16)
        cbar.ax.tick_params(labelsize=14)

        plt.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.show()

    def selectLayersForPlots(self, nrPlots):
        """
        Pick up to ``nrPlots`` evenly spaced layer indices out of ``self.num_layers``.

        When more than one plot is requested and the model has more layers than plots,
        the first layer (0) and the last layer (num_layers - 1) are always included and
        the remaining indices are spread evenly between them.

        Args:
            nrPlots: number of layers to select.

        Returns:
            List of layer indices in increasing order. All layers when the model has no
            more layers than requested, ``[0]`` when a single plot is requested, and an
            empty list when ``nrPlots`` or ``self.num_layers`` is not positive.
        """
        if nrPlots <= 0 or self.num_layers <= 0:
            return []

        if self.num_layers <= nrPlots:
            # Not more layers than requested plots: use every layer
            return list(range(self.num_layers))

        if nrPlots == 1:
            # A single plot has no "first and last" spacing to compute; the step formula
            # below would divide by zero.
            return [0]

        # nrPlots points span nrPlots - 1 intervals between layer 0 and the last layer
        step = (self.num_layers - 1) / (nrPlots - 1)
        return [int(round(i * step)) for i in range(nrPlots)]

    # def selectLayersForPlots(self):
    #     # ═══════════ 6b. Plot selected layers ════════════════════════════════════════════════
    #     # Select 6 evenly-spaced layers including first and last
    #     if self.num_layers >= 6:
    #         # Calculate step size to get 4 intermediate layers between first and last
    #         step = (self.num_layers - 1) / 5  # 5 intervals for 6 points
    #         selected_layers = [0]  # Always include first layer
    #         for i in range(1, 5):  # Add 4 intermediate layers
    #             layer_idx = int(round(i * step))
    #             if layer_idx not in selected_layers:  # Avoid duplicates
    #                 selected_layers.append(layer_idx)
    #         if self.num_layers - 1 not in selected_layers:  # Always include last layer
    #             selected_layers.append(self.num_layers - 1)
    #         selected_layers = sorted(selected_layers)  # Ensure they're in order
    #     else:
    #         # If fewer than 6 layers, use all layers
    #         selected_layers = list(range(self.num_layers))
        
    #     return selected_layers

    def plotMap(self):
        """
        Plot the per-layer "LLM map": one panel for every analysis whose results load.

        Up to four result sets are read, each on a best-effort basis (a failure is printed
        and that panel is skipped):
          * UMAP clustering coefficients from ``self.umapAggFile`` (skipped when
            ``self.analysis_name`` is "dosage"),
          * saliency from ``self.saliencyFile``, reduced per layer by
            ``compute_saliency_per_layer``,
          * layer-ablation scores from ``self.lesioningFile``,
          * activation-patching effects from ``self.activationPatchingFile``.

        Every panel shows the per-layer mean with a shaded 95% confidence band. Lesioning
        and patching curves are drawn against the layer indices stored in their files.
        The figure is written to ``results/{analysis_name}_llm-maps_{MODEL_NAME}.pdf``
        (the directory is created if needed) and then shown. If nothing could be loaded,
        a message is printed and the method returns without plotting.
        """
        print('Plotting LLM Maps ...')

        # One dict per successfully loaded modality, in left-to-right display order.
        # A panel is appended only after everything it needs has been computed, so a
        # partially loaded modality can never reach the plotting stage.
        panels = []

        # --- UMAP clustering coefficient (not used for the dosage analysis) ---
        if self.analysis_name != "dosage":
            try:
                # NOTE: allow_pickle=True can execute code stored in the file; only load
                # result files produced by this pipeline.
                umap_data = np.load(self.umapAggFile, allow_pickle=True)
                data_umap = umap_data["coeffs"]
                umap_panel = {
                    "x": np.arange(len(data_umap)),
                    "y": data_umap,
                    "lower": umap_data["coeffs_ci_lower"],
                    "upper": umap_data["coeffs_ci_upper"],
                    "label": "UMAP Clustering Coeff.",
                    "color": "blue",
                    "marker": None,
                    "ylabel": "Clustering Coefficient",
                    "title": "Clustering Coefficient",
                    "ablation_scale": False,
                }
                print(f"Loaded UMAP data with {len(data_umap)} layers")
                panels.append(umap_panel)
            except Exception as e:
                print(f"Could not load UMAP data: {e}")

        # --- Saliency ---
        try:
            # NOTE: pickle.load can execute code stored in the file; only load result
            # files produced by this pipeline.
            with open(self.saliencyFile, "rb") as f:
                sal_data = pickle.load(f)

            data_sal, sal_ci_lower, sal_ci_upper = self.compute_saliency_per_layer(
                sal_data["avg_sal_per_prompt"]
            )
            sal_panel = {
                "x": np.arange(len(data_sal)),
                "y": data_sal,
                "lower": sal_ci_lower,
                "upper": sal_ci_upper,
                "label": "Total Saliency",
                "color": "green",
                "marker": None,
                "ylabel": "Saliency",
                "title": "Saliency",
                "ablation_scale": False,
            }
            print(f"Loaded saliency data with {len(data_sal)} layers")
            panels.append(sal_panel)
        except Exception as e:
            print(f"Could not load saliency data: {e}")

        # --- Layer lesioning (ablation) ---
        try:
            with open(self.lesioningFile, "r") as f:
                lesion_data = json.load(f)

            # Collect the scores of all prompts for each layer
            layer_scores = {}
            for prompt_result in lesion_data["prompt_scores"].values():
                for layer_idx, score_info in prompt_result["scores_and_justifications"].items():
                    layer_scores.setdefault(int(layer_idx), []).append(score_info["score"])

            lesion_layers = sorted(layer_scores)
            lesion_summary = [self._layer_mean_ci(layer_scores[idx]) for idx in lesion_layers]
            lesion_panel = {
                # Real layer indices, so a partially lesioned model is not compressed
                # onto positions 0..k-1.
                "x": lesion_layers,
                "y": [mean for mean, _, _ in lesion_summary],
                "lower": [lower for _, lower, _ in lesion_summary],
                "upper": [upper for _, _, upper in lesion_summary],
                "label": "Layer Ablation Impact",
                "color": "red",
                "marker": "o",
                "ylabel": "Change from Original",
                "title": "Layer Ablation Analysis",
                "ablation_scale": True,
            }
            print(f"Loaded lesioning data with {len(lesion_layers)} layers")
            panels.append(lesion_panel)
        except Exception as e:
            print(f"Could not load lesioning data: {e}")

        # --- Activation patching ---
        try:
            with open(self.activationPatchingFile, "r") as f:
                patching_data = json.load(f)

            # Per layer, the patching effects of all prompts
            layer_patching_effects = {
                int(layer_idx): list(layer_results["patching_effect"].values())
                for layer_idx, layer_results in patching_data["all_patching_results"].items()
            }

            patching_layers = sorted(layer_patching_effects)
            patching_summary = [
                self._layer_mean_ci(layer_patching_effects[idx]) for idx in patching_layers
            ]
            patching_panel = {
                "x": patching_layers,
                "y": [mean for mean, _, _ in patching_summary],
                "lower": [lower for _, lower, _ in patching_summary],
                "upper": [upper for _, _, upper in patching_summary],
                "label": "Activation Patching Effect",
                "color": "purple",
                "marker": "^",
                "ylabel": "Normalized Effect",
                "title": "Activation Patching Analysis",
                "ablation_scale": False,
            }
            print(f"Loaded activation patching data with {len(patching_layers)} layers")
            panels.append(patching_panel)
        except Exception as e:
            print(f"Could not load activation patching data: {e}")

        # Check if we have any data to plot
        if not panels:
            print("No data available for plotting. UMAP, saliency, lesioning, and activation patching data are missing.")
            return

        # One subplot per available modality, side by side
        figure_sizes = {1: (8, 6), 2: (15, 6), 3: (18, 6), 4: (24, 6)}
        fig, axes = plt.subplots(1, len(panels), figsize=figure_sizes[len(panels)], squeeze=False)
        axs = list(axes[0])

        # add more padding between subplots
        plt.subplots_adjust(wspace=0.3)

        for ax, panel in zip(axs, panels):
            line_kwargs = {"label": panel["label"], "color": panel["color"], "linewidth": 2}
            if panel["marker"] is not None:
                line_kwargs["marker"] = panel["marker"]

            ax.plot(panel["x"], panel["y"], **line_kwargs)
            ax.fill_between(panel["x"], panel["lower"], panel["upper"], color=panel["color"], alpha=0.2)
            ax.set_xlabel("Layer")

            if panel["ablation_scale"]:
                # Ablation scores are shown on a fixed 1-10 scale with annotated ends
                ax.set_ylabel(panel["ylabel"], labelpad=10)
                ax.set_ylim(1, 10)
                ax.set_yticks([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
                ax.set_yticklabels(['1\n(no change)', '2', '3', '4', '5', '6', '7', '8', '9', '10\n(significant\n disruption)'])
            else:
                ax.set_ylabel(panel["ylabel"])

            ax.set_title(panel["title"])
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Main title for the entire figure
        if len(panels) >= 2:
            fig.suptitle(f"LLM Map: {self.analysis_name} {self.MODEL_NAME}", fontsize=14)
        else:
            # Single plot - fold the modality name into the subplot title
            modality = panels[0]["title"]
            axs[0].set_title(f"LLM Map: {self.analysis_name} {self.MODEL_NAME} - {modality}", fontsize=14)

        out_file = f"results/{self.analysis_name}_llm-maps_{self.MODEL_NAME}.pdf"
        # MODEL_NAME may contain "/" (hub-style ids), which adds a sub-directory level
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        fig.savefig(out_file, dpi=300)
        print(f"Saved → {out_file}")
        plt.show()

    @staticmethod
    def _layer_mean_ci(samples, confidence_level=0.95):
        """
        Return (mean, ci_lower, ci_upper) of ``samples`` using a two-sided t-interval.

        With fewer than two samples the interval is undefined and both bounds are NaN.
        An empty ``samples`` raises ZeroDivisionError.
        """
        n_samples = len(samples)
        mean = sum(samples) / n_samples
        if n_samples < 2:
            # Sample standard deviation needs at least two observations
            return mean, float("nan"), float("nan")

        std = np.std(samples, ddof=1)  # Sample standard deviation
        alpha = 1 - confidence_level
        t_critical = stats.t.ppf(1 - alpha / 2, df=n_samples - 1)
        ci_margin = t_critical * (std / np.sqrt(n_samples))
        return mean, mean - ci_margin, mean + ci_margin

    

    def compute_saliency_per_layer(self, avg_sal_per_prompt):
        """
        Compute average saliency at each layer for the LLM maps.

        For every prompt, the saliency of all parameters whose name ends in "weight" is
        summed per layer; the layer index is the integer found at position
        ``self.layer_part`` of the dot-separated parameter name. Names without a valid,
        non-negative layer index there are ignored. The per-layer totals are then
        averaged over prompts, with a 95% confidence interval from the t-distribution.

        Args:
            avg_sal_per_prompt: Dict of {prompt -> {param_name -> scalar_saliency}}

        Returns:
            sal_mean, sal_ci_lower, sal_ci_upper: Arrays of saliency statistics per layer.
            Their length is ``self.num_layers``, or larger if the data references a
            higher layer index. With fewer than two prompts the interval is undefined
            and both bounds are NaN; with no prompts the mean is NaN as well.
        """
        # Pass 1: per prompt, accumulate {layer_idx: summed saliency}. The number of
        # layers grows with the data so that a stale or unset self.num_layers cannot
        # cause an out-of-bounds write.
        n_layers = self.num_layers
        totals_per_prompt = []
        for sal_dict in avg_sal_per_prompt.values():
            layer_totals = {}
            for name, val in sal_dict.items():
                if not name.endswith("weight"):
                    continue
                try:
                    layer_idx = int(name.split(".")[self.layer_part])
                except (ValueError, IndexError):
                    continue
                if layer_idx < 0:
                    continue

                layer_totals[layer_idx] = layer_totals.get(layer_idx, 0.0) + float(val)
                n_layers = max(n_layers, layer_idx + 1)
            totals_per_prompt.append(layer_totals)

        # Pass 2: dense (prompts x layers) matrix; layers without parameters stay at 0
        n_prompts = len(totals_per_prompt)
        saliency_per_layer_per_prompt = np.zeros((n_prompts, n_layers))
        for row, layer_totals in enumerate(totals_per_prompt):
            for layer_idx, total in layer_totals.items():
                saliency_per_layer_per_prompt[row, layer_idx] = total

        if n_prompts == 0:
            # Nothing to average
            return (np.full(n_layers, np.nan),
                    np.full(n_layers, np.nan),
                    np.full(n_layers, np.nan))

        sal_mean = saliency_per_layer_per_prompt.mean(axis=0)
        if n_prompts < 2:
            # The sample standard deviation needs at least two prompts
            return sal_mean, np.full(n_layers, np.nan), np.full(n_layers, np.nan)

        sal_std = saliency_per_layer_per_prompt.std(axis=0, ddof=1)  # Sample standard deviation

        # Compute 95% confidence intervals using t-distribution
        confidence_level = 0.95
        alpha = 1 - confidence_level
        t_critical = stats.t.ppf(1 - alpha / 2, df=n_prompts - 1)
        ci_margin = t_critical * (sal_std / np.sqrt(n_prompts))

        return sal_mean, sal_mean - ci_margin, sal_mean + ci_margin

    def create_identity_layer(self, original_layer=None):
        """
        Build a stand-in module used to replace a transformer layer during ablation.

        The stand-in copies the public, non-callable attributes of ``original_layer``
        (plus a few named attributes, including the layer-norm submodules when present)
        so that code inspecting the layer keeps working, and its ``forward`` passes the
        hidden states through.

        Args:
            original_layer: the layer being replaced, or None for a bare pass-through.

        Returns:
            A ``torch.nn.Module`` whose ``forward(hidden_states, *args, **kwargs)`` returns
              * ``(input_layernorm(hidden_states), None, None)`` when the original layer's
                type name contains "Gemma3" and an ``input_layernorm`` was copied
                (``(hidden_states, None, None)`` if it was not),
              * ``(hidden_states, None)`` when the stand-in has both ``self_attn`` and
                ``mlp`` attributes,
              * ``hidden_states`` otherwise.
        """
        class IdentityLayer(torch.nn.Module):
            # Attributes copied by name even if the generic copy loop skipped them
            # (private names and submodules are not picked up by that loop).
            _NAMED_ATTRIBUTES = (
                'attention_type',             # GPT-OSS models
                'layer_type',
                '_attn_implementation',
                'input_layernorm',            # Gemma-style norm layers
                'post_attention_layernorm',
                'pre_feedforward_layernorm',
                'post_feedforward_layernorm',
            )

            def __init__(self, original_layer=None):
                super().__init__()
                # Store original layer type information for return format determination
                self._original_layer_type = str(type(original_layer)) if original_layer is not None else "unknown"
                self._is_gemma3 = "Gemma3" in self._original_layer_type

                if original_layer is None:
                    return

                # Copy public, non-callable attributes to maintain compatibility.
                # This is best-effort: reading a property or assigning the value may fail
                # for individual attributes, which are then skipped.
                for attr_name in dir(original_layer):
                    if attr_name.startswith('_'):
                        continue
                    try:
                        value = getattr(original_layer, attr_name)
                        if not callable(value):
                            setattr(self, attr_name, value)
                    except Exception:
                        continue

                for attr_name in self._NAMED_ATTRIBUTES:
                    if hasattr(original_layer, attr_name):
                        setattr(self, attr_name, getattr(original_layer, attr_name))

            def forward(self, hidden_states, *args, **kwargs):
                # Return the input in the same structure the replaced layer would produce
                if self._is_gemma3:
                    if hasattr(self, 'input_layernorm'):
                        # NOTE(review): this applies the input norm, so for Gemma3 the
                        # stand-in is not a strict identity - confirm this is intended.
                        normalized_states = self.input_layernorm(hidden_states)
                        # Format expected by Gemma3: (hidden_states, attention_weights, present_key_value)
                        return (normalized_states, None, None)
                    # Fallback: return input unchanged
                    return (hidden_states, None, None)

                if hasattr(self, 'self_attn') and hasattr(self, 'mlp'):
                    # NOTE(review): submodules are callable and therefore skipped by the
                    # copy loop, so this branch looks unreachable for layers whose
                    # self_attn/mlp are nn.Modules - verify.
                    return (hidden_states, None)  # (hidden_states, attention_weights)

                # For Llama and simpler models, return just hidden_states
                return hidden_states

        return IdentityLayer(original_layer)

    def _check_lesioning_compatibility(self):
        """
        Checks if the current model is compatible with layer lesioning.
        Provides warnings and recommendations for problematic models.
        """
        model_name = self.MODEL_NAME.lower()
        
        # Check for known problematic model types
        # if "gemma" in model_name and "27b" in model_name:
        #     print("Warning: Gemma-27B models are very large and may use disk offloading.")
        #     print("Recommendation: Use a smaller model for lesioning, or ensure sufficient GPU memory.")
        #     print("Alternative: Use Gemma-2B or Gemma-7B models for lesioning analysis.")
        
        # if "70b" in model_name or "120b" in model_name:
        #     print("Warning: Very large models (70B+ parameters) may have memory issues during lesioning.")
        #     print("Recommendation: Ensure sufficient GPU memory or use smaller models.")
        
        # Check if model is using disk offloading
        if hasattr(self.model, 'hf_device_map'):
            device_map = self.model.hf_device_map
            if device_map and any('disk' in str(v) for v in device_map.values()):
                print("Warning: Model is using disk offloading which can cause issues during lesioning.")
                print("Recommendation: Use explicit device mapping or smaller models.")
                raise ValueError("Model is using disk offloading which takes a long time and is not supported for lesioning.")

    def ablate_layer(self, layer_idx):
        """
        Swap one entry of ``self.model_layers`` for a pass-through stand-in.

        The stand-in comes from ``create_identity_layer``. The caller is expected to
        hand the returned layer back to ``restore_layer`` to undo the ablation.

        Args:
            layer_idx: Index of the layer to ablate

        Returns:
            The original layer (to be restored later)

        Raises:
            ValueError: if ``layer_idx`` is not smaller than the number of layers.
        """
        layer_count = len(self.model_layers)
        if not layer_idx < layer_count:
            raise ValueError(f"Layer index {layer_idx} out of range. Model has {layer_count} layers.")

        # Keep a handle on the layer being replaced
        replaced = self.model_layers[layer_idx]
        stand_in = self.create_identity_layer(replaced)

        # If the replaced layer exposes a weight directly, follow it to its device
        weight = getattr(replaced, 'weight', None)
        if weight is not None:
            stand_in = stand_in.to(weight.device)

        # Mirror an explicit device attribute when the replaced layer has one
        if hasattr(replaced, 'device'):
            stand_in.device = replaced.device

        self.model_layers[layer_idx] = stand_in
        return replaced

    def restore_layer(self, layer_idx, original_layer):
        """
        Put a previously ablated layer back into ``self.model_layers``.

        Args:
            layer_idx: Position in ``self.model_layers`` to overwrite
            original_layer: The layer object to store at that position
        """
        layers = self.model_layers
        layers[layer_idx] = original_layer

    def run_model_with_prompts(self, prompts, max_new_tokens=50):
        """
        Runs the model on a list of prompts and returns the generated text.
        All prompts are tokenized and generated as a single padded batch, using greedy
        (deterministic) decoding.

        Args:
            prompts: List of prompt strings
            max_new_tokens: Maximum number of new tokens to generate

        Returns:
            List of generated text responses (prompt tokens removed, special tokens
            skipped, surrounding whitespace stripped), one per prompt. An empty list
            when ``prompts`` is empty.
        """
        if not prompts:
            return []

        device = next(self.model.parameters()).device

        # Batched generation with a decoder-only model needs left padding; otherwise
        # shorter prompts would be continued after their pad tokens. The tokenizer's own
        # setting is restored afterwards so other callers are unaffected.
        previous_side = getattr(self.tokenizer, "padding_side", None)
        if previous_side is not None:
            self.tokenizer.padding_side = "left"
        try:
            inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
        finally:
            if previous_side is not None:
                self.tokenizer.padding_side = previous_side

        with torch.no_grad():
            # Greedy decoding is already deterministic. No temperature is passed because
            # it only applies to sampling.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Every row of the padded batch has the same prompt length; everything after it
        # is newly generated.
        prompt_length = inputs["input_ids"].shape[-1]
        return [
            self.tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True).strip()
            for sequence in outputs
        ]

    def runLesioning(self, load_cached=True):
        """
        Run (or reload) the layer lesioning analysis and print its results.

        Prompts are always generated first. If ``load_cached`` is true and
        ``self.lesioningFile`` exists, the results are read from that JSON file;
        otherwise ``run_layer_lesioning`` is called with that file as its output path.
        Either way the results are passed to ``print_lesioning_results`` and
        ``print_lesioning_table``.
        """
        print(f"Running layer lesioning analysis for {self.analysis_name}...")

        # Prompts are produced up front, before the cache is consulted
        prompts = self.generate_lesioning_prompts()

        lesion_file = self.lesioningFile
        cache_hit = load_cached and os.path.exists(lesion_file)

        if not cache_hit:
            # No usable cache: compute the results now
            lesion_results = self.run_layer_lesioning(prompts, output_file=lesion_file)
            print(f"{self.analysis_name} lesioning analysis complete. Results saved to {lesion_file}")
        else:
            with open(lesion_file, "r") as cache:
                lesion_results = json.load(cache)
            print(f"Loaded cached lesioning results from {lesion_file}")

        self.print_lesioning_results(lesion_results)
        self.print_lesioning_table(lesion_results)

    def runLesioningFinegrained(self, load_cached=True):
        """
        Generate the lesioning prompts, keep them on ``self.prompts``, and hand them to
        ``run_layer_lesioning_finegrained`` together with the ``load_cached`` flag.
        """
        lesioning_prompts = self.generate_lesioning_prompts()
        self.prompts = lesioning_prompts
        self.run_layer_lesioning_finegrained(lesioning_prompts, load_cached=load_cached)

    # --- Fine-grained lesioning (per attention head and per MLP) ---
    def run_layer_lesioning_finegrained(self, prompts, max_new_tokens=50, load_cached=True):
        """
        For each layer, lesion each attention head independently (zero its residual contribution)
        and lesion the MLP (zero its residual contribution). Collect responses and run a
        GPT-based judge over all heads for that layer alongside the MLP to assess impact.
        
        Reads PART environment variable (0,1,2,3) to determine which quarter of the model to process.
        """
        
        # output_file = self.lesioningFinegrainedFile
        # if load_cached is true, we'll check if all files exist for all layers and assemble the results
        
        if load_cached:
            # Try to assemble results from individual layer files
            results = self._assemble_finegrained_results_from_layers()
            if results is not None:
                print(f"Assembled fine-grained lesioning results from individual layer files")
            else:
                print("No cached layer files found, running fresh analysis...")
                load_cached = False

            self.plot_finegrained_lesioning_heatmap(results)
        
        if not load_cached:

            print(f"Starting fine-grained lesioning with {len(prompts)} prompts…")
            self._check_lesioning_compatibility()

            # Helper: find attention output projection module (o_proj/out_proj)
            def _get_o_proj(layer):
                attn = getattr(layer, 'self_attn', None)
                if attn is None:
                    return None
                for name in ['o_proj', 'out_proj', 'o_proj_linear', 'proj_out']:
                    if hasattr(attn, name):
                        return getattr(attn, name)
                # fall back: try any nn.Linear child with in_features == hidden
                for m in attn.modules():
                    if hasattr(m, 'in_features') and hasattr(m, 'out_features'):  # likely nn.Linear
                        return m
                return None

            # Helper: infer number of heads and head_dim
            def _infer_heads(layer, o_proj):
                attn = getattr(layer, 'self_attn', None)
                num_heads = None
                head_dim = None
                hidden_in = None
                if o_proj is not None and hasattr(o_proj, 'in_features'):
                    hidden_in = int(o_proj.in_features)
                # common names
                for name in ['num_heads', 'n_heads', 'num_attention_heads']:
                    if attn is not None and hasattr(attn, name):
                        num_heads = int(getattr(attn, name))
                        break
                if num_heads is None and attn is not None and hasattr(attn, 'head_dim') and hidden_in is not None:
                    hd = int(getattr(attn, 'head_dim'))
                    if hd > 0:
                        num_heads = hidden_in // hd
                if num_heads is not None and hidden_in is not None:
                    head_dim = hidden_in // num_heads
                return num_heads, head_dim, hidden_in

            # Prepare original responses (for reference in judge prompt)
            print("Generating original responses…")
            original_responses = self.run_model_with_prompts(prompts, max_new_tokens=max_new_tokens)

            # Judge per layer: compare original to each head+mlp variant using GPT
            def _analyze_layer_with_gpt(prompt, orig_resp, head_to_resp, mlp_resp):
                try:
                    from openai import OpenAI
                    import os
                    api_key = None
                    try:
                        with open(os.path.expanduser("~/.oai"), "r") as f:
                            api_key = f.read().strip()
                    except FileNotFoundError:
                        api_key = os.environ.get("OPENAI_API_KEY")
                    if not api_key:
                        raise RuntimeError("OPENAI_API_KEY not found")
                    client = OpenAI(api_key=api_key)
                    # Build analysis text
                    txt = []
                    txt.append("Original prompt: " + prompt)
                    txt.append("Original response: " + orig_resp)
                    txt.append("Lesioned responses for this layer (each entry zeroes one head's residual; 'mlp' zeroes the MLP residual):")
                    for h, resp in head_to_resp.items():
                        txt.append(f"Head {h}: {resp}")
                    if mlp_resp is not None:
                        txt.append(f"MLP: {mlp_resp}")
                    txt.append("\nFor each lesion above, rate disruption 1-10 (10=severe), and give a one-line justification. Format: 'Head i: Score X - justification' and 'MLP: Score Y - justification'.")
                    content = "\n".join(txt)
                    resp = client.chat.completions.create(
                        model="gpt-4o",
                        messages=[
                            {"role": "system", "content": "You are an expert analyzing fine-grained transformer lesions."},
                            {"role": "user", "content": content},
                        ],
                        temperature=0.3,
                        max_tokens=6000,
                    )
                    return resp.choices[0].message.content.strip()
                except Exception as e:
                    return f"Judge failed: {e}"

            # Read PART environment variable to determine which quarter to process
            import os
            part_of_model_to_process = int(os.environ.get('PART', '0'))
            
            # Determine which layers to process based on quarter
            total_layers = len(self.model_layers)
            layers_per_quarter = total_layers // 4
            remainder = total_layers % 4
            
            # Calculate start and end indices for the specified quarter
            start_idx = part_of_model_to_process * layers_per_quarter
            if part_of_model_to_process < remainder:
                start_idx += part_of_model_to_process
                end_idx = start_idx + layers_per_quarter + 1
            else:
                start_idx += remainder
                end_idx = start_idx + layers_per_quarter
            
            layers_to_process = list(range(start_idx, end_idx))
            print(f"Processing quarter {part_of_model_to_process}: layers {layers_to_process} (indices {start_idx}-{end_idx-1})")
            all_layer_results = {}
            per_prompt_analysis = {}

            debug = False
            layers = self.model_layers if not debug else [self.model_layers[0], self.model_layers[1]] # only 
            
            for layer_idx in layers_to_process:
                layer = self.model_layers[layer_idx]
                print(f"Layer {layer_idx}: probing heads and MLP")
                o_proj = _get_o_proj(layer)
                num_heads, head_dim, hidden_in = _infer_heads(layer, o_proj)
                if debug:
                    num_heads = 2 # only run first 2 heads
                
                layer_result = { 'heads': {}, 'mlp': None }

                # Lesion each attention head (if available)
                if o_proj is not None and num_heads and head_dim:
                    for h in range(num_heads):
                        print(f"Layer {layer_idx}: probing head {h}")
                        start = h * head_dim
                        end = start + head_dim

                        def _pre_hook_factory(s, e):
                            def _pre_hook(mod, inputs):
                                x = inputs[0]
                                # zero a slice along the last dimension (residual chunk for this head)
                                x = x.clone()
                                x[..., s:e] = 0
                                return (x,)
                            return _pre_hook

                        pre_hook = o_proj.register_forward_pre_hook(_pre_hook_factory(start, end))
                        try:
                            responses = self.run_model_with_prompts(prompts, max_new_tokens=max_new_tokens)
                        finally:
                            pre_hook.remove()
                        layer_result['heads'][h] = responses
                else:
                    print(f"  Layer {layer_idx}: could not infer attention heads; skipping head-wise lesioning.")

                # Lesion MLP by zeroing its residual output
                mlp = getattr(layer, 'mlp', None)
                if mlp is None:
                    # try alternative names
                    for name in ['feed_forward', 'ffn', 'ff']:
                        if hasattr(layer, name):
                            mlp = getattr(layer, name)
                            break
                if mlp is not None:
                    def _mlp_zero_hook(mod, inputs, output):
                        return torch.zeros_like(output)
                    hdl = mlp.register_forward_hook(_mlp_zero_hook)
                    try:
                        responses = self.run_model_with_prompts(prompts, max_new_tokens=max_new_tokens)
                    finally:
                        hdl.remove()
                    layer_result['mlp'] = responses
                else:
                    print(f"  Layer {layer_idx}: no MLP module found; skipping MLP lesioning.")

                all_layer_results[layer_idx] = layer_result
                
                # Run GPT analysis for this layer
                print(f"Layer {layer_idx}: running GPT analysis...")
                layer_analysis = {}
                for p_idx, prompt in enumerate(prompts):
                    orig = original_responses[p_idx]
                    heads_responses = {h: layer_result['heads'][h][p_idx] for h in layer_result['heads']}
                    mlp_resp = layer_result['mlp'][p_idx] if layer_result['mlp'] is not None else None
                    analysis_text = _analyze_layer_with_gpt(prompt, orig, heads_responses, mlp_resp)
                    layer_analysis[p_idx] = analysis_text
                
                per_prompt_analysis[layer_idx] = layer_analysis
                
                # Save layer results with GPT analysis immediately
                layer_file = self.lesioningFinegrainedFile[:-5] + f'_layer{layer_idx}.json'
                layer_data = {
                    'original_responses': original_responses,
                    'layer_result': layer_result,
                    'layer_analysis': layer_analysis,
                    'layer_idx': layer_idx,
                    'prompts': prompts,
                }
                with open(layer_file, 'w') as f:
                    json.dump(layer_data, f, indent=2)
                print(f"Layer {layer_idx} results with GPT analysis saved to {layer_file}")

            # # Compile final results from all layer files
            # print("Compiling final results from all layer files...")
            # results = {
            #     'original_responses': original_responses,
            #     'finegrained_responses': all_layer_results,
            #     'per_prompt_analysis': per_prompt_analysis,
            #     'prompts': prompts,
            # }

            # with open(output_file, 'w') as f:
            #     json.dump(results, f, indent=2)
            # print(f"Final fine-grained lesioning results compiled and saved to {output_file}")

        

    def _assemble_finegrained_results_from_layers(self):
        """
        Assemble fine-grained lesioning results from individual layer JSON files.
        Returns None if not all layer files are found.
        """
        import os
        import json
        
        # Get the base filename for layer files
        base_file = self.lesioningFinegrainedFile[:-5]  # Remove .json extension
        
        # Try to find all layer files
        all_layer_results = {}
        per_prompt_analysis = {}
        original_responses = None
        prompts = None
        
        # Check for layer files starting from 0
        layer_idx = 0
        found_any = False
        
        while True:
            layer_file = f"{base_file}_layer{layer_idx}.json"
            if os.path.exists(layer_file):
                try:
                    with open(layer_file, 'r') as f:
                        layer_data = json.load(f)
                    
                    # Extract data from this layer
                    all_layer_results[layer_idx] = layer_data['layer_result']
                    
                    # Reorganize layer_analysis from {prompt_idx: analysis_text} to {prompt_idx: {layer_idx: analysis_text}}
                    layer_analysis = layer_data['layer_analysis']
                    for prompt_idx, analysis_text in layer_analysis.items():
                        if prompt_idx not in per_prompt_analysis:
                            per_prompt_analysis[prompt_idx] = {}
                        per_prompt_analysis[prompt_idx][layer_idx] = analysis_text
                    
                    # Get original_responses and prompts from first layer (should be same for all)
                    if original_responses is None:
                        original_responses = layer_data['original_responses']
                        prompts = layer_data['prompts']
                    
                    found_any = True
                    print(f"Loaded layer {layer_idx} from {layer_file}")
                    layer_idx += 1
                    
                except Exception as e:
                    print(f"Error loading {layer_file}: {e}")
                    break
            else:
                break
        
        if not found_any:
            return None
        
        # Assemble final results
        results = {
            'original_responses': original_responses,
            'finegrained_responses': all_layer_results,
            'per_prompt_analysis': per_prompt_analysis,
            'prompts': prompts,
        }
        
        print(f"Successfully assembled results from {len(all_layer_results)} layers")
        return results

    def plot_finegrained_lesioning_heatmap(self, lesion_results):
        """
        Visualize fine-grained lesioning scores (per attention head and per MLP) as a heatmap.

        Expects a dict produced by run_layer_lesioning_finegrained with:
          - 'finegrained_responses': per-layer responses (not used directly here)
          - 'per_prompt_analysis': mapping prompt_idx -> { layer_idx -> analysis_text }

        The analysis_text is expected to contain lines like:
          "Head i: Score X - justification" and "MLP: Score Y - justification".

        We parse scores, average across prompts, and plot a heatmap with
        columns [H0..H{num_heads-1}, MLP] and rows [L0..L{num_layers-1}].
        """
        import re
        print("Creating fine-grained lesioning heatmap…")

        # Infer number of layers and heads from content
        fine = lesion_results.get('finegrained_responses', {})
        if not fine:
            print("No finegrained_responses found; nothing to plot.")
            return

        layer_indices = sorted(list(fine.keys()))
        # keys might be ints; ensure ints
        layer_indices = [int(k) for k in layer_indices]
        num_layers = max(layer_indices) + 1 if layer_indices else len(self.model_layers)
        # infer max heads across layers
        max_heads = 0
        for lyr in layer_indices:
            heads_dict = fine[lyr].get('heads', {})
            max_heads = max(max_heads, len(heads_dict))
        num_heads = max_heads

        # containers for accumulating scores per cell (list of floats)
        cell_scores = [[[] for _ in range(num_heads + 1)] for _ in range(num_layers)]

        per_prompt = lesion_results.get('per_prompt_analysis', {})
        # regex for Head i and MLP lines
        re_head = re.compile(r"Head\s+(\d+)\s*:\s*Score\s*([0-9]+(?:\.[0-9]+)?)", re.IGNORECASE)
        re_mlp  = re.compile(r"MLP\s*:\s*Score\s*([0-9]+(?:\.[0-9]+)?)", re.IGNORECASE)

        for p_idx, layer_map in per_prompt.items():
            # layer_map: {layer_idx -> analysis_text}
            for layer_key, txt in layer_map.items():
                try:
                    layer_idx = int(layer_key)
                except Exception:
                    continue
                if layer_idx < 0 or layer_idx >= num_layers:
                    continue
                if not isinstance(txt, str):
                    continue

                # Parse all head scores
                for m in re_head.finditer(txt):
                    h = int(m.group(1))
                    try:
                        s = float(m.group(2))
                    except Exception:
                        continue
                    if 0 <= h < num_heads:
                        cell_scores[layer_idx][h].append(s)

                # Parse MLP score
                m2 = re_mlp.search(txt)
                if m2:
                    try:
                        s = float(m2.group(1))
                        cell_scores[layer_idx][num_heads].append(s)
                    except Exception:
                        pass

        # Build matrix of average scores
        import numpy as np
        mat = np.zeros((num_layers, num_heads + 1), dtype=float)
        for li in range(num_layers):
            for hj in range(num_heads + 1):
                vals = cell_scores[li][hj]
                mat[li, hj] = float(np.mean(vals)) if len(vals) > 0 else 0.0

        # Plot with white→red colormap, not log-space
        import matplotlib.pyplot as plt
        from matplotlib.colors import LinearSegmentedColormap
        colors = ['white', '#fee5e5', '#fca5a5', '#ef4444', '#b91c1c']
        cmap = LinearSegmentedColormap.from_list('white_to_red', colors, N=256)

        vmax_val = 10.0  # scores are 1–10 by convention
        fig, ax = plt.subplots(figsize=(max(12, (num_heads+1) * 0.6), max(8, num_layers * 0.35)))
        im = ax.imshow(mat, cmap=cmap, aspect='auto', vmin=0.0, vmax=vmax_val)

        ax.set_xlabel('Attention Head / MLP', fontsize=24)
        ax.set_ylabel('Layer', fontsize=24)
        ax.set_title(f'Fine-grained Lesioning Scores (LLM-as-judge)\n{self.MODEL_NAME}', fontsize=14)

        x_labels = [f'H{i}' for i in range(num_heads)] + ['MLP']
        ax.set_xticks(range(num_heads + 1))
        ax.set_xticklabels(x_labels, fontsize=20)
        ax.set_yticks(range(num_layers))
        ax.set_yticklabels([f'L{i}' for i in range(num_layers)])
        ax.axvline(x=num_heads - 0.5, color='black', linewidth=2)

        cbar = plt.colorbar(im, ax=ax, shrink=0.85)
        cbar.set_label('Disruption score (1–10)', fontsize=10)

        # annotate cells with values
        for i in range(num_layers):
            for j in range(num_heads + 1):
                val = mat[i, j]
                txt_color = 'black' if val < (vmax_val * 0.6) else 'white'
                ax.text(j, i, f'{val:.1f}', ha='center', va='center', color=txt_color, fontsize=8)

        plt.tight_layout()
        out_pdf = f"results/{self.analysis_name}_2Dlesioning_{self.MODEL_NAME}.pdf"
        plt.savefig(out_pdf, dpi=300, bbox_inches='tight')
        print(f"Fine-grained lesioning heatmap saved to {out_pdf}")
        plt.show()
        plt.close()

    def analyze_prompt_across_layers(self, prompt, prompt_data):
        """
        Analyzes a single prompt across all lesioned layers and scores degradation from 1-10.
        
        Args:
            prompt: The original prompt
            prompt_data: Dictionary containing original_response and lesioned_responses for all layers
            
        Returns:
            Dictionary with scores and justifications for each layer
        """
        # Load analysis model if not already loaded
        if self.analysis_model is None:
            # analysis_model_path = "/data/hf/meta-llama/Llama-3.1-8B-Instruct"
            analysis_model_path = "/data/hf/meta-llama/Llama-3.3-70B-Instruct"
            print(f"Loading LLM-as-a-judge model for lesion analysis: {analysis_model_path}")
            
            # Load the analysis model - use AutoModelForCausalLM to avoid type conflicts
            from transformers import AutoProcessor, AutoModelForCausalLM
            
            self.analysis_processor = AutoProcessor.from_pretrained(analysis_model_path)
            self.analysis_tokenizer = self.analysis_processor
            
            # Create explicit device map for Llama-8B
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(analysis_model_path)
            num_layers = config.num_hidden_layers
            layers_per_gpu = num_layers // 4
            
            device_map = {
                "model.embed_tokens": 0,
                "model.norm": 3,
                "lm_head": 3,
            }
            
            # Distribute layers across GPUs
            for i in range(num_layers):
                gpu_id = i // layers_per_gpu
                if gpu_id >= 4:  # Fallback to GPU 3 if we have more layers than expected
                    gpu_id = 3
                device_map[f"model.layers.{i}"] = gpu_id
            
            # Memory configuration for 4 GPUs (8B model needs less memory)
            max_mem = {
                0: "5GB",
                1: "5GB", 
                2: "5GB",
                3: "5GB",
            }
            
            self.analysis_model = AutoModelForCausalLM.from_pretrained(
                analysis_model_path,
                device_map=device_map,
                max_memory=max_mem,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            ).eval()
        
        # Build comparison text for all layers
        comparison_text = f"Here is the original prompt and response, followed by responses from the same model with different layers ablated (replaced with identity functions).\n\n"
        comparison_text += f"Original prompt: {prompt}\n"
        comparison_text += f"Original response: {prompt_data['original_response']}\n\n"
        
        comparison_text += "Lesioned responses:\n"
        for layer_idx, lesioned_response in prompt_data['lesioned_responses'].items():
            comparison_text += f"Layer {layer_idx} ablated: {lesioned_response}\n"
        
        comparison_text += "\n\nPlease analyze each lesioned response and provide:\n"
        comparison_text += "1. A score from 1-10 for each layer ablation (10 = complete rubbish or significant disruption from original, 1 = no noticeable change from original)\n"
        comparison_text += "2. A brief justification for each score\n\n"
        comparison_text += "Format your response as:\n"
        comparison_text += "Layer 0: Score X - [justification]\n"
        comparison_text += "Layer 1: Score Y - [justification]\n"
        comparison_text += "...\n"
        comparison_text += "Layer N: Score Z - [justification]\n"
        
        # Generate analysis
        inputs = self.analysis_tokenizer(comparison_text, return_tensors="pt", truncation=True, max_length=4096)
        inputs = {k: v.to(self.analysis_model.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.analysis_model.generate(
                **inputs,
                max_new_tokens=4000,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.analysis_tokenizer.eos_token_id
            )
        
        input_length = inputs["input_ids"].shape[-1]
        generated_tokens = outputs[0][input_length:]
        analysis_text = self.analysis_tokenizer.batch_decode([generated_tokens])[0].strip()
        
        # Parse the scores and justifications
        scores_and_justifications = self.parse_layer_scores(analysis_text, self.num_layers)
        
        return {
            'analysis_text': analysis_text,
            'scores_and_justifications': scores_and_justifications
        }

    def analyze_prompt_across_layers_gpt5(self, prompt, prompt_data):
        """
        Analyzes a single prompt across all lesioned layers using the OpenAI chat API and
        scores degradation from 1-10.

        Despite the method name, the request currently uses the "gpt-4o" model.

        The API key is read from ~/.oai; if that file does not exist, the OPENAI_API_KEY
        environment variable is used instead.

        Args:
            prompt: The original prompt
            prompt_data: Dictionary containing original_response and lesioned_responses for all layers

        Returns:
            Dictionary with 'analysis_text' and 'scores_and_justifications'. If the API call
            fails, every layer in range(self.num_layers) receives the default score 5 and
            'analysis_text' records the failure.

        Raises:
            FileNotFoundError: If ~/.oai is missing and OPENAI_API_KEY is not set.
        """
        import openai
        import os
        print("Using OpenAI GPT-5 for layer lesioning analysis...")

        # Load OpenAI API key: key file first, environment variable as fallback
        try:
            with open(os.path.expanduser("~/.oai"), "r") as f:
                api_key = f.read().strip()
        except FileNotFoundError:
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                raise FileNotFoundError(
                    "OpenAI API key not found in ~/.oai file or OPENAI_API_KEY environment variable"
                ) from None
        openai.api_key = api_key

        # Build comparison text for all layers
        parts = [
            "Here is the original prompt and response, followed by responses from the same model with different layers ablated (replaced with identity functions).\n\n",
            f"Original prompt: {prompt}\n",
            f"Original response: {prompt_data['original_response']}\n\n",
            "Lesioned responses:\n",
        ]
        parts.extend(
            f"Layer {layer_idx} ablated: {lesioned_response}\n"
            for layer_idx, lesioned_response in prompt_data['lesioned_responses'].items()
        )
        parts.extend([
            "\n\nPlease analyze each lesioned response and provide:\n",
            "1. A score from 1-10 for each layer ablation (10 = complete rubbish or significant disruption from original, 1 = no noticeable change from original)\n",
            "2. A brief justification for each score\n\n",
            "Format your response as:\n",
            "Layer 0: Score X - [justification]\n",
            "Layer 1: Score Y - [justification]\n",
            "...\n",
            "Layer N: Score Z - [justification]\n",
        ])
        comparison_text = "".join(parts)

        # Call the OpenAI chat completions API
        try:
            response = openai.chat.completions.create(
                model="gpt-4o",  # Use gpt-4o instead of gpt-5 for now
                messages=[
                    {"role": "system", "content": "You are an expert AI researcher analyzing the effects of layer ablation on language model performance. Provide detailed, accurate analysis of how each layer ablation affects the model's responses."},
                    {"role": "user", "content": comparison_text}
                ],
                max_completion_tokens=4000,
                temperature=0.7
            )
            analysis_text = response.choices[0].message.content.strip()
        except Exception as e:
            # Deliberate best-effort: one failed judge call must not abort the whole
            # analysis, so fall back to a neutral default score for every layer.
            print(f"Error calling OpenAI GPT-5 API: {e}")
            return {
                'analysis_text': "API call failed. Using default scores.",
                'scores_and_justifications': {
                    layer_idx: {
                        'score': 5,  # Default middle score
                        'justification': 'API call failed - using default score'
                    }
                    for layer_idx in range(self.num_layers)
                }
            }

        # Parse the scores and justifications
        scores_and_justifications = self.parse_layer_scores(analysis_text, self.num_layers)

        return {
            'analysis_text': analysis_text,
            'scores_and_justifications': scores_and_justifications
        }

    def parse_layer_scores(self, analysis_text, num_layers):
        """
        Extract per-layer scores and justifications from a judge's free-text reply.

        Each line is matched, case-insensitively, against a short list of accepted shapes,
        e.g. "Layer X: Score Y - justification", "Layer X: Y - justification" and the
        bulleted "* Layer X: Y - justification". Scores are clamped to 1..10. Layers with an
        index >= num_layers are ignored, and any layer in range(num_layers) that the text
        does not mention gets the default score 5.

        Args:
            analysis_text: The analysis text from the LLM
            num_layers: Number of layers to expect

        Returns:
            Dictionary mapping layer index -> {'score': int, 'justification': str}
        """
        import re

        # Tried in this order for every line; the first pattern whose layer index is in
        # range wins.
        line_patterns = [
            re.compile(source, re.IGNORECASE)
            for source in (
                r'Layer\s+(\d+):\s*(?:Score\s+)?(\d+)\s*-\s*(.+)',
                r'Layer\s+(\d+):\s*(\d+)\s*-\s*(.+)',
                r'\*\s*Layer\s+(\d+):\s*(\d+)\s*-\s*(.+)',
            )
        ]

        def _first_in_range_match(text_line):
            # An out-of-range layer index does not end the search: the remaining patterns
            # still get a chance on the same line.
            for compiled in line_patterns:
                found = compiled.search(text_line)
                if found and int(found.group(1)) < num_layers:
                    return found
            return None

        parsed = {}
        for raw_line in analysis_text.split('\n'):
            stripped = raw_line.strip()
            if not stripped:
                continue
            hit = _first_in_range_match(stripped)
            if hit is None:
                continue
            clamped = min(10, int(hit.group(2)))
            if clamped < 1:
                clamped = 1
            parsed[int(hit.group(1))] = {
                'score': clamped,
                'justification': hit.group(3).strip(),
            }

        # Layers the judge did not mention get a neutral default.
        for missing_idx in range(num_layers):
            parsed.setdefault(missing_idx, {
                'score': 5,  # Default middle score
                'justification': 'Score not found in analysis',
            })

        return parsed

    def cleanup_analysis_model(self):
        """
        Clean up the cached analysis model to free GPU memory.
        """
        if self.analysis_model is not None:
            del self.analysis_model
            del self.analysis_processor
            del self.analysis_tokenizer
            self.analysis_model = None
            self.analysis_processor = None
            self.analysis_tokenizer = None
            torch.cuda.empty_cache()
            print("Analysis model cleaned up and memory freed.")

    def run_layer_lesioning(self, prompts, output_file=None, llm_as_a_judge_only=False, gpt5_judge=True):
        """
        Performs layer lesioning by removing one layer at a time and analyzing the effects.
        For each prompt, analyzes all layers together and scores degradation from 1-10.
        
        Args:
            prompts: List of prompts to test
            output_file: Optional file path to save results
            llm_as_a_judge_only: If True, it will load the lesioning from the json and 
                re-run the LLM-as-a-judge only the LLM will only be used as a judge to 
                score the degradation of the layers.
            gpt5_judge: If True, uses OpenAI GPT-5 for analysis instead of local model
            
        Returns:
            Dictionary containing lesioning results with scores for each layer
        """
        print(f"Starting layer lesioning analysis with {len(prompts)} prompts...")
        
        # Check if model uses disk offloading (which can cause issues during lesioning)
        if hasattr(self.model, 'hf_device_map') and self.model.hf_device_map:
            print("Warning: Model uses device mapping. This should work with our explicit device maps.")
        
        # Check model compatibility for lesioning
        self._check_lesioning_compatibility()
        
        if llm_as_a_judge_only:
            print("Loading lesioning results from json file... Will only re-run the LLM-as-a-judge.")
            with open(output_file, 'r') as f:
                lesion_results = json.load(f)
                original_responses = lesion_results['original_responses']
                all_lesioned_responses = lesion_results['all_lesioned_responses']
                prompts = lesion_results['prompts']
            
            # Analyze each prompt across all layers
            prompt_scores = {}
            
            for prompt_idx, prompt in enumerate(prompts):
                print(f"Re-analyzing prompt {prompt_idx + 1}/{len(prompts)}: {prompt[:50]}...")
                
                # Collect all responses for this prompt across all layers
                prompt_analysis_data = {
                    'original_response': original_responses[prompt_idx],
                    'lesioned_responses': {}
                }
                
                for layer_idx in range(self.num_layers):
                    prompt_analysis_data['lesioned_responses'][layer_idx] = all_lesioned_responses[str(layer_idx)][prompt_idx]
                
                # Analyze this prompt across all layers
                if gpt5_judge:
                    analysis_result = self.analyze_prompt_across_layers_gpt5(prompt, prompt_analysis_data)
                else:
                    analysis_result = self.analyze_prompt_across_layers(prompt, prompt_analysis_data)
                prompt_scores[prompt_idx] = analysis_result

        else:
            print("Running layer lesioning analysis...")

            # Get original responses
            print("Getting original model responses...")
            original_responses = self.run_model_with_prompts(prompts)
            
            # Store all lesioned responses for each layer
            all_lesioned_responses = {}
            
            # Test each layer and collect responses
            for layer_idx in range(self.num_layers):
                print(f"Testing ablation of layer {layer_idx}/{self.num_layers-1}...")
                
                # Ablate the layer
                original_layer = self.ablate_layer(layer_idx)
                
                # Get responses with lesioned model
                lesioned_responses = self.run_model_with_prompts(prompts)
                
                # Store lesioned responses
                all_lesioned_responses[layer_idx] = lesioned_responses
                
                # Restore the layer
                self.restore_layer(layer_idx, original_layer)
                
                print(f"Layer {layer_idx} responses collected.")
            
            # Analyze each prompt across all layers
            prompt_scores = {}
            
            for prompt_idx, prompt in enumerate(prompts):
                print(f"Analyzing prompt {prompt_idx + 1}/{len(prompts)}: {prompt[:50]}...")
                
                # Collect all responses for this prompt across all layers
                prompt_analysis_data = {
                    'original_response': original_responses[prompt_idx],
                    'lesioned_responses': {}
                }
                
                for layer_idx in range(self.num_layers):
                    prompt_analysis_data['lesioned_responses'][layer_idx] = all_lesioned_responses[layer_idx][prompt_idx]
                
                # Analyze this prompt across all layers
                if gpt5_judge:
                    analysis_result = self.analyze_prompt_across_layers_gpt5(prompt, prompt_analysis_data)
                else:
                    analysis_result = self.analyze_prompt_across_layers(prompt, prompt_analysis_data)
                prompt_scores[prompt_idx] = analysis_result
        
        # Compile final results
        lesion_results = {
            'original_responses': original_responses,
            'all_lesioned_responses': all_lesioned_responses,
            'prompt_scores': prompt_scores,
            'prompts': prompts
        }
        
        # Save results if output file specified
        if output_file:
            with open(output_file, 'w') as f:
                json.dump(lesion_results, f, indent=2)
            print(f"Lesioning results saved to {output_file}")
        
        # Clean up the analysis model
        self.cleanup_analysis_model()
        
        return lesion_results



    def print_lesioning_results(self, lesion_results):
        """
        Print a human-readable report of the lesioning scores.

        The report lists, for every prompt, the score and (truncated)
        justification of each layer, followed by the per-layer score averaged
        over all prompts.

        Args:
            lesion_results: Dictionary with a 'prompts' list and a
                'prompt_scores' mapping; each prompt entry holds a
                'scores_and_justifications' mapping from layer to a dict with
                'score' and 'justification'. Keys may be ints or int-like
                strings.
        """
        def int_keys(mapping):
            # Cast keys to int when every key allows it; otherwise keep the
            # mapping exactly as it was given.
            try:
                return {int(key): value for key, value in mapping.items()}
            except Exception:
                return mapping

        banner = "=" * 80

        print("\n" + banner)
        print("LAYER LESIONING ANALYSIS RESULTS")
        print(banner)

        per_prompt = int_keys(lesion_results['prompt_scores'])

        # Per-prompt breakdown, in ascending prompt order
        for p_idx in sorted(per_prompt.keys()):
            entry = per_prompt[p_idx]
            prompt_text = lesion_results['prompts'][int(p_idx)]
            print(f"\n--- Prompt {p_idx + 1}: {prompt_text[:60]}... ---")

            by_layer = int_keys(entry['scores_and_justifications'])
            for layer in sorted(by_layer.keys()):
                info = by_layer[layer]
                print(f"Layer {layer}: Score {info['score']}/10 - {info['justification'][:80]}...")

            print("-" * 50)

        # Per-layer averages over all prompts
        print("\n--- AVERAGE SCORES ACROSS ALL PROMPTS ---")
        collected = {}
        for entry in per_prompt.values():
            by_layer = int_keys(entry['scores_and_justifications'])
            for layer, info in by_layer.items():
                collected.setdefault(layer, []).append(info['score'])

        for layer in sorted(collected.keys()):
            values = collected[layer]
            mean_score = sum(values) / len(values)
            print(f"Layer {layer}: Average Score {mean_score:.2f}/10")

        print("\n" + banner)
        print("END OF LESIONING ANALYSIS")
        print(banner)

    def load_and_print_lesioning_results(self, lesion_file):
        """
        Read lesioning results from a JSON file and print them.

        Any failure (missing file, malformed JSON, or an error raised while
        printing) is reported on stdout instead of being raised.

        Args:
            lesion_file: Path to the JSON file containing lesioning results
        """
        failure = None
        try:
            with open(lesion_file, 'r') as handle:
                loaded = json.loads(handle.read())

            print(f"\nLoading lesioning results from {lesion_file}")
            self.print_lesioning_results(loaded)

        except json.JSONDecodeError:
            failure = f"Error decoding JSON from {lesion_file}"
        except FileNotFoundError:
            failure = f"Lesioning results file {lesion_file} not found."
        except Exception as err:
            failure = f"Error loading lesioning results: {err}"

        if failure is not None:
            print(failure)

    def print_lesioning_table(self, lesion_results, out_pdf_path=None):
        """
        Render key lesioning results as PDF tables for inclusion in papers.

        One single-page PDF is written per prompt. Layout, top to bottom:
        - Prompt block (light blue), at most 2 wrapped lines ("..." if cut)
        - Original answer block (light green), at most 2 wrapped lines
        - Two rows per layer (background alternates per layer):
            * "Layer k Response after lesioning: ..." (single line, truncated)
            * "Degradation s/10 - justification" (single line, truncated); the
              score text is colored on a green -> orange -> red scale (1 -> 10),
              or shown as a grey "NA" when no usable score is available

        Args:
            lesion_results: dict as returned by run_layer_lesioning; integer-like
                keys may be strings (e.g. after a JSON round trip).
            out_pdf_path: optional base path; "_prompt{n}" is inserted before the
                extension for each prompt. Defaults to
                results/{analysis_name}_lesioning_table_{MODEL_NAME}.pdf
        """
        import textwrap
        from matplotlib.colors import LinearSegmentedColormap

        line_chars = 100        # wrap / truncation width in characters
        max_header_lines = 2    # rows available to the prompt and answer blocks
        header_gap = 0.4        # vertical gap (in rows) between the two header blocks
        text_kw = dict(va='center', ha='left', fontsize=10, usetex=False)

        def _int_keyed(d):
            """Return d with keys cast to int; return d unchanged if that is not possible."""
            try:
                return {int(k): v for k, v in d.items()}
            except (AttributeError, TypeError, ValueError):
                return d

        def _flat(text):
            """Coerce to str and collapse newlines so the text fits table rows."""
            return str(text or '').replace('\n', ' ')

        def truncate_single(text, width):
            """Shorten text to a single line of at most `width` characters."""
            return textwrap.shorten(_flat(text), width=width, placeholder='...')

        def escape_matplotlib_text(text):
            """Escape '$' so matplotlib never switches into mathtext mode.

            Outside mathtext every other character ('_', '^', '{', '}', '\\')
            is drawn literally, so escaping those would only add visible
            backslashes to the rendered text.
            """
            if not text:
                return text
            return str(text).replace('$', r'\$')

        def wrap_for_header(text, width):
            """Wrap text to at most `max_header_lines` lines.

            An ellipsis is appended only when text was actually cut. Empty
            text yields a single empty line so the block (and its label) is
            still drawn.
            """
            lines = textwrap.wrap(_flat(text), width=width,
                                  max_lines=max_header_lines, placeholder='...')
            return lines or ['']

        def parse_score(entry):
            """Return the entry's score as an int, or None if missing/unparseable."""
            raw = entry.get('score') if isinstance(entry, dict) else None
            try:
                return int(float(raw))
            except (TypeError, ValueError, OverflowError):
                return None

        def draw_header_block(ax, y, lines, labels, facecolor):
            """Draw one header block downwards from y; return the new y."""
            for i_line, line in enumerate(lines):
                y -= 1
                ax.add_patch(plt.Rectangle((0, y), 1, 1, facecolor=facecolor, edgecolor='none'))
                if i_line < len(labels):
                    ax.text(0.01, y + 0.5, labels[i_line], fontweight='bold', **text_kw)
                ax.text(0.12, y + 0.5, escape_matplotlib_text(line), **text_kw)
            return y

        prompts = lesion_results.get('prompts', [])
        original_responses = lesion_results.get('original_responses', [])
        # Keys may have been serialized as strings
        all_lesioned_responses = _int_keyed(lesion_results.get('all_lesioned_responses', {}))
        prompt_scores = _int_keyed(lesion_results.get('prompt_scores', {}))

        # Determine output path and make sure its directory exists
        if out_pdf_path is None:
            analysis = getattr(self, 'analysis_name', 'analysis')
            out_pdf_path = f"results/{analysis}_lesioning_table_{self.MODEL_NAME}.pdf"
        base, ext = os.path.splitext(out_pdf_path)
        ext = ext or '.pdf'
        os.makedirs(os.path.dirname(out_pdf_path) or '.', exist_ok=True)

        # Colormap for scores (green to orange to red, avoiding yellow)
        cmap = LinearSegmentedColormap.from_list('custom', ['green', 'orange', 'red'], N=256)
        norm = Normalize(vmin=1, vmax=10)

        for prompt_idx, prompt in enumerate(prompts):
            orig_resp = original_responses[prompt_idx] if prompt_idx < len(original_responses) else ''

            # Layers present for this prompt
            prompt_entry = prompt_scores.get(prompt_idx, {})
            scores_entry = prompt_entry.get('scores_and_justifications', {}) if isinstance(prompt_entry, dict) else {}
            scores_entry = _int_keyed(scores_entry) if isinstance(scores_entry, dict) else {}
            layer_ids = sorted(set(all_lesioned_responses) | set(scores_entry))

            # Wrap the header text once, so the row count used for sizing
            # always matches what is actually drawn below.
            prompt_wrapped = wrap_for_header(prompt, line_chars)
            orig_wrapped = wrap_for_header(orig_resp, line_chars)
            total_rows = len(prompt_wrapped) + header_gap + len(orig_wrapped) + 2 * len(layer_ids)

            # Figure sizing scales with number of rows/layers
            fig = plt.figure(figsize=(10, max(6, 0.28 * total_rows)))
            try:
                # Main drawing axis (full width); one data unit in y is one row
                ax = fig.add_axes([0.05, 0.06, 0.92, 0.88])
                ax.set_xlim(0, 1)
                ax.set_ylim(0, total_rows)
                ax.axis('off')

                y = total_rows

                # Prompt block (light blue) with bold label
                y = draw_header_block(ax, y, prompt_wrapped, ["Prompt:"], '#dbeafe')

                # spacing between prompt and original answer
                y -= header_gap

                # Original response block (light green); the bold label is
                # split over the first two lines to fit the label column.
                y = draw_header_block(ax, y, orig_wrapped, ["Original", "Answer:"], '#dcfce7')

                # Per-layer rows
                for li, layer_idx in enumerate(layer_ids):
                    # Alternating background per layer (two rows)
                    layer_bg = '#ffffff' if (li % 2 == 0) else '#f5f5f5'

                    # Response row
                    y -= 1
                    les_resp = ''
                    layer_list = all_lesioned_responses.get(layer_idx)
                    if isinstance(layer_list, dict):
                        les_resp = _int_keyed(layer_list).get(prompt_idx, '')
                    elif layer_list is not None and 0 <= prompt_idx < len(layer_list):
                        les_resp = layer_list[prompt_idx]

                    label_text = f"Layer {layer_idx} Response after lesioning:"
                    response_text = escape_matplotlib_text(truncate_single(les_resp, int(line_chars * 0.8)))
                    ax.add_patch(plt.Rectangle((0, y), 1, 1, facecolor=layer_bg, edgecolor='none'))
                    ax.text(0.01, y + 0.5, label_text, fontweight='bold', **text_kw)
                    # x offsets are a character-count approximation of the label width
                    ax.text(0.1 + len(label_text) / (line_chars * 1.6), y + 0.5, response_text, **text_kw)

                    # Degradation/justification row
                    y -= 1
                    entry = scores_entry.get(layer_idx)
                    score = parse_score(entry)
                    justification = entry.get('justification', '') if isinstance(entry, dict) else ''

                    just_line = escape_matplotlib_text(truncate_single(justification, line_chars - 12))
                    ax.add_patch(plt.Rectangle((0, y), 1, 1, facecolor=layer_bg, edgecolor='none'))
                    label_text = "Degradation"
                    ax.text(0.01, y + 0.5, label_text, fontweight='bold', **text_kw)

                    # Score as colored text (no box) - only the score part is colored
                    score_str = f"{score}/10" if score is not None else "NA"
                    score_color = cmap(norm(score)) if score is not None else (0.5, 0.5, 0.5, 1.0)
                    x_score = 0.05 + len(label_text) / (line_chars * 1.6)
                    ax.text(x_score, y + 0.5, score_str, color=score_color, **text_kw)

                    # Separator and justification follow the score
                    x_just = x_score + len(score_str) / (line_chars * 1.6) + 0.01
                    ax.text(x_just, y + 0.5, f" - {just_line}", color='black', **text_kw)

                    # Separator line before next layer
                    ax.plot([0.0, 1.0], [y, y], color='#cccccc', linewidth=0.8)

                # Title
                fig.suptitle(f"Lesioning summary – Prompt {prompt_idx + 1}", fontsize=12)

                # Save as a separate PDF per prompt with suffix
                prompt_out = f"{base}_prompt{prompt_idx + 1}{ext}"
                with PdfPages(prompt_out) as pdf:
                    pdf.savefig(fig, dpi=300)
            finally:
                # Release the figure even if drawing or saving failed
                plt.close(fig)
            print(f"Saved lesioning table PDF → {prompt_out}")

    def runActivationPatching(self):
        """
        Run layer-wise activation patching and save the results as JSON.

        Prompt tuples from generate_activation_patching_prompts() are read here
        as (clean_prompt, clean_answer, corrupt_prompt, corrupt_answer). For
        every pair the logit difference LD = logit(r) - logit(r') is computed on
        the clean and on the corrupt prompt, and LD_clean - LD_corrupt is kept
        as the normalization term. act_patching_params is then called once per
        layer.

        Returns:
            Dictionary with the baseline logits, 'original_ld', the per-layer
            'all_patching_results', the prompts and both answer lists. The same
            dictionary is written to self.activationPatchingFile.
        """
        # Generate prompt pairs
        prompts = self.generate_activation_patching_prompts()
        pair_count = len(prompts)

        print(f"Starting activation patching analysis with {pair_count} prompt pairs...")

        # Device-mapped (possibly offloaded) models can cause issues during patching
        if getattr(self.model, 'hf_device_map', None):
            print("Warning: Model uses device mapping. This should work with our explicit device maps.")

        # Check model compatibility for patching
        self._check_lesioning_compatibility()

        print("Getting original model responses...")
        # Split the tuples into four parallel columns.
        # NOTE(review): act_patching_params indexes the same tuples as
        # (clean_prompt, corrupt_prompt, clean_answer, corrupt_answer), which
        # disagrees with the order used here - confirm which one is intended.
        clean_prompts, clean_answers, corrupt_prompts, corrupt_answers = (
            [pair[col] for pair in prompts] for col in range(4)
        )

        for clean_text, _clean_ans, corrupt_text, _corrupt_ans in prompts:
            print(f"{clean_text}")
            print(f"{corrupt_text}")
            print("--------------------------------")

        # Baseline logits for both candidate answers on each kind of prompt:
        # clean run -> logit_cl(r), logit_cl(r')
        clean_r_logits, clean_rp_logits = self.get_logits_for_answers(clean_prompts, clean_answers, corrupt_answers)
        # corrupt run -> logit_*(r), logit_*(r')
        corrupt_r_logits, corrupt_rp_logits = self.get_logits_for_answers(corrupt_prompts, clean_answers, corrupt_answers)

        # Per-run logit differences
        ld_clean = {}
        ld_corrupt = {}
        for idx in range(pair_count):
            ld_clean[idx] = clean_r_logits[idx] - clean_rp_logits[idx]
            ld_corrupt[idx] = corrupt_r_logits[idx] - corrupt_rp_logits[idx]

        # Denominator for normalization: LD_clean - LD_corrupt
        original_ld = {idx: ld_clean[idx] - ld_corrupt[idx] for idx in range(pair_count)}

        # Patch one layer at a time and collect its results
        all_patching_results = {}
        for layer_idx in range(self.num_layers):
            print(f"Testing activation patching for layer {layer_idx}/{self.num_layers-1}...")

            outcome = self.act_patching_params(
                prompts,
                {'layer': layer_idx},
                clean_r_logits,
                clean_rp_logits,
                corrupt_r_logits,
                corrupt_rp_logits,
            )
            all_patching_results[layer_idx] = outcome

            print(f"Layer {layer_idx} patching results collected.")
            print(f"Normalization (LD_clean - LD_corrupt): {original_ld}")
            print(f"Patched LD: {outcome['patched_ld']}")
            print(f"Patching effect: {outcome['patching_effect']}")

        # Compile final results; 'original_ld' is the normalization denominator,
        # stored for transparency
        patching_results = dict(
            clean_r_logits=clean_r_logits,
            clean_rp_logits=clean_rp_logits,
            corrupt_r_logits=corrupt_r_logits,
            corrupt_rp_logits=corrupt_rp_logits,
            original_ld=original_ld,
            all_patching_results=all_patching_results,
            prompts=prompts,
            clean_answers=clean_answers,
            corrupt_answers=corrupt_answers,
        )

        # Save results
        output_file = self.activationPatchingFile
        with open(output_file, 'w') as sink:
            json.dump(patching_results, sink, indent=2)
        print(f"Activation patching results saved to {output_file}")
        print("Activation patching analysis complete.")

        return patching_results

        
    def runActivationPatchingFinegrained(self, load_cached):
        """
        Run fine-grained activation patching (each attention head and each MLP).

        Prompt tuples from generate_activation_patching_prompts() are read as
        (clean_prompt, clean_answer, corrupt_prompt, corrupt_answer). After the
        results are computed or loaded, plot_finegrained_patching_heatmap is
        called on them.

        Args:
            load_cached: If True, skip the computation and load previously saved
                results from self.activationPatchingFinegrainedFile (falling
                back to the legacy results/... path when that attribute is not
                set). If False, compute the results and save them to
                self.activationPatchingFinegrainedFile.

        Returns:
            Dictionary of patching results. When loaded from JSON, all
            dictionary keys are strings.
        """
        # Generate prompt pairs
        prompts = self.generate_activation_patching_prompts()
        num_prompts = len(prompts)

        print(f"Starting fine-grained activation patching analysis with {num_prompts} prompt pairs...")

        # Check if model uses disk offloading (which can cause issues during patching)
        if hasattr(self.model, 'hf_device_map') and self.model.hf_device_map:
            print("Warning: Model uses device mapping. This should work with our explicit device maps.")

        # Check model compatibility for patching
        self._check_lesioning_compatibility()

        if load_cached:
            # Read from the same location the non-cached branch writes to, so a
            # saved run can always be reloaded.
            output_file = getattr(self, 'activationPatchingFinegrainedFile', None)
            if not output_file:
                output_file = f"results/{self.analysis_name}_activation_patching_finegrained_{self.MODEL_NAME}.json"
            with open(output_file, 'r') as f:
                patching_results = json.load(f)
            print(f"Loaded cached fine-grained activation patching results from {output_file}")
        else:
            # Get model configuration to determine number of attention heads.
            # The fallback is resolved lazily: a getattr() default would be
            # evaluated even when `config` exists and would raise for models
            # without a `.model` attribute.
            cfg = getattr(self.model, "config", None)
            if cfg is None:
                cfg = self.model.model.config
            num_heads = getattr(cfg, "num_attention_heads", None) or getattr(cfg, "num_heads")
            print(f"Model has {num_heads} attention heads per layer")

            # Get original responses for clean and corrupt prompts
            print("Getting original model responses...")
            clean_prompts = [prompt[0] for prompt in prompts]
            clean_answers = [prompt[1] for prompt in prompts]
            corrupt_prompts = [prompt[2] for prompt in prompts]
            corrupt_answers = [prompt[3] for prompt in prompts]

            # Logits for both candidate answers on each kind of prompt:
            # logit_cl(r), logit_cl(r')
            clean_r_logits, clean_rp_logits = self.get_logits_for_answers(clean_prompts, clean_answers, corrupt_answers)
            # logit_*(r), logit_*(r')
            corrupt_r_logits, corrupt_rp_logits = self.get_logits_for_answers(corrupt_prompts, clean_answers, corrupt_answers)

            ld_clean = {i: (clean_r_logits[i] - clean_rp_logits[i]) for i in range(num_prompts)}
            ld_corrupt = {i: (corrupt_r_logits[i] - corrupt_rp_logits[i]) for i in range(num_prompts)}
            # Normalization term: ld_cl - ld_*
            original_ld = {i: (ld_clean[i] - ld_corrupt[i]) for i in range(num_prompts)}

            # Store all patching results for each layer, head, and MLP
            all_patching_results = {}

            for layer_idx in range(self.num_layers):
                print(f"Testing fine-grained activation patching for layer {layer_idx}/{self.num_layers-1}...")
                layer_results = {}
                all_patching_results[layer_idx] = layer_results

                # Test each attention head individually
                for head_idx in range(num_heads):
                    print(f"  Testing attention head {head_idx}/{num_heads-1}...")

                    head_results = self.act_patching_finegrained_params(prompts, {
                        'layer': layer_idx,
                        'component': 'attention_head',
                        'head_idx': head_idx,
                        'num_heads': num_heads
                    })
                    print(head_results)

                    layer_results[f'head_{head_idx}'] = head_results

                    print(f"    Head {head_idx} patching effect: {head_results['patching_effect']}")

                # Test MLP separately
                print(f"  Testing MLP...")
                mlp_results = self.act_patching_finegrained_params(prompts, {
                    'layer': layer_idx,
                    'component': 'mlp'
                })
                layer_results['mlp'] = mlp_results

                print(f"    MLP patching effect: {mlp_results['patching_effect']}")

            # Compile final results
            patching_results = {
                'original_ld': original_ld,
                'all_patching_results': all_patching_results,
                'prompts': prompts,
                'clean_answers': clean_answers,
                'corrupt_answers': corrupt_answers,
                'num_heads': num_heads,
                'num_layers': self.num_layers
            }

            # Save results
            output_file = self.activationPatchingFinegrainedFile
            with open(output_file, 'w') as f:
                json.dump(patching_results, f, indent=2)
            print(f"Fine-grained activation patching results saved to {output_file}")

        # Create heatmap visualization
        self.plot_finegrained_patching_heatmap(patching_results)
        print("Fine-grained activation patching analysis complete.")

        return patching_results
        
        

    def act_patching_params(self, prompts, params, clean_r_logits, clean_rp_logits, corrupt_r_logits, corrupt_rp_logits):
        """
        Helper function for activation patching with specific parameters.
        
        Args:
            prompts: List of tuples (clean_prompt, corrupt_prompt, clean_answer, corrupt_answer)
            params: Dictionary specifying what to patch (e.g., {'layer': 5})
            
        Returns:
            Dictionary containing patching results for the specified parameters
        """
        clean_prompts = [prompt[0] for prompt in prompts]
        corrupt_prompts = [prompt[1] for prompt in prompts]
        clean_answers = [prompt[2] for prompt in prompts]
        corrupt_answers = [prompt[3] for prompt in prompts]


        # LDs
        ld_clean = {}
        ld_corrupt = {}
        for i in range(len(prompts)):
            ld_clean[i] = clean_r_logits[i] - clean_rp_logits[i] # ld_cl
            ld_corrupt[i] = corrupt_r_logits[i] - corrupt_rp_logits[i] # ld_*

        # Denominator for normalization
        original_ld = {}
        for i in range(len(prompts)):
            original_ld[i] = ld_clean[i] - ld_corrupt[i] # ld_cl - ld_*
        
        layer_idx = params.get('layer', 0)
        layer = self.model_layers[layer_idx]
        
        # Store LD_patched for each prompt
        ld_patched = {}
        
        # Process each prompt pair individually
        for prompt_idx in range(len(prompts)):
            clean_prompt = clean_prompts[prompt_idx]
            corrupt_prompt = corrupt_prompts[prompt_idx]
            clean_answer = clean_answers[prompt_idx]
            corrupt_answer = corrupt_answers[prompt_idx]
            
            print(f"    Processing prompt {prompt_idx}: {clean_prompt[:50]}... vs {corrupt_prompt[:50]}...")
            
            # Capture clean activations for this specific prompt
            clean_activations = {}
            
            def clean_hook_fn(module, input, output):
                if module == layer.mlp:
                    if isinstance(output, tuple):
                        clean_activations['mlp'] = output[0].clone().detach()
                    else:
                        clean_activations['mlp'] = output.clone().detach()
                elif module == layer.self_attn:
                    if isinstance(output, tuple):
                        clean_activations['attention'] = output[0].clone().detach()
                    else:
                        clean_activations['attention'] = output.clone().detach()
            
            # Register hooks for clean run
            mlp_hook = layer.mlp.register_forward_hook(clean_hook_fn)
            attn_hook = layer.self_attn.register_forward_hook(clean_hook_fn)
            
            # Run this specific clean prompt to capture activations
            _ = self.get_logits_for_answers([clean_prompt], [clean_answer], [corrupt_answer])
            
            # Remove clean hooks
            mlp_hook.remove()
            attn_hook.remove()
            
            print(f"      Clean activations captured: {list(clean_activations.keys())}")
            
            # Now implement patching for this specific prompt
            def patching_hook_fn(module, input, output):
                # Handle input - it might be a tuple or a single tensor
                if isinstance(input, tuple):
                    if len(input) > 0:
                        input_tensor = input[0]
                    else:
                        return  # Skip if input is empty
                else:
                    input_tensor = input
                    
                if module == layer.mlp:
                    clean_mlp = clean_activations['mlp']
                    if isinstance(output, tuple):
                        output_tensor = output[0]
                    else:
                        output_tensor = output
                    
                    # Handle sequence length mismatch by truncating clean activations
                    if clean_mlp.shape[1] != output_tensor.shape[1]:
                        # Truncate clean activations to match the sequence length of the corrupt run
                        clean_mlp_truncated = clean_mlp[:, :output_tensor.shape[1], :]
                    else:
                        clean_mlp_truncated = clean_mlp
                    
                    if clean_mlp_truncated.shape == output_tensor.shape:
                        # Replace the MLP output with the clean MLP output
                        output_tensor.copy_(clean_mlp_truncated)
                    else:
                        print(f"        MLP shape mismatch after truncation: clean {clean_mlp_truncated.shape} vs output {output_tensor.shape}")
                
                elif module == layer.self_attn:
                    clean_attn = clean_activations['attention']
                    if isinstance(output, tuple):
                        if clean_attn.shape == output[0].shape:
                            # Patch the residual part: output[0] = input_tensor + clean_f(x_in)
                            output[0].copy_(input_tensor + clean_attn)
                        else:
                            print(f"        Attention shape mismatch: clean {clean_attn.shape} vs output {output[0].shape}")
                    else:
                        if clean_attn.shape == output.shape:
                            # Patch the residual part: output = input_tensor + clean_f(x_in)
                            output.copy_(input_tensor + clean_attn)
                        else:
                            print(f"        Attention shape mismatch: clean {clean_attn.shape} vs output {output.shape}")
            
            # Register patching hooks
            mlp_patch_hook = layer.mlp.register_forward_hook(patching_hook_fn)
            attn_patch_hook = layer.self_attn.register_forward_hook(patching_hook_fn)
            
            # Run patched corrupt prompt once to get logits for r and r'
            patched_r_list, patched_rp_list = self.get_logits_for_answers([corrupt_prompt], [clean_answer], [corrupt_answer])
            patched_r = patched_r_list[0] # logit_pt(r)
            patched_rp = patched_rp_list[0] # logit_pt(r')
            ld_patched[prompt_idx] = patched_r - patched_rp # ld_pt(r,r')
            
            # Remove patching hooks
            mlp_patch_hook.remove()
            attn_patch_hook.remove()
            
            # For debugging context, log LDs
            # print(f"      Prompt {prompt_idx} LD_corrupt={ld_corrupt[prompt_idx]:.6f}, LD_patched={ld_patched[prompt_idx]:.6f}")
        
        # Calculate normalized patching effect using
        # (LD_patched - LD_corrupt) / (LD_clean - LD_corrupt)
        patching_effect = {}
        for i in range(len(prompts)):
            denom = original_ld[i] # ld_cl - ld_*
            if denom != 0:  # Avoid division by zero
                patching_effect[i] = (ld_patched[i] - ld_corrupt[i]) / denom # (ld_pt - ld_*) / (ld_cl - ld_*)
            else:
                patching_effect[i] = 0.0
        
        return {
            'logit_pt': patched_r,
            'logit_pt_p': patched_rp,
            'original_ld': original_ld,  # denominator per prompt (ld_cl - ld_*)
            'patched_ld': ld_patched, # ld_pt
            'patching_effect': patching_effect, # (ld_pt - ld_*) / (ld_cl - ld_*)
            'params': params
        }

    def get_logits_for_answers(self, prompts, answers_a, answers_b):
        """
        Get logits for two specific answer tokens for a batch of prompts in a single forward pass.

        Logits are read at each prompt's last attended (non-pad) position. That
        position is located from the attention mask itself, so both left- and
        right-padded batches are handled.

        Args:
            prompts: List[str] prompts
            answers_a: List[str] first set of answers (r)
            answers_b: List[str] second set of answers (r')

        Returns:
            Tuple[List[float], List[float]]: logits for answers_a and answers_b respectively, per prompt.
            Two empty lists are returned when ``prompts`` is empty.

        Note:
            Only the first token of each answer is scored (single-token answers
            are assumed). An answer that tokenizes to nothing falls back to
            token id 0.
        """
        if not prompts:
            # Nothing to score: skip tokenization and the forward pass entirely.
            return [], []

        processor = getattr(self, 'processor', None)
        tok = processor.tokenizer if processor else self.tokenizer

        # Ensure a pad token exists for batching
        if getattr(tok, 'pad_token_id', None) is None:
            # Prefer using EOS as PAD to avoid resizing embeddings
            if getattr(tok, 'eos_token', None) is not None:
                tok.pad_token = tok.eos_token
            elif getattr(tok, 'bos_token', None) is not None:
                tok.pad_token = tok.bos_token

        # Mirror the tokenizer's pad id onto the model config when it has none.
        tok_pad_id = getattr(tok, 'pad_token_id', None)
        model_config = getattr(self.model, 'config', None)
        if (model_config is not None
                and tok_pad_id is not None
                and getattr(model_config, 'pad_token_id', None) is None):
            model_config.pad_token_id = tok_pad_id

        # Batch tokenize prompts
        if processor:
            inputs = processor(prompts, return_tensors="pt", padding=True)
        else:
            inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)

        # Move inputs to the same device as the model
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Model forward once for the whole batch
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits  # [batch, seq_len, vocab]

        # Index of the last attended token per prompt.
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            # Count the padding on the right-hand side of each row: after
            # flipping, argmax returns the first position holding a 1.
            # Right padding -> number of pad tokens; left padding -> 0.
            # `mask.sum() - 1` would only be correct for right padding.
            trailing_pad = attention_mask.long().flip(dims=[1]).argmax(dim=1)
            last_token_idx = (attention_mask.shape[1] - 1) - trailing_pad  # [batch]
        else:
            batch_size = inputs['input_ids'].shape[0]
            last_token_idx = torch.full(
                (batch_size,), inputs['input_ids'].shape[1] - 1, dtype=torch.long, device=device
            )

        def first_token_id(text):
            # Tokenize once per answer; fall back to id 0 for an empty encoding.
            ids = tok(text, add_special_tokens=False)['input_ids']
            return ids[0] if ids else 0

        # Token ids for answers per prompt (assumes single-token answers)
        a_ids = torch.tensor([first_token_id(a) for a in answers_a], device=device)
        b_ids = torch.tensor([first_token_id(b) for b in answers_b], device=device)
        batch = torch.arange(logits.size(0), device=device)

        a_logits = logits[batch, last_token_idx, a_ids]
        b_logits = logits[batch, last_token_idx, b_ids]

        return a_logits.tolist(), b_logits.tolist()

    # def act_patching_finegrained(self, prompts, load_cached):
    #     """
    #     Run fine-grained activation patching across all layers for individual attention heads and MLP.
        
    #     Args:
    #         prompts: List of tuples (clean_prompt, corrupt_prompt, clean_answer, corrupt_answer)
            
    #     Returns:
    #         Dictionary containing patching results for each layer, attention head, and MLP
    #     """
        

    def act_patching_finegrained_params(self, prompts, params):
        """
        Helper function for fine-grained activation patching with specific parameters.

        Three batched forward passes are made through ``get_logits_for_answers``:
          1. the clean batch, with a capture hook attached (clean logits and
             the activation to transplant come from the same pass),
          2. the corrupt batch without hooks (baseline),
          3. the corrupt batch with the captured clean activation patched in.
        Hooks are always removed, even if a forward pass raises, so a failure
        cannot leave the model permanently patched.

        Args:
            prompts: List of tuples (clean_prompt, corrupt_prompt, clean_answer, corrupt_answer)
            params: Dictionary specifying what to patch (e.g., {'layer': 5, 'component': 'attention_head', 'head_idx': 3}).
                Keys read here: 'layer' (default 0), 'component' (default 'mlp';
                any other value is handled as an attention head), and for
                attention heads 'head_idx' (default 0) and 'num_heads' (default 1).

        Returns:
            Dictionary containing patching results for the specified parameters:
                'original_ld':     {i: LD_clean - LD_corrupt} (normalization denominator)
                'patched_ld':      {i: LD_patched}
                'patching_effect': {i: (LD_patched - LD_corrupt) / (LD_clean - LD_corrupt)},
                                   0.0 where the denominator is exactly zero
                'logit_pt':        patched-run logits for the clean answers (r)
                'logit_pt_p':      patched-run logits for the corrupt answers (r')
                'params':          the ``params`` dict that was passed in
        """
        clean_prompts = [prompt[0] for prompt in prompts]
        corrupt_prompts = [prompt[1] for prompt in prompts]
        clean_answers = [prompt[2] for prompt in prompts]
        corrupt_answers = [prompt[3] for prompt in prompts]
        indices = range(len(prompts))

        layer_idx = params.get('layer', 0)
        component = params.get('component', 'mlp')
        layer = self.model_layers[layer_idx]
        patch_mlp = component == 'mlp'

        # Clean activations for the entire batch, filled in by the capture hook
        clean_activations = {}

        def clean_hook_fn(module, input, output):
            captured = output[0] if isinstance(output, tuple) else output
            clean_activations['mlp'] = captured.detach().clone()

        def clean_o_proj_pre_hook_fn(module, inputs):
            # inputs is a tuple; first element is the concatenated per-head tensor [B, T, H]
            if inputs:
                clean_activations['attn_pre_o'] = inputs[0].detach().clone()

        # Clean run: capture activations and obtain the clean logits in one pass.
        if patch_mlp:
            hook = layer.mlp.register_forward_hook(clean_hook_fn)
        else:  # attention_head
            hook = layer.self_attn.o_proj.register_forward_pre_hook(clean_o_proj_pre_hook_fn)
        try:
            clean_r_logits, clean_rp_logits = self.get_logits_for_answers(
                clean_prompts, clean_answers, corrupt_answers)
        finally:
            hook.remove()

        # Corrupt baseline run (no hooks attached)
        corrupt_r_logits, corrupt_rp_logits = self.get_logits_for_answers(
            corrupt_prompts, clean_answers, corrupt_answers)

        # LDs
        ld_clean = {i: (clean_r_logits[i] - clean_rp_logits[i]) for i in indices}
        ld_corrupt = {i: (corrupt_r_logits[i] - corrupt_rp_logits[i]) for i in indices}

        # Denominator for normalization
        original_ld = {i: (ld_clean[i] - ld_corrupt[i]) for i in indices}

        # Define patching hooks that operate over the batch
        def mlp_patching_hook_fn(module, input, output):
            clean_mlp = clean_activations.get('mlp', None)
            if clean_mlp is None:
                return

            output_tensor = output[0] if isinstance(output, tuple) else output

            # Handle sequence length mismatch by truncating clean activations
            if clean_mlp.shape[1] != output_tensor.shape[1]:
                # Truncate clean activations to match the sequence length of the corrupt run
                clean_mlp_truncated = clean_mlp[:, :output_tensor.shape[1], :]
            else:
                clean_mlp_truncated = clean_mlp

            if clean_mlp_truncated.shape == output_tensor.shape:
                # Replace the MLP output with the clean MLP output (in place)
                output_tensor.copy_(clean_mlp_truncated)
            else:
                # Surface the skipped patch instead of silently reporting a
                # near-zero effect (same message as the per-prompt routine).
                print(f"        MLP shape mismatch after truncation: clean {clean_mlp_truncated.shape} vs output {output_tensor.shape}")

        def o_proj_patching_pre_hook_fn(module, inputs):
            if not inputs:
                return
            x = inputs[0]
            clean_pre_o = clean_activations.get('attn_pre_o', None)
            if clean_pre_o is None:
                return
            head_idx = params.get('head_idx', 0)
            num_heads = params.get('num_heads', 1)
            # Slice of the concatenated head outputs that belongs to this head
            head_dim = clean_pre_o.shape[-1] // max(1, num_heads)
            start_idx = head_idx * head_dim
            end_idx = (head_idx + 1) * head_dim
            patched = x.clone()
            # Replace only the final token position to avoid seq length mismatches
            # NOTE(review): index -1 is the last *sequence* position. In a
            # right-padded batch that is a pad slot for shorter prompts, while
            # get_logits_for_answers reads logits at the last attended token,
            # so the patch may not reach the scored position - confirm the
            # padding side / prompt lengths used by callers.
            patched[:, -1, start_idx:end_idx] = clean_pre_o[:, -1, start_idx:end_idx]
            # Returning a tuple from a forward pre-hook replaces the module's inputs
            return (patched,)

        # Register patching hooks for batch corrupt run
        if patch_mlp:
            patch_hook = layer.mlp.register_forward_hook(mlp_patching_hook_fn)
        else:  # attention_head
            patch_hook = layer.self_attn.o_proj.register_forward_pre_hook(o_proj_patching_pre_hook_fn)
        try:
            # Run patched corrupt batch once and compute LD_patched per prompt
            patched_r_list, patched_rp_list = self.get_logits_for_answers(
                corrupt_prompts, clean_answers, corrupt_answers)
        finally:
            # Remove patching hooks
            patch_hook.remove()

        ld_patched = {i: (patched_r_list[i] - patched_rp_list[i]) for i in indices}

        # Calculate normalized patching effect: (LD_patched - LD_corrupt) / (LD_clean - LD_corrupt)
        patching_effect = {
            i: ((ld_patched[i] - ld_corrupt[i]) / original_ld[i]) if original_ld[i] != 0 else 0.0
            for i in indices
        }

        return {
            'original_ld': original_ld,
            'patched_ld': ld_patched,
            'patching_effect': patching_effect,
            'logit_pt': patched_r_list,
            'logit_pt_p': patched_rp_list,
            'params': params
        }

    def plot_finegrained_patching_heatmap(self, patching_results, dolog=False):
        """
        Create a heatmap visualization of fine-grained activation patching results.

        Rows are layers; columns are the attention heads followed by one MLP
        column. Each cell is the patching effect averaged over prompts. The
        figure is written to
        ``results/{analysis_name}_activation_patching_finegrained_heatmap_{MODEL_NAME}.pdf``
        and then shown.

        Args:
            patching_results: Dictionary containing patching results from act_patching_finegrained.
                Keys read here: 'num_layers', 'num_heads' and
                'all_patching_results' (layer index as str or int ->
                {'head_<i>' | 'mlp' -> {'patching_effect': {prompt_idx: value}}}).
            dolog: If True, plot log1p(|effect|) instead of the raw averaged effect.
        """
        print("Creating fine-grained activation patching heatmap...")

        num_layers = patching_results['num_layers']
        num_heads = patching_results['num_heads']
        all_results = patching_results['all_patching_results']

        def mean_effect(component_result):
            # Average patching effect across all prompts. An empty dict would
            # make np.mean return NaN (with a warning); treat it as "no effect".
            effects = list(component_result['patching_effect'].values())
            return float(np.mean(effects)) if effects else 0.0

        # Create a matrix to store patching effects
        # Rows: layers, Columns: attention heads + MLP
        heatmap_data = np.zeros((num_layers, num_heads + 1))  # +1 for MLP column

        # Fill the matrix with patching effects
        for layer_idx in range(num_layers):
            # Support both string and int keys
            layer_results = all_results.get(str(layer_idx))
            if layer_results is None:
                layer_results = all_results.get(layer_idx)
            if layer_results is None:
                print(f"Warning: No results found for layer {layer_idx}")
                continue

            # Fill attention head results
            for head_idx in range(num_heads):
                head_key = f'head_{head_idx}'
                if head_key in layer_results:
                    heatmap_data[layer_idx, head_idx] = mean_effect(layer_results[head_key])

            # Fill MLP result (last column)
            if 'mlp' in layer_results:
                heatmap_data[layer_idx, num_heads] = mean_effect(layer_results['mlp'])

        if dolog:
            # Transform to log-space magnitude so color ~ log1p(|effect|).
            # log1p (not log) keeps zero cells at 0 instead of -inf and keeps
            # every value non-negative, matching vmin=0 and the colorbar label.
            heatmap_data = np.log1p(np.abs(heatmap_data))

        # Upper color limit. Fall back to 1.0 when the data has no finite,
        # strictly positive maximum; otherwise vmax <= vmin=0 would give a
        # degenerate (or invalid) color range.
        data_max = float(np.nanmax(heatmap_data)) if heatmap_data.size else float('nan')
        vmax_val = data_max if (np.isfinite(data_max) and data_max > 0) else 1.0

        # Create the heatmap
        fig, ax = plt.subplots(figsize=(max(12, num_heads * 0.5), max(8, num_layers * 0.3)))
        try:
            # Create custom colormap from white to purple
            from matplotlib.colors import LinearSegmentedColormap
            colors = ['white', 'lightblue', 'blue', 'purple', 'indigo']
            n_bins = 100
            cmap = LinearSegmentedColormap.from_list('white_to_purple', colors, N=n_bins)

            # Plot the heatmap (values are in log-space only when dolog is set)
            im = ax.imshow(heatmap_data, cmap=cmap, aspect='auto', vmin=0, vmax=vmax_val)

            # Set labels
            ax.set_xlabel('Attention Head / MLP', fontsize=24)
            ax.set_ylabel('Layer', fontsize=24)
            ax.set_title(f'Fine-grained Activation Patching Effects\n{self.MODEL_NAME}', fontsize=14)

            # Set x-axis labels (attention heads + MLP)
            x_labels = [f'H{i}' for i in range(num_heads)] + ['MLP']
            ax.set_xticks(range(num_heads + 1))
            ax.set_xticklabels(x_labels, fontsize=20)

            # Set y-axis labels (layers)
            ax.set_yticks(range(num_layers))
            ax.set_yticklabels([f'L{i}' for i in range(num_layers)])

            # Add a vertical line to separate attention heads from MLP
            ax.axvline(x=num_heads - 0.5, color='black', linewidth=2)

            # Add colorbar
            cbar = plt.colorbar(im, ax=ax, shrink=0.8)
            cbar.set_label('log1p(|patching effect|)' if dolog else 'Patching effect', fontsize=10)

            # Add text annotations for values; switch to white text on the
            # darker (upper) half of the color range for readability
            for i in range(num_layers):
                for j in range(num_heads + 1):
                    value = heatmap_data[i, j]
                    text_color = 'black' if value < (0.5 * vmax_val) else 'white'
                    ax.text(j, i, f'{value:.2f}', ha='center', va='center',
                            color=text_color, fontsize=8)

            # Adjust layout
            plt.tight_layout()

            # Save the plot
            output_file = f"results/{self.analysis_name}_activation_patching_finegrained_heatmap_{self.MODEL_NAME}.pdf"
            # Create the target directory first; this also covers a MODEL_NAME
            # containing '/' (which nests the file one level deeper).
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            plt.savefig(output_file, dpi=300, bbox_inches='tight')
            print(f"Fine-grained activation patching heatmap saved to {output_file}")
            plt.show()
        finally:
            # Release the figure even if plotting or saving failed
            plt.close(fig)

    def heatmap(self, saliency_results=None, lesioning_results=None, patching_results=None, load_cached=True):
        """
        Create an integrated RGB heatmap combining saliency, lesioning, and activation patching analyses.
        
        Args:
            saliency_results: Results from saliency analysis (optional)
            lesioning_results: Results from fine-grained lesioning analysis (optional) 
            patching_results: Results from fine-grained activation patching analysis (optional)
        
        RGB Channels:
            - R (Red): Lesioning disruption scores (0-10 scale)
            - G (Green): Saliency scores (log space)
            - B (Blue): Activation patching effects (log space)
        
        Color mapping:
            - White (0,0,0): No effects in any analysis
            - Black (1,1,1): High effects in all analyses
        """
        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib.colors import Normalize
        
        print("Creating integrated RGB heatmap...")
        
        # Load cached results if not provided
        if load_cached:
            if saliency_results is None:
                print("Loading cached saliency results...")
                saliency_results = self.load_cached('saliency')
            if lesioning_results is None:
                print("Loading cached lesioning results...")
                lesioning_results = self.load_cached('lesioning_finegrained')
            if patching_results is None:
                print("Loading cached patching results...")
                patching_results = self.load_cached('activation_patching_finegrained')
        
        # Get model configuration
        if hasattr(self, 'model') and self.model is not None:
            cfg = getattr(self.model, "config", getattr(self.model, "model").config)
            if hasattr(cfg, "text_config"):
                num_heads = getattr(cfg.text_config, "num_attention_heads", None) or getattr(cfg.text_config, "num_heads", None)
                num_layers = getattr(cfg.text_config, "num_hidden_layers", None) or getattr(cfg.text_config, "n_layers", None)
            else:
                num_heads = getattr(cfg, "num_attention_heads", None) or getattr(cfg, "num_heads")
                num_layers = getattr(cfg, "num_hidden_layers", None) or getattr(cfg, "n_layers")
        else:
            # Fallback to hardcoded values for common models
            if 'Llama-3.2-1B' in self.MODEL_NAME:
                num_heads = 8
                num_layers = 16
            elif 'Llama-3.1-8B' in self.MODEL_NAME:
                num_heads = 32
                num_layers = 32
            elif 'Llama-3.3-70B' in self.MODEL_NAME:
                num_heads = 64
                num_layers = 80
            else:
                # Default fallback
                num_heads = 8
                num_layers = 16
                print(f"Warning: Using default model config for {self.MODEL_NAME}")
        
        print(f"Using model config: {num_heads} heads, {num_layers} layers")
        
        # Initialize RGB channels (layers x (heads + MLP))
        n_cols = num_heads + 1  # +1 for MLP
        n_rows = num_layers
        
        # Initialize CMY arrays (subtractive color model)
        C_channel = np.zeros((n_rows, n_cols))  # Lesioning (Cyan)
        M_channel = np.zeros((n_rows, n_cols))  # Saliency (Magenta) 
        Y_channel = np.zeros((n_rows, n_cols))  # Patching (Yellow)
        
        # Process lesioning results (Cyan channel)
        if lesioning_results is not None:
            print("Processing lesioning results...")
            print(f"Lesioning results keys: {list(lesioning_results.keys())}")
            fine = lesioning_results.get('finegrained_responses', {})
            per_prompt = lesioning_results.get('per_prompt_analysis', {})
            print(f"Fine responses layers: {list(fine.keys())}")
            print(f"Per prompt keys: {list(per_prompt.keys())}")
            
            # Parse lesioning scores using the same logic as plot_finegrained_lesioning_heatmap
            import re
            re_head = re.compile(r"Head\s+(\d+)\s*:\s*Score\s*([0-9]+(?:\.[0-9]+)?)", re.IGNORECASE)
            re_mlp = re.compile(r"MLP\s*:\s*Score\s*([0-9]+(?:\.[0-9]+)?)", re.IGNORECASE)
            
            cell_scores = [[[] for _ in range(n_cols)] for _ in range(n_rows)]
            
            for p_idx, layer_map in per_prompt.items():
                print(f"Processing prompt {p_idx}, layers: {list(layer_map.keys())}")
                for layer_key, txt in layer_map.items():
                    try:
                        layer_idx = int(layer_key)
                    except Exception:
                        continue
                    if layer_idx < 0 or layer_idx >= n_rows:
                        continue
                    if not isinstance(txt, str):
                        continue
                    
                    # Parse head scores
                    head_matches = 0
                    for m in re_head.finditer(txt):
                        h = int(m.group(1))
                        try:
                            s = float(m.group(2))
                            head_matches += 1
                        except Exception:
                            continue
                        if 0 <= h < num_heads:
                            cell_scores[layer_idx][h].append(s)
                    
                    # Parse MLP score
                    mlp_matches = 0
                    m2 = re_mlp.search(txt)
                    if m2:
                        try:
                            s = float(m2.group(1))
                            mlp_matches += 1
                            cell_scores[layer_idx][num_heads].append(s)
                        except Exception:
                            pass
                    
                    if head_matches > 0 or mlp_matches > 0:
                        print(f"  Layer {layer_idx}: {head_matches} head matches, {mlp_matches} MLP matches")
            
            # Fill C channel with average lesioning scores
            for li in range(n_rows):
                for hj in range(n_cols):
                    vals = cell_scores[li][hj]
                    C_channel[li, hj] = float(np.mean(vals)) if len(vals) > 0 else 0.0
            
            print(f"C channel stats: min={np.min(C_channel):.3f}, max={np.max(C_channel):.3f}, mean={np.mean(C_channel):.3f}")
        else:
            print("No lesioning results provided")
        
        # Process saliency results (Magenta channel)
        if saliency_results is not None:
            print("Processing saliency results...")
            print(f"Saliency results keys: {list(saliency_results.keys())}")
            
            # Try different possible keys for saliency data
            avg_sal_per_head = saliency_results.get('avg_sal_per_head', {})
            avg_sal_per_mlp = saliency_results.get('avg_sal_per_mlp', {})
            
            # If not found, try to extract from avg_sal_per_prompt structure
            if not avg_sal_per_head and 'avg_sal_per_prompt' in saliency_results:
                print("Trying to extract saliency from avg_sal_per_prompt...")
                avg_sal_per_prompt = saliency_results['avg_sal_per_prompt']
                print(f"avg_sal_per_prompt type: {type(avg_sal_per_prompt)}")
                if isinstance(avg_sal_per_prompt, dict):
                    print(f"avg_sal_per_prompt keys: {list(avg_sal_per_prompt.keys())}")
                    
                    # Extract saliency from the first disease (they should be similar)
                    first_disease = list(avg_sal_per_prompt.keys())[0]
                    disease_saliency = avg_sal_per_prompt[first_disease]
                    print(f"Using saliency from disease: {first_disease}")
                    
                    # Parse attention head saliency from parameter names
                    for param_name, saliency_value in disease_saliency.items():
                        if 'self_attn.q_proj.weight' in param_name:
                            # Extract layer from parameter name like "model.layers.5.self_attn.q_proj.weight"
                            try:
                                layer_part = param_name.split('model.layers.')[1].split('.self_attn')[0]
                                layer_idx = int(layer_part)
                                if 0 <= layer_idx < n_rows:
                                    # Average across all heads for this layer (q_proj represents all heads)
                                    M_channel[layer_idx, :num_heads] = saliency_value
                            except (ValueError, IndexError):
                                continue
                        elif 'mlp.gate_proj.weight' in param_name:
                            # Extract layer from parameter name like "model.layers.5.mlp.gate_proj.weight"
                            try:
                                layer_part = param_name.split('model.layers.')[1].split('.mlp')[0]
                                layer_idx = int(layer_part)
                                if 0 <= layer_idx < n_rows:
                                    M_channel[layer_idx, num_heads] = saliency_value
                            except (ValueError, IndexError):
                                continue
            
            print(f"Head saliency entries: {len(avg_sal_per_head)}")
            print(f"MLP saliency entries: {len(avg_sal_per_mlp)}")
            
            # Fill attention head saliency (if available)
            for (layer_idx, head_idx), saliency in avg_sal_per_head.items():
                try:
                    layer_idx = int(layer_idx)
                    head_idx = int(head_idx)
                    if 0 <= layer_idx < n_rows and 0 <= head_idx < num_heads:
                        M_channel[layer_idx, head_idx] = saliency
                except (ValueError, TypeError):
                    continue
            
            # Fill MLP saliency (if available)
            for layer_idx, saliency in avg_sal_per_mlp.items():
                try:
                    layer_idx = int(layer_idx)
                    if 0 <= layer_idx < n_rows:
                        M_channel[layer_idx, num_heads] = saliency
                except (ValueError, TypeError):
                    continue
            
            print(f"M channel stats: min={np.min(M_channel):.3f}, max={np.max(M_channel):.3f}, mean={np.mean(M_channel):.3f}")
        else:
            print("No saliency results provided")
        
        # Process activation patching results (Yellow channel)
        if patching_results is not None:
            print("Processing activation patching results...")
            print(f"Patching results keys: {list(patching_results.keys())}")
            all_results = patching_results.get('all_patching_results', {})
            print(f"All results layers: {list(all_results.keys())}")
            
            for layer_idx in range(n_rows):
                layer_key_str = str(layer_idx)
                layer_results = all_results.get(layer_key_str) or all_results.get(layer_idx)
                if layer_results is None:
                    continue
                
                # Fill attention head patching effects
                for head_idx in range(num_heads):
                    head_key = f'head_{head_idx}'
                    if head_key in layer_results:
                        patching_effects = layer_results[head_key]['patching_effect']
                        avg_effect = np.mean(list(patching_effects.values()))
                        Y_channel[layer_idx, head_idx] = avg_effect
                
                # Fill MLP patching effect
                if 'mlp' in layer_results:
                    patching_effects = layer_results['mlp']['patching_effect']
                    avg_effect = np.mean(list(patching_effects.values()))
                    Y_channel[layer_idx, num_heads] = avg_effect
            
            print(f"Y channel stats: min={np.min(Y_channel):.3f}, max={np.max(Y_channel):.3f}, mean={np.mean(Y_channel):.3f}")
        else:
            print("No patching results provided")
        
        # Normalize each channel to [0,1] for subtractive color model
        def normalize(channel):
            if np.max(channel) > 0:
                normalized = channel / np.max(channel)
                return normalized
            return np.zeros_like(channel)  # All white if no data
        
        C_norm = normalize(C_channel)
        M_norm = normalize(M_channel) 
        Y_norm = normalize(Y_channel)
        
        # Ensure all values are in [0,1] range
        C_norm = np.clip(C_norm, 0, 1)
        M_norm = np.clip(M_norm, 0, 1)
        Y_norm = np.clip(Y_norm, 0, 1)
        
        print(f"Normalized C channel stats: min={np.min(C_norm):.3f}, max={np.max(C_norm):.3f}, mean={np.mean(C_norm):.3f}")
        print(f"Normalized M channel stats: min={np.min(M_norm):.3f}, max={np.max(M_norm):.3f}, mean={np.mean(M_norm):.3f}")
        print(f"Normalized Y channel stats: min={np.min(Y_norm):.3f}, max={np.max(Y_norm):.3f}, mean={np.mean(Y_norm):.3f}")
        
        # Convert CMY to RGB for display (subtractive to additive conversion)
        # RGB = 1 - CMY for subtractive color model
        R_display = 1 - C_norm
        G_display = 1 - M_norm  
        B_display = 1 - Y_norm
        
        # Combine into RGB image for display
        rgb_image = np.stack([R_display, G_display, B_display], axis=-1)
        
        print(f"RGB image shape: {rgb_image.shape}")
        print(f"RGB image stats: min={np.min(rgb_image):.3f}, max={np.max(rgb_image):.3f}, mean={np.mean(rgb_image):.3f}")
        
        # Print some sample RGB values for debugging
        print("Sample RGB values (first 5x5):")
        for i in range(min(5, n_rows)):
            for j in range(min(5, n_cols)):
                r, g, b = rgb_image[i, j]
                print(f"  [{i},{j}]: R={r:.3f}, G={g:.3f}, B={b:.3f}")
        
        # Create the plot
        fig, ax = plt.subplots(figsize=(max(12, n_cols * 0.6), max(8, n_rows * 0.4)))
        
        # Display RGB image
        ax.imshow(rgb_image, aspect='auto')
        
        # Set labels
        ax.set_xlabel('Attention Head / MLP', fontsize=30)
        ax.set_ylabel('Layer', fontsize=30)
        ax.set_title(f'Integrated Analysis Heatmap\n{self.MODEL_NAME}\nC:Cyan M:Magenta Y:Yellow', fontsize=30)
        
        # Set ticks and labels
        x_labels = [f'H{i}' for i in range(num_heads)] + ['MLP']
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(x_labels, fontsize=30)
        ax.set_yticks(range(n_rows))
        ax.set_yticklabels([f'L{i}' for i in range(n_rows)], fontsize=30)
        
        # Show every other tick label
        ax.set_xticks(range(0, n_cols, 2))
        ax.set_xticklabels([x_labels[i] for i in range(0, n_cols, 2)], fontsize=30)
        ax.set_yticks(range(0, n_rows, 2))
        ax.set_yticklabels([f'L{i}' for i in range(0, n_rows, 2)], fontsize=30)
        

        # Add vertical line to separate attention heads from MLP
        ax.axvline(x=num_heads - 0.5, color='black', linewidth=2)
        
        # Add color legend
        legend_elements = [
            plt.Rectangle((0,0),1,1, facecolor='cyan', alpha=0.7, label='Lesioning (C)'),
            plt.Rectangle((0,0),1,1, facecolor='magenta', alpha=0.7, label='Saliency (M)'),
            plt.Rectangle((0,0),1,1, facecolor='yellow', alpha=0.7, label='Patching (Y)'),
            plt.Rectangle((0,0),1,1, facecolor='blue', alpha=0.7, label='Lesioning + Saliency'),
            plt.Rectangle((0,0),1,1, facecolor='green', alpha=0.7, label='Lesioning + Patching'),
            plt.Rectangle((0,0),1,1, facecolor='red', alpha=0.7, label='Saliency + Patching'),
            plt.Rectangle((0,0),1,1, facecolor='white', edgecolor='black', linewidth=1, label='No effects'),
            plt.Rectangle((0,0),1,1, facecolor='black', label='All effects')
        ]
        
        # Create legend with title
        legend = ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=30)
        legend.set_title("Primary Effects:\nDual Effects:", prop={'size': 30, 'weight': 'bold'})
        
        plt.tight_layout()
        
        # Save the plot
        output_file = f"results/{self.analysis_name}_integrated_heatmap_{self.MODEL_NAME}.pdf"
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Integrated RGB heatmap saved to {output_file}")
        
        plt.show()
        plt.close()
        
        return rgb_image