# Script for empirical evaluation of softmax Lipschitz constant — supplementary material for TMLR submission.

import os
import re
import math
import argparse
import random
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

def softmax_torch(x: torch.Tensor, temperature: float = 1.0, dim: int = -1) -> torch.Tensor:
    """
    Compute a temperature-scaled softmax with the usual max-shift for stability.

    Args:
        x (torch.Tensor): Input logits.
        temperature (float): Strictly positive divisor applied to `x` before
            exponentiation; smaller values give a sharper distribution.
        dim (int): Dimension along which the outputs are normalized.

    Returns:
        torch.Tensor: Tensor of the same shape as `x`, normalized along `dim`.

    Raises:
        ValueError: If `temperature` is zero or negative.
    """
    if not temperature > 0:
        # A NaN temperature is let through, exactly like a `<= 0` check would.
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

    scaled = x / temperature
    # Shift by the per-slice maximum so that exp() never sees a positive input.
    peak, _ = scaled.max(dim=dim, keepdim=True)
    weights = torch.exp(scaled - peak)
    normalizer = weights.sum(dim=dim, keepdim=True)
    return weights / normalizer

# ----------------------------- Utilities ------------------------------------ #

def set_seed(seed: Optional[int]):
    """Seed the Python, NumPy and PyTorch RNGs; do nothing when `seed` is None."""
    if seed is not None:
        # Same order as always: stdlib random, then NumPy, then PyTorch.
        for seeder in (random.seed, np.random.seed, torch.manual_seed):
            seeder(seed)

def lp_norm(t: torch.Tensor, p: float, dim=None, keepdim=False) -> torch.Tensor:
    """
    Return the vector p-norm of `t` via `torch.linalg.vector_norm`.

    Args:
        t (torch.Tensor): Input tensor.
        p (float): Norm order; `float('inf')` selects the max-abs norm.
        dim: Dimension(s) to reduce over; None reduces over all elements.
        keepdim (bool): Whether reduced dimensions are kept with size 1.

    Returns:
        torch.Tensor: The norm value(s).
    """
    if p == float('inf'):
        order = float('inf')
    else:
        order = p
    return torch.linalg.vector_norm(t, ord=order, dim=dim, keepdim=keepdim)
# -------------------------- Dataset loading --------------------------------- #

def load_prompts(dataset_name: str, split: str, num_prompts: int, seed: int) -> List[str]:
    """
    Load prompt strings from one of the supported Hugging Face datasets.

    Supported sources (name is matched case-insensitively, ignoring
    surrounding whitespace):
      - piqa: the 'goal' field of "regisss/piqa"
      - hellaswag: the 'ctx' field of "hellaswag"

    Args:
        dataset_name: 'piqa' or 'hellaswag'.
        split: Dataset split passed through to `load_dataset`.
        num_prompts: Number of prompts to sample without replacement (capped
            at the dataset size). A value <= 0 returns every prompt, unsampled.
        seed: Seed for the private `random.Random` used for sampling.

    Returns:
        The selected prompt strings.

    Raises:
        ValueError: If `dataset_name` is not one of the supported names.
    """
    from datasets import load_dataset

    sampler = random.Random(seed)

    # dataset key -> (hub repo id, text field)
    sources = {
        "piqa": ("regisss/piqa", "goal"),
        "hellaswag": ("hellaswag", "ctx"),
    }
    key = dataset_name.lower().strip()
    if key not in sources:
        raise ValueError("Unsupported dataset_name. Use 'piqa' or 'hellaswag'.")

    repo_id, field = sources[key]
    records = load_dataset(repo_id, split=split)
    texts = [record[field] for record in records]

    if num_prompts <= 0:
        return texts
    return sampler.sample(texts, k=min(num_prompts, len(texts)))

