# probe_gemma_singlepass_with_json_output.py
import json, gc, os, random, itertools, pathlib, re, warnings, math
from typing import List, Dict, Iterable, Literal, Optional
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import IterableDataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

# silence some noisy but harmless warnings
warnings.filterwarnings("ignore", message=".*not compiled with flash attention.*", category=UserWarning)

# -------------- USER / MODEL CONFIG --------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "google/gemma-3-1b-it"
# Hugging Face access token: read from the HF_TOKEN env var (or paste a token string here).
# When unset this is None, so transformers falls back to a cached `huggingface-cli login`;
# the old placeholder default was sent to the Hub as a real (invalid) token instead.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Output dir for heatmaps and the metrics JSON (can be overridden via the PROBE_OUTPUT_DIR env var):
PROBE_OUTPUT_DIR = os.environ.get("PROBE_OUTPUT_DIR", "your output directory goes here")
os.makedirs(PROBE_OUTPUT_DIR, exist_ok=True)

# Domain JSONs - change paths if needed. Each file is one domain/class.
DOMAIN_JSONS = [
    "General_datasets/finance.json",
    "General_datasets/medical.json",
    "General_datasets/math.json",
    "General_datasets/science.json",
    "Coding_datasets/cpp_top.json",
    "Coding_datasets/python_top.json",
]
# -------------------------------------------------

# --------- Load tokenizer & model -------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, token=HF_TOKEN)

# Weights are loaded in float16 (swap in a quantized load here if preferred).
# NOTE(review): Gemma checkpoints are commonly run in bfloat16; if activations come
# back inf/NaN under float16, consider torch.bfloat16 where the GPU supports it — confirm.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)
model = model.to(DEVICE).eval()

# Batches are tokenized with padding=True, which needs a pad token; reuse EOS when absent.
if getattr(tokenizer, "pad_token", None) is None:
    tokenizer.pad_token = tokenizer.eos_token

def get_num_layers(m):
    """
    Return the number of decoder layers of ``m``, or None if it cannot be determined.

    Scans the root sub-modules "model", "transformer", "transformer_model", "decoder"
    and "gemma" (in that order, skipping absent ones). For each root it looks for a
    layer container named "layers", "h" or "block", and returns the length of the
    first one found. As a last resort it tries ``len(m.model.layers)``.
    """
    root_names = ("model", "transformer", "transformer_model", "decoder", "gemma")
    container_names = ("layers", "h", "block")
    for root in (getattr(m, name) for name in root_names if hasattr(m, name)):
        container = None
        # Mirrors `a or b or c`: stop at the first truthy container; otherwise keep the last value.
        for cname in container_names:
            container = getattr(root, cname, None)
            if container:
                break
        if container is not None:
            return len(container)
    try:
        return len(m.model.layers)
    except Exception:
        return None

# Use the detected depth; the previously hard-coded 26 is kept only as a fallback for
# module layouts that get_num_layers() does not recognise.
_DETECTED_LAYERS = get_num_layers(model)
TOTAL_LAYERS = _DETECTED_LAYERS or 26
print(f"Loaded {MODEL_NAME} on {DEVICE}. Total layers "
      f"({'detected' if _DETECTED_LAYERS else 'fallback'}): {TOTAL_LAYERS}")
# -----------------------------------------------

