# probe_gemma_4b_singlepass_with_json.py
# Single-pass Gemma-4B probing pipeline (GPU-only activations + GPU metrics)
# - captures activations for many layers/components in one forward pass
# - computes Fisher separability, MMD (RBF), Wasserstein (1D Gaussian approx) on GPU
# - only transfers small metric matrices to CPU for plotting
#
# Usage: run on a machine with a CUDA GPU; export HF_TOKEN first if the model requires authentication.
# Requires torch, transformers, numpy, matplotlib, seaborn and tqdm. For 8-bit quantization,
# adapt the AutoModelForCausalLM.from_pretrained call to use BitsAndBytesConfig (not included here).

import json
import gc
import os
import random
import itertools
import pathlib
import re
import warnings
import math
from typing import List, Dict, Iterable, Literal, Optional

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import IterableDataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

# silence some noisy but harmless warnings
warnings.filterwarnings("ignore", message=".*not compiled with flash attention.*", category=UserWarning)

# ---------------- User / Model config ----------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "google/gemma-3-4b-it"
# Hugging Face access token taken from $HF_TOKEN. When the variable is unset (or empty)
# this stays None so huggingface_hub can fall back to a cached `huggingface-cli login`;
# a placeholder string here would be sent as an invalid token.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Output directory for heatmaps and the metrics JSON; override with $PROBE_OUTPUT_DIR.
PROBE_OUTPUT_DIR = os.environ.get("PROBE_OUTPUT_DIR", "Your directory path goes here")
os.makedirs(PROBE_OUTPUT_DIR, exist_ok=True)

# One JSON file per domain; the file stem becomes the domain (class) name.
DOMAIN_JSONS = [
    "General_datasets/finance.json",
    "General_datasets/medical.json",
    "General_datasets/math.json",
    "General_datasets/science.json",
    "Coding_datasets/cpp_top.json",
    "Coding_datasets/python_top.json",
]
# -----------------------------------------------------

# ------------- Load tokenizer & model --------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, token=HF_TOKEN)
# Batched tokenization uses padding=True, which needs a pad token; reuse EOS if absent.
if getattr(tokenizer, "pad_token", None) is None:
    tokenizer.pad_token = tokenizer.eos_token

# bfloat16 weights for speed / memory; swap in a quantized load here if needed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)
model = model.to(DEVICE).eval()

def get_num_layers(m):
    """
    Return the number of transformer blocks in model ``m``, or None if none are found.

    Looks under the usual root attributes (``model``, ``transformer``, ``transformer_model``,
    ``decoder``, ``gemma``, ``language_model``) and one level below each (``<root>.model``)
    for a non-empty sized ``layers`` / ``h`` / ``block`` container. The extra level covers
    wrappers that nest the decoder, e.g. ``m.language_model.model.layers``.
    """
    def _container_len(obj):
        # length of the first non-empty sized block container on ``obj``
        for name in ("layers", "h", "block"):
            cont = getattr(obj, name, None)
            if cont is None:
                continue
            try:
                n = len(cont)
            except TypeError:
                continue
            if n:
                return n
        return None

    for attr in ("model", "transformer", "transformer_model", "decoder", "gemma", "language_model"):
        root = getattr(m, attr, None)
        if root is None:
            continue
        for candidate in (root, getattr(root, "model", None)):
            if candidate is None:
                continue
            n = _container_len(candidate)
            if n:
                return n
    return None

# Prefer the depth discovered on the loaded model so a different MODEL_NAME works;
# fall back to 34 (the value this script was written against) if discovery fails.
_detected_layers = get_num_layers(model)
if _detected_layers:
    TOTAL_LAYERS = _detected_layers
    print(f"Loaded {MODEL_NAME} on {DEVICE}. Detected TOTAL_LAYERS = {TOTAL_LAYERS}")
else:
    TOTAL_LAYERS = 34
    print(f"Loaded {MODEL_NAME} on {DEVICE}. Could not detect layer count; using TOTAL_LAYERS = {TOTAL_LAYERS}")
# -----------------------------------------------------

