{
 "cells": [
  {
   "cell_type": "code",
   "id": "7c9420a5-7045-434d-b421-05522a6461b9",
   "metadata": {},
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "from matplotlib.colors import Normalize\n",
    "from mpl_toolkits.mplot3d import Axes3D    # noqa: F401"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "567c0e63-6b62-4bf9-a99b-71db87d0b323",
   "metadata": {},
   "source": [
    "AOD_Data = np.load('results/AOD_data/FFNP/targets.npy')\n",
    "AOD_Locs = np.load('results/AOD_data/FFNP/pred_locs.npy')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "cfc951e6-f7bc-461e-9356-347cd34d9b73",
   "metadata": {},
   "source": [
    "STACI_pred = np.load('results/AOD_data/FFNP/preds.npy')\n",
    "STACI_pred_var = np.load('results/AOD_data/FFNP/pred_uncer.npy')\n",
    "STACI_pred_conf = np.load('results/AOD_data/FFNP/pred_uncer_conf.npy')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def plotTests_sphere(datapath, save_dir, save_dir2, conformal, prefix,\n",
    "                          cmap_mean=\"RdBu_r\", cmap_unc=\"viridis\",\n",
    "                          vmin_mean=-1, vmax_mean=3,\n",
    "                          vmin_unc=0,  vmax_unc=6,\n",
    "                          sphere_res=(60, 120),   # target resolution (v, u)\n",
    "                          sphere_size=6):\n",
    "    # --- load & build your full-resolution 2D arrays as before ---\n",
    "    pred_locs = np.load(os.path.join(save_dir, 'pred_locs.npy'))\n",
    "    test_locs = pd.DataFrame(pred_locs[:, :2], columns=['x','y'])\n",
    "    targets   = np.load(os.path.join(save_dir, 'targets.npy'))\n",
    "\n",
    "    dx = np.diff(np.sort(test_locs[\"x\"].unique())).min()\n",
    "    dy = np.diff(np.sort(test_locs[\"y\"].unique())).min()\n",
    "    x0, x1 = test_locs[\"x\"].min(), test_locs[\"x\"].max()\n",
    "    y0, y1 = test_locs[\"y\"].min(), test_locs[\"y\"].max()\n",
    "    x_grid, y_grid = np.meshgrid(\n",
    "        np.arange(x0, x1+dx, dx),\n",
    "        np.arange(y0, y1+dy, dy)\n",
    "    )\n",
    "\n",
    "    ix = np.round((test_locs['x']-x0)/dx).astype(int)\n",
    "    iy = np.round((test_locs['y']-y0)/dy).astype(int)\n",
    "\n",
    "    def fill2d(arr):\n",
    "        grid = np.full_like(x_grid, np.nan, dtype=float)\n",
    "        grid[iy, ix] = arr\n",
    "        return grid\n",
    "\n",
    "    fields = {\n",
    "        \"ground_truth\": fill2d(targets),\n",
    "        \"pred_mean\":    fill2d(np.load(os.path.join(save_dir, 'preds.npy'))),\n",
    "        \"bayes_uncer\":  fill2d(np.load(os.path.join(save_dir, 'pred_uncer.npy'))),\n",
    "        \"resids\":       fill2d(np.load(os.path.join(save_dir, 'resids.npy')))\n",
    "    }\n",
    "    if conformal:\n",
    "        fields[\"conf_uncer\"] = fill2d(\n",
    "            np.load(os.path.join(save_dir, 'pred_uncer_conf.npy'))\n",
    "        )\n",
    "\n",
    "    # --- downsample each field to (sphere_res) for speed ---\n",
    "    M, N = fields[\"ground_truth\"].shape\n",
    "    target_m, target_n = sphere_res\n",
    "    step_m = max(1, M // target_m)\n",
    "    step_n = max(1, N // target_n)\n",
    "\n",
    "    for k,v in fields.items():\n",
    "        fields[k] = v[::step_m, ::step_n]\n",
    "\n",
    "    # new mesh sizes\n",
    "    M2, N2 = fields[\"ground_truth\"].shape\n",
    "\n",
    "    # --- precompute a coarse unit-sphere mesh ---\n",
    "    u = np.linspace(0, 2*np.pi, N2)\n",
    "    v = np.linspace(0,     np.pi,  M2)\n",
    "    uu, vv = np.meshgrid(u, v)\n",
    "    Xs = np.sin(vv) * np.cos(uu)\n",
    "    Ys = np.sin(vv) * np.sin(uu)\n",
    "    Zs = np.cos(vv)\n",
    "\n",
    "    def _plot_sphere(field2d, fname, cmap, vmin, vmax):\n",
    "        fig = plt.figure(figsize=(sphere_size, sphere_size))\n",
    "        ax  = fig.add_subplot(1,1,1, projection='3d')\n",
    "        ax.set_axis_off()\n",
    "\n",
    "        # force full-figure, zero-margins\n",
    "        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)\n",
    "        ax.set_position([0, 0, 1, 1])\n",
    "        \n",
    "        # force equal aspect so the sphere isn't an ellipsoid\n",
    "        #ax.set_box_aspect((1,1,1))\n",
    "        #1. equal aspect ratio\n",
    "        ax.set_box_aspect((1,1,1))\n",
    "        # 2. ensure the plotting limits are symmetric\n",
    "        ax.set_xlim(-1, 1)\n",
    "        ax.set_ylim(-1, 1)\n",
    "        ax.set_zlim(-1, 1)\n",
    "        # 3. optional: remove perspective\n",
    "        try:\n",
    "            ax.set_proj_type('ortho')\n",
    "        except AttributeError:\n",
    "            pass\n",
    "        # set camera distance: lower = zoom in\n",
    "        ax.dist = 7\n",
    "\n",
    "        norm     = Normalize(vmin=vmin, vmax=vmax, clip=True)\n",
    "        mappable = cm.ScalarMappable(norm=norm, cmap=cmap)\n",
    "        facecols = mappable.to_rgba(field2d)\n",
    "\n",
    "        ax.plot_surface(\n",
    "            Xs, Ys, Zs,\n",
    "            rstride=1, cstride=1,\n",
    "            facecolors=facecols,\n",
    "            linewidth=0, antialiased=False, shade=False\n",
    "        )\n",
    "\n",
    "        out = os.path.join(save_dir2, fname)\n",
    "        fig.savefig(out, dpi=500, bbox_inches='tight', pad_inches=0)\n",
    "        plt.close(fig)\n",
    "\n",
    "    # --- render and save ---\n",
    "    if conformal:\n",
    "        _plot_sphere(fields[\"ground_truth\"],\n",
    "                     f\"{prefix}_ground_truth_{datapath}.png\",\n",
    "                     cmap_mean, vmin_mean, vmax_mean)\n",
    "    _plot_sphere(fields[\"pred_mean\"],\n",
    "                 f\"{prefix}_pred_mean_{datapath}.png\",\n",
    "                 cmap_mean, vmin_mean, vmax_mean)\n",
    "    _plot_sphere(fields[\"bayes_uncer\"],\n",
    "                 f\"{prefix}_bayes_uncer_{datapath}.png\",\n",
    "                 cmap_unc,  vmin_unc,  vmax_unc)\n",
    "    if conformal:\n",
    "        _plot_sphere(fields[\"conf_uncer\"],\n",
    "                     f\"{prefix}_conf_uncer_{datapath}.png\",\n",
    "                     cmap_unc, vmin_unc, vmax_unc)\n",
    "\n",
    "    print(\"Fast sphere plots saved to\", save_dir2)"
   ],
   "id": "908e37436c18d895",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "plotTests_sphere('STACI', 'results/AOD_data/FFNP', 'results/AOD_data/Plots_Sphere', True, 'AOD_data', sphere_res=(400, 750))",
   "id": "e429fd7c62e075bb",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NERSC Python",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
