{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "997e6ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyro import distributions as dist\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from glob import glob\n",
    "from src.usflows.explib.config_parser import from_checkpoint\n",
    "from src.usflows.explib.eval import RadialFlowEvaluator\n",
    "import os\n",
    "from src.usflows.distributions import Chi\n",
    "from src.usflows.explib.datasets import DistributionDataset\n",
    "from src.usflows.distributions import GMM\n",
    "from torch.nn.functional import softplus"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b37a8d7c",
   "metadata": {},
   "source": [
    "# Evaluate all Dims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e7bdd13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change these variables as needed\n",
    "arch = \"\" # Architecture identifier string (for titles / file names)\n",
    "base_dir = \"\" # Path to directory containing subdirectories of model checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73587fee",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "subfolders = sorted(os.listdir(base_dir))\n",
    "subfolders = [os.path.join(base_dir, d) for d in subfolders]\n",
    "subfolders = sorted([d for d in subfolders if os.path.isdir(d)])\n",
    "\n",
    "model_dirs = [\n",
    "    os.path.join(base_dir, subfolder) for subfolder in subfolders if os.path.isdir(os.path.join(base_dir, subfolder))\n",
    "]\n",
    "\n",
    "data = {}\n",
    "print(model_dirs)\n",
    "for i, model_dir in enumerate(model_dirs):\n",
    "    print(model_dir)\n",
    "    # Locate model files\n",
    "    pkl_files = sorted([f for f in os.listdir(model_dir) if f.endswith(\".pkl\")])\n",
    "    pt_files = sorted([f for f in os.listdir(model_dir) if f.endswith(\".pt\")])\n",
    "\n",
    "    if not pkl_files or not pt_files:\n",
    "        print(f\"Skipping {model_dir} (missing files)\")\n",
    "        continue\n",
    "\n",
    "    pkl_path = os.path.join(model_dir, pkl_files[-1])\n",
    "    pt_path = os.path.join(model_dir, pt_files[-1])\n",
    "    try:\n",
    "        model = from_checkpoint(pkl_path, pt_path)\n",
    "    except:\n",
    "        continue\n",
    "        \n",
    "    dim = int(model_dir.split(\"_\")[-1][:-1])\n",
    "\n",
    "    print(f\"{dim}D GMM\")\n",
    "    hdim = int(dim/2)\n",
    "    distribution=GMM(\n",
    "        loc=torch.stack([-torch.ones(dim), torch.ones(dim)]), \n",
    "        covariance_matrix=torch.stack([\n",
    "            torch.diag(torch.tensor([5.]*hdim + [.5]*hdim)), \n",
    "            torch.eye(dim)\n",
    "        ]),\n",
    "        mixture_weights=torch.ones(2)/2\n",
    "    )\n",
    "    ref_dist = distribution\n",
    "    \n",
    "    ds = DistributionDataset(\n",
    "        distribution=distribution,\n",
    "        num_samples=10000\n",
    "    )[:][0]\n",
    "\n",
    "    data[i] = ds\n",
    "\n",
    "    evaluator = RadialFlowEvaluator(\n",
    "        model,\n",
    "        ds,\n",
    "        p=2.0,\n",
    "        norm_distribution=Chi(\n",
    "            df=dim,\n",
    "            scale=softplus(model.base_distribution.scale_unconstrained),\n",
    "            validate_args=False\n",
    "        )\n",
    "    )\n",
    "\n",
    "    row = i\n",
    "    col = 0\n",
    "    \n",
    "    scatter_fig, ax = plt.subplots()\n",
    "    evaluator.logprob_reference_scatter_plot(ax=ax, ref_distribution=ref_dist)\n",
    "    ax.set_title(f\"Log-Probability Comparison ({dim}D)\")\n",
    "    scatter_fig.savefig(f\"gmm_eval_logprobs_{dim}D_{arch}.png\")\n",
    "    \n",
    "    scatter_fig, ax = plt.subplots()\n",
    "    evaluator.nll_norm_scatter_plot(ax=ax, ref_distribution=ref_dist)\n",
    "    ax.set_title(f\"NLL vs Latent Norm ({dim}D)\")\n",
    "    scatter_fig.savefig(f\"gmm_eval_nll_vs_latent_norms_{dim}D_{arch}.png\")\n",
    "\n",
    "plt.show()    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cc7599d",
   "metadata": {},
   "source": [
    "# Eval 2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc7351fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.patches import Circle\n",
    "\n",
    "base_path =  base_dir + \"/0_gaussian_mixture_2D\"\n",
    "\n",
    "pkl_path = sorted(glob(f\"{base_path}/*.pkl\"))[-1]\n",
    "pt_path = sorted(glob(f\"{base_path}/*.pt\"))[-1]\n",
    "\n",
    "dim = 2\n",
    "distribution=GMM(\n",
    "    loc=torch.stack([-torch.ones(dim), torch.ones(dim)]), \n",
    "    covariance_matrix=torch.stack([torch.diag(torch.Tensor([5., .5])) ,torch.eye(dim)]),\n",
    "    mixture_weights=torch.ones(2)/2\n",
    ")\n",
    "ref_dist = distribution\n",
    "\n",
    "model = from_checkpoint(pkl_path, pt_path)\n",
    "\n",
    "with torch.no_grad():\n",
    "    ds = distribution.sample([1000])\n",
    "    latents = model.backward(ds) - model.base_distribution.loc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f26ccd94",
   "metadata": {},
   "source": [
    "## Plot Determinant"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c46e04bf",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "_, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.set_facecolor('white')\n",
    "\n",
    "ax\n",
    "# 1. Generate sample data (replace with your dataset)\n",
    "np.random.seed(42)\n",
    "x = ds[:, 0].numpy()\n",
    "y = ds[:, 1].numpy()\n",
    "\n",
    "# 2. Create grid for density evaluation\n",
    "x_grid, y_grid = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]\n",
    "positions = np.vstack([x_grid.ravel(), y_grid.ravel()])\n",
    "\n",
    "# 3. Calculate density using Kernel Density Estimation (KDE)\n",
    "\n",
    "density = np.reshape(torch.exp(ref_dist.log_prob(torch.Tensor(positions).permute(1,0))).detach().T, x_grid.shape)\n",
    "\n",
    "\n",
    "# Contour lines only\n",
    "contour = ax.contour(x_grid, y_grid, density, levels=8, colors='black', linewidths=0.5)\n",
    "\n",
    "# Add data points overlay (optional)\n",
    "with torch.no_grad():\n",
    "    #c = torch.exp(distribution.log_prob(ds))\n",
    "    c = torch.exp(model.log_abs_det_jacobian(ds))\n",
    "    if c.dim() == 0:\n",
    "        c = [c] * len(ds)\n",
    "scatter = ax.scatter(x, y, s=5, c=c, cmap=\"magma\", alpha=1)\n",
    "plt.colorbar(scatter, label='$|\\\\det J_{\\\\phi}(x)|$')\n",
    "\n",
    "# Customize plot\n",
    "ax.set_title('Data Distribution 2D GMM')\n",
    "ax.set_aspect('equal')\n",
    "ax.grid(alpha=0.2)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(f\"contour_2d_gmm_{arch}_log_abs_det_jac.png\")\n",
    "\n",
    "###############################################3\n",
    "\n",
    "_, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.set_facecolor('white')\n",
    "\n",
    "# 1. Generate sample data (replace with your dataset)\n",
    "np.random.seed(42)\n",
    "x = latents[:, 0].numpy()\n",
    "y = latents[:, 1].numpy()\n",
    "\n",
    "# 2. Create grid for density evaluation\n",
    "x_grid, y_grid = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]\n",
    "positions = np.vstack([x_grid.ravel(), y_grid.ravel()])\n",
    "\n",
    "# 3. Calculate density using Kernel Density Estimation (KDE)\n",
    "kde = gaussian_kde(np.vstack([x, y]))\n",
    "#density = np.reshape(kde(positions).T, x_grid.shape)\n",
    "\n",
    "# Add data points overlay (optional)\n",
    "\n",
    "scatter = ax.scatter(x, y, s=5, c=c, cmap=\"magma\", alpha=1)\n",
    "plt.colorbar(scatter, label='$|\\\\det J_{\\\\phi}(x)|$')\n",
    "\n",
    "# Customize plot\n",
    "ax.set_title(f'Centered Latent Data Distribution\\n2D GMM ({arch})')\n",
    "\n",
    "scale = softplus(model.base_distribution.scale_unconstrained)\n",
    "ax.add_patch(Circle((0., 0.), radius=scale, fill=False, edgecolor='black', linewidth=.5, linestyle='--'))\n",
    "ax.add_patch(Circle((0., 0.), radius=1.5*scale, fill=False, edgecolor='black', linewidth=.5, linestyle='--'))\n",
    "ax.add_patch(Circle((0., 0.), radius=2*scale, fill=False, edgecolor='black', linewidth=.5 , linestyle='--'))\n",
    "ax.add_patch(Circle((0., 0.), radius=2.5*scale, fill=False, edgecolor='black', linewidth=.5, linestyle=\"--\"))\n",
    "ax.add_patch(Circle((0., 0.), radius=3*scale, fill=False, edgecolor='black', linewidth=.5, linestyle=\"--\"))\n",
    "ax.set_aspect('equal')\n",
    "ax.grid(alpha=0.2)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"contour_latent_2d_gmm_{arch}_log_abs_det_jac.png\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe90691",
   "metadata": {},
   "source": [
    "## Plot Densities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03bdcd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "_, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.set_facecolor('white')\n",
    "\n",
    "ax\n",
    "# 1. Generate sample data (replace with your dataset)\n",
    "np.random.seed(42)\n",
    "x = ds[:, 0].numpy()\n",
    "y = ds[:, 1].numpy()\n",
    "\n",
    "# 2. Create grid for density evaluation\n",
    "x_grid, y_grid = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]\n",
    "positions = np.vstack([x_grid.ravel(), y_grid.ravel()])\n",
    "\n",
    "# 3. Calculate density using Kernel Density Estimation (KDE)\n",
    "\n",
    "density = np.reshape(torch.exp(ref_dist.log_prob(torch.Tensor(positions).permute(1,0))).detach().T, x_grid.shape)\n",
    "\n",
    "\n",
    "# Contour lines only\n",
    "contour = ax.contour(x_grid, y_grid, density, levels=8, colors='black', linewidths=0.5)\n",
    "\n",
    "\n",
    "# Add data points overlay (optional)\n",
    "with torch.no_grad():\n",
    "    c = torch.exp(distribution.log_prob(ds))\n",
    "    #c = torch.exp(model.log_abs_det_jacobian(ds))\n",
    "scatter = ax.scatter(x, y, s=5, c=c, cmap=\"magma\", alpha=1)\n",
    "plt.colorbar(scatter, label='Data Density')\n",
    "\n",
    "# Customize plot\n",
    "ax.set_title('Data Distribution 2D GMM')\n",
    "ax.set_aspect('equal')\n",
    "ax.grid(alpha=0.2)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(f\"contour_2d_gmm_{arch}_density.png\")\n",
    "\n",
    "###############################################3\n",
    "\n",
    "_, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.set_facecolor('white')\n",
    "\n",
    "# 1. Generate sample data (replace with your dataset)\n",
    "np.random.seed(42)\n",
    "x = latents[:, 0].numpy()\n",
    "y = latents[:, 1].numpy()\n",
    "\n",
    "# 2. Create grid for density evaluation\n",
    "x_grid, y_grid = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]\n",
    "positions = np.vstack([x_grid.ravel(), y_grid.ravel()])\n",
    "\n",
    "# 3. Calculate density using Kernel Density Estimation (KDE)\n",
    "kde = gaussian_kde(np.vstack([x, y]))\n",
    "#density = np.reshape(kde(positions).T, x_grid.shape)\n",
    "\n",
    "\n",
    "# Add data points overlay (optional)\n",
    "\n",
    "scatter = ax.scatter(x, y, s=5, c=c, cmap=\"magma\", alpha=1)\n",
    "#plt.colorbar(scatter, label='Data Density')\n",
    "\n",
    "# Customize plot\n",
    "ax.set_title(f'Centered Latent Data Distribution\\n2D GMM ({arch})')\n",
    "\n",
    "ax.add_patch(Circle((0., 0.), radius=scale, fill=False, edgecolor='black', linewidth=.5, linestyle='--'))\n",
    "ax.add_patch(Circle((0., 0.), radius=1.5*scale, fill=False, edgecolor='black', linewidth=.5, linestyle='--'))\n",
    "ax.add_patch(Circle((0., 0.), radius=2*scale, fill=False, edgecolor='black', linewidth=.5 , linestyle='--'))\n",
    "ax.add_patch(Circle((0., 0.), radius=2.5*scale, fill=False, edgecolor='black', linewidth=.5, linestyle=\"--\"))\n",
    "ax.add_patch(Circle((0., 0.), radius=3*scale, fill=False, edgecolor='black', linewidth=.5, linestyle=\"--\"))\n",
    "ax.set_aspect('equal')\n",
    "ax.grid(alpha=0.2)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"contour_latent_2d_gmm_{arch}_density.png\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd88ab5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