def normalize_to_lp(delta: torch.Tensor, p: float, eps: float) -> torch.Tensor:
    """
    Rescale every sample of `delta` (indexed by the first dim) to Lp norm `eps`.

    Each sample is flattened, its norm is computed (max-abs for p = inf,
    otherwise `lp_norm`), clamped below at 1e-12 to avoid division by zero,
    and the sample is multiplied by eps / norm. The result has the shape of
    `delta`.
    """
    batch = delta.shape[0]
    rows = delta.reshape(batch, -1)

    if p == float("inf"):
        magnitude = rows.abs().amax(dim=-1)
    else:
        magnitude = lp_norm(rows, p, dim=-1)

    factor = (eps / magnitude.clamp(min=1e-12)).view(batch, 1)
    return (rows * factor).view_as(delta)

# --------------------------- Q/K Hooking ------------------------------------ #


# Regexes matched against fully qualified module names (as yielded by
# `named_modules()`) to locate attention projection layers. Every pattern is
# anchored at the end of the name and requires an `attn` / `self_attn` parent.

# Separate query projection.
Q_PATTERNS = [
    r"\.(attn|self_attn)\.(q_proj)$",
    r"\.(attn|self_attn)\.(wq)$",  # naming used by some Qwen / LLaMA ports
]
# Separate key projection.
K_PATTERNS = [
    r"\.(attn|self_attn)\.(k_proj)$",
    r"\.(attn|self_attn)\.(wk)$",
]
# Packed projection that emits Q, K and V from a single module.
QKV_PATTERNS = [
    r"\.(attn|self_attn)\.(qkv|qkv_proj|W_pack|wqkv)$"
]

# Compiled once at import time.
Q_REGEXES = list(map(re.compile, Q_PATTERNS))
K_REGEXES = list(map(re.compile, K_PATTERNS))
QKV_REGEXES = list(map(re.compile, QKV_PATTERNS))

def _name_matches(name: str, regs):
    """Return True when at least one compiled regex in `regs` is found in `name`."""
    for pattern in regs:
        if pattern.search(name):
            return True
    return False

