# Script for empirical evaluation of softmax Lipschitz constant — supplementary material for TMLR submission.

import os
import argparse
from typing import List, Tuple, Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import pandas as pd

from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.models import resnet50, ResNet50_Weights

# -------------------------- Utils -------------------------- #

def set_seed(seed: int) -> None:
    """Seed the global PyTorch and NumPy random number generators with `seed`."""
    # Same order as before: PyTorch first, then NumPy.
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)

def get_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def lp_norm(t: torch.Tensor, p: float, dim=None, keepdim=False) -> torch.Tensor:
    """
    Compute the Lp norm of `t`.

    Args:
        t: Input tensor.
        p: Norm order. `float('inf')` selects the max-abs norm.
        dim: Dimension (int or tuple of ints) to reduce over. `None` reduces
            over every dimension of `t`.
        keepdim: Whether the reduced dimensions are kept with size 1.

    Returns:
        Tensor holding the norm(s) over the reduced dimension(s).
    """
    # amax is not documented to accept dim=None, so spell out "all dimensions".
    dims = tuple(range(t.dim())) if dim is None else dim
    mag = t.abs()

    if p == float('inf'):
        return mag.amax(dim=dims, keepdim=keepdim)

    if mag.numel() == 0:
        # Nothing to rescale (and amax rejects empty reductions); the plain
        # formula yields 0 for empty slices.
        return (mag ** p).sum(dim=dims, keepdim=keepdim) ** (1.0 / p)

    # Factor out the largest magnitude before raising to the p-th power.
    # Without this, |x|**p underflows to 0 (or overflows to inf) in float32 for
    # small (large) entries and large p, e.g. (1e-8)**10, giving a wrong norm.
    peak = mag.amax(dim=dims, keepdim=True)
    usable = (peak > 0) & torch.isfinite(peak)
    # Zero / inf / nan peaks fall back to a scale of 1 so the plain formula's
    # result (0, inf or nan respectively) is reproduced for those slices.
    scale = torch.where(usable, peak, torch.ones_like(peak))
    root = ((mag / scale) ** p).sum(dim=dims, keepdim=keepdim) ** (1.0 / p)
    return root * scale.reshape(root.shape)

def normalize_to_lp(delta: torch.Tensor, p: float, eps: float) -> torch.Tensor:
    """
    Rescale each sample of `delta` (batch is the first dim) so its Lp norm is eps.

    Every sample is flattened, its norm is measured (clamped below at 1e-12 to
    avoid dividing by zero), and the sample is multiplied by eps / norm. The
    result has the same shape as `delta`.
    """
    batch = delta.shape[0]
    rows = delta.reshape(batch, -1)

    if p != float("inf"):
        norms = lp_norm(rows, p, dim=-1)
    else:
        norms = rows.abs().amax(dim=-1)

    factor = (eps / norms.clamp(min=1e-12)).view(batch, 1)
    return (rows * factor).view_as(delta)

# -------------------------- Data -------------------------- #