# ---------------- Dataset ----------------
class DomainIterableDataset(IterableDataset):
    """
    Stream ``(text, domain_id)`` pairs from one JSON file per domain.

    Each file is read with ``json.load`` and its top-level value is shuffled/sliced as a
    sequence (presumably a list of text strings — TODO confirm the schema). Domain ids
    follow the order of ``json_paths``; ``domain_names[d_id]`` is the matching file stem.
    Examples are yielded domain by domain (not interleaved); with ``shuffle`` both the
    domain order and the texts within each domain are randomized per iteration.
    """

    def __init__(self, json_paths: List[str], max_per_domain: int = 1000, shuffle: bool = True):
        self.paths = [pathlib.Path(p) for p in json_paths]
        self.max_per_domain = max_per_domain
        self.shuffle = shuffle
        # label id per path, fixed by the constructor's path order
        self.domain2id = {p: i for i, p in enumerate(self.paths)}
        self.domain_names = [p.stem for p in self.paths]

    def _sample_texts(self, file_path: pathlib.Path) -> Iterable[str]:
        """Load one domain file and return at most ``max_per_domain`` entries (shuffled if enabled)."""
        with open(file_path, "r", encoding="utf-8") as f:
            texts = json.load(f)
        if not texts:
            return []
        if self.shuffle:
            random.shuffle(texts)
        return texts[:self.max_per_domain]

    def __iter__(self):
        # Shuffle a copy: shuffling self.paths in place would permanently reorder the
        # dataset's public path list relative to domain_names / domain2id.
        files = list(self.paths)
        if self.shuffle:
            random.shuffle(files)
        for fp in files:
            d_id = self.domain2id[fp]
            for txt in self._sample_texts(fp):
                yield txt, d_id

def collate(batch):
    """Transpose a batch of ``(text, label)`` pairs into a ``(texts, labels)`` pair of lists."""
    texts, labels = map(list, zip(*batch))
    return texts, labels
# -------------------------------------------

