# Script for empirical evaluation of softmax Lipschitz constant — supplementary material for TMLR submission.

import os
import math
import argparse
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import pandas as pd

import torchvision
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.models import (
    vit_b_16, vit_l_16, vit_h_14,
    ViT_B_16_Weights, ViT_L_16_Weights, ViT_H_14_Weights
)

# ----------------------------- Utilities ------------------------------------ #

def set_seed(seed: int) -> None:
    """Seed the global PyTorch and NumPy random number generators with `seed`."""
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

def get_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def lp_norm(t: torch.Tensor, p: float, dim=None, keepdim=False) -> torch.Tensor:
    """
    Lp norm of `t`, reduced over `dim` (over every axis when `dim` is None).

    Args:
        t: input tensor.
        p: norm order; `float('inf')` selects the max-magnitude norm.
        dim: int or tuple of ints to reduce over, or None for all axes.
        keepdim: keep the reduced axes as size-1 dimensions.

    Returns:
        Tensor of norms.

    For finite `p` the magnitudes are divided by their largest value before
    being raised to the p-th power, and the result is multiplied back
    afterwards. Mathematically this is the same norm, but it keeps the
    intermediate `|t|**p` inside the float range: without it, large `p`
    (e.g. 25) turns entries around 1e-4 into 0 and entries around 100 into
    inf in float32.
    """
    if dim is None:
        # Spell out "all axes": amax may reject dim=None depending on the torch version.
        dim = tuple(range(t.dim()))

    mag = t.abs()
    if p == float('inf'):
        return mag.amax(dim=dim, keepdim=keepdim)

    if mag.numel() == 0:
        # Nothing to rescale by; fall back to the plain formula (empty sum).
        return (mag ** p).sum(dim=dim, keepdim=keepdim) ** (1.0 / p)

    peak = mag.amax(dim=dim, keepdim=True)
    # An all-zero slice has peak 0; divide by 1 there so the result is 0, not NaN.
    safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    total = ((mag / safe_peak) ** p).sum(dim=dim, keepdim=keepdim)
    # `peak` kept its reduced axes for broadcasting; align it with `total`.
    return total ** (1.0 / p) * peak.reshape(total.shape)

def normalize_to_lp(delta: torch.Tensor, p: float, eps: float) -> torch.Tensor:
    """
    Rescale every sample of `delta` so that its Lp norm equals `eps`.

    A sample is one entry along the first dimension; all of its remaining
    dimensions are flattened together before the norm is taken. Norms are
    clamped from below at 1e-12, so an all-zero sample stays all-zero.
    The result has the same shape as `delta`.
    """
    n_samples = delta.shape[0]
    rows = delta.reshape(n_samples, -1)

    if p == float("inf"):
        norms = rows.abs().amax(dim=-1)
    else:
        norms = lp_norm(rows, p, dim=-1)

    factor = (eps / norms.clamp(min=1e-12)).view(n_samples, 1)
    return (rows * factor).view_as(delta)

# ----------------------------- Data ----------------------------------------- #

