{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed982c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from src.light_sb import LightSB\n",
    "from src.distributions import StandardNormalSampler, SwissRollSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41a0ed90",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = []\n",
    "\n",
    "for eps in [0.002]:\n",
    "    EXP_NAME = f'LightSB_Swiss_Roll_EPSILON_{eps}'\n",
    "    OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)\n",
    "\n",
    "    D = LightSB(dim=2, n_potentials=500, epsilon=eps,\n",
    "                sampling_batch_size=128, is_diagonal=True)\n",
    "\n",
    "    D.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f'D.pt')))\n",
    "    \n",
    "    models.append(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cba1ec52",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_sampler = StandardNormalSampler(dim=2, device=\"cpu\")\n",
    "Y_sampler = SwissRollSampler(dim=2, device=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c178cf53",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "SEED = 12\n",
    "\n",
    "torch.manual_seed(SEED); np.random.seed(SEED)\n",
    "\n",
    "fig, axes = plt.subplots(1, 4, figsize=(15, 3.75), dpi=200)\n",
    "\n",
    "for ax in axes:\n",
    "    ax.grid(zorder=-20)\n",
    "\n",
    "x_samples = X_sampler.sample(2048)\n",
    "y_samples = Y_sampler.sample(2048)\n",
    "tr_samples = torch.tensor([[0.0, 0.0], [1.75, -1.75], [-1.5, 1.5], [2, 2]])\n",
    "\n",
    "tr_samples = tr_samples[None].repeat(3, 1, 1).reshape(12, 2)\n",
    "\n",
    "axes[0].scatter(x_samples[:, 0], x_samples[:, 1], alpha=0.3, \n",
    "                c=\"g\", s=32, edgecolors=\"black\", label = r\"Input distirubtion $p_0$\")\n",
    "axes[0].scatter(y_samples[:, 0], y_samples[:, 1], \n",
    "                c=\"orange\", s=32, edgecolors=\"black\", label = r\"Target distribution $p_1$\")\n",
    "\n",
    "for ax, model in zip(axes[1:], models):\n",
    "    y_pred = model(x_samples)\n",
    "    \n",
    "    ax.scatter(y_pred[:, 0], y_pred[:, 1], \n",
    "               c=\"yellow\", s=32, edgecolors=\"black\", label = \"Fitted distribution\", zorder=1)\n",
    "    \n",
    "    trajectory = model.sample_euler_maruyama(tr_samples, 1000).detach().cpu()\n",
    "    \n",
    "    ax.scatter(tr_samples[:, 0], tr_samples[:, 1], \n",
    "       c=\"g\", s=128, edgecolors=\"black\", label = r\"Trajectory start ($x \\sim p_0$)\", zorder=3)\n",
    "    \n",
    "    ax.scatter(trajectory[:, -1, 0], trajectory[:, -1, 1], \n",
    "       c=\"red\", s=64, edgecolors=\"black\", label = r\"Trajectory end (fitted)\", zorder=3)\n",
    "        \n",
    "    for i in range(12):\n",
    "        ax.plot(trajectory[i, ::1, 0], trajectory[i, ::1, 1], \"black\", markeredgecolor=\"black\",\n",
    "             linewidth=1.5, zorder=2)\n",
    "        if i == 0:\n",
    "            ax.plot(trajectory[i, ::1, 0], trajectory[i, ::1, 1], \"grey\", markeredgecolor=\"black\",\n",
    "                     linewidth=0.5, zorder=2, label=r\"Trajectory of $T_{\\theta}$\")\n",
    "        else:\n",
    "            ax.plot(trajectory[i, ::1, 0], trajectory[i, ::1, 1], \"grey\", markeredgecolor=\"black\",\n",
    "                     linewidth=0.5, zorder=2)\n",
    "    \n",
    "for ax, title in zip(axes, titles):\n",
    "    ax.set_xlim([-2.5, 2.5])\n",
    "    ax.set_ylim([-2.5, 2.5])\n",
    "    ax.legend(loc=\"lower left\")\n",
    "\n",
    "fig.tight_layout(pad=0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1e53baf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cec4b67",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