# ---------- Multi-layer activation grabber (single-pass) ----------
class MultiLayerActivationGrabber:
    """
    Capture sequence-pooled activations for many (component, layer) pairs in one forward pass.

    One forward hook is registered per requested component and layer:
      - "attn":  ``self_attn.o_proj`` when present, else the first submodule whose name
                 matches an attention-output pattern
      - "mlp":   the first of ``mlp.up_proj / gate_proj / down_proj / fc1 / fc2`` that
                 exists (``up_proj`` when present), else a name-pattern match
      - "resid": the first post-* layernorm found, else the whole layer's output
                 (NOTE(review): a post-norm output need not equal the residual stream;
                 confirm this is the intended "resid" signal)
    Each hook pools its output over dim 1 (sequence) and appends the detached
    ``[batch, dim]`` result to ``buffers[component][layer]`` on the device the model
    produced it on.

    Padding: a forward pre-hook on ``model`` records the ``attention_mask`` keyword
    argument of each call. When that mask is 2-D and matches the hooked activation's
    ``[batch, seq]`` shape, pooling ignores padded positions; otherwise every position
    is pooled (which is also what happens when no mask is passed).
    """
    def __init__(self, model, layers: List[int], components: List[Literal["attn","mlp","resid"]], seq_pool: str = "mean"):
        self.model = model
        self.layers_idx = layers
        self.components = components
        self.seq_pool = seq_pool  # "mean" or "first"; validated when a hook fires
        # buffers[component][layer] -> list of [batch, dim] tensors (kept on the model's device)
        self.buffers: Dict[str, Dict[int, List[torch.Tensor]]] = {c: {L: [] for L in layers} for c in components}
        self.handles = []
        self._registered = False
        # attention_mask of the forward call in progress, recorded by the model pre-hook
        self._attention_mask: Optional[torch.Tensor] = None

    def _pool_tensor(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Pool ``x`` over dim 1 (sequence), skipping positions where ``mask`` is 0.

        ``mask`` is only used for a 3-D ``x`` when it is 2-D and equals ``x.shape[:2]``;
        otherwise every position is pooled.
          "mean":  masked average, accumulated in float32 and returned in ``x.dtype``
          "first": the first position whose mask is non-zero (position 0 without a mask),
                   which stays correct for left-padded batches
        Raises ValueError for any other ``seq_pool``.
        """
        if mask is not None and not (
            x.dim() == 3 and mask.dim() == 2 and tuple(mask.shape) == tuple(x.shape[:2])
        ):
            mask = None

        if self.seq_pool == "mean":
            if mask is None:
                return x.mean(dim=1)
            weights = mask.to(device=x.device, dtype=torch.float32).unsqueeze(-1)  # [B, S, 1]
            summed = (x.to(torch.float32) * weights).sum(dim=1)                   # [B, D]
            counts = weights.sum(dim=1).clamp(min=1.0)  # [B, 1]; all-pad rows pool to 0, not NaN
            return (summed / counts).to(x.dtype)
        elif self.seq_pool == "first":
            if mask is None:
                return x[:, 0]
            # argmax returns the first maximal index, i.e. the first non-pad position
            first_idx = mask.to(device=x.device, dtype=torch.float32).argmax(dim=1)
            rows = torch.arange(x.size(0), device=x.device)
            return x[rows, first_idx]
        else:
            raise ValueError(f"Unsupported seq_pool {self.seq_pool}")

    def _record_mask_hook(self, module, args, kwargs):
        """Forward pre-hook on the wrapped model: remember this call's ``attention_mask`` kwarg."""
        mask = kwargs.get("attention_mask")
        self._attention_mask = mask if isinstance(mask, torch.Tensor) else None

    def _save_hook_factory(self, comp: str, L: int):
        """Build a forward hook that pools a module's output into ``buffers[comp][L]``."""
        def hook(module, inputs, out):
            # decoder layers / attention modules may return tuples; the hidden states come first
            out_t = out[0] if isinstance(out, tuple) else out
            pooled = self._pool_tensor(out_t, self._attention_mask)  # [batch, dim]
            self.buffers[comp][L].append(pooled.detach())
        return hook

    def _find_layer_module(self, L: int):
        """
        Robust search for a layer module indexed by L. Handles Gemma variants where
        transformer blocks may live under model.language_model.[encoder/decoder/model], etc.

        Raises RuntimeError (with structural diagnostics) when no candidate is found.
        """
        from torch import nn

        # Collect candidate roots to search (de-duplicated by identity)
        candidate_roots = []
        for root_attr in ("model", "transformer", "transformer_model", "decoder", "gemma", "language_model"):
            if hasattr(self.model, root_attr):
                root = getattr(self.model, root_attr)
                if all(root is not seen for seen in candidate_roots):
                    candidate_roots.append(root)
        if not candidate_roots:
            candidate_roots = [self.model]

        def is_indexable_container(obj):
            return hasattr(obj, "__len__") and hasattr(obj, "__getitem__")

        # 1) Look for obvious container attributes under each root
        for root in candidate_roots:
            for candidate_name in ("layers", "h", "block", "blocks", "encoder", "decoder", "transformer"):
                if hasattr(root, candidate_name):
                    cont = getattr(root, candidate_name)
                    if is_indexable_container(cont):
                        try:
                            if 0 <= L < len(cont):
                                return cont[L]
                        except Exception:
                            continue
            # also try root.model.* under root
            if hasattr(root, "model"):
                sub = getattr(root, "model")
                for candidate_name in ("layers", "h", "block", "blocks", "encoder", "decoder"):
                    if hasattr(sub, candidate_name):
                        cont = getattr(sub, candidate_name)
                        if is_indexable_container(cont):
                            try:
                                if 0 <= L < len(cont):
                                    return cont[L]
                            except Exception:
                                continue

        # 2) Search named_modules for containers (list/ModuleList) whose elements look like layers.
        # NOTE(review): on multimodal models this scan could also match vision-encoder blocks.
        for name, sub in self.model.named_modules():
            for attr_name in dir(sub):
                if attr_name.startswith("_"):
                    continue
                try:
                    attr = getattr(sub, attr_name)
                except Exception:
                    continue
                if isinstance(attr, (list, nn.ModuleList, tuple)):
                    try:
                        if len(attr) == 0:
                            continue
                    except Exception:
                        continue
                    first = attr[0]
                    if any(hasattr(first, a) for a in ("self_attn", "mlp", "attention", "feed_forward", "ffn", "out_proj", "up_proj")):
                        if 0 <= L < len(attr):
                            return attr[L]

        # 3) Build candidates by scanning modules whose children look like transformer layers
        layer_candidates = []
        for nm, sub in self.model.named_modules():
            child_attrs = [c for _, c in sub.named_children()]
            if not child_attrs:
                continue
            if any(any(hasattr(ch, a) for a in ("self_attn", "mlp", "attention", "feed_forward", "ffn", "out_proj", "up_proj")) for ch in child_attrs):
                layer_candidates.append(sub)
        if layer_candidates:
            if 0 <= L < len(layer_candidates):
                return layer_candidates[L]

        # diagnostics to help debugging if not found
        diagnostics = []
        for root in candidate_roots:
            for candidate_name in ("layers", "h", "block", "blocks", "encoder", "decoder", "transformer"):
                if hasattr(root, candidate_name):
                    cont = getattr(root, candidate_name)
                    try:
                        diagnostics.append(f"{root.__class__.__name__}.{candidate_name} (len={len(cont)})")
                    except Exception:
                        diagnostics.append(f"{root.__class__.__name__}.{candidate_name} (len=n/a)")

        nm_sample = []
        for i, (n, sub) in enumerate(self.model.named_modules()):
            nm_sample.append((i, n, sub.__class__.__name__))
            if i >= 120:
                break

        raise RuntimeError(
            f"Layer {L} not found in model. Diagnostics: candidate containers: {diagnostics}\n"
            f"Sample named_modules[0..120]: {nm_sample[:30]} ...\n"
            f"Use the printed diagnostics to inspect model structure."
        )

    def _find_module_for(self, layer_module, component: str):
        """Return the submodule of ``layer_module`` to hook for ``component``; RuntimeError if none."""
        # Gemma-style: try clear attribute names first, else fallback to name-scan
        if component == "attn":
            if hasattr(layer_module, "self_attn") and hasattr(layer_module.self_attn, "o_proj"):
                return layer_module.self_attn.o_proj
            for name, sub in layer_module.named_modules():
                if re.search(r"(o_proj|out_proj|wo|self_attn)", name, flags=re.IGNORECASE):
                    return sub
        elif component == "mlp":
            if hasattr(layer_module, "mlp"):
                mlp = layer_module.mlp
                for cand in ("up_proj", "gate_proj", "down_proj", "fc1", "fc2"):
                    if hasattr(mlp, cand):
                        return getattr(mlp, cand)
            for name, sub in layer_module.named_modules():
                if re.search(r"(mlp|up_proj|gate_proj|down_proj|wo|fc1|fc2)", name, flags=re.IGNORECASE):
                    return sub
        elif component == "resid":
            for cand in ("post_feedforward_layernorm", "post_attention_layernorm", "post_norm", "layernorm2", "layer_norm2"):
                if hasattr(layer_module, cand):
                    return getattr(layer_module, cand)
            # fallback to full layer module (captures block output)
            return layer_module
        raise RuntimeError(f"Module for component {component} not found in layer")

    def register_hooks(self):
        """
        Attach the mask pre-hook and one pooling hook per (component, layer); idempotent.

        If any lookup or registration fails, every hook attached so far is removed
        before the exception propagates, so the model is never left half-instrumented.
        """
        if self._registered:
            return
        try:
            try:
                self.handles.append(
                    self.model.register_forward_pre_hook(self._record_mask_hook, with_kwargs=True)
                )
            except TypeError:
                # torch without ``with_kwargs`` support: pooling falls back to unmasked
                pass
            # resolve each layer once instead of once per component
            layer_modules = {L: self._find_layer_module(L) for L in self.layers_idx}
            for comp in self.components:
                for L in self.layers_idx:
                    mod = self._find_module_for(layer_modules[L], comp)
                    self.handles.append(mod.register_forward_hook(self._save_hook_factory(comp, L)))
        except Exception:
            self.remove_hooks()
            raise
        self._registered = True

    def remove_hooks(self):
        """Detach every hook this grabber registered (best effort) and forget the cached mask."""
        for h in self.handles:
            try:
                h.remove()
            except Exception:
                pass  # best-effort cleanup: one stale handle must not block removing the rest
        self.handles = []
        self._registered = False
        self._attention_mask = None

    def clear_buffers(self):
        """Drop all captured activations (one empty list per component/layer)."""
        self.buffers = {c: {L: [] for L in self.layers_idx} for c in self.components}

    def stacked_features(self, component: str, layer: int) -> Optional[torch.Tensor]:
        """Concatenate captured batches for (component, layer) into [N, D], or None if nothing was captured."""
        lst = self.buffers[component][layer]
        if not lst:
            return None
        return torch.cat(lst, dim=0)  # [N, D] on the model's device
# ------------------------------------------------------------------

# ---------- Single-pass collection (keeps everything on GPU) ----------
def collect_all_features_gpu(dataloader: DataLoader, layers: List[int], components: List[str], seq_pool: str = "mean", n_examples: Optional[int] = None):
    """
    Run the global ``model`` once over ``dataloader`` and return pooled activations and labels.

    Each batch of ``(texts, labels)`` is tokenized (padded, truncated to 512 tokens) and
    passed through the model with MultiLayerActivationGrabber hooks attached, so every
    requested (component, layer) is captured in the same forward pass. Collection stops
    after the batch that brings the running count to ``n_examples``, so up to
    ``batch_size - 1`` extra examples may be included.

    Returns:
        features: features[component][layer] -> [N, D] tensor of pooled activations.
            An empty (0, 0) tensor on DEVICE is stored when nothing was captured, or
            when the captured row count does not match the labels (a warning is issued).
        y: [N] long tensor of domain labels on DEVICE.
    """
    grabber = MultiLayerActivationGrabber(model, layers, components, seq_pool=seq_pool)
    grabber.register_hooks()

    labels_accum: List[int] = []
    seen = 0
    try:
        with torch.no_grad():
            for batch_texts, batch_labels in tqdm(dataloader, desc="Collecting activations (single-pass)"):
                toks = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
                model(**toks)  # logits are discarded; the hooks capture what we need
                labels_accum.extend(batch_labels)
                seen += len(batch_labels)
                if n_examples is not None and seen >= n_examples:
                    break
    finally:
        # Always detach hooks, even if a forward pass fails (e.g. CUDA OOM), so the
        # global model is not left instrumented for later callers.
        grabber.remove_hooks()

    # labels -> tensor on DEVICE
    y = torch.tensor(labels_accum, dtype=torch.long, device=DEVICE)

    features: Dict[str, Dict[int, torch.Tensor]] = {c: {} for c in components}
    for c in components:
        for L in layers:
            stacked = grabber.stacked_features(c, L)  # on DEVICE or None
            if stacked is None:
                features[c][L] = torch.empty((0, 0), device=DEVICE)
            elif stacked.size(0) != y.size(0):
                # A module that ran more/less than once per forward would misalign rows and
                # labels; skip this layer instead of silently scoring the wrong examples.
                warnings.warn(
                    f"{c} layer {L}: captured {stacked.size(0)} rows for {y.size(0)} labels; skipping."
                )
                features[c][L] = torch.empty((0, 0), device=DEVICE)
            else:
                features[c][L] = stacked  # kept on DEVICE

    grabber.clear_buffers()
    gc.collect()
    torch.cuda.empty_cache()
    return features, y
# ---------------------------------------------------------------------

# ---------- GPU metric computations (Fisher, MMD, Wasserstein) ----------
def compute_pairwise_metrics_gpu_tensors(X: torch.Tensor, y: torch.Tensor, class_i: int, class_j: int):
    """
    Score how separable two classes of pooled activations are.

    X: [N, D] floating tensor (any dtype/device); y: [N] integer labels on X's device.
    Both class subsets are upcast to float32 before any arithmetic. The activations this
    script collects are bfloat16 (~3 significant digits). At that precision the kernel-trick
    squared distance ||a||^2 + ||b||^2 - 2<a, b> used for MMD cancels catastrophically.

    Returns (fisher, mmd, wasserstein) as Python floats. Returns three NaNs when
    class_i == class_j or either class has no rows.
      fisher      = ||mu_i - mu_j||^2 / ((tr Var_i + 1e-6) + (tr Var_j + 1e-6))
      mmd         = sqrt(max(0, biased RBF-kernel MMD^2)), with
                    gamma = 1 / (2 * median_cross_class_distance^2)
      wasserstein = sqrt(||mu_i - mu_j||^2 + (s_i - s_j)^2), where
                    s = sqrt(mean per-dim variance + 1e-6)
                    (NOTE(review): the exact W2 between isotropic Gaussians would scale
                    the (s_i - s_j)^2 term by D; this is a cheaper proxy.)
    """
    nan3 = (float('nan'), float('nan'), float('nan'))
    if class_i == class_j:
        return nan3

    X_i = X[y == class_i]
    X_j = X[y == class_j]
    if X_i.size(0) == 0 or X_j.size(0) == 0:
        return nan3

    eps = 1e-6
    with torch.no_grad():
        X_i = X_i.to(torch.float32)
        X_j = X_j.to(torch.float32)

        mu_i = X_i.mean(dim=0)
        mu_j = X_j.mean(dim=0)
        var_i = X_i.var(dim=0, unbiased=False)  # per-dimension population variance
        var_j = X_j.var(dim=0, unbiased=False)
        between = ((mu_i - mu_j) ** 2).sum()

        # Fisher ratio: between-class distance over summed within-class variance
        fisher = (between / ((var_i.sum() + eps) + (var_j.sum() + eps))).item()

        # MMD (RBF) with the median heuristic on cross-class distances
        median_dist = torch.cdist(X_i, X_j, p=2).median().item()
        gamma = 1.0 / (2.0 * (median_dist ** 2 + 1e-12))

        def rbf(a, b):
            dist2 = (a * a).sum(dim=1, keepdim=True) + (b * b).sum(dim=1).unsqueeze(0) - 2.0 * (a @ b.t())
            # rounding can leave tiny negative squared distances; they are >= 0 exactly
            return torch.exp(-gamma * dist2.clamp(min=0.0))

        mmd_sq = rbf(X_i, X_i).mean() + rbf(X_j, X_j).mean() - 2.0 * rbf(X_i, X_j).mean()
        mmd = mmd_sq.clamp(min=0.0).sqrt().item()

        # Gaussian-approximate Wasserstein distance with one pooled std per class
        std_i = torch.sqrt(var_i.sum() / X_i.size(1) + eps)
        std_j = torch.sqrt(var_j.sum() / X_j.size(1) + eps)
        wasserstein = torch.sqrt(between + (std_i - std_j) ** 2).item()

    return fisher, mmd, wasserstein
# ---------------------------------------------------------------------

# ---------- Probe single-pass and plotting ----------
def _metric_key(comp: str, name_i: str, name_j: str, metric: str) -> str:
    """Return the JSON key for one (component, domain pair, metric) series; non-word runs become "_"."""
    san_i = re.sub(r"\W+", "_", name_i)
    san_j = re.sub(r"\W+", "_", name_j)
    return f"{comp}_{san_i}_vs_{san_j}_{metric}"


def _nan_to_none(value: float) -> Optional[float]:
    """Map NaN to None so it serialises as JSON null; other values pass through as float."""
    return None if math.isnan(value) else float(value)


def _save_metric_heatmap(matrix: np.ndarray, metric_name: str, comp: str, L: int,
                         dataset_names: List[str], cmap: str) -> str:
    """Save one annotated domain-by-domain heatmap under PROBE_OUTPUT_DIR and return its path."""
    fname = os.path.join(PROBE_OUTPUT_DIR, f"{comp}_{metric_name}_layer_{L}.png")
    plt.figure(figsize=(9, 7))
    try:
        sns.heatmap(matrix, annot=True, cmap=cmap, fmt=".3f",
                    xticklabels=dataset_names, yticklabels=dataset_names,
                    linewidths=.5, linecolor='lightgray', square=True, cbar_kws={'label': metric_name})
        plt.xlabel("Domain Class")
        plt.ylabel("Domain Class")
        plt.title(f"{metric_name.upper()} - Component: {comp} Layer: {L}")
        plt.tight_layout()
        plt.savefig(fname)
    finally:
        # close even if plotting fails so figures do not pile up across layers
        plt.close()
    return fname


def probe_single_pass_and_plot(json_paths: List[str], components: List[str], layers: List[int], max_per_domain: int = 1000, batch_size: int = 8, n_examples: Optional[int] = None, seq_pool: str = "mean"):
    """
    Collect activations in one pass, then score every pair of domains per component and layer.

    For each (component, layer) with captured features:
      - compute Fisher separability, MMD and the Gaussian-approximate Wasserstein
        distance for every domain pair, and print them;
      - save three heatmaps to PROBE_OUTPUT_DIR as ``{comp}_{metric}_layer_{L}.png``.
    All values are also written to ``metrics_per_pair_layers.json``, keyed
    ``{comp}_{domain_i}_vs_{domain_j}_{metric}``. Each key maps to a list with one entry
    per probed layer; the entry is null when the metric was NaN or the layer had no features.

    Raises RuntimeError when fewer than two domains are configured.
    """
    ds = DomainIterableDataset(json_paths, max_per_domain=max_per_domain, shuffle=True)
    dl = DataLoader(ds, batch_size=batch_size, collate_fn=collate, shuffle=False, num_workers=0)
    print("Starting single-pass collection...")
    features, y = collect_all_features_gpu(dl, layers, components, seq_pool=seq_pool, n_examples=n_examples)
    N = y.size(0)
    print(f"Collected activations for {N} examples (device={DEVICE}).")

    dataset_names = ds.domain_names
    n_cls = len(dataset_names)
    if n_cls < 2:
        raise RuntimeError("Not enough classes for pairwise probing.")

    pair_indices = list(itertools.combinations(range(n_cls), 2))
    metrics_keys = ["fisher_separability", "mmd", "wasserstein_distance"]
    cmap_map = {"fisher": "YlOrRd", "mmd": "Blues", "wasserstein": "Blues"}

    # key -> list with one entry per probed layer (None where missing / NaN)
    results_json: Dict[str, List[Optional[float]]] = {
        _metric_key(comp, dataset_names[i], dataset_names[j], mk): []
        for comp in components
        for i, j in pair_indices
        for mk in metrics_keys
    }
    if len(results_json) != len(components) * len(pair_indices) * len(metrics_keys):
        warnings.warn("Duplicate component or (sanitised) domain names: some JSON series will be merged.")

    for comp in components:
        for L in layers:
            X = features[comp][L]  # [N, D] torch on DEVICE
            if X.numel() == 0:
                print(f"No features for {comp} layer {L}; skipping.")
                # record a null so every series stays aligned with `layers`
                for i, j in pair_indices:
                    for mk in metrics_keys:
                        results_json[_metric_key(comp, dataset_names[i], dataset_names[j], mk)].append(None)
                continue

            # n_cls x n_cls matrices are tiny, so build them directly on the CPU;
            # NaN marks pairs that could not be scored.
            fisher_np = np.full((n_cls, n_cls), np.nan, dtype=np.float32)
            mmd_np = np.full((n_cls, n_cls), np.nan, dtype=np.float32)
            wass_np = np.full((n_cls, n_cls), np.nan, dtype=np.float32)

            print(f"\n--- Layer {L} | Component {comp.upper()} ---")

            for i, j in pair_indices:
                f, m, w = compute_pairwise_metrics_gpu_tensors(X, y, i, j)
                fisher_np[i, j] = fisher_np[j, i] = f
                mmd_np[i, j] = mmd_np[j, i] = m
                wass_np[i, j] = wass_np[j, i] = w

                name_i, name_j = dataset_names[i], dataset_names[j]
                print(f"  {name_i} vs {name_j} -> Fisher: {f:.6f}, MMD: {m:.6f}, Wasserstein: {w:.6f}")

                for mk, value in zip(metrics_keys, (f, m, w)):
                    results_json[_metric_key(comp, name_i, name_j, mk)].append(_nan_to_none(value))

            for mat in (fisher_np, mmd_np, wass_np):
                np.fill_diagonal(mat, 0.0)

            for metric_name, matrix in (("fisher", fisher_np), ("mmd", mmd_np), ("wasserstein", wass_np)):
                fname = _save_metric_heatmap(matrix, metric_name, comp, L, dataset_names, cmap_map[metric_name])
                print("Saved:", fname)

            # free GPU memory for this layer+component
            del features[comp][L]
            gc.collect(); torch.cuda.empty_cache()

    out_path = os.path.join(PROBE_OUTPUT_DIR, "metrics_per_pair_layers.json")
    with open(out_path, "w", encoding="utf-8") as of:
        json.dump(results_json, of, indent=2)
    print("Saved per-pair per-layer metrics JSON to:", out_path)

    print("All done. Plots saved to:", PROBE_OUTPUT_DIR)
# ---------------------------------------------------------------------

# --------------- Run if main ----------------
if __name__ == "__main__":
    print("Running Gemma single-pass GPU probe...")
    components = "attn mlp resid".split()
    # every decoder layer; a subset such as list(range(7, 26)) works too
    layers_to_probe = [*range(TOTAL_LAYERS)]
    probe_single_pass_and_plot(
        DOMAIN_JSONS,
        components,
        layers_to_probe,
        max_per_domain=1000,
        batch_size=8,
        n_examples=None,
        seq_pool="mean",
    )
