# Script for empirically evaluating the Lipschitz constant of softmax on GPT-2 attention scores — supplementary material for a TMLR submission.


import os
import math
import argparse
import random
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from transformers import GPT2Model, GPT2Tokenizer

# ----------------------------- Utilities ------------------------------------ #

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch global RNGs (in that order) with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def lp_norm(t: torch.Tensor, p: float, dim=None, keepdim=False) -> torch.Tensor:
    """
    Return the l_p norm of ``t`` reduced over ``dim``.

    For finite ``p`` the entries are divided by their max absolute value along
    the reduced dims before being raised to the power ``p``, and the result is
    rescaled afterwards. This keeps ``|t|**p`` from underflowing (e.g. p=10 on
    float32 entries ~1e-5 would otherwise give exactly 0) or overflowing.
    ``p == inf`` returns the max absolute value.

    Args:
        t: input tensor.
        p: norm order (> 0, or ``float('inf')``).
        dim: dimension(s) to reduce; None reduces over every dimension.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        Tensor of norms with the reduced dimensions removed (or kept as 1).
    """
    a = t.abs()
    # Explicit dim tuple so amax/sum behave the same for dim=None.
    dims = tuple(range(a.dim())) if dim is None else dim
    if p == float('inf'):
        return a.amax(dim=dims, keepdim=keepdim)
    if a.numel() == 0:
        # amax cannot reduce empty slices; the plain formula yields 0 here.
        return (a ** p).sum(dim=dims, keepdim=keepdim) ** (1.0 / p)

    m = a.amax(dim=dims, keepdim=True)
    # Fall back to a scale of 1 where the max is 0 / inf / nan, so those
    # slices give 0 / inf / nan exactly as the unscaled formula would.
    scale = torch.where(torch.isfinite(m) & (m > 0), m, torch.ones_like(m))
    powered_sum = ((a / scale) ** p).sum(dim=dims, keepdim=keepdim)
    return scale.amax(dim=dims, keepdim=keepdim) * powered_sum ** (1.0 / p)

def to_list_floats(x: List[float]) -> List[float]:
    """
    Convert every element of ``x`` to ``float``.

    Strings spelled "inf", "infty" or "infinite" (case-insensitive) become
    positive infinity; everything else goes through ``float(...)``.
    """
    def _coerce(value):
        is_inf_word = isinstance(value, str) and value.lower() in ("inf", "infty", "infinite")
        return float('inf') if is_inf_word else float(value)

    return [_coerce(value) for value in x]

# -------------------------- Dataset loading --------------------------------- #

