# Script for empirical evaluation of softmax Lipschitz constant — supplementary material for TMLR submission.

import os, math, random, argparse
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Tuple

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env, make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.utils import set_random_seed


# ----------------------------- Env helpers ---------------------------------- #

def has_ale() -> bool:
    """Report whether the ``ale_py`` package can be imported.

    Any exception raised while importing is treated as "not available",
    so a broken installation is reported the same way as a missing one.
    """
    try:
        import ale_py  # noqa
    except Exception:
        available = False
    else:
        available = True
    return available


def make_env(env_choice: str, n_envs: int, seed: Optional[int] = None):
    """Build a vectorized environment and pick the matching SB3 policy name.

    Args:
        env_choice: One of 'auto', 'cartpole', 'lunarlander' or 'atari_pong'
            (case-insensitive, surrounding whitespace ignored). 'auto'
            resolves to 'atari_pong' when ``has_ale()`` is true, otherwise
            to 'cartpole'.
        n_envs: Number of parallel environments, forwarded to the SB3 factory.
        seed: Seed forwarded to the SB3 factory.

    Returns:
        A ``(env, policy_name)`` tuple, where ``policy_name`` is "CnnPolicy"
        for Pong and "MlpPolicy" otherwise.

    Raises:
        ValueError: If ``env_choice`` is not one of the supported names.
    """
    choice = env_choice.lower().strip()
    if choice == "auto":
        choice = "atari_pong" if has_ale() else "cartpole"

    if choice not in ("atari_pong", "cartpole", "lunarlander"):
        raise ValueError("ENV must be 'auto', 'cartpole', 'lunarlander', or 'atari_pong'.")

    if choice == "atari_pong":
        if has_ale():
            # Pong observations are stacked 4 frames deep for the CNN policy.
            frames = make_atari_env("ALE/Pong-v5", n_envs=n_envs, seed=seed)
            return VecFrameStack(frames, n_stack=4), "CnnPolicy"
        # Explicitly requested Pong but ALE is missing: degrade to CartPole.
        print("[WARN] ALE not available; falling back to CartPole-v1.")
        choice = "cartpole"

    env_ids = {"cartpole": "CartPole-v1", "lunarlander": "LunarLander-v2"}
    return make_vec_env(env_ids[choice], n_envs=n_envs, seed=seed), "MlpPolicy"


# ----------------------------- Core utilities ------------------------------- #