def build_transform(image_size: int) -> transforms.Compose:
    """
    Build the image preprocessing pipeline.

    Steps, in order: bicubic resize to `max(image_size, 32)`, centre crop to
    `image_size`, conversion to a tensor, and per-channel normalisation with
    fixed mean/std constants (these look like the usual ImageNet statistics).
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    resize_to = max(image_size, 32)

    steps = [
        transforms.Resize(resize_to, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def make_loader(dataset_choice: str, dataset_path: str, image_size: int,
                batch_size: int, seed: int) -> DataLoader:
    """
    Create the DataLoader for the requested dataset (name is case-insensitive).

    - "CIFAR10" / "CIFAR100": test split under `dataset_path` (downloaded if
      missing), shuffled.
    - "ImageFolder": images under `dataset_path`, shuffled.
    - "Random": exactly `batch_size` Gaussian images drawn from a generator
      seeded with `seed`, min-max scaled to [0, 1], normalised with the same
      constants as the real data, all labelled 0, not shuffled.

    Raises:
        ValueError: if `dataset_choice` is none of the above.
    """
    tfm = build_transform(image_size)
    kind = dataset_choice.upper()

    if kind in ("CIFAR10", "CIFAR100"):
        cifar_cls = CIFAR10 if kind == "CIFAR10" else CIFAR100
        ds = cifar_cls(root=dataset_path, train=False, download=True, transform=tfm)
        shuffle = True
    elif kind == "IMAGEFOLDER":
        ds = ImageFolder(root=dataset_path, transform=tfm)
        shuffle = True
    elif kind == "RANDOM":
        from torch.utils.data import TensorDataset

        rng = torch.Generator().manual_seed(seed)
        images = torch.randn(batch_size, 3, image_size, image_size, generator=rng)
        # Map the Gaussian samples onto [0, 1] so they resemble ToTensor() output.
        lo, hi = images.min(), images.max()
        images = (images - lo) / (hi - lo)
        normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
        images = normalize(images)
        labels = torch.zeros(len(images), dtype=torch.long)
        ds = TensorDataset(images, labels)
        shuffle = False
    else:
        raise ValueError(f"Unsupported dataset_choice: {dataset_choice}")

    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      num_workers=2, pin_memory=True)

# ----------------------------- Model + Hooks -------------------------------- #

def load_vit(model_name: str, use_pretrained: bool, device: torch.device):
    """
    Build a torchvision ViT and return `(model, preprocess)`.

    `model_name` is matched case-insensitively against "vit_b_16", "vit_l_16"
    and "vit_h_14". With `use_pretrained` the listed checkpoint is loaded; if
    that fails for any reason a message is printed and a randomly initialised
    model is used instead. `preprocess` is always the transform bundled with
    that checkpoint. The model is put in eval mode and moved to `device`.

    Raises:
        ValueError: if `model_name` is not one of the supported names.
    """
    # name -> (constructor, checkpoint whose weights/transforms are used)
    registry = {
        "vit_b_16": (vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1),
        "vit_l_16": (vit_l_16, ViT_L_16_Weights.IMAGENET1K_V1),
        "vit_h_14": (vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1),
    }

    model_name = model_name.lower()
    if model_name not in registry:
        raise ValueError(f"Unsupported model_name: {model_name}")
    builder, checkpoint = registry[model_name]

    model = None
    if use_pretrained:
        try:
            model = builder(weights=checkpoint)
            preprocess = checkpoint.transforms()
        except Exception:
            print("Pretrained weights unavailable. Falling back to random init.")
            model = None

    if model is None:
        # Either random init was requested or the checkpoint could not be loaded.
        model = builder(weights=None)
        preprocess = checkpoint.transforms()

    model.eval().to(device)
    return model, preprocess

def register_qkv_hooks(model: nn.Module, max_layers: int) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[int, torch.Tensor]]:
    """
    Attach forward hooks that record the projected q, k, v of attention layers.

    Only the first `max_layers` entries of `model.encoder.layers` are
    considered, and only those whose `self_attention` is an
    `nn.MultiheadAttention`. Each hook re-applies the module's input
    projections to the first three positional inputs of the call and stores
    the detached concatenation `[q | k | v]` (last dimension 3 * embed_dim)
    under the layer index.

    Returns:
        (handles, qkv_buffers): the hook handles, and the dict the hooks write
        into (empty until the model is run).
    """
    captured: Dict[int, torch.Tensor] = {}
    hook_handles: List[torch.utils.hooks.RemovableHandle] = []

    def build_hook(layer_idx):
        # A factory so that each hook closes over its own layer index.
        def capture(mod: nn.MultiheadAttention, inp, out):
            query_in, key_in, value_in = inp[:3]
            width = mod.embed_dim

            bias = mod.in_proj_bias
            if bias is None:
                q_bias = k_bias = v_bias = None
            else:
                q_bias, k_bias, v_bias = bias[:width], bias[width:2*width], bias[2*width:]

            if mod._qkv_same_embed_dim:
                # Single packed projection matrix: rows are [Wq; Wk; Wv].
                packed = mod.in_proj_weight
                q_w, k_w, v_w = packed[:width], packed[width:2*width], packed[2*width:]
            else:
                q_w, k_w, v_w = mod.q_proj_weight, mod.k_proj_weight, mod.v_proj_weight

            q = F.linear(query_in, q_w, q_bias)
            k = F.linear(key_in, k_w, k_bias)
            v = F.linear(value_in, v_w, v_bias)
            captured[layer_idx] = torch.cat([q, k, v], dim=-1).detach()  # [*, 3E]
        return capture

    for layer_idx, block in enumerate(model.encoder.layers):
        if layer_idx >= max_layers:
            break
        attention = block.self_attention
        if not isinstance(attention, nn.MultiheadAttention):
            continue
        hook_handles.append(attention.register_forward_hook(build_hook(layer_idx)))

    return hook_handles, captured

def remove_hooks(handles: List[torch.utils.hooks.RemovableHandle]) -> None:
    """
    Remove every hook in `handles`, best-effort.

    A failure on one handle does not stop the others from being removed and
    is never raised to the caller. It is reported as a `RuntimeWarning`
    instead, because a hook that stays attached keeps firing on later
    forward passes.
    """
    import warnings

    for handle in handles:
        try:
            handle.remove()
        except Exception as exc:
            warnings.warn(
                f"Failed to remove forward hook {handle!r}: {exc}",
                RuntimeWarning,
            )

# ----------------------------- Attention math ------------------------------- #

def reshape_qkv(qkv_cat: torch.Tensor, num_heads: int, embed_dim: int):
    """
    Split a packed [B, N, 3*E] tensor into per-head q, k, v.

    The last axis is read as q, then k, then v (E values each), and each E is
    read as H heads of D = E // H values. Returns (q, k, v), each [B, H, N, D].
    """
    batch, tokens, packed_width = qkv_cat.shape
    assert packed_width == 3 * embed_dim, f"Expected 3*{embed_dim}, got {packed_width}"

    head_dim = embed_dim // num_heads
    stacked = qkv_cat.reshape(batch, tokens, 3, num_heads, head_dim)
    # Bring the q/k/v selector to the front, then peel it off: 3 x [B,H,N,D].
    q, k, v = stacked.permute(2, 0, 3, 1, 4).unbind(dim=0)
    return q, k, v

def scores_from_qk(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """
    Scaled dot-product attention scores (pre-softmax).

    q, k: [B,H,N,D]. Returns q @ k^T / sqrt(D), shape [B,H,N,N].
    """
    head_dim = q.size(-1)
    keys_t = k.transpose(-2, -1)
    return torch.matmul(q, keys_t) / math.sqrt(head_dim)

# ----------------------------- Experiment core ------------------------------ #

def compute_stability(
    S_by_layer: List[Tuple[int, torch.Tensor]],
    p_list: List[float],
    epsilons: List[float],
    max_heads: int,
    seed: int,
    accum: Optional[dict] = None,
) -> dict:
    """
    Update the running maxima of the softmax sensitivity ratio for one batch.

    For every layer's score tensor S [B,H,N,N] (restricted to the first
    `max_heads` heads when that is positive) and every (p, eps):
      1. draw Gaussian noise shaped like the scores and rescale each
         (sample, head) N x N matrix to Lp norm `eps`;
      2. per score row, compute
         ||softmax(S + noise) - softmax(S)||_p / ||noise_row||_p;
      3. store the largest ratio of the batch in
         accum[(layer, p, eps)] = max(previous, current).

    The global torch RNG is reseeded before each draw from `seed`, `eps` and
    `p` only, so equally shaped score tensors receive the same noise for a
    given (p, eps) regardless of layer or batch.

    `accum` is created when None, updated in place otherwise, and returned.
    """
    running = {} if accum is None else accum
    use_all_heads = max_heads is None or max_heads <= 0

    for layer_idx, scores in S_by_layer:
        batch, heads, tokens, _ = scores.shape
        kept = heads if use_all_heads else min(heads, max_heads)
        # One row-stochastic N x N problem per (sample, head).
        flat_scores = scores[:, :kept, :, :].contiguous().view(batch * kept, tokens, tokens)

        with torch.no_grad():
            base_attn = F.softmax(flat_scores, dim=-1)

        for p in p_list:
            for eps in epsilons:
                p_offset = 0 if p == float('inf') else int(100 * p)
                torch.manual_seed(seed + int(1e6 * eps) + p_offset)
                noise = normalize_to_lp(torch.randn_like(flat_scores), p, eps)

                with torch.no_grad():
                    noisy_attn = F.softmax(flat_scores + noise, dim=-1)

                # Both norms are taken row-wise: shape [B*H_sel, N].
                output_change = lp_norm(noisy_attn - base_attn, p, dim=-1)
                input_change = lp_norm(noise, p, dim=-1).clamp(min=1e-12)
                batch_max = float((output_change / input_change).cpu().max())

                key = (layer_idx, p, eps)
                previous = running.get(key, None)
                running[key] = batch_max if previous is None else max(previous, batch_max)

    return running

def aggregate_and_plot(
    df: pd.DataFrame,
    p_list: List[float],
    out_png: str,
    lip_line: float = 0.5
) -> None:
    """
    Plot the worst-case ratio against epsilon, one curve per p, and save it.

    `df` needs the columns "p", "epsilon" and "ratio_max"; rows sharing
    (p, epsilon) — i.e. different layers — are reduced with max. A dashed
    horizontal reference line is drawn at `lip_line` unless it is None. The
    x axis is logarithmic and the figure is written to `out_png` at 300 dpi.
    `p_list` is accepted for interface symmetry but not used here.
    """
    worst = df.groupby(["p", "epsilon"], as_index=False)["ratio_max"].max()
    worst = worst.sort_values(["p", "epsilon"])

    fig, ax = plt.subplots(figsize=(6, 4))
    for p_val, curve in worst.groupby("p"):
        curve = curve.sort_values("epsilon")
        ax.plot(curve["epsilon"], curve["ratio_max"],
                marker="o", markersize=5, linewidth=2.2, label=f"p = {p_val}")

    if lip_line is not None:
        ax.axhline(lip_line, linestyle="--", linewidth=2.0)

    ax.set_xscale("log")
    ax.set_xlabel(r"Perturbation $\epsilon$", fontsize=15)
    ax.set_ylabel("Empirical $L_p$", fontsize=15)
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)

# ----------------------------- Main ----------------------------------------- #

def main():
    """
    Command-line entry point.

    Loads a torchvision ViT, streams up to `--num_images` images through it in
    batches, captures the attention scores of the first `--max_layers` encoder
    blocks, and keeps a running maximum of the softmax sensitivity ratio for
    every (layer, p, epsilon). The maxima are then aggregated over layers and
    plotted to `--out_png`.

    Raises:
        RuntimeError: if the hooks capture nothing, or if no image could be
            processed (e.g. an empty dataset).
    """
    parser = argparse.ArgumentParser(description="ViT Attention Stability under Score Noise (ImageFolder/CIFAR, batched, running max)")
    parser.add_argument("--model_name", type=str, default="vit_b_16",
                        choices=["vit_b_16", "vit_l_16", "vit_h_14"])
    # Pretrained weights are the default. `--use_pretrained` is kept so existing
    # command lines still parse; `--no_pretrained` is the switch that actually
    # turns them off (a lone store_true flag defaulting to True never could).
    parser.add_argument("--use_pretrained", dest="use_pretrained", action="store_true", default=True,
                        help="Load pretrained weights (default).")
    parser.add_argument("--no_pretrained", dest="use_pretrained", action="store_false",
                        help="Use randomly initialised weights instead of the pretrained checkpoint.")
    parser.add_argument("--image_size", type=int, default=224)

    parser.add_argument("--dataset_choice", type=str, default="ImageFolder",
                        choices=["CIFAR10", "CIFAR100", "ImageFolder", "Random"])
    parser.add_argument("--dataset_path", type=str, default="./data",
                        help="Path root. For ImageFolder, set to e.g. ./Imagenet/val. For CIFAR*/Random, ./data is always used.")
    parser.add_argument("--num_images", type=int, default=100,
                        help="Total number of images to evaluate across batches.")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Per-batch size; lower it for large models to fit memory.")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--epsilons", type=float, nargs="+",
                        default=[1e-3, 1e-2, 1e-1, 1, 10, 100])
    parser.add_argument("--p_list", type=float, nargs="+",
                        default=[1, 2, 10, 25, float("inf")])
    parser.add_argument("--max_layers", type=int, default=24,
                        help="Analyze at most this many early encoder blocks")
    parser.add_argument("--max_heads", type=int, default=0,
                        help="0 or negative = use all heads; else limit per layer")
    parser.add_argument("--out_png", type=str, default="empirical_Lp_ViT.png")

    args = parser.parse_args()

    if args.num_images < 1:
        # Nothing would be evaluated and there would be nothing to plot.
        parser.error("--num_images must be a positive integer.")

    # Enforce ./data for CIFAR/Random; allow custom path only for ImageFolder
    if args.dataset_choice != "ImageFolder":
        args.dataset_path = "./data"

    set_seed(args.seed)
    device = get_device()

    # Model (the bundled preprocessing is not used; see build_transform)
    model, _ = load_vit(args.model_name, args.use_pretrained, device)

    # Quick summary
    num_layers = len(model.encoder.layers)
    num_heads = model.encoder.layers[0].self_attention.num_heads
    embed_dim = model.hidden_dim
    head_dim = embed_dim // num_heads
    print(f"Model: {args.model_name} | layers={num_layers} | heads={num_heads} | "
          f"embed_dim={embed_dim} | head_dim={head_dim} | device={device}")

    # Data loader (iterate until num_images are processed)
    loader = make_loader(args.dataset_choice, args.dataset_path, args.image_size,
                         args.batch_size, args.seed)

    processed = 0
    accum = None  # running maxima dict {(layer,p,eps) -> ratio_max}

    for xb, _ in loader:
        # Trim the batch if it would overshoot the target, then move it over.
        xb = xb[: args.num_images - processed].to(device)
        this_bs = xb.shape[0]

        # Hooks to capture qkv for THIS batch; always detached again, even if
        # the forward pass fails, so no hook is left on the model.
        handles, qkv_buffers = register_qkv_hooks(model, max_layers=max(1, args.max_layers))
        try:
            with torch.no_grad():
                model(xb)
        finally:
            remove_hooks(handles)

        if not qkv_buffers:
            raise RuntimeError("No qkv tensors were captured. Check max_layers and model internals.")

        # Build per-layer S for this batch
        S_by_layer: List[Tuple[int, torch.Tensor]] = []
        for i in sorted(qkv_buffers):
            q, k, _v = reshape_qkv(qkv_buffers[i], num_heads=num_heads, embed_dim=embed_dim)  # [B,H,N,D]
            S_by_layer.append((i, scores_from_qk(q, k)))  # [B,H,N,N]

        # Update running maxima
        accum = compute_stability(
            S_by_layer=S_by_layer,
            p_list=args.p_list,
            epsilons=args.epsilons,
            max_heads=args.max_heads,
            seed=args.seed,
            accum=accum,   # first batch initializes, later batches take max
        )

        processed += this_bs
        print(f"Processed {processed}/{args.num_images} images...")

        # Stop here rather than at the top of the loop, so that no extra batch
        # is loaded once the target has been reached.
        if processed >= args.num_images:
            break

    if not accum:
        raise RuntimeError("No images were processed; the dataset appears to be empty.")

    # Convert accum -> DataFrame (one row per (layer,p,eps))
    rows = [
        {
            "layer": layer_idx,
            "p": p,
            "epsilon": eps,
            "ratio_max": float(ratio_max),
            "num_samples": None,
        }
        for (layer_idx, p, eps), ratio_max in sorted(accum.items())
    ]
    df = pd.DataFrame(rows).sort_values(["layer", "p", "epsilon"])

    # Aggregate and plot
    aggregate_and_plot(df, p_list=args.p_list, out_png=args.out_png, lip_line=0.5)
    print(f"Saved plot to: {os.path.abspath(args.out_png)}")

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