class QKCollector:
    """
    Capture query/key projection outputs of a model through forward hooks.

    On construction, every submodule whose qualified name matches
    `Q_REGEXES`, `K_REGEXES` or `QKV_REGEXES` gets a forward hook. After a
    forward pass, `self.q[layer]` / `self.k[layer]` hold the most recent
    projection outputs, and `compute_scores` turns them into scaled
    dot-product scores.

    NOTE(review): the tensors captured are the raw projection outputs. Any
    processing the model applies afterwards inside its attention module
    (positional rotation, normalization, masking, ...) is not reflected —
    confirm this is the intended quantity for the model being measured.
    """

    def __init__(self, model, num_q_heads, num_kv_heads):
        """
        Register the hooks.

        Args:
            model: Object exposing `named_modules()`; matching modules must
                support `register_forward_hook`.
            num_q_heads: Number of query heads.
            num_kv_heads: Number of key/value heads, or None to reuse
                `num_q_heads`.
        """
        self.model = model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.q, self.k = {}, {}
        self.handles = []
        self.layer_counter = 0
        # Fallback layer ids, keyed by the parent (attention) module name so
        # that the Q and K projections of one attention block share an id.
        self._fallback_ids = {}

        for name, module in model.named_modules():
            if _name_matches(name, Q_REGEXES):
                hook = self._make_store_single(self.q, self._layer_index_from_name(name))
            elif _name_matches(name, K_REGEXES):
                hook = self._make_store_single(self.k, self._layer_index_from_name(name))
            elif _name_matches(name, QKV_REGEXES):
                hook = self._make_store_packed(self.q, self.k, self._layer_index_from_name(name))
            else:
                continue
            self.handles.append(module.register_forward_hook(hook))

    def _layer_index_from_name(self, name: str) -> int:
        """
        Map a projection module name to a layer id.

        The id is the integer directly preceding `.attn.` / `.self_attn.` in
        the name when present. Otherwise a sequential id is assigned per
        parent module, so `x.attn.q_proj` and `x.attn.k_proj` resolve to the
        same id (a single shared counter per call would give them different
        ids and the pair would never be matched up).
        """
        m = re.search(r"\.(\d+)\.(attn|self_attn)\.", name)
        if m:
            return int(m.group(1))
        parent = name.rpartition(".")[0]
        if parent not in self._fallback_ids:
            self._fallback_ids[parent] = self.layer_counter
            self.layer_counter += 1
        return self._fallback_ids[parent]

    def _make_store_single(self, store: dict, lid: int):
        """Build a forward hook that saves the module output in `store[lid]`."""
        def hook_fn(module, inputs, output):
            # expected Linear: [B, T, hidden]
            store[lid] = output
        return hook_fn

    def _make_store_packed(self, q_store: dict, k_store: dict, lid: int):
        """
        Build a forward hook that splits a packed QKV output into Q and K.

        Assumes the last dimension is laid out contiguously as [Q | K | V].
        When the width is divisible by (num_q_heads + 2 * num_kv_heads) the
        split honors grouped-query sizes (Q wider than K); with equal head
        counts this is the same as cutting into thirds. Otherwise it falls
        back to an equal three-way split.

        Raises (from the hook):
            ValueError: If neither split is possible.
        """
        n_q = self.num_q_heads
        n_kv = self.num_kv_heads if self.num_kv_heads is not None else n_q

        def hook_fn(module, inputs, output):
            width = output.shape[-1]
            total_heads = (n_q + 2 * n_kv) if (n_q and n_kv) else 0
            if total_heads and width % total_heads == 0:
                head_dim = width // total_heads
                q_end = n_q * head_dim
                k_end = q_end + n_kv * head_dim
            else:
                if width % 3 != 0:
                    raise ValueError("Packed QKV projection not divisible by 3")
                q_end = width // 3
                k_end = 2 * q_end
            q_store[lid] = output[..., :q_end]
            k_store[lid] = output[..., q_end:k_end]
        return hook_fn

    def clear(self):
        """Drop all captured Q/K tensors (hooks stay registered)."""
        self.q.clear()
        self.k.clear()

    def remove(self):
        """Detach every registered hook; failures on individual handles are ignored."""
        for h in self.handles:
            try:
                h.remove()
            except Exception:
                # Best-effort cleanup: one bad handle must not keep the rest
                # attached. KeyboardInterrupt/SystemExit still propagate.
                pass
        self.handles = []

    def compute_scores(self, idx: int, hidden_size: int) -> Optional[torch.Tensor]:
        """
        Return S = (Q K^T)/sqrt(dh) with shape [B, Hq, T, T] for layer idx.

        Args:
            idx: Layer id as used for the keys of `self.q` / `self.k`.
            hidden_size: Unused; kept for interface compatibility.

        Returns:
            The score tensor, or None if Q or K was not captured for `idx`.

        Raises:
            ValueError: If the number of query heads is not a multiple of the
                number of key/value heads.
        """
        if idx not in self.q or idx not in self.k:
            return None
        q = self.q[idx]  # [B, T, Hq * Dh]
        k = self.k[idx]  # [B, T, Hkv * Dh]
        B, T, _ = q.shape
        Hq = self.num_q_heads
        Hkv = self.num_kv_heads if self.num_kv_heads is not None else Hq
        Dh = q.shape[-1] // Hq
        # reshape (not view) so non-contiguous slices of a packed output work too
        qh = q.reshape(B, T, Hq, Dh).transpose(1, 2).contiguous()   # [B, Hq, T, Dh]
        kh = k.reshape(B, T, Hkv, Dh).transpose(1, 2).contiguous()  # [B, Hkv, T, Dh]
        if Hkv != Hq:
            if Hq % Hkv != 0:
                raise ValueError(
                    "num_attention_heads must be divisible by num_key_value_heads for GQA"
                )
            # Each K head is shared by Hq // Hkv consecutive query heads.
            kh = kh.repeat_interleave(Hq // Hkv, dim=1)  # [B, Hq, T, Dh]
        scale = 1.0 / math.sqrt(Dh)
        return torch.matmul(qh, kh.transpose(-1, -2)) * scale  # [B, Hq, T, T]

# --------------------------- Core measurement ------------------------------- #

@torch.no_grad()
def measure_empirical_lip_on_scores(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompts: List[str],
    p_list: List[float],
    eps_list: List[float],
    num_trials: int,
    device: str,
    max_length: int,
    batch_prompts: int,
) -> pd.DataFrame:
    """
    Measure the empirical Lipschitz ratio of softmax on captured attention scores.

    For each batch of prompts the model is run once while `QKCollector` hooks
    capture Q and K per layer. For every layer and head, S = (QK^T)/sqrt(dh)
    is built, a random row-wise perturbation `delta` with ||delta_row||_p = eps
    is drawn, and the ratio
        ||softmax(S + delta) - softmax(S)||_p / ||delta||_p
    is evaluated per row. The maximum over rows, trials and prompt batches is
    reported.

    Scores are used exactly as returned by `compute_scores`: no attention or
    causal mask is applied here, so rows for padded positions are included.

    Args:
        model: Causal LM; `model.config` must provide `num_attention_heads`
            and `hidden_size` (or `n_embd`).
        tokenizer: Tokenizer used with padding and truncation to `max_length`.
        prompts: Input texts.
        p_list: Norm orders p to evaluate (may include inf).
        eps_list: Perturbation magnitudes.
        num_trials: Random perturbations drawn per (layer, head, p, eps) and
            batch; must be >= 1.
        device: Device the tokenized batch is moved to.
        max_length: Tokenizer truncation length.
        batch_prompts: Number of prompts per forward pass; must be >= 1.

    Returns:
        DataFrame with columns layer, head, p, epsilon and
        ratio_max_over_prompts, sorted by the first four. It is empty (no
        columns) when nothing was measured.

    Raises:
        ValueError: If `num_trials` or `batch_prompts` is < 1, or if the head
            count / hidden size cannot be read from `model.config`.
    """
    if num_trials < 1:
        raise ValueError("num_trials must be at least 1.")
    if batch_prompts < 1:
        raise ValueError("batch_prompts must be at least 1.")

    cfg = model.config
    num_q_heads = int(getattr(cfg, "num_attention_heads", 0) or 0)
    num_kv_heads = int(getattr(cfg, "num_key_value_heads", num_q_heads) or num_q_heads)
    # Validate before int(): int(None) would raise TypeError and hide the
    # intended error message below.
    raw_hidden = getattr(cfg, "hidden_size", None) or getattr(cfg, "n_embd", None)
    if not raw_hidden or num_q_heads == 0:
        raise ValueError("Could not infer hidden_size/num_attention_heads from model.config")
    hidden_size = int(raw_hidden)

    collector = QKCollector(model, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads)
    ratios: Dict[Tuple[int, int, float, float], List[float]] = {}

    try:
        for start in range(0, len(prompts), batch_prompts):
            batch_texts = prompts[start:start + batch_prompts]
            collector.clear()

            enc = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True,
                            max_length=max_length).to(device)
            model(**enc, output_attentions=False, use_cache=False)

            # Only layers for which both Q and K were captured.
            layer_ids = sorted(set(collector.q) & set(collector.k))
            for lid in layer_ids:
                S = collector.compute_scores(lid, hidden_size=hidden_size)  # [B, Hq, T, T]
                if S is None:
                    continue
                # Cast once, up front, so every perturbation is drawn and
                # normalized in float32 regardless of the model dtype.
                S = S.to(torch.float32)

                for h in range(S.shape[1]):
                    S_h = S[:, h, :, :]            # [B, T, T]
                    W0 = F.softmax(S_h, dim=-1)    # unperturbed output, reused by all trials
                    for p in p_list:
                        for eps in eps_list:
                            trial_maxima = []
                            for _ in range(num_trials):
                                delta = torch.randn_like(S_h)
                                row_norm = lp_norm(delta, p, dim=-1, keepdim=True).clamp_min(1e-12)  # [B, T, 1]
                                delta = eps * (delta / row_norm)

                                Wp = F.softmax(S_h + delta, dim=-1)

                                num = lp_norm(Wp - W0, p, dim=-1)  # [B, T]
                                den = lp_norm(delta, p, dim=-1).clamp_min(1e-12)
                                trial_maxima.append((num / den).max().item())

                            key = (lid, h, float(p), float(eps))
                            ratios.setdefault(key, []).append(float(np.max(trial_maxima)))
    finally:
        # Always detach the hooks, even when the forward pass or the
        # measurement fails, so the model is not left instrumented.
        collector.remove()

    # Aggregate across prompt batches
    rows = [
        {
            "layer": layer_idx,
            "head": head,
            "p": p,
            "epsilon": eps,
            "ratio_max_over_prompts": float(np.max(vals)),
        }
        for (layer_idx, head, p, eps), vals in ratios.items()
    ]
    df = pd.DataFrame(rows)
    if not df.empty:
        df = df.sort_values(["layer", "head", "p", "epsilon"]).reset_index(drop=True)
    return df