# ---------------- Dataset ----------------
class DomainIterableDataset(IterableDataset):
    """
    Stream (text, domain_id) pairs from one JSON file per domain.

    Each file must contain a JSON list; its items are yielded as-is (presumably strings,
    since callers feed them to a tokenizer — confirm against the data files). A file's
    domain id is its position in ``json_paths`` and ``domain_names[id]`` is its stem.

    With ``shuffle=True`` the file order and the texts within each file are reshuffled on
    every iteration using the global ``random`` state. At most ``max_per_domain`` texts
    are drawn per file.
    """
    def __init__(self, json_paths: List[str], max_per_domain: int = 1000, shuffle: bool = True):
        self.paths = [pathlib.Path(p) for p in json_paths]
        self.max_per_domain = max_per_domain
        self.shuffle = shuffle
        self.domain2id = {p: i for i, p in enumerate(self.paths)}
        self.domain_names = [p.stem for p in self.paths]

    def _sample_texts(self, file_path: pathlib.Path) -> Iterable[str]:
        """Load one domain file and return up to ``max_per_domain`` (optionally shuffled) items.

        Raises:
            TypeError: if the file's top-level JSON value is a non-empty non-list.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            texts = json.load(f)
        if not texts:
            return []
        if not isinstance(texts, list):
            raise TypeError(f"{file_path}: expected a JSON list of texts, got {type(texts).__name__}")
        if self.shuffle:
            random.shuffle(texts)
        return texts[:self.max_per_domain]

    def __iter__(self):
        # Shuffle a copy: shuffling self.paths in place mutated the dataset's own
        # path list on every iteration.
        files = list(self.paths)
        if self.shuffle:
            random.shuffle(files)
        for fp in files:
            d_id = self.domain2id[fp]
            for txt in self._sample_texts(fp):
                yield txt, d_id

def collate(batch):
    """Transpose a batch of (text, label) pairs into (list_of_texts, list_of_labels)."""
    texts, labels = map(list, zip(*batch))
    return texts, labels
# ------------------------------------------

# ---------- Multi-layer activation grabber (single-pass) ----------
class MultiLayerActivationGrabber:
    """
    Register forward hooks on selected decoder layers/components and store pooled outputs.

    Component -> hooked module (Gemma naming first, then name-scan fallbacks):
      * "attn":  ``self_attn.o_proj`` — the attention output projection.
      * "mlp":   ``mlp.down_proj`` — the MLP output projection (``up_proj`` / ``gate_proj``
                 only if a layer has no ``down_proj``).
      * "resid": the decoder layer module itself, i.e. the hidden state the block returns
                 (its first element when the layer returns a tuple).

    Each hook pools its output over dim 1 (``seq_pool`` = "mean" or "first") and appends a
    detached tensor, left on the device that produced it, to ``buffers[component][layer]``.

    Set ``attention_mask`` to the [batch, seq] mask of the batch before each forward pass
    so pooling skips padding positions. When it is None, or its shape does not match the
    hooked output's first two dims, every position is pooled.
    """
    # MLP sub-modules in order of preference. down_proj is the MLP's output projection,
    # matching "attn", which hooks the attention output projection (o_proj).
    _MLP_CANDIDATES = ("down_proj", "up_proj", "gate_proj")

    def __init__(self, model, layers: List[int], components: List[Literal["attn","mlp","resid"]], seq_pool: str = "mean"):
        self.model = model
        self.layers_idx = layers
        self.components = components
        self.seq_pool = seq_pool
        self.buffers: Dict[str, Dict[int, List[torch.Tensor]]] = {c: {L: [] for L in layers} for c in components}
        self.handles = []
        self._registered = False
        # [batch, seq] attention mask of the batch about to be run (None = pool every position).
        self.attention_mask: Optional[torch.Tensor] = None

    def _mask_for(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Return ``attention_mask`` on x's device if it matches x's [batch, seq] dims, else None."""
        mask = self.attention_mask
        if mask is None or x.dim() != 3 or tuple(mask.shape) != tuple(x.shape[:2]):
            return None
        return mask.to(x.device)

    def _pool_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Pool x over dim 1, ignoring padding positions when a matching mask is set."""
        mask = self._mask_for(x)
        if self.seq_pool == "mean":
            if mask is None:
                return x.mean(dim=1)
            weights = mask.to(torch.float32).unsqueeze(-1)        # [batch, seq, 1]
            # Accumulate in float32: a float16 sum over the sequence can overflow.
            summed = (x.float() * weights).sum(dim=1)
            counts = weights.sum(dim=1).clamp(min=1.0)            # guard all-padding rows
            return (summed / counts).to(x.dtype)
        elif self.seq_pool == "first":
            if mask is None:
                return x[:, 0]
            # First non-padding position per row (argmax returns the first maximal index),
            # so left-padded batches do not pool a pad token.
            first = mask.to(torch.float32).argmax(dim=1)
            return x[torch.arange(x.size(0), device=x.device), first]
        else:
            raise ValueError(f"Unsupported seq_pool {self.seq_pool}")

    def _save_hook_factory(self, comp: str, L: int):
        """Build a forward hook that pools the module output and appends it to buffers[comp][L]."""
        def hook(module, input, out):
            out_t = out[0] if isinstance(out, tuple) else out
            pooled = self._pool_tensor(out_t)
            # detach and append (stays on the device it was produced on)
            self.buffers[comp][L].append(pooled.detach())
        return hook

    def _find_layer_module(self, L: int):
        """Return decoder layer L from the first root exposing a layers/h/block container that has index L.

        Raises:
            RuntimeError: if no such layer exists.
        """
        candidate_roots = []
        for root_attr in ("model", "transformer", "transformer_model", "decoder", "gemma"):
            if hasattr(self.model, root_attr):
                candidate_roots.append(getattr(self.model, root_attr))
        if not candidate_roots:
            candidate_roots = [self.model]
        for root in candidate_roots:
            layers = getattr(root, "layers", None) or getattr(root, "h", None) or getattr(root, "block", None)
            if layers is None:
                continue
            if L < 0 or L >= len(layers):
                continue
            return layers[L]
        raise RuntimeError(f"Layer {L} not found in model")

    def _find_module_for(self, layer_module, component: str):
        """Pick the sub-module of ``layer_module`` whose output represents ``component``.

        Raises:
            RuntimeError: if no suitable module is found.
        """
        if component == "attn":
            if hasattr(layer_module, "self_attn") and hasattr(layer_module.self_attn, "o_proj"):
                return layer_module.self_attn.o_proj
            for name, sub in layer_module.named_modules():
                if re.search(r"(o_proj|out_proj|wo|self_attn)", name, flags=re.IGNORECASE):
                    return sub
        elif component == "mlp":
            if hasattr(layer_module, "mlp"):
                mlp = layer_module.mlp
                for cand in self._MLP_CANDIDATES:
                    if hasattr(mlp, cand):
                        return getattr(mlp, cand)
            for name, sub in layer_module.named_modules():
                if re.search(r"(mlp|up_proj|gate_proj|down_proj|wo)", name, flags=re.IGNORECASE):
                    return sub
        elif component == "resid":
            # Hook the decoder layer itself so "resid" is the residual stream the block
            # returns. The previously preferred post_*_layernorm modules are, in Gemma's
            # sandwich-norm layout, applied to a sub-layer's output *before* the residual
            # add, so they captured a normalised branch output rather than the residual.
            return layer_module
        raise RuntimeError(f"Module for component {component} not found in layer")

    def register_hooks(self):
        """Attach one forward hook per (component, layer); all-or-nothing on failure."""
        if self._registered:
            return
        try:
            for comp in self.components:
                for L in self.layers_idx:
                    layer_module = self._find_layer_module(L)
                    mod = self._find_module_for(layer_module, comp)
                    self.handles.append(mod.register_forward_hook(self._save_hook_factory(comp, L)))
        except Exception:
            # Don't leave a partial set of hooks attached to the (shared) model.
            self.remove_hooks()
            raise
        self._registered = True

    def remove_hooks(self):
        """Detach every registered hook."""
        for h in self.handles:
            h.remove()
        self.handles = []
        self._registered = False

    def clear_buffers(self):
        """Drop all captured activations."""
        self.buffers = {c: {L: [] for L in self.layers_idx} for c in self.components}

    def stacked_features(self, component: str, layer: int) -> Optional[torch.Tensor]:
        """Concatenate captured batches for (component, layer) along dim 0; None if nothing was captured."""
        lst = self.buffers[component][layer]
        if not lst:
            return None
        return torch.cat(lst, dim=0)
# ------------------------------------------------------------------


# ---------- Single-pass collection (keeps everything on GPU) ----------
def collect_all_features_gpu(dataloader: DataLoader, layers: List[int], components: List[str], seq_pool: str = "mean", n_examples: Optional[int] = None, max_length: int = 512):
    """
    Run the global ``model`` once over ``dataloader`` and collect pooled activations.

    Args:
        dataloader: yields (list_of_texts, list_of_labels) batches (see ``collate``).
        layers: decoder layer indices to hook.
        components: any of "attn", "mlp", "resid".
        seq_pool: "mean" over non-padding tokens, or "first" non-padding token.
        n_examples: if given, stop after exactly this many examples.
        max_length: tokenizer truncation length.

    Returns:
        (features, y): ``features[component][layer]`` is the row-wise concatenation of
        pooled activations (an empty (0, 0) tensor on DEVICE if nothing was captured), and
        ``y`` is a long tensor of labels on DEVICE, row-aligned with the features.
    """
    grabber = MultiLayerActivationGrabber(model, layers, components, seq_pool=seq_pool)
    grabber.clear_buffers()
    grabber.register_hooks()

    labels_accum: List[int] = []
    seen = 0
    try:
        with torch.no_grad():
            for batch_texts, batch_labels in tqdm(dataloader, desc="Collecting activations (single-pass)"):
                if n_examples is not None:
                    remaining = n_examples - seen
                    if remaining <= 0:
                        break
                    # Trim the final batch so exactly n_examples are collected.
                    batch_texts = batch_texts[:remaining]
                    batch_labels = batch_labels[:remaining]
                toks = tokenizer(batch_texts, return_tensors="pt", padding=True,
                                 truncation=True, max_length=max_length).to(DEVICE)
                # The hooks read this to exclude padding positions from pooling.
                grabber.attention_mask = toks.get("attention_mask")
                model(**toks)
                labels_accum.extend(batch_labels)
                seen += len(batch_labels)
                if n_examples is not None and seen >= n_examples:
                    break
    finally:
        # Always detach hooks from the global model, even if a batch raised.
        grabber.attention_mask = None
        grabber.remove_hooks()

    # labels -> tensor on DEVICE to keep subsequent masking on GPU
    y = torch.tensor(labels_accum, dtype=torch.long, device=DEVICE)

    features: Dict[str, Dict[int, torch.Tensor]] = {c: {} for c in components}
    for c in components:
        for L in layers:
            stacked = grabber.stacked_features(c, L)
            features[c][L] = torch.empty((0, 0), device=DEVICE) if stacked is None else stacked

    grabber.clear_buffers()
    gc.collect()
    torch.cuda.empty_cache()
    return features, y
# ---------------------------------------------------------------------


# ---------- GPU metric computations (Fisher, MMD, Wasserstein) ----------
def _rbf_kernel(a: torch.Tensor, b: torch.Tensor, gamma: float) -> torch.Tensor:
    """RBF kernel matrix exp(-gamma * ||a_r - b_c||^2) between the rows of a and b."""
    aa = (a * a).sum(dim=1).unsqueeze(1)
    bb = (b * b).sum(dim=1).unsqueeze(0)
    # The expanded form can dip slightly below zero from rounding. With the large gamma
    # produced by a near-zero median distance, that would become exp(+large).
    dist2 = torch.clamp(aa + bb - 2.0 * (a @ b.t()), min=0.0)
    return torch.exp(-gamma * dist2)


def compute_pairwise_metrics_gpu_tensors(X: torch.Tensor, y: torch.Tensor, class_i: int, class_j: int):
    """
    Compute separability metrics between two classes of pooled activations.

    X: [N, D] features (any floating dtype; computed in float32 internally)
    y: [N] integer labels on the same device as X
    class_i, class_j: label values of the two classes to compare

    Returns:
        (fisher, mmd, wasserstein) as Python floats. All three are NaN when
        class_i == class_j or when either class has no rows in X.
        - fisher: ||mu_i - mu_j||^2 / (sum of per-dim variances of both classes)
        - mmd: sqrt of the biased RBF-kernel MMD^2 (bandwidth = median between-class distance)
        - wasserstein: sqrt(||mu_i - mu_j||^2 + (std_i - std_j)^2) with one pooled std per class
    """
    nan3 = (float('nan'), float('nan'), float('nan'))
    if class_i == class_j:
        return nan3

    # Work in float32: float16 squared norms / variance sums over D dims easily exceed
    # float16's max (65504), which turned Fisher/MMD into inf/NaN.
    X = X.float()
    X_i = X[y == class_i]
    X_j = X[y == class_j]
    if X_i.size(0) == 0 or X_j.size(0) == 0:
        return nan3

    eps = 1e-6
    with torch.no_grad():
        mu_i = X_i.mean(dim=0)
        mu_j = X_j.mean(dim=0)
        var_i = X_i.var(dim=0, unbiased=False)  # per-dimension, population variance
        var_j = X_j.var(dim=0, unbiased=False)
        between = torch.sum((mu_i - mu_j) ** 2)

        fisher = (between / (var_i.sum() + var_j.sum() + 2.0 * eps)).item()

        # MMD with an RBF kernel; median-heuristic bandwidth on between-class distances.
        median_dist = torch.cdist(X_i, X_j, p=2).median().item()
        gamma = 1.0 / (2.0 * (median_dist ** 2 + 1e-12))
        mmd_sq = (_rbf_kernel(X_i, X_i, gamma).mean()
                  + _rbf_kernel(X_j, X_j, gamma).mean()
                  - 2.0 * _rbf_kernel(X_i, X_j, gamma).mean())
        mmd = torch.sqrt(torch.clamp(mmd_sq, min=0.0)).item()

        # Gaussian approximation of a Wasserstein distance.
        # NOTE(review): this mixes the full D-dim mean distance with a single pooled std.
        # W2 between isotropic Gaussians N(mu, s^2 I) would be
        # sqrt(||dmu||^2 + D * (s_i - s_j)^2). Kept as-is so results stay comparable.
        std_i = torch.sqrt(var_i.sum() / X_i.size(1) + eps)
        std_j = torch.sqrt(var_j.sum() / X_j.size(1) + eps)
        wasserstein = torch.sqrt(between + (std_i - std_j) ** 2).item()

    return fisher, mmd, wasserstein
# ---------------------------------------------------------------------

# ---------- Probe single-pass and plotting ----------
def _sanitize_domain_name(name: str) -> str:
    """Collapse runs of non-word characters to '_' so a domain name is safe inside JSON keys."""
    return re.sub(r"\W+", "_", name)


def _pair_metric_key(comp: str, san_i: str, san_j: str, metric_key: str) -> str:
    """Results-JSON key for one (component, domain pair, metric) series."""
    return f"{comp}_{san_i}_vs_{san_j}_{metric_key}"


def _json_number(value: float) -> Optional[float]:
    """Map a metric to a JSON-safe value. NaN/inf become None (written as null); json.dump
    would otherwise emit NaN/Infinity, which is not valid JSON."""
    value = float(value)
    return value if math.isfinite(value) else None


def _save_metric_heatmaps(comp: str, L: int, domain_names: List[str], metrics: List[tuple]) -> None:
    """Save one annotated heatmap PNG per (metric_name, matrix) pair into PROBE_OUTPUT_DIR."""
    cmap_map = {"fisher": "YlOrRd", "mmd": "Blues", "wasserstein": "Blues"}
    for metric_name, matrix in metrics:
        plt.figure(figsize=(9,7))
        sns.heatmap(matrix, annot=True, cmap=cmap_map[metric_name], fmt=".3f",
                    xticklabels=domain_names, yticklabels=domain_names,
                    linewidths=.5, linecolor='lightgray', square=True, cbar_kws={'label': metric_name})
        plt.xlabel("Domain Class")
        plt.ylabel("Domain Class")
        plt.title(f"{metric_name.upper()} - Component: {comp} Layer: {L}")
        plt.tight_layout()
        fname = os.path.join(PROBE_OUTPUT_DIR, f"{comp}_{metric_name}_layer_{L}.png")
        plt.savefig(fname)
        plt.close()
        print("Saved:", fname)


def probe_single_pass_and_plot(json_paths: List[str], components: List[str], layers: List[int], max_per_domain: int = 1000, batch_size: int = 8, n_examples: Optional[int] = None, seq_pool: str = "mean"):
    """
    Collect activations in one model pass, then score every domain pair per component/layer.

    For each (component, layer), prints and plots Fisher / MMD / Wasserstein matrices
    (heatmaps saved to PROBE_OUTPUT_DIR). It also writes ``metrics_per_pair_layers.json``,
    mapping "{comp}_{domain_i}_vs_{domain_j}_{metric}" to a list with one entry per
    element of ``layers`` (same order). An entry is null when the metric is undefined or
    non-finite, or when the layer produced no features.

    Raises:
        RuntimeError: if fewer than two domain files are given.
    """
    ds = DomainIterableDataset(json_paths, max_per_domain=max_per_domain, shuffle=True)
    dataset_names = ds.domain_names
    n_cls = len(dataset_names)
    # Checked before collection so a bad config fails without a full model pass.
    if n_cls < 2:
        raise RuntimeError("Not enough classes for pairwise probing.")

    dl = DataLoader(ds, batch_size=batch_size, collate_fn=collate, shuffle=False, num_workers=0)
    print("Starting single-pass collection...")
    features, y = collect_all_features_gpu(dl, layers, components, seq_pool=seq_pool, n_examples=n_examples)
    N = y.size(0)
    print(f"Collected activations for {N} examples (device={DEVICE}).")

    pair_indices = list(itertools.combinations(range(n_cls), 2))
    sanitized = [_sanitize_domain_name(name) for name in dataset_names]
    metrics_keys = ["fisher_separability", "mmd", "wasserstein_distance"]
    results_json: Dict[str, List[Optional[float]]] = {
        _pair_metric_key(comp, sanitized[i], sanitized[j], mk): []
        for comp in components
        for (i, j) in pair_indices
        for mk in metrics_keys
    }

    for comp in components:
        for L in layers:
            X = features[comp][L]
            if X.numel() == 0:
                print(f"No features for {comp} layer {L}; skipping.")
                # Append nulls so every series stays aligned with `layers`.
                for (i, j) in pair_indices:
                    for mk in metrics_keys:
                        results_json[_pair_metric_key(comp, sanitized[i], sanitized[j], mk)].append(None)
                continue

            # Tiny n_cls x n_cls matrices; build them on the host, where they are plotted.
            fisher_np = np.full((n_cls, n_cls), np.nan, dtype=np.float32)
            mmd_np = np.full((n_cls, n_cls), np.nan, dtype=np.float32)
            wass_np = np.full((n_cls, n_cls), np.nan, dtype=np.float32)

            print(f"\n--- Layer {L} | Component {comp.upper()} ---")

            for i, j in pair_indices:
                f, m, w = compute_pairwise_metrics_gpu_tensors(X, y, i, j)
                fisher_np[i, j] = fisher_np[j, i] = f
                mmd_np[i, j] = mmd_np[j, i] = m
                wass_np[i, j] = wass_np[j, i] = w

                # A float NaN formats as "nan" under :.6f, so no special-casing is needed.
                print(f"  {dataset_names[i]} vs {dataset_names[j]} -> Fisher: {f:.6f}, MMD: {m:.6f}, Wasserstein: {w:.6f}")

                for mk, value in zip(metrics_keys, (f, m, w)):
                    results_json[_pair_metric_key(comp, sanitized[i], sanitized[j], mk)].append(_json_number(value))

            for mat in (fisher_np, mmd_np, wass_np):
                np.fill_diagonal(mat, 0.0)

            _save_metric_heatmaps(comp, L, dataset_names,
                                  [("fisher", fisher_np), ("mmd", mmd_np), ("wasserstein", wass_np)])

            # free GPU memory for this layer+component
            del features[comp][L]
            gc.collect(); torch.cuda.empty_cache()

    out_path = os.path.join(PROBE_OUTPUT_DIR, "metrics_per_pair_layers.json")
    # None is written as null; non-finite values were already mapped to None above.
    with open(out_path, "w", encoding="utf-8") as of:
        json.dump(results_json, of, indent=2)
    print("Saved per-pair per-layer metrics JSON to:", out_path)

    print("All done. Plots saved to:", PROBE_OUTPUT_DIR)
# ---------------------------------------------------------------------


# ----------------- Run if main -----------------
if __name__ == "__main__":
    print("Running Gemma single-pass GPU probe...")
    components = ["attn", "mlp", "resid"]   # probe these components
    # Every decoder layer; use a subset such as list(range(7, TOTAL_LAYERS)) if preferred.
    layers_to_probe = list(range(TOTAL_LAYERS))
    probe_single_pass_and_plot(DOMAIN_JSONS, components, layers_to_probe, max_per_domain=1000,
                               batch_size=8, n_examples=None, seq_pool="mean")