def build_transform(image_size: int) -> transforms.Compose:
    """
    Build the preprocessing pipeline: bicubic resize to max(32, image_size),
    center crop to `image_size`, tensor conversion, then normalization.
    """
    resize_to = max(32, image_size)
    steps = [
        transforms.Resize(resize_to, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # ImageNet normalization
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)

def make_loader(
    dataset_choice: str,
    dataset_path: str,
    image_size: int,
    batch_size: int,
    seed: int
) -> DataLoader:
    """
    Create the DataLoader for the selected dataset (name is case-insensitive).

    - CIFAR10 / CIFAR100: train and test splits concatenated, downloaded to
      `dataset_path` if needed, shuffled.
    - ImageFolder: images read from `dataset_path`, shuffled.
    - Random: `batch_size` Gaussian images drawn from a generator seeded with
      `seed`, min-max scaled to [0, 1], normalized, with all-zero labels; not
      shuffled.

    Raises:
        ValueError: if `dataset_choice` is none of the above.
    """
    tfm = build_transform(image_size)
    choice = dataset_choice.upper()
    shuffle = True

    if choice in ("CIFAR10", "CIFAR100"):
        cifar_cls = CIFAR10 if choice == "CIFAR10" else CIFAR100
        # Train split first, then test split.
        ds = torch.utils.data.ConcatDataset([
            cifar_cls(root=dataset_path, train=is_train, download=True, transform=tfm)
            for is_train in (True, False)
        ])
    elif choice == "IMAGEFOLDER":
        ds = ImageFolder(root=dataset_path, transform=tfm)
    elif choice == "RANDOM":
        gen = torch.Generator().manual_seed(seed)
        noise = torch.randn(batch_size, 3, image_size, image_size, generator=gen)
        lo = noise.min()
        images = (noise - lo) / (noise.max() - lo)
        normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
        images = normalize(images)
        labels = torch.zeros(len(images), dtype=torch.long)
        ds = torch.utils.data.TensorDataset(images, labels)
        shuffle = False
    else:
        raise ValueError(f"Unsupported dataset_choice: {dataset_choice}")

    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True)

# -------------------------- Model -------------------------- #

@torch.no_grad()
def forward_logits(model: torch.nn.Module, x: torch.Tensor, device: torch.device) -> torch.Tensor:
    """
    Move `x` to `device`, run `model` on it without gradient tracking, and
    return the output.

    Raises:
        RuntimeError: if the model output is not a tensor.
    """
    out = model(x.to(device))
    if isinstance(out, torch.Tensor):
        return out
    raise RuntimeError("Model forward did not return a tensor of logits.")

def load_model(pretrained: bool, device: torch.device) -> torch.nn.Module:
    """
    Build a ResNet-50 (with IMAGENET1K_V2 weights when `pretrained` is true,
    otherwise with `weights=None`), switch it to eval mode and move it to `device`.
    """
    if pretrained:
        chosen_weights = ResNet50_Weights.IMAGENET1K_V2
    else:
        chosen_weights = None
    net = resnet50(weights=chosen_weights)
    net.eval()
    net.to(device)
    return net

# -------------------------- Experiment -------------------------- #