def plot_agg(df: pd.DataFrame, p_list: List[float], out_png: str, y_ref: float = 0.5):
    """
    Plot, per norm order p, the worst ratio over all layers/heads against epsilon.

    Args:
        df: Frame with columns "p", "epsilon" and "ratio_max_over_prompts".
        p_list: Not used by this function; the p values come from `df`.
        out_png: Output path; ".png" is appended when it has no extension.
        y_ref: Height of a dashed horizontal reference line, or None to omit it.

    Prints a message and returns without plotting when `df` is empty.
    """
    if df.empty:
        print("No data to plot. (Empty dataframe)")
        return

    stem, suffix = os.path.splitext(out_png)
    target = f"{stem}{suffix or '.png'}"

    # Worst case over every remaining column (layer, head) per (p, epsilon).
    worst = df.groupby(["p", "epsilon"], as_index=False)["ratio_max_over_prompts"].max()
    worst = worst.sort_values(["p", "epsilon"])

    fig = plt.figure(figsize=(6, 4))
    for p_val, curve in worst.groupby("p"):
        curve = curve.sort_values("epsilon")
        plt.plot(
            curve["epsilon"],
            curve["ratio_max_over_prompts"],
            marker="o",
            markersize=5,
            linewidth=2.2,
            label=f"p={p_val}",
        )
    if y_ref is not None:
        plt.axhline(y_ref, linestyle="--", linewidth=2.0)

    plt.xscale("log")
    plt.xlabel(r"Perturbation $\epsilon$", fontsize=15)
    plt.ylabel("Empirical $L_p$", fontsize=15)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved plot to: {target}")