def load_prompts(dataset_name: str, split: str, num_prompts: int, seed: int) -> List[str]:
    """
    Load a lightweight Hugging Face text dataset and return prompts from it.

    Supported ``dataset_name`` values (case-insensitive, surrounding whitespace ignored):
      - piqa:      the 'goal' field of "regisss/piqa"
      - hellaswag: the 'ctx' field of "hellaswag"

    When ``num_prompts`` <= 0 every prompt is returned in dataset order;
    otherwise ``min(num_prompts, len(dataset))`` prompts are sampled without
    replacement using ``random.Random(seed)``.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    from datasets import load_dataset

    sampler = random.Random(seed)
    # dataset key -> (hub id, text field); alt mirrors can be swapped in here if needed
    sources = {
        "piqa": ("regisss/piqa", "goal"),
        "hellaswag": ("hellaswag", "ctx"),
    }

    key = dataset_name.lower().strip()
    if key not in sources:
        raise ValueError("Unsupported dataset_name. Use 'piqa' or 'hellaswag'.")
    hub_id, field = sources[key]

    rows = load_dataset(hub_id, split=split)
    candidates = [row[field] for row in rows]

    if num_prompts > 0:
        return sampler.sample(candidates, k=min(num_prompts, len(candidates)))
    return candidates

# ----------------------------- Hooking -------------------------------------- #

def register_attn_hooks(model: GPT2Model):
    """
    Attach a forward hook to every block's ``attn.c_attn`` projection and
    recompute the pre-softmax attention scores from its output.

    On each forward pass, one dict per layer is appended to the shared
    ``captured`` list:
      - "layer":   block index
      - "scores":  scaled q·kᵀ scores, shape [B, H, T, T], detached
      - "weights": softmax(scores, dim=-1), shape [B, H, T, T], detached

    Score scaling mirrors the attention module's configuration flags:
    ``scale_attn_weights`` divides by sqrt(head_dim), and
    ``scale_attn_by_inverse_layer_idx`` additionally divides by (layer_idx + 1).
    When a flag attribute is absent, the stock GPT-2 behaviour is used
    (scale by 1/sqrt(head_dim) only).

    NOTE(review): no causal mask or attention (padding) mask is applied here,
    so "weights" is a softmax over *all* key positions rather than the masked
    attention probabilities the model itself uses. Confirm this is intended.

    Returns:
        (handles, captured): the hook handles (pass them to ``remove_hooks``)
        and the shared capture list (clear it between forward passes).
    """
    captured: List[Dict] = []
    handles = []

    for i, blk in enumerate(model.h):
        attn = blk.attn
        n_head = attn.num_heads
        head_dim = attn.head_dim

        # Mirror GPT2Attention's scaling flags; the defaults reproduce stock GPT-2.
        scale = 1.0
        if getattr(attn, "scale_attn_weights", True):
            scale /= math.sqrt(head_dim)
        if getattr(attn, "scale_attn_by_inverse_layer_idx", False):
            scale /= float(i + 1)

        # Default args bind the per-layer values now (avoids late-binding closures).
        def make_hook(layer_idx, n_head=n_head, head_dim=head_dim, scale=scale):
            def hook_fn(module, inputs, output):
                # output of c_attn: [B, T, 3*E] holding q | k | v along the last dim
                B, T, three_e = output.shape
                E = three_e // 3

                # Reshape q and k to [B, H, T, Dh]; v is not needed for the scores.
                q = output[..., :E].view(B, T, n_head, head_dim).transpose(1, 2)
                k = output[..., E:2 * E].view(B, T, n_head, head_dim).transpose(1, 2)

                # Pre-softmax scores: [B, H, T, T]
                scores = torch.matmul(q, k.transpose(-1, -2)) * scale

                # Post-softmax weights (for reference; unmasked, see docstring)
                weights = F.softmax(scores, dim=-1)

                captured.append({
                    "layer": layer_idx,
                    "scores": scores.detach(),
                    "weights": weights.detach()
                })
            return hook_fn

        h = attn.c_attn.register_forward_hook(make_hook(i))
        handles.append(h)

    return handles, captured

def remove_hooks(handles):
    """
    Detach every hook handle in ``handles`` (best effort).

    A handle whose ``remove()`` raises is skipped so the remaining handles
    still get detached.
    """
    def _detach(handle) -> bool:
        try:
            handle.remove()
        except Exception:
            return False
        return True

    for handle in handles:
        _detach(handle)

# --------------------------- Core measurement ------------------------------- #

@torch.no_grad()
def measure_empirical_lip_on_scores(
    model: GPT2Model,
    tokenizer: GPT2Tokenizer,
    prompts: List[str],
    p_list: List[float],
    eps_list: List[float],
    num_trials: int,
    device: str,
    max_length: int
) -> pd.DataFrame:
    """
    For each prompt: run GPT-2, capture pre-softmax attention scores S [B,H,T,T].
    For each (layer, head): perturb S by Δ (row-wise in last dim), evaluate
        r = || softmax(S + Δ) - softmax(S) ||_p / || Δ ||_p
    Aggregate: keep max over trials and max over prompts per (layer, head, p, eps).
    """
    handles, captured = register_attn_hooks(model)

    ratios: Dict[Tuple[int, int, float, float], List[float]] = {}
    
    prompt_no = 0
    for text in prompts:
        prompt_no = prompt_no + 1
        print("Started processing prompt no: ", prompt_no)
        captured.clear()
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        _ = model(**enc)

        for blob in captured:
            layer_idx = blob["layer"]
            S_all = blob["scores"]  # [B,H,T,T]
            B, H, T, _ = S_all.shape

            for h in range(H):
                S = S_all[:, h, :, :]  # [B, T, T]
                S = S.to(torch.float32)

                for p in p_list:
                    for eps in eps_list:
                        ratios_trials = []
                        for _ in range(num_trials):
                            delta = torch.randn_like(S)
                            dv = delta.reshape(B, -1)
                            denom = lp_norm(dv, p, dim=1, keepdim=True).clamp_min(1e-12)
                            dv = eps * dv / denom
                            delta = dv.view_as(S)
                            delta = delta.to(torch.float32)
                            
                            Sp = S + delta
                            Wp = F.softmax(Sp, dim=-1)
                            W0 = F.softmax(S,  dim=-1)

                            num = lp_norm(Wp - W0, p, dim=-1)      # [B, T]
                            den = lp_norm(delta, p, dim=-1).clamp_min(1e-12)
                            ratio = (num / den).max().item()
                            ratios_trials.append(ratio)

                        key = (layer_idx, h, p, eps)
                        ratios.setdefault(key, []).append(float(np.max(ratios_trials)))

    remove_hooks(handles)

    # Aggregate across prompts
    rows = []
    for (layer_idx, head, p, eps), vals in ratios.items():
        rows.append({
            "layer": layer_idx,
            "head": head,
            "p": p,
            "epsilon": eps,
            "ratio_max_over_prompts": float(np.max(vals)),
        })
    df = pd.DataFrame(rows)
    if not df.empty:
        df = df.sort_values(["layer", "head", "p", "epsilon"]).reset_index(drop=True)
    return df

def plot_agg(df: pd.DataFrame, p_list: List[float], out_png: str, y_ref: float = 0.5):
    """
    Save the results table as CSV and plot the aggregated ratio vs epsilon.

    The full per-(layer, head, p, epsilon) table is written to
    ``<out_png without extension>.csv``. The figure shows, for each p present
    in ``df``, the maximum of ``ratio_max_over_prompts`` over all layers and
    heads against epsilon, on a log-scale x-axis.

    Args:
        df: table produced by measure_empirical_lip_on_scores.
        p_list: currently unused; one curve is drawn per p value found in ``df``.
        out_png: figure path; ".png" is appended when it has no extension.
        y_ref: y value of a dashed horizontal reference line; None to omit it.
    """
    base, ext = os.path.splitext(out_png)
    if not ext:
        ext = ".png"
    csv_path = f"{base}.csv"
    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)

    if df.empty:
        print("No data to plot. (Empty dataframe)")
        return

    # Persist the raw measurements next to the figure (csv_path was
    # previously computed but never written).
    df.to_csv(csv_path, index=False)
    print(f"Saved data to: {csv_path}")

    agg = (
        df.groupby(["p", "epsilon"], as_index=False)["ratio_max_over_prompts"]
          .max()
          .sort_values(["p", "epsilon"])
    )

    fig, ax = plt.subplots(figsize=(6, 4))
    for p_val, g in agg.groupby("p"):
        g = g.sort_values("epsilon")
        ax.plot(g["epsilon"], g["ratio_max_over_prompts"],
                marker="o", markersize=5, linewidth=2.2, label=f"p={p_val}")
    if y_ref is not None:
        ax.axhline(y_ref, linestyle="--", linewidth=2.0)
    ax.set_xscale("log")
    ax.set_xlabel(r"Perturbation $\epsilon$", fontsize=15)
    ax.set_ylabel("Empirical $L_p$", fontsize=15)

    ax.grid(True, linestyle="--", alpha=0.5)
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{base}{ext}", dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved plot to: {base}{ext}")

# ------------------------------- CLI ---------------------------------------- #

def parse_args(argv=None):
    """
    Parse the command-line options for the experiment.

    Args:
        argv: argument list to parse; None (the default) reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. ``seed`` is an int, or None
        when ``--seed none`` is given (the caller then skips seeding).
    """
    def seed_or_none(value: str):
        # "none" (any case) disables seeding; any other value must be an int.
        # With plain type=int, the documented "None" option was unreachable.
        return None if value.strip().lower() == "none" else int(value)

    parser = argparse.ArgumentParser(description="Empirical Lipschitz of softmax over attention (GPT-2)")
    parser.add_argument("--model", type=str, default="gpt2", help="HF model name (e.g., gpt2, gpt2-medium)")
    parser.add_argument("--dataset", type=str, default="piqa", choices=["piqa", "hellaswag"],
                        help="Dataset to draw prompts from")
    parser.add_argument("--split", type=str, default="train", help="Dataset split")
    parser.add_argument("--num_prompts", type=int, default=100, help="Number of prompts to sample")
    parser.add_argument("--p_list", type=float, nargs="+",
                        default=[1.0, 2.0, 5.0, 10.0, float('inf')], help="p norms (use 'inf' for max-norm)")
    parser.add_argument("--eps_list", type=float, nargs="+",
                        default=[5e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1.0, 5.0, 10.0], help="epsilon magnitudes")
    parser.add_argument("--num_trials", type=int, default=5, help="Random directions per (p, eps)")
    parser.add_argument("--max_length", type=int, default=256, help="Tokenizer truncation length")
    parser.add_argument("--seed", type=seed_or_none, default=0,
                        help="RNG seed; pass 'none' to skip seeding (full randomness)")
    parser.add_argument("--device", type=str, default="auto", help="'auto'|'cuda'|'cpu'")
    parser.add_argument("--out_png", type=str, default="empirical_Lp_gpt2.png", help="Output plot filename")
    return parser.parse_args(argv)

# -------------------------------- Main -------------------------------------- #

def main():
    """
    Run the experiment end to end.

    Steps: parse the CLI, optionally seed the RNGs, resolve the device, load
    the GPT-2 model/tokenizer and the prompts, measure the empirical ratios,
    then save the plot.
    """
    args = parse_args()

    if args.seed is not None:
        set_seed(args.seed)

    # "auto" picks CUDA when available, otherwise CPU; any other value is used verbatim.
    device = args.device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)

    tokenizer = GPT2Tokenizer.from_pretrained(args.model)
    model = GPT2Model.from_pretrained(args.model)
    model.eval().to(device)

    # Prompt sampling falls back to seed 0 when no seed was given.
    prompt_seed = args.seed or 0
    prompts = load_prompts(args.dataset, args.split, args.num_prompts, seed=prompt_seed)
    print(f"Loaded {len(prompts)} prompts from {args.dataset}:{args.split}.")

    measure_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        p_list=args.p_list,
        eps_list=args.eps_list,
        num_trials=args.num_trials,
        device=device,
        max_length=args.max_length,
    )
    results = measure_empirical_lip_on_scores(**measure_kwargs)

    plot_agg(results, p_list=args.p_list, out_png=args.out_png, y_ref=0.5)

# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