def empirical_softmax_lipschitz(
    logits: torch.Tensor,
    p_list: List[float],
    eps_list: List[float],
    num_trials: int,
    seed: int
) -> pd.DataFrame:
    """
    Probe the softmax Lipschitz ratio around `logits` with random perturbations.

    For logits z (shape [B, C]) and a Gaussian perturbation Δ rescaled so that
    each sample has ||Δ||_p = epsilon, compute
        r = || softmax(z + Δ) - softmax(z) ||_p / ||Δ||_p
    for every p in `p_list` and epsilon in `eps_list`, repeated `num_trials`
    times.

    Args:
        logits: 2-D tensor of logits, one row per sample.
        p_list: Norm orders to evaluate (float('inf') allowed).
        eps_list: Perturbation norms to evaluate.
        num_trials: Number of random perturbations per (p, epsilon); >= 1.
        seed: Seed of the private generator used to draw perturbations.

    Returns:
        DataFrame sorted by (p, epsilon) with columns
        p, epsilon, ratio_max_over_trials_max, num_samples, num_trials, where
        ratio_max_over_trials_max is the largest ratio over all trials AND all
        samples of the batch (a max, not an average).

    Raises:
        ValueError: if `num_trials` is smaller than 1.
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")

    columns = ["p", "epsilon", "ratio_max_over_trials_max", "num_samples", "num_trials"]

    # Detach so the result can be converted to Python floats even when the
    # caller passes logits that still track gradients.
    logits = logits.detach()
    B, _ = logits.shape  # also enforces the [B, C] layout

    # Private generator: reproducible draws for a given seed without
    # reseeding (and thereby clobbering) the global RNG on every call.
    gen = torch.Generator(device=logits.device)
    gen.manual_seed(seed)

    W0 = torch.softmax(logits, dim=-1)

    rows = []
    for p in p_list:
        for eps in eps_list:
            worst = None  # running per-sample max over trials, shape [B]
            for _ in range(num_trials):
                delta = torch.randn(logits.shape, generator=gen,
                                    device=logits.device, dtype=logits.dtype)
                delta = normalize_to_lp(delta, p, eps)   # per sample Lp = eps
                Wp = torch.softmax(logits + delta, dim=-1)

                num = lp_norm(Wp - W0, p, dim=-1)                 # [B]
                den = lp_norm(delta, p, dim=-1).clamp(min=1e-12)  # [B]
                ratio = num / den                                 # [B]
                worst = ratio if worst is None else torch.maximum(worst, ratio)

            rows.append({
                "p": p,
                "epsilon": eps,
                "ratio_max_over_trials_max": float(worst.max()),
                "num_samples": int(B),
                "num_trials": int(num_trials),
            })

    # Explicit columns keep sort_values working when p_list or eps_list is empty.
    df = pd.DataFrame(rows, columns=columns).sort_values(["p", "epsilon"])
    return df

def plot_curves_with_name(
    df: pd.DataFrame,
    p_list: List[float],
    out_png: str,
    stat: str = "ratio_max_over_trials_max",
    lip_ref: float = 0.5
):
    """
    Plot `stat` against epsilon (log x-axis), one curve per p, and save it.

    Args:
        df: Results with at least the columns "p", "epsilon" and `stat`.
        p_list: Norm orders to draw; each becomes one curve.
        out_png: Output image path; its parent directory is created if missing.
        stat: Name of the column plotted on the y-axis.
        lip_ref: Height of a dashed horizontal reference line; pass None to
            omit the line.
    """
    # Single plot saved to out_png, name provided by --out_png
    os.makedirs(os.path.dirname(out_png) or ".", exist_ok=True)
    fig = plt.figure(figsize=(6, 4))
    try:
        for p in p_list:
            sub = df[df["p"] == p].sort_values("epsilon")
            xs = sub["epsilon"].tolist()
            ys = sub[stat].tolist()
            # ":g" still prints "p=1" for 1.0 but no longer truncates a
            # fractional order such as 1.5 to "p=1" the way int(p) did.
            label = "p=∞" if p == float("inf") else f"p={p:g}"
            plt.plot(xs, ys, marker="o", linewidth=2.0, markersize=5, label=label)

        if lip_ref is not None:
            plt.axhline(lip_ref, linestyle="--", linewidth=2.0)

        plt.xscale("log")
        plt.xlabel(r"Perturbation $\epsilon$", fontsize=15)
        plt.ylabel("Empirical $L_p$", fontsize=15)
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_png, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved plot: {os.path.abspath(out_png)}")

# -------------------------- Main -------------------------- #

def main():
    """
    Command-line entry point.

    Parses the arguments, builds the data loader and the model, evaluates the
    empirical softmax Lipschitz ratio batch by batch until `--num_images`
    images have been processed (keeping, per (p, epsilon), the maximum ratio
    seen over all batches), and writes a single plot to `--out_png`.

    Raises:
        RuntimeError: if the data loader yields no images at all.
    """
    ap = argparse.ArgumentParser(description="Empirical Lipschitz of softmax in logit space (selectable datasets, batched aggregation)")
    ap.add_argument("--dataset_choice", type=str, default="CIFAR100",
                    choices=["CIFAR10", "CIFAR100", "ImageFolder", "Random"],
                    help="Dataset to use. For ImageFolder, point --dataset_path to the root (e.g., ./Imagenet/val).")
    ap.add_argument("--dataset_path", type=str, default="./data",
                    help="Path for ImageFolder or CIFAR download root. (Random ignores this.)")
    ap.add_argument("--image_size", type=int, default=224,
                    help="Input size (e.g., 224 for ImageNet models)")

    ap.add_argument("--num_images", type=int, default=25000,
                    help="Total images to evaluate across batches.")
    ap.add_argument("--batch_size", type=int, default=128,
                    help="Per-batch size; reduce for memory-heavy models.")

    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--pretrained", action="store_true", default=True)
    # --pretrained is store_true with default=True, so on its own it can never
    # be switched off; this companion flag writes False to the same dest.
    ap.add_argument("--no_pretrained", dest="pretrained", action="store_false", default=True,
                    help="Disable loading of pretrained weights.")

    ap.add_argument("--p_list", type=float, nargs="+", default=[1.0, 2.0, 5.0, 10.0, float("inf")])
    ap.add_argument("--eps_list", type=float, nargs="+", default=[1e-3, 1e-2, 1e-1, 1, 10])
    ap.add_argument("--num_trials", type=int, default=5)

    ap.add_argument("--out_png", type=str, default="./images/empirical_Lp_softmax.png")
    args = ap.parse_args()

    # Reject values that would leave nothing to aggregate or plot.
    if args.num_images < 1:
        ap.error("--num_images must be at least 1")
    if args.num_trials < 1:
        ap.error("--num_trials must be at least 1")

    # Optional guard: if not ImageFolder, prefer ./data to avoid user path issues.
    if args.dataset_choice != "ImageFolder" and args.dataset_path != "./data":
        print("[info] Overriding --dataset_path to ./data for non-ImageFolder datasets.")
        args.dataset_path = "./data"

    set_seed(args.seed)
    device = get_device()
    print(f"Device: {device}")

    # Data loader (we'll iterate multiple batches until num_images are processed)
    loader = make_loader(args.dataset_choice, args.dataset_path, args.image_size,
                         args.batch_size, args.seed)

    # Model
    model = load_model(pretrained=args.pretrained, device=device)

    processed = 0
    accum: Dict[Tuple[float, float], float] = {}  # key=(p, epsilon) -> running max of ratio_max_over_trials_max

    for xb, _ in loader:
        # Trim the batch so the total never exceeds num_images
        remaining = args.num_images - processed
        if remaining <= 0:
            break
        xb = xb[:remaining]

        # logits for this batch
        logits = forward_logits(model, xb, device)  # [B, C]

        # NOTE: the same seed is passed for every batch.
        df_batch = empirical_softmax_lipschitz(
            logits=logits,
            p_list=args.p_list,
            eps_list=args.eps_list,
            num_trials=args.num_trials,
            seed=args.seed
        )

        # Running max per (p, epsilon); the first batch simply initializes the entry.
        for p, eps, ratio in zip(df_batch["p"], df_batch["epsilon"],
                                 df_batch["ratio_max_over_trials_max"]):
            key = (float(p), float(eps))
            val = float(ratio)
            prev = accum.get(key)
            accum[key] = val if prev is None else max(prev, val)

        processed += xb.shape[0]
        print(f"Processed {processed}/{args.num_images} images...")

        # Stop here so the loader is not asked for a batch that would be discarded.
        if processed >= args.num_images:
            break

    if processed == 0:
        # Previously this surfaced as a KeyError from sorting an empty frame.
        raise RuntimeError("No images were processed; the data loader yielded no batches.")

    # Build final DataFrame from accum (schema matching request)
    columns = ["p", "epsilon", "ratio_max_over_trials_max", "num_samples", "num_trials"]
    rows = [
        {
            "p": p,
            "epsilon": eps,
            "ratio_max_over_trials_max": float(vmax),
            "num_samples": int(processed),
            "num_trials": int(args.num_trials),
        }
        for (p, eps), vmax in accum.items()
    ]
    df_final = pd.DataFrame(rows, columns=columns).sort_values(["p", "epsilon"])

    # Single plot written to --out_png
    plot_curves_with_name(df_final, p_list=args.p_list, out_png=args.out_png,
                          stat="ratio_max_over_trials_max", lip_ref=0.5)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