# ------------------------------- CLI ---------------------------------------- #

def parse_args(argv=None):
    """
    Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads `sys.argv[1:]`.

    Returns:
        argparse.Namespace with the parsed options.

    Notes:
        `--bf16` and `--trust_remote_code` default to True; the paired
        `--no_bf16` / `--no_trust_remote_code` flags switch them off (with
        plain `store_true` and `default=True` they could never be disabled).
        `--seed` accepts an integer or the literal `none`.
    """
    def int_or_none(text):
        # Lets `--seed none` request an unseeded run.
        return None if text.strip().lower() == "none" else int(text)

    parser = argparse.ArgumentParser(description="Empirical Lipschitz over pre-softmax attention (Qwen; hooks Q/K)")
    # Model & loading
    parser.add_argument("--model_id", type=str, default="Qwen/Qwen3-32B",
                        help="HF repo id, e.g., Qwen/Qwen3-32B")
    parser.add_argument("--no_4bit", action="store_true",
                        help="Disable 4-bit loading (uses full precision bf16/fp16 if available)")
    parser.add_argument("--bf16", dest="bf16", action="store_true", default=True,
                        help="Use bfloat16 where possible (default True)")
    parser.add_argument("--no_bf16", dest="bf16", action="store_false", default=True,
                        help="Use float16 instead of bfloat16")
    parser.add_argument("--device_map", type=str, default="auto",
                        help="'auto' (shard across GPUs) or a device string like 'cuda:0'")
    # NOTE(review): trusting remote code is on by default; pass
    # --no_trust_remote_code for repositories you have not vetted.
    parser.add_argument("--trust_remote_code", dest="trust_remote_code", action="store_true", default=True,
                        help="Pass trust_remote_code=True to from_pretrained")
    parser.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false", default=True,
                        help="Pass trust_remote_code=False to from_pretrained")

    # Data / prompts
    parser.add_argument("--dataset", type=str, default="piqa", choices=["piqa", "hellaswag"],
                        help="Dataset to draw prompts from")
    parser.add_argument("--split", type=str, default="train", help="Dataset split")
    parser.add_argument("--num_prompts", type=int, default=100, help="Number of prompts to sample")
    parser.add_argument("--batch_prompts", type=int, default=8, help="Batch size for tokenizer/model")
    parser.add_argument("--max_length", type=int, default=128, help="Tokenizer truncation length")

    # Perturbation
    parser.add_argument("--p_list", type=float, nargs="+",
                        default=[1.0, 2.0, 5.0, 10.0, float('inf')], help="p norms")
    parser.add_argument("--eps_list", type=float, nargs="+",
                        default=[5e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1, 5, 10], help="epsilon magnitudes")
    parser.add_argument("--num_trials", type=int, default=1, help="Random directions per (p, eps)")

    # Misc
    parser.add_argument("--seed", type=int_or_none, default=0,
                        help="RNG seed (pass 'none' for full randomness)")
    parser.add_argument("--out_png", type=str, default="empirical_Lp_qwen_qk.png", help="Output plot filename")

    return parser.parse_args(argv)

