{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd090b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved 4 PDFs (and PNGs) to: loss_landscape/\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as colors\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from functools import partial\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401\n",
    "\n",
    "\n",
    "# Matplotlib look-and-feel, applied in three themed batches (fonts, label sizes, resolution).\n",
    "mpl.rcParams.update({\"font.family\": \"serif\", \"mathtext.fontset\": \"cm\", \"font.size\": 12})\n",
    "mpl.rcParams.update({\n",
    "    \"axes.labelsize\": 12,\n",
    "    \"axes.titlesize\": 18,\n",
    "    \"xtick.labelsize\": 11,\n",
    "    \"ytick.labelsize\": 11,\n",
    "})\n",
    "mpl.rcParams.update({\"figure.dpi\": 150, \"savefig.dpi\": 350})\n",
    "\n",
    "\n",
    "# Coefficients multiplying pi1 and pi2 in the value surfaces computed further down.\n",
    "r1 = 0.0\n",
    "r2 = 1.0\n",
    "# Weight of the regulariser term subtracted in every value surface.\n",
    "lambda_reg = 1.0\n",
    "# Both theta axes span [theta_lo, theta_hi] with grid_n samples each.\n",
    "theta_lo = -50.0\n",
    "theta_hi = 50.0\n",
    "grid_n = 50\n",
    "\n",
    "# alpha is shared by f_reg_tsallis and f_tsallis_softmax_jax; prior feeds the latter.\n",
    "alpha = 0.1\n",
    "prior = jnp.array([1.0, 1.0])\n",
    "\n",
    "# Colormap used for both the surface and its projected contour.\n",
    "cmap = plt.get_cmap(\"viridis\")\n",
    "\n",
    "\n",
    "def f_reg_tsallis(pi, a):\n",
    "    return (pi**a - a*pi + a - 1) / (a * (a - 1))\n",
    "\n",
    "def neg_entropy(pi, eps=1e-12):\n",
    "    pi_c = np.clip(pi, eps, 1.0)\n",
    "    return pi_c * np.log(pi_c)\n",
    "\n",
    "# Square (theta_1, theta_2) evaluation grid shared by every surface.\n",
    "theta1 = np.linspace(theta_lo, theta_hi, num=grid_n)\n",
    "theta2 = np.linspace(theta_lo, theta_hi, num=grid_n)\n",
    "T1, T2 = np.meshgrid(theta1, theta2)\n",
    "\n",
    "\n",
    "# Two-way softmax; the pointwise maximum is subtracted before exponentiating\n",
    "# so neither exponent is ever positive.\n",
    "m = np.maximum(T1, T2)\n",
    "exp1, exp2 = np.exp(T1 - m), np.exp(T2 - m)\n",
    "Z = exp1 + exp2\n",
    "pi1_soft, pi2_soft = exp1 / Z, exp2 / Z\n",
    "\n",
    "\n",
    "@partial(jax.jit, static_argnums=(2,))\n",
    "def f_tsallis_softmax_jax(theta, prior, alpha, eps=1e-6, max_iter=200):\n",
    "    j_star = jnp.argmax(theta)\n",
    "    theta_max = jnp.max(theta)\n",
    "    prior_star = prior[j_star]\n",
    "\n",
    "    tau_min = theta_max - ((1 / prior_star) ** (alpha - 1) - 1) / (alpha - 1)\n",
    "    tau_max = theta_max - ((1 / jnp.sum(prior)) ** (alpha - 1) - 1) / (alpha - 1)\n",
    "\n",
    "    tau = (tau_min + tau_max) / 2\n",
    "    p_tau = prior * (1 + (alpha - 1) * (theta - tau)) ** (1 / (alpha - 1))\n",
    "    phi_tau = jnp.sum(p_tau) - 1\n",
    "\n",
    "    def cond(s):\n",
    "        it, _, _, phi = s\n",
    "        return (jnp.abs(phi) > eps) & (it < max_iter)\n",
    "\n",
    "    def body(s):\n",
    "        it, tmin, tmax, phi = s\n",
    "        mid = (tmin + tmax) / 2\n",
    "        tmin2, tmax2 = jax.lax.cond(phi < 0, lambda: (tmin, mid), lambda: (mid, tmax))\n",
    "        mid = (tmin2 + tmax2) / 2\n",
    "        p = prior * (1 + (alpha - 1) * (theta - mid)) ** (1 / (alpha - 1))\n",
    "        phi = jnp.sum(p) - 1\n",
    "        return it + 1, tmin2, tmax2, phi\n",
    "\n",
    "    _, tmin, tmax, _ = jax.lax.while_loop(cond, body, (0, tau_min, tau_max, phi_tau))\n",
    "    tau = (tmin + tmax) / 2\n",
    "    return prior * (1 + (alpha - 1) * (theta - tau)) ** (1 / (alpha - 1))\n",
    "\n",
    "def _regularized_value(pi1, pi2, regularizer):\n",
    "    \"\"\"Return pi1*r1 + pi2*r2 minus lambda_reg times the summed regulariser of both entries.\"\"\"\n",
    "    weighted_reward = pi1 * r1 + pi2 * r2\n",
    "    penalty = lambda_reg * (regularizer(pi1) + regularizer(pi2))\n",
    "    return weighted_reward - penalty\n",
    "\n",
    "\n",
    "def _tsallis_reg(pi):\n",
    "    \"\"\"f_reg_tsallis with the notebook-wide alpha bound in.\"\"\"\n",
    "    return f_reg_tsallis(pi, alpha)\n",
    "\n",
    "\n",
    "# Evaluate the Tsallis parameterisation at every grid point (one row per point).\n",
    "grid = jnp.stack([T1.ravel(), T2.ravel()], axis=1)\n",
    "P = jax.vmap(lambda th: f_tsallis_softmax_jax(th, prior, alpha))(grid)\n",
    "pi1_t, pi2_t = (np.asarray(P[:, k]).reshape(T1.shape) for k in range(2))\n",
    "\n",
    "\n",
    "# Four surfaces: {softmax, Tsallis} parameterisation x {entropy, Tsallis} regulariser.\n",
    "V_soft_ent = _regularized_value(pi1_soft, pi2_soft, neg_entropy)\n",
    "\n",
    "V_soft_tsreg = _regularized_value(pi1_soft, pi2_soft, _tsallis_reg)\n",
    "\n",
    "V_tsparam_ent = _regularized_value(pi1_t, pi2_t, neg_entropy)\n",
    "\n",
    "V_tsparam_tsreg = _regularized_value(pi1_t, pi2_t, _tsallis_reg)\n",
    "\n",
    "\n",
    "def style_ax(ax):\n",
    "    \"\"\"Apply the shared camera angle, pane, tick and label styling to a 3D axes.\n",
    "\n",
    "    Sets the view to elev=15 / azim=25, makes the three panes transparent with\n",
    "    a faint edge, limits each axis to about four integer ticks, draws a light\n",
    "    grid and labels the x and y axes as theta_1 and theta_2.\n",
    "    \"\"\"\n",
    "    ax.view_init(elev=15, azim=25)\n",
    "\n",
    "    transparent = (1, 1, 1, 0)\n",
    "    faint_black = (0, 0, 0, 0.15)\n",
    "    for axis_name in (\"xaxis\", \"yaxis\", \"zaxis\"):\n",
    "        axis = getattr(ax, axis_name)\n",
    "        axis.pane.set_facecolor(transparent)\n",
    "        axis.pane.set_edgecolor(faint_black)\n",
    "        axis.set_major_locator(MaxNLocator(nbins=4, integer=True))\n",
    "\n",
    "    ax.grid(True, linewidth=0.4, alpha=0.25)\n",
    "    for set_label, text in ((ax.set_xlabel, r\"$\\theta_1$\"), (ax.set_ylabel, r\"$\\theta_2$\")):\n",
    "        set_label(text, labelpad=3)\n",
    "\n",
    "def save_single_surface(V, out_pdf, out_png=None, pad=0.03):\n",
    "    \"\"\"Render ``V`` over the module-level ``T1``/``T2`` grid as a 3D surface and save it.\n",
    "\n",
    "    A filled contour of the same data is projected onto the floor of the plot.\n",
    "\n",
    "    Args:\n",
    "        V: 2-D array of surface heights, shaped like ``T1`` and ``T2``.\n",
    "        out_pdf: path the figure is always written to.\n",
    "        out_png: optional second path; nothing extra is written when it is None.\n",
    "        pad: fraction of the value range added below and above the z-limits.\n",
    "\n",
    "    The figure is closed even if plotting or saving raises, so a failed call\n",
    "    does not leave an open figure behind.\n",
    "    \"\"\"\n",
    "    # NaN-aware extrema: a single non-finite grid point must not turn the colour\n",
    "    # scale and the z-limits of the whole figure into NaN.\n",
    "    vmin, vmax = float(np.nanmin(V)), float(np.nanmax(V))\n",
    "    span = vmax - vmin\n",
    "    norm = colors.Normalize(vmin=vmin, vmax=vmax)\n",
    "    zlim = (vmin - pad * span, vmax + pad * span)\n",
    "\n",
    "    fig = plt.figure(figsize=(3.6, 3.1))\n",
    "    try:\n",
    "        ax = fig.add_subplot(1, 1, 1, projection=\"3d\")\n",
    "        style_ax(ax)\n",
    "        ax.set_zlim(*zlim)\n",
    "        ax.set_title(\"\")  # no title\n",
    "\n",
    "        ax.plot_surface(T1, T2, V, cmap=cmap, norm=norm,\n",
    "                        rcount=140, ccount=140, alpha=0.85)\n",
    "        # Contour shadow drawn on the bottom plane of the padded z-range.\n",
    "        ax.contourf(T1, T2, V, zdir=\"z\", offset=zlim[0],\n",
    "                    levels=20, cmap=cmap, norm=norm, alpha=0.40)\n",
    "\n",
    "        ax.set_zlabel(r\"$\\tilde V_\\lambda(\\theta)$\", labelpad=8)\n",
    "\n",
    "        # Adjust this figure explicitly instead of pyplot's implicit current figure.\n",
    "        fig.subplots_adjust(left=0.00, right=1.00, bottom=0.00, top=1.00)\n",
    "\n",
    "        fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "        if out_png is not None:\n",
    "            fig.savefig(out_png, bbox_inches=\"tight\")\n",
    "    finally:\n",
    "        plt.close(fig)\n",
    "\n",
    "out_dir = \"loss_landscape\"\n",
    "save_png = False  # set to True to also write a PNG next to every PDF\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "# (file stem, surface) pairs; the stems reproduce the original output names.\n",
    "surfaces = [\n",
    "    (\"landscape_softmax_entropy\", V_soft_ent),\n",
    "    (f\"landscape_softmax_tsallisreg_alpha{alpha}\", V_soft_tsreg),\n",
    "    (f\"landscape_tsallisparam_alpha{alpha}_entropy\", V_tsparam_ent),\n",
    "    (f\"landscape_tsallisparam_alpha{alpha}_tsallisreg_alpha{alpha}\", V_tsparam_tsreg),\n",
    "]\n",
    "\n",
    "for stem, surface in surfaces:\n",
    "    save_single_surface(\n",
    "        surface,\n",
    "        out_pdf=f\"{out_dir}/{stem}.pdf\",\n",
    "        out_png=f\"{out_dir}/{stem}.png\" if save_png else None,\n",
    "    )\n",
    "\n",
    "# Report only the formats that were actually written.\n",
    "saved_kinds = \"PDFs and PNGs\" if save_png else \"PDFs\"\n",
    "print(f\"Saved {len(surfaces)} {saved_kinds} to: {out_dir}/\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