def set_all_seeds(seed: int):
    """Seed Python's ``random``, NumPy, PyTorch and SB3 with the same value."""
    # Order matches the original call sequence: stdlib, NumPy, torch, SB3.
    seeders = (random.seed, np.random.seed, torch.manual_seed, set_random_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def lp_norm(t: torch.Tensor, p: float, dim=None, keepdim=False):
    """Return the l_p norm of ``t``.

    Args:
        t: Input tensor.
        p: Norm order; ``float('inf')`` selects the max-absolute-value norm,
            any other value computes ``(sum |t|^p)^(1/p)``.
        dim: Axis or axes to reduce over. ``None`` (the default) reduces over
            every axis of ``t``.
        keepdim: Whether reduced axes are kept with size 1.

    Returns:
        A tensor holding the norm along the reduced axes.
    """
    # Resolve "all axes" explicitly: Tensor.amax does not accept dim=None on
    # every torch version, so the default argument would otherwise fail for
    # p = inf.
    reduce_dims = tuple(range(t.dim())) if dim is None else dim
    magnitudes = t.abs()
    if p == float('inf'):
        return magnitudes.amax(dim=reduce_dims, keepdim=keepdim)
    return (magnitudes ** p).sum(dim=reduce_dims, keepdim=keepdim) ** (1.0 / p)


# Shared softmax module normalizing over the last axis; applied to the
# temperature-scaled logits to obtain action probabilities.
softmax = nn.Softmax(dim=-1)


def obs_to_logits(policy, obs_batch: np.ndarray) -> torch.Tensor:
    """Return pre-softmax logits [B, A] from an SB3 policy for a batch of observations.

    The logits are read from ``policy.get_distribution(obs).distribution.logits``
    when that attribute exists. Otherwise the actor head is evaluated
    explicitly (feature extractor -> actor MLP -> ``action_net``) and its raw
    output is returned.

    Args:
        policy: An SB3 actor-critic policy (anything exposing
            ``obs_to_tensor``, ``get_distribution`` and, for the fallback,
            ``extract_features``, ``mlp_extractor.forward_actor`` and
            ``action_net``).
        obs_batch: Batch of observations accepted by ``policy.obs_to_tensor``.

    Returns:
        A tensor of logits computed without gradient tracking.
    """
    obs_t, _ = policy.obs_to_tensor(obs_batch)
    with torch.no_grad():
        dist = policy.get_distribution(obs_t)
        # Probe for the attribute instead of swallowing every exception, so
        # genuine errors inside the policy are not hidden.
        logits = getattr(getattr(dist, "distribution", None), "logits", None)
        if logits is None:
            # Fallback: rebuild the actor path by hand. The former private
            # helper `_get_latent` is not available in current SB3 releases.
            features = policy.extract_features(obs_t)
            if isinstance(features, tuple):
                # Separate actor/critic extractors yield (pi_features, vf_features).
                features = features[0]
            latent_pi = policy.mlp_extractor.forward_actor(features)
            logits = policy.action_net(latent_pi)
    return logits  # [B, A]


def state_loader(states: np.ndarray, batch_size: int):
    """Yield consecutive slices of ``states`` along axis 0.

    Each slice holds up to ``batch_size`` rows; the final one may be shorter.
    """
    total = states.shape[0]
    yield from (states[lo:lo + batch_size] for lo in range(0, total, batch_size))


def collect_observations(env, model, n_total: int) -> np.ndarray:
    """Step a vectorized env with the model's sampled actions and gather observations.

    Only observations returned by ``env.step`` are recorded; the one returned
    by ``env.reset`` is used solely to choose the first action. For dict
    observations, only the array stored under the first key is kept.

    Args:
        env: Vectorized environment exposing ``reset()`` and ``step(actions)``.
        model: Object exposing ``predict(obs, deterministic=...)``.
        n_total: Number of individual observations to gather.

    Returns:
        Array of shape [n_total, obs_dim...] stacking the gathered observations.

    Raises:
        RuntimeError: If a step yields something that is neither an ndarray
            nor a dict.
    """
    obs = env.reset()
    if isinstance(obs, tuple) and len(obs) == 2:
        # Gymnasium-style reset returns (obs, info).
        obs = obs[0]

    gathered = []
    while len(gathered) < n_total:
        actions, _ = model.predict(obs, deterministic=False)
        outcome = env.step(actions)
        is_step_tuple = isinstance(outcome, tuple) and len(outcome) >= 4
        obs = outcome[0] if is_step_tuple else outcome

        if isinstance(obs, np.ndarray):
            # VecEnv: leading axis indexes the parallel environments.
            per_env = obs
        elif isinstance(obs, dict):
            per_env = obs[list(obs.keys())[0]]
        else:
            raise RuntimeError("Unsupported observation type from environment.")

        # Take one observation per env, but never exceed the requested total.
        room = n_total - len(gathered)
        gathered.extend(per_env[k] for k in range(min(room, per_env.shape[0])))
    return np.stack(gathered, axis=0)  # [N, obs_dim...]


# ----------------------------- Main experiment ------------------------------ #

def _load_or_build_model(args, env, policy_type, device):
    """Return a PPO model: loaded from ``args.pretrained`` if it exists, else freshly built."""
    if args.pretrained:
        if os.path.exists(args.pretrained):
            print(f"Loading pretrained PPO from: {args.pretrained}")
            return PPO.load(args.pretrained, env=env, device=device)
        # Make the fallback visible: otherwise a typo in the path would
        # silently evaluate a different policy than the one requested.
        print(f"[WARN] Pretrained model not found: {args.pretrained}; building a fresh PPO instead.")

    policy_kwargs = dict(net_arch=[256, 256]) if policy_type == "MlpPolicy" else {}
    model = PPO(policy_type, env, verbose=0, tensorboard_log=None,
                policy_kwargs=policy_kwargs, device=device)
    if args.train_steps > 0:
        print(f"Training PPO for {args.train_steps} timesteps...")
        model.learn(total_timesteps=args.train_steps)
    else:
        print("Using randomly initialized policy (no training).")
    return model


def _sweep_ratios(logit_batches, args):
    """Return one result row per (tau, p, eps) holding the largest observed ratio."""
    rows = []
    # Scoped no_grad instead of flipping the global autograd switch.
    with torch.no_grad():
        for tau in args.tau_list:
            # Unperturbed probabilities depend only on tau: compute once per tau.
            base_probs = [softmax(z0 / tau) for z0 in logit_batches]
            for p in args.p_list:
                for eps in args.eps_list:
                    per_batches = []
                    for z0, p0 in zip(logit_batches, base_probs):
                        per_trials = []
                        for _ in range(args.num_trials):
                            # Random Gaussian direction, rescaled so each row
                            # has l_p norm eps.
                            delta = torch.randn_like(z0)
                            denom = lp_norm(delta, p, dim=1, keepdim=True).clamp_min(1e-12)
                            delta = eps * delta / denom
                            p1 = softmax((z0 + delta) / tau)
                            # empirical Lipschitz ratio: ||Δπ||_p / ε
                            ratio = lp_norm(p1 - p0, p, dim=-1) / eps
                            per_trials.append(ratio.max().item())
                        per_batches.append(float(np.max(per_trials)))
                    rows.append({
                        "env_choice": args.env,
                        "tau": float(tau),
                        "p": float(p),
                        "epsilon": float(eps),
                        "ratio_max_over_states": float(np.max(per_batches)),
                        "n_batches": int(len(per_batches)),
                        "num_trials": int(args.num_trials),
                    })
    return rows


def _save_tau_plots(df, p_list, out_png):
    """Write one ratio-vs-epsilon figure per tau, named ``<base>_tau_<tau><ext>``."""
    base, ext = os.path.splitext(out_png)
    if ext == "":
        ext = ".png"
    out_dir = os.path.dirname(base)
    if out_dir:
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    def p_label(p):
        # ":g" keeps "1", "2", "10" for whole numbers but no longer truncates
        # fractional orders such as 1.5 to "1".
        return "∞" if np.isinf(p) else f"{p:g}"

    for tau in sorted(df["tau"].unique()):
        fig = plt.figure(figsize=(6, 4))
        df_tau = df[df["tau"] == tau]
        for p in p_list:
            sub = df_tau[df_tau["p"] == float(p)].sort_values("epsilon")
            if sub.empty:
                continue
            plt.plot(sub["epsilon"].values, sub["ratio_max_over_states"].values,
                     marker="o", markersize=5, linewidth=2.2,
                     label=f"p={p_label(p)}")
        # Reference line at 0.5 / tau (labelled lambda/2; presumably
        # lambda = 1/tau -- confirm against the paper's notation).
        plt.axhline(0.5 / float(tau), linestyle="--", linewidth=2, label=r"$\lambda/2$")
        plt.xscale('log')
        plt.xlabel(r"Perturbation $\epsilon$", fontsize=15)
        plt.ylabel("Empirical $L_p$", fontsize=15)
        plt.legend()
        plt.tight_layout()

        out_path = f"{base}_tau_{tau}{ext}"
        plt.savefig(out_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        print(f"Saved plot: {out_path}")


def run(args):
    """Run the empirical softmax-Lipschitz experiment described by ``args``.

    Steps: pick the device, optionally seed all RNGs, build the environment
    and a PPO model (loaded, trained, or randomly initialized), collect
    ``args.num_states`` observations, compute the policy logits once, sweep
    every (tau, p, eps) combination with ``args.num_trials`` random logit
    perturbations per batch, and save one plot per tau derived from
    ``args.out_png``.

    Args:
        args: Namespace with the attributes produced by ``parse_args``.
    """
    device = "cuda" if (torch.cuda.is_available() and not args.force_cpu) else "cpu"
    print("Device:", device)
    if args.seed is not None:
        set_all_seeds(args.seed)

    # Environment + policy
    env, policy_type = make_env(args.env, args.num_envs, args.seed)
    print(f"Resolved env: {args.env} | Policy: {policy_type}")
    try:
        model = _load_or_build_model(args, env, policy_type, device)
        policy = model.policy.to(device).eval()
        states = collect_observations(env, model, args.num_states)
    finally:
        # The env is only needed for training and rollouts; release it even
        # if model construction or collection fails.
        env.close()
    print("Collected states:", states.shape)

    # Logits do not depend on (tau, p, eps): evaluate the policy once per
    # batch instead of once per sweep configuration.
    logit_batches = [obs_to_logits(policy, batch)
                     for batch in state_loader(states, args.batch_size)]
    print("Actor logits shape:", tuple(logit_batches[0].shape))

    # Sweep (tau, p, eps)
    rows = _sweep_ratios(logit_batches, args)
    df = pd.DataFrame(rows).sort_values(["tau", "p", "epsilon"]).reset_index(drop=True)

    # Plot per-tau; filenames derived from --out_png basename
    _save_tau_plots(df, args.p_list, args.out_png)


# ----------------------------- CLI ----------------------------------------- #

def parse_args():
    """Parse the command-line options of the experiment from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the data-collection, PPO, sweep and
        output settings.
    """
    parser = argparse.ArgumentParser(description="Empirical Lipschitz of softmax in RL (actor; logits perturbation)")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        # Data collection
        ("--env", dict(type=str, default="cartpole",
                       help="auto | cartpole | lunarlander | atari_pong")),
        ("--num_envs", dict(type=int, default=4, help="Number of parallel envs")),
        ("--num_states", dict(type=int, default=256, help="Total states to collect")),
        ("--batch_size", dict(type=int, default=64, help="Batch size for logits/perturb eval")),
        ("--seed", dict(type=int, default=None, help="RNG seed")),
        # PPO policy (random or brief training, or load)
        ("--train_steps", dict(type=int, default=0,
                               help="If >0, briefly train PPO for these timesteps")),
        ("--pretrained", dict(type=str, default=None,
                              help="Path to PPO .zip to load (overrides train_steps)")),
        ("--force_cpu", dict(action="store_true", help="Force CPU even if CUDA is available")),
        # Perturbation sweep
        ("--p_list", dict(type=float, nargs="+",
                          default=[1.0, 2.0, 8.0, 10.0, float('inf')],
                          help="p norms to evaluate")),
        ("--eps_list", dict(type=float, nargs="+",
                            default=[1e-2, 5e-2, 1e-1, 5e-1, 1, 5, 10, 50, 100],
                            help="epsilon magnitudes")),
        ("--tau_list", dict(type=float, nargs="+", default=[0.25, 0.5, 2.0],
                            help="softmax temperatures")),
        ("--num_trials", dict(type=int, default=5,
                              help="Random directions per (p, eps) per batch")),
        # Outputs
        ("--out_png", dict(type=str, default="empirical_Lp.png",
                           help="Basename for plots (plots saved as <basename>_tau_<T>.png)")),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def main():
    """Script entry point: parse the CLI options and run the experiment."""
    run(parse_args())


# Execute only when invoked as a script, not when imported.
if __name__ == "__main__":
    main()

