# Script for empirical evaluation of softmax Lipschitz constant — supplementary material for TMLR submission.

import os, math, argparse, random
from typing import List
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Tuple

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed

# ----------------------------- Utils ---------------------------------------- #

def set_all_seeds(seed: int):
    """Seed every random number generator this script relies on.

    The same integer is handed, in order, to Python's ``random``, NumPy's
    legacy global RNG, PyTorch, and Stable-Baselines3's ``set_random_seed``.

    Args:
        seed: Seed value passed unchanged to each seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, set_random_seed)
    for seed_fn in seeders:
        seed_fn(seed)

def lp_norm(t: torch.Tensor, p: float, dim=None, keepdim=False):
    """Compute the l_p norm of ``t`` along ``dim``.

    Args:
        t: Input tensor.
        p: Norm order. ``float('inf')`` selects the max-absolute-value norm;
            any other value is used as ``(sum |t|^p)^(1/p)``.
        dim: Axis (int) or axes (tuple of ints) to reduce. ``None`` reduces
            over every axis of ``t``.
        keepdim: If True, reduced axes are kept with size 1.

    Returns:
        Tensor holding the norm(s).
    """
    magnitudes = t.abs()
    if dim is None:
        # Spell out "all axes" explicitly: Tensor.amax documents `dim` as an
        # int or tuple of ints, so forwarding None is not portable.
        dim = tuple(range(magnitudes.dim()))
    if p == float('inf'):
        return magnitudes.amax(dim=dim, keepdim=keepdim)
    return (magnitudes ** p).sum(dim=dim, keepdim=keepdim) ** (1.0 / p)

# Shared softmax module; normalizes over the last dimension of its input.
softmax = nn.Softmax(dim=-1)

def obs_to_logits(policy, obs_batch: np.ndarray) -> torch.Tensor:
    """Return action logits [B, A] from an SB3 policy for a batch of observations.

    The logits are read from ``policy.get_distribution(obs).distribution.logits``
    when that attribute exists. Otherwise they are rebuilt from the policy's
    components: features -> actor latent -> ``action_net``.

    Args:
        policy: Policy object exposing ``obs_to_tensor`` and ``get_distribution``
            (and, for the fallback, ``extract_features``, ``mlp_extractor`` and
            ``action_net``).
        obs_batch: Batch of observations accepted by ``policy.obs_to_tensor``.

    Returns:
        Tensor of logits computed without gradient tracking.
    """
    obs_t, _ = policy.obs_to_tensor(obs_batch)
    with torch.no_grad():
        dist = policy.get_distribution(obs_t)
        # Probe for the attribute instead of catching every Exception, so
        # genuine errors in the forward pass are not silently hidden.
        logits = getattr(getattr(dist, "distribution", None), "logits", None)
        if logits is None:
            # Fallback: recompute from the actor head.
            features = policy.extract_features(obs_t)
            if isinstance(features, tuple):
                # NOTE(review): assumes a tuple means separate (actor, critic)
                # features with the actor's first — confirm against the
                # installed SB3 version.
                features = features[0]
            latent_pi = policy.mlp_extractor.forward_actor(features)
            logits = policy.action_net(latent_pi)
    return logits  # [B, A]

def state_loader(states: np.ndarray, batch_size: int):
    """Lazily yield consecutive slices of ``states`` along its first axis.

    Every slice holds ``batch_size`` rows except possibly the last, which
    holds whatever remains.
    """
    total = states.shape[0]
    yield from (
        states[start:start + batch_size]
        for start in range(0, total, batch_size)
    )

def collect_observations(env, model, n_total: int) -> np.ndarray:
    """Roll out a vectorized env with ``model`` and return ``n_total`` observations.

    Actions are sampled with ``model.predict(obs, deterministic=False)``. Only
    observations returned by ``env.step`` are recorded; the observation from
    ``env.reset`` is used solely to pick the first action.

    Args:
        env: Environment whose ``reset``/``step`` return batched observations
            (first axis indexes the parallel envs). ``reset`` may return either
            the observation or an ``(obs, info)`` pair.
        model: Object with a ``predict(obs, deterministic=...)`` method.
        n_total: Number of observations to collect; must be positive.

    Returns:
        Array of shape ``[n_total, ...]`` stacking the collected observations.
        For dict observations only the entry under the first key is recorded.

    Raises:
        ValueError: If ``n_total`` is not positive.
        RuntimeError: If an observation is neither an ndarray nor a dict.
    """
    if n_total <= 0:
        # np.stack on an empty list raises an opaque ValueError; fail early
        # with a message that names the actual problem.
        raise ValueError(f"n_total must be a positive integer, got {n_total}")

    obs = env.reset()
    if isinstance(obs, tuple) and len(obs) == 2:
        obs, _ = obs

    states = []
    while len(states) < n_total:
        actions, _ = model.predict(obs, deterministic=False)
        step_out = env.step(actions)
        # A step normally returns a tuple led by the observation; anything
        # else is treated as the observation itself.
        if isinstance(step_out, tuple) and len(step_out) >= 4:
            obs = step_out[0]
        else:
            obs = step_out

        if isinstance(obs, np.ndarray):
            batch = obs
        elif isinstance(obs, dict):
            batch = list(obs.values())[0]
        else:
            raise RuntimeError("Unsupported observation type.")

        # Take only as many rows as are still needed to reach n_total.
        remaining = n_total - len(states)
        states.extend(batch[:remaining])
    return np.stack(states, axis=0)

# ----------------------------- Core ----------------------------------------- #

def _format_p(p) -> str:
    """Return the legend text for norm order ``p`` ('∞' for infinity)."""
    # ':g' keeps fractional orders (e.g. 1.5) instead of truncating them.
    return "∞" if np.isinf(p) else f"{p:g}"

def _load_or_build_model(args, env, device):
    """Load PPO from ``args.pretrained`` if it exists, else build (and maybe train) one."""
    if args.pretrained is not None:
        if os.path.exists(args.pretrained):
            print(f"Loading pretrained PPO from: {args.pretrained}")
            return PPO.load(args.pretrained, env=env, device=device)
        # Do not silently evaluate a different policy than the one requested.
        print(f"[WARN] --pretrained path not found: {args.pretrained}; "
              "falling back to a freshly built PPO policy.")

    policy_kwargs = dict(net_arch=[256, 256])
    model = PPO("MlpPolicy", env, verbose=0, tensorboard_log=None,
                policy_kwargs=policy_kwargs, device=device)
    if args.train_steps > 0:
        print(f"Training PPO for {args.train_steps} timesteps...")
        model.learn(total_timesteps=args.train_steps)
    else:
        print("Using randomly initialized policy (no training).")
    return model

def _sweep_lipschitz(policy, states, args):
    """Measure max ||softmax(z1/tau) - softmax(z0/tau)||_p / eps over the sweep grid.

    For every (tau, p, eps) combination, each logit batch z0 is perturbed
    ``args.num_trials`` times by eps * delta, where delta is a Gaussian
    direction normalized to unit l_p norm per row. The reported value is the
    maximum ratio over all rows, trials and batches.

    Returns:
        DataFrame with one row per (tau, p, eps), sorted by those columns.
    """
    rows = []
    with torch.no_grad():
        # Logits do not depend on tau, p or eps, so compute them once instead
        # of once per grid point.
        logit_batches = [obs_to_logits(policy, obs_batch)
                         for obs_batch in state_loader(states, args.batch_size)]
        for tau in args.tau_list:
            for p in args.p_list:
                for eps in args.eps_list:
                    per_batches = []
                    for z0 in logit_batches:            # z0: [B, A]
                        p0 = softmax(z0 / tau)

                        per_trials = []
                        for _ in range(args.num_trials):
                            delta = torch.randn_like(z0)
                            denom = lp_norm(delta, p, dim=1, keepdim=True).clamp_min(1e-12)
                            delta = delta / denom         # normalize direction
                            z1 = z0 + eps * delta
                            p1 = softmax(z1 / tau)

                            # empirical Lipschitz ratio: ||Δπ||_p / ε
                            ratio = lp_norm(p1 - p0, p, dim=-1) / eps
                            per_trials.append(ratio.max().item())

                        per_batches.append(float(np.max(per_trials)))

                    rows.append({
                        "tau": float(tau),
                        "p": float(p),
                        "epsilon": float(eps),
                        "ratio_max_over_states": float(np.max(per_batches)),
                        "n_batches": int(len(per_batches)),
                        "num_trials": int(args.num_trials),
                    })

    return pd.DataFrame(rows).sort_values(["tau", "p", "epsilon"]).reset_index(drop=True)

def _save_plots(df, args):
    """Save one ratio-vs-epsilon figure per tau (one curve per p) derived from ``args.out_png``."""
    base, ext = os.path.splitext(args.out_png)
    if not ext:
        ext = ".png"

    # If one tau, save exactly to --out_png; else suffix the filename by tau.
    unique_taus = sorted(df["tau"].unique())
    for tau in unique_taus:
        fig = plt.figure(figsize=(6, 4))
        df_tau = df[df["tau"] == tau]
        for p in args.p_list:
            sub = df_tau[df_tau["p"] == float(p)].sort_values("epsilon")
            if sub.empty:
                continue
            xs = sub["epsilon"].values
            ys = sub["ratio_max_over_states"].values
            plt.plot(xs, ys, marker="o", markersize=5, linewidth=2.2,
                     label=f"p={_format_p(p)}")
        # Reference line at 0.5 / tau (labelled λ/2; presumably λ = 1/τ — confirm
        # against the accompanying paper).
        plt.axhline(0.5 / float(tau), linestyle="--", linewidth=2, label=r"$\lambda/2$")
        plt.xscale('log')
        plt.xlabel(r"Perturbation $\epsilon$", fontsize=15)
        plt.ylabel("Empirical $L_p$", fontsize=15)
        plt.legend()
        plt.tight_layout()

        if len(unique_taus) == 1:
            out_path = f"{base}{ext}"
        else:
            out_path = f"{base}_tau_{tau}{ext}"
        plt.savefig(out_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved plot: {out_path}")

def run(args):
    """Run the full experiment: build the policy, collect states, sweep, and plot.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Raises:
        ValueError: If ``args.num_trials`` or ``args.batch_size`` is below 1.
    """
    # Fail fast on settings that would otherwise surface as opaque errors
    # deep inside the sweep (empty reductions / empty iterators).
    if args.num_trials < 1:
        raise ValueError(f"--num_trials must be >= 1, got {args.num_trials}")
    if args.batch_size < 1:
        raise ValueError(f"--batch_size must be >= 1, got {args.batch_size}")

    device = "cuda" if (torch.cuda.is_available() and not args.force_cpu) else "cpu"
    print("Device:", device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    # Environment (LunarLander with 4 discrete actions)
    env_id = "LunarLander-v3" if args.env_version == 3 else "LunarLander-v2"
    env = make_vec_env(env_id, n_envs=args.num_envs, seed=args.seed)
    try:
        print(f"Env: {env_id} | n_envs={args.num_envs}")

        # Build or load PPO policy
        model = _load_or_build_model(args, env, device)
        policy = model.policy.to(device).eval()

        # Confirm number of actions
        try:
            a_n = env.action_space.n
        except AttributeError:
            a_n = env.envs[0].action_space.n
        print("Number of discrete actions A =", a_n)
        if a_n != 4:
            print("[WARN] Expected A=4 for LunarLander; got A=", a_n)

        # Collect states
        states = collect_observations(env, model, args.num_states)
    finally:
        # The env is only needed for rollouts; release it even on failure.
        env.close()
    print("Collected states:", states.shape)

    # Sanity check: one batch of logits
    test_batch = next(state_loader(states, args.batch_size))
    with torch.no_grad():
        test_logits = obs_to_logits(policy, test_batch)
    print("Actor logits shape:", tuple(test_logits.shape))

    # Sweep and measure empirical Lipschitz, then save plots
    df = _sweep_lipschitz(policy, states, args)
    _save_plots(df, args)

def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for the experiment.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``; passing a list allows programmatic use.

    Returns:
        ``argparse.Namespace`` with data-collection, PPO, sweep and output settings.
    """
    parser = argparse.ArgumentParser(description="Empirical Lipschitz of softmax in RL (LunarLander; 4 actions)")
    # Data collection
    parser.add_argument("--num_envs", type=int, default=4, help="Number of parallel envs")
    parser.add_argument("--num_states", type=int, default=256, help="Total states to collect")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for logits/perturb eval")
    parser.add_argument("--seed", type=int, default=None, help="RNG seed")
    parser.add_argument("--env_version", type=int, choices=[2,3], default=3, help="Gymnasium LunarLander version (2 or 3)")

    # PPO policy (train briefly, or load)
    parser.add_argument("--train_steps", type=int, default=0, help="If >0, briefly train PPO for these timesteps")
    parser.add_argument("--pretrained", type=str, default=None, help="Path to PPO .zip to load (overrides train_steps)")
    parser.add_argument("--force_cpu", action="store_true", help="Force CPU even if CUDA is available")

    # Perturbation sweep
    parser.add_argument("--p_list", type=float, nargs="+", default=[1.0, 2.0, 4.0, 8.0, float('inf')],
                        help="p norms to evaluate")
    parser.add_argument("--eps_list", type=float, nargs="+", default=[1e-2, 5e-2, 1e-1, 5e-1, 1, 5, 10, 50, 100],
                        help="epsilon magnitudes")
    parser.add_argument("--tau_list", type=float, nargs="+", default=[0.25, 0.5, 1.0, 2.0],
                        help="softmax temperatures")
    parser.add_argument("--num_trials", type=int, default=5, help="Random directions per (p, eps) per batch")

    # Outputs
    parser.add_argument("--out_png", type=str, default="empirical_Lp.png",
                        help="Plot filename (if multiple tau values, suffix '_tau_T' is added)")

    return parser.parse_args(argv)

def main():
    """Script entry point: parse the command line and run the experiment."""
    run(parse_args())

# Execute the experiment only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