# -------------------------------- Main -------------------------------------- #

def main():
    """
    Script entry point: load the model, sample prompts, measure, and plot.

    Steps: parse CLI options and seed the RNGs; load config, tokenizer and
    model (4-bit NF4 quantized unless --no_4bit is given); sample prompts from
    the chosen dataset; run `measure_empirical_lip_on_scores`; save the
    aggregated plot with `plot_agg`.
    """
    args = parse_args()
    set_seed(args.seed)

    # dtype
    dtype = torch.bfloat16 if args.bf16 else torch.float16

    # Config & tokenizer
    print("Loading model and tokenizer...")
    config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=args.trust_remote_code)
    # We hook Q/K; no need for output_attentions
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=False, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        # Batched tokenization pads, so a pad token is required.
        tokenizer.pad_token = tokenizer.eos_token

    # Model
    load_kwargs = dict(
        config=config,
        trust_remote_code=args.trust_remote_code,
        device_map=args.device_map,
        torch_dtype=dtype,
    )
    if not args.no_4bit:
        # Passing load_in_4bit / bnb_4bit_* directly to from_pretrained is
        # deprecated; the supported route is an explicit quantization config.
        from transformers import BitsAndBytesConfig

        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    model = AutoModelForCausalLM.from_pretrained(args.model_id, **load_kwargs)
    model.eval()

    # Device for token batches: wherever the first parameter lives.
    try:
        first_param = next(model.parameters())
        model_device = first_param.device
    except StopIteration:
        model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_str = str(model_device)

    print("Loaded:", args.model_id)
    print("Layers:", getattr(model.config, "num_hidden_layers", "N/A"))
    print("Heads :", getattr(model.config, "num_attention_heads", "GQA"))
    if hasattr(model.config, "num_key_value_heads"):
        print("KV Heads:", getattr(model.config, "num_key_value_heads"))

    # Prompts (a None seed falls back to 0 for the prompt sampler)
    prompts = load_prompts(args.dataset, args.split, args.num_prompts, seed=(args.seed or 0))
    print(f"Loaded {len(prompts)} prompts from {args.dataset}:{args.split}.")

    # Measure
    df = measure_empirical_lip_on_scores(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        p_list=args.p_list,
        eps_list=args.eps_list,
        num_trials=args.num_trials,
        device=device_str,
        max_length=args.max_length,
        batch_prompts=args.batch_prompts,
    )

    # Plot
    plot_agg(df, p_list=args.p_list, out_png=args.out_png, y_ref=0.5)


if __name__ == "__main__":
    main()
