{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4cfb1849-03af-460a-a94e-9bb478af6cf5",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d3d75eb4-5d77-4647-ba3f-1fe047848671",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from jax import random\n",
    "from jax.experimental.ode import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1bcfaeb-3893-486e-b55f-1603d06365f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def a(t):\n",
    "    \"\"\"Mean-scaling coefficient a(t) = exp(-t): data at time t is a(t) * x0 plus noise.\"\"\"\n",
    "    return jnp.exp(-t)\n",
    "\n",
    "\n",
    "def b(t):\n",
    "    \"\"\"Noise std b(t) = sqrt(1 - exp(-2t)); b(0) = 0, so t must be > 0 wherever b(t) divides.\"\"\"\n",
    "    return jnp.sqrt(1 - jnp.exp(-2 * t))\n",
    "\n",
    "\n",
    "def p_btxi_grad(x, t):\n",
    "    \"\"\"Gradient w.r.t. x of the unnormalized Gaussian kernel exp(-|x|^2 / (2 b(t)^2)).\n",
    "\n",
    "    x holds one displacement per row, shape (n, d) (callers pass x - samples, which\n",
    "    broadcasts a single point against n samples). Returns shape (n, d); row i is\n",
    "    -x_i / b(t)^2 * exp(-|x_i|^2 / (2 b(t)^2)).\n",
    "    \"\"\"\n",
    "    sigma_sq = b(t) ** 2\n",
    "    kernel = jnp.exp(-(x ** 2).sum(axis=1, keepdims=True) / (2 * sigma_sq))\n",
    "    return -1 / sigma_sq * kernel * x\n",
    "\n",
    "\n",
    "def p_btxi(x, t):\n",
    "    \"\"\"Unnormalized Gaussian kernel exp(-|x|^2 / (2 b(t)^2)), evaluated per row.\n",
    "\n",
    "    x has shape (n, d); returns shape (n, 1) (keepdims so it broadcasts against x).\n",
    "    The normalizing constant is omitted because the score estimators only use ratios.\n",
    "    \"\"\"\n",
    "    sigma_sq = b(t) ** 2\n",
    "    return jnp.exp(-(x ** 2).sum(axis=1, keepdims=True) / (2 * sigma_sq))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "981f5c54-018a-4878-a897-58003722a28b",
   "metadata": {},
   "source": [
    "## Data Distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "46cc126c-c3b8-43be-a5d0-d60380470cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class centres sit at x = -sep_scale (y = 0) and x = +sep_scale (y = 1), before the a(t) scaling.\n",
    "sep_scale = 1\n",
    "# Number of classes. The samplers shift by (2 * y - 1), which only separates k = 2 classes.\n",
    "k = 2\n",
    "\n",
    "# All figures are saved under synthetic_plots/; create it so savefig works on a fresh checkout.\n",
    "os.makedirs(\"synthetic_plots\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a4f4b6d1-6a3a-4ccf-805b-a60332cf7209",
   "metadata": {},
   "outputs": [],
   "source": [
    "def uniform_square(key, score_samples, t):\n",
    "    \"\"\"Draw score_samples points uniformly from [-0.5, 0.5)^2, scaled by a(t).\n",
    "\n",
    "    Returns (new_key, points) with points of shape (score_samples, 2).\n",
    "    \"\"\"\n",
    "    next_key, sample_key = random.split(key, 2)\n",
    "    unit = random.uniform(sample_key, shape=(score_samples, 2), minval=-0.5, maxval=0.5)\n",
    "    return next_key, a(t) * unit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fad05567-dd56-47b9-91ab-13bd3dcf3701",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shifted_uniform_square(key, score_samples, t, y):\n",
    "    \"\"\"uniform_square samples moved along x by a(t) * sep_scale * (2 * y - 1).\n",
    "\n",
    "    y = 0 shifts left, y = 1 shifts right; y may be a scalar or an array that\n",
    "    broadcasts against (score_samples, 1).\n",
    "    \"\"\"\n",
    "    key, base = uniform_square(key, score_samples, t)\n",
    "    x_axis = jnp.zeros(shape=(score_samples, 2)).at[:, 0].set(1.0)\n",
    "    offset = a(t) * sep_scale * (2 * y - 1)\n",
    "    return key, base + offset * x_axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "03c1756e-4e36-4049-bf87-41b945a9568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shifted_uniform_square_marginal(key, score_samples, t):\n",
    "    \"\"\"Sample the equal-weight mixture of shifted squares via one random label per point.\"\"\"\n",
    "    key, label_key = random.split(key)\n",
    "    labels = random.randint(label_key, shape=(score_samples, 1), minval=0, maxval=k)\n",
    "    return shifted_uniform_square(key, score_samples, t, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e01d38bf-f64c-447c-8b38-4a17be277aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian(key, score_samples, t):\n",
    "    \"\"\"Draw score_samples standard-normal 2-D points, scaled by a(t).\n",
    "\n",
    "    Returns (new_key, points) with points of shape (score_samples, 2).\n",
    "    \"\"\"\n",
    "    next_key, sample_key = random.split(key, 2)\n",
    "    draws = random.normal(sample_key, shape=(score_samples, 2))\n",
    "    return next_key, a(t) * draws"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "33a6a02c-9200-4455-9e0e-10027f8def80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shifted_gaussian(key, score_samples, t, y):\n",
    "    \"\"\"gaussian samples moved along x by a(t) * sep_scale * (2 * y - 1).\n",
    "\n",
    "    y = 0 shifts left, y = 1 shifts right; y may be a scalar or an array that\n",
    "    broadcasts against (score_samples, 1).\n",
    "    \"\"\"\n",
    "    key, base = gaussian(key, score_samples, t)\n",
    "    x_axis = jnp.zeros(shape=(score_samples, 2)).at[:, 0].set(1.0)\n",
    "    offset = a(t) * sep_scale * (2 * y - 1)\n",
    "    return key, base + offset * x_axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4fa5d87f-7882-43e9-bb7c-68a61952fe98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shifted_gaussian_marginal(key, score_samples, t):\n",
    "    \"\"\"Sample the equal-weight mixture of shifted Gaussians via one random label per point.\"\"\"\n",
    "    key, label_key = random.split(key)\n",
    "    labels = random.randint(label_key, shape=(score_samples, 1), minval=0, maxval=k)\n",
    "    return shifted_gaussian(key, score_samples, t, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a63168d-af50-4cac-9312-07f9cf855dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the notebook-wide PRNG key; later cells thread `key` forward.\n",
    "key = random.key(0)\n",
    "key, marginal_samples = shifted_gaussian_marginal(key, 500, 0)\n",
    "plt.scatter(*marginal_samples.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "682c3ef9-dc7d-4650-8a7c-2f5e8d0358f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class-conditional samples at t = 0: y = 0 (left, blue) and y = 1 (right, red).\n",
    "key, left_samples = shifted_gaussian(key, 500, 0, y=0)\n",
    "key, right_samples = shifted_gaussian(key, 500, 0, y=1)\n",
    "\n",
    "plt.scatter(*left_samples.T, color=\"blue\")\n",
    "plt.scatter(*right_samples.T, color=\"red\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"synthetic_plots/two_gaussians.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b3498e94-5ca0-4017-8e53-c7803c970afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte Carlo draws per kernel average (each score estimate draws this many twice).\n",
    "global_score_samples = 500\n",
    "global_uncond_sampler = shifted_gaussian_marginal\n",
    "global_cond_sampler = shifted_gaussian\n",
    "# Solver output grid over reversed time s in [0, global_T), spaced global_delta apart.\n",
    "global_T = 5\n",
    "global_delta = 5e-3\n",
    "# n_steps + 1 points give spacing exactly global_delta (n_steps points would give\n",
    "# global_T / (n_steps - 1)). The endpoint s = global_T is dropped because it maps to\n",
    "# diffusion time 0, where b(0) = 0 and the kernels divide by zero.\n",
    "global_n_steps = int(round(global_T / global_delta))\n",
    "global_ts = jnp.linspace(0, global_T, global_n_steps + 1)[:-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2128dfde-9b0c-4549-8f53-eda23716c46a",
   "metadata": {},
   "source": [
    "## Diffusion Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "133a63fd-a435-49ca-af9f-005764ee11f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def uncond_score(key, sampler, x, score_samples, t):\n",
    "    \"\"\"Monte Carlo estimate of the marginal score grad_x log p_t(x).\n",
    "\n",
    "    Uses mean(grad kernel(x - m)) / mean(kernel(x - m')) where m and m' are two\n",
    "    independent batches of score_samples draws from sampler(key, score_samples, t).\n",
    "    NOTE: if x is far from every draw the kernel values can all underflow to 0,\n",
    "    making the ratio 0 / 0 = NaN.\n",
    "    Returns (new_key, score); for x of shape (d,) the score has shape (d,).\n",
    "    \"\"\"\n",
    "    key, numer_pts = sampler(key, score_samples, t)\n",
    "    key, denom_pts = sampler(key, score_samples, t)\n",
    "    numer = p_btxi_grad(x - numer_pts, t).mean(axis=0)\n",
    "    denom = p_btxi(x - denom_pts, t).mean(axis=0)\n",
    "    return key, numer / denom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1c05493b-2224-41d7-996f-3404d9cb0e26",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_score(key, sampler, x, score_samples, t, y):\n",
    "    \"\"\"Monte Carlo estimate of the class-conditional score grad_x log p_t(x | y).\n",
    "\n",
    "    Same estimator as uncond_score, but the draws come from sampler(key, n, t, y).\n",
    "    NOTE: if x is far from every draw the kernel values can all underflow to 0,\n",
    "    making the ratio 0 / 0 = NaN.\n",
    "    Returns (new_key, score); for x of shape (d,) the score has shape (d,).\n",
    "    \"\"\"\n",
    "    key, numer_pts = sampler(key, score_samples, t, y)\n",
    "    key, denom_pts = sampler(key, score_samples, t, y)\n",
    "    numer = p_btxi_grad(x - numer_pts, t).mean(axis=0)\n",
    "    denom = p_btxi(x - denom_pts, t).mean(axis=0)\n",
    "    return key, numer / denom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ee15b8ec-d812-42b1-9089-486d8269342c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def uncond_ode(key, sampler, x, score_samples, t):\n",
    "    \"\"\"Right-hand side x + score(x, t) of the reverse-time ODE, using the unconditional score.\"\"\"\n",
    "    key, score = uncond_score(key, sampler, x, score_samples, t)\n",
    "    velocity = x + score\n",
    "    return key, velocity\n",
    "\n",
    "def partial_uncond_ode(x, t, key):\n",
    "    \"\"\"odeint right-hand side for unconditional sampling.\n",
    "\n",
    "    t is the solver's reversed time s; the score is evaluated at diffusion time\n",
    "    global_T - s. odeint passes the same key on every call, so each evaluation\n",
    "    reuses the same underlying Monte Carlo draws (rescaled by a(t)).\n",
    "    Returns the velocity dx/ds.\n",
    "    \"\"\"\n",
    "    _, ode_key = random.split(key)\n",
    "    _, velocity = uncond_ode(key=ode_key,\n",
    "                             sampler=global_uncond_sampler,\n",
    "                             x=x,\n",
    "                             score_samples=global_score_samples,\n",
    "                             t=global_T - t)\n",
    "    return velocity\n",
    "\n",
    "def diffuse_uncond(key, x):\n",
    "    \"\"\"Integrate the unconditional ODE from x over global_ts.\n",
    "\n",
    "    Returns (new_key, trajectory); trajectory has one row per entry of global_ts.\n",
    "    mxstep=25 caps the solver steps taken per output time point.\n",
    "    \"\"\"\n",
    "    next_key, solver_key = random.split(key)\n",
    "    trajectory = odeint(partial_uncond_ode, x, global_ts, solver_key, mxstep=25)\n",
    "    return next_key, trajectory\n",
    "\n",
    "\n",
    "# Batched version: maps over a leading axis of keys and starting points.\n",
    "v_diffuse_uncond = jax.vmap(diffuse_uncond)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "89c2e297-bca1-4339-bb8b-7122e488b6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def guided_ode(key, uncond_sampler, cond_sampler, x, score_samples, t, y, w):\n",
    "    \"\"\"Right-hand side x + guided score, where guided = (1 + w) * cond - w * uncond.\n",
    "\n",
    "    w = 0 gives the plain conditional score. Returns (new_key, velocity).\n",
    "    \"\"\"\n",
    "    key, s_uncond = uncond_score(key, uncond_sampler, x, score_samples, t)\n",
    "    key, s_cond = cond_score(key, cond_sampler, x, score_samples, t, y)\n",
    "    mixed = (1 + w) * s_cond - w * s_uncond\n",
    "    return key, x + mixed\n",
    "\n",
    "def partial_guided_ode(x, t, key, y, w):\n",
    "    \"\"\"odeint right-hand side for guided sampling toward class y with weight w.\n",
    "\n",
    "    t is the solver's reversed time s; scores are evaluated at diffusion time\n",
    "    global_T - s. odeint passes the same key, y and w on every call, so each\n",
    "    evaluation reuses the same underlying Monte Carlo draws.\n",
    "    Returns the velocity dx/ds.\n",
    "    \"\"\"\n",
    "    # Drive the score estimates with the split-off key, as partial_uncond_ode does.\n",
    "    _, ode_key = random.split(key)\n",
    "    _, velocity = guided_ode(key=ode_key,\n",
    "                             uncond_sampler=global_uncond_sampler,\n",
    "                             cond_sampler=global_cond_sampler,\n",
    "                             x=x,\n",
    "                             score_samples=global_score_samples,\n",
    "                             t=global_T - t,\n",
    "                             y=y,\n",
    "                             w=w)\n",
    "    return velocity\n",
    "\n",
    "def diffuse_cond(key, x, y, w):\n",
    "    \"\"\"Integrate the guided ODE from x over global_ts for class y and weight w.\n",
    "\n",
    "    Returns (new_key, trajectory); trajectory has one row per entry of global_ts.\n",
    "    mxstep=25 caps the solver steps taken per output time point.\n",
    "    \"\"\"\n",
    "    next_key, solver_key = random.split(key)\n",
    "    trajectory = odeint(partial_guided_ode, x, global_ts, solver_key, y, w, mxstep=25)\n",
    "    return next_key, trajectory\n",
    "\n",
    "\n",
    "# Batched version: maps over leading axes of keys, starting points, y and w.\n",
    "v_diffuse_cond = jax.vmap(diffuse_cond)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "178c69e7-e8ed-489e-b529-e57a85f54b6a",
   "metadata": {},
   "source": [
    "## Generate Unconditional Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6b109448-bcd0-4511-8db9-911948c6e3ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_sample = 500\n",
    "key, x_key = random.split(key)\n",
    "# Starting points: standard-normal draws, one per trajectory.\n",
    "x = random.normal(x_key, shape=(n_sample, 2))\n",
    "# One PRNG key per trajectory; element 0 carries forward as `key`.\n",
    "split_keys = random.split(key, n_sample + 1)\n",
    "key, all_keys = split_keys[0], split_keys[1:]\n",
    "_, trajs = v_diffuse_uncond(all_keys, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f04b232-8cfd-4bba-86d3-d1a017454e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (n_sample, len(global_ts), 2): one trajectory of 2-D states per starting point.\n",
    "jnp.shape(trajs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6c656a2-ded6-4b62-8c59-150a5f858be2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final states of the unconditional trajectories.\n",
    "plt.scatter(*trajs[:, -1].T)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "326843e6-354a-4bf5-a8c5-1f29e00f9a35",
   "metadata": {},
   "source": [
    "## Generate Conditional Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "28130c83-4d05-4799-ba4f-47659ae44317",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conditional_sample(key, y, w, n_sample=500):\n",
    "    \"\"\"Run n_sample guided trajectories toward class y with guidance weight w.\n",
    "\n",
    "    Starting points are standard-normal 2-D draws and each trajectory gets its own\n",
    "    PRNG key. Returns (new_key, trajs) with trajs of shape\n",
    "    (n_sample, len(global_ts), 2).\n",
    "    \"\"\"\n",
    "    key, x_key = random.split(key)\n",
    "    x = random.normal(x_key, shape=(n_sample, 2))\n",
    "    # Split once and slice rather than unpacking n_sample keys into a Python list.\n",
    "    split_keys = random.split(key, n_sample + 1)\n",
    "    key, traj_keys = split_keys[0], split_keys[1:]\n",
    "    # vmap maps over every argument, so y and w are repeated once per trajectory.\n",
    "    _, trajs = v_diffuse_cond(traj_keys,\n",
    "                              x,\n",
    "                              y * jnp.ones(shape=(n_sample,)),\n",
    "                              w * jnp.ones(shape=(n_sample,)))\n",
    "    return key, trajs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1984fad8-b547-49ac-9823-e6f07dab2c63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# w = 0: the guided score reduces to the plain conditional score.\n",
    "key, cond_trajs = conditional_sample(key, w=0, y=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48647902-33f8-431d-99aa-eeec733dc5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# w = 3: guided score = 4 * conditional - 3 * unconditional.\n",
    "key, guide_trajs = conditional_sample(key, w=3, y=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9e10af-456c-4de0-b390-900d26ed8b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final states of the unguided (w = 0) class-1 trajectories.\n",
    "plt.scatter(*cond_trajs[:, -1].T, color=\"red\")\n",
    "plt.ylim(-3, 3)\n",
    "plt.xlim(-4, 4)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"synthetic_plots/right_gaussian.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06746ff8-c8e6-4601-ab40-b4cf11f361f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final states of the guided (w = 3) class-1 trajectories.\n",
    "plt.scatter(*guide_trajs[:, -1].T, color=\"red\")\n",
    "plt.ylim(-3, 3)\n",
    "plt.xlim(-4, 4)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"synthetic_plots/right_gaussian_guided.jpg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f3c8472-51fb-4ee4-a6c8-06c7f3842b67",
   "metadata": {},
   "source": [
    "## Analyze Trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fb8563d-4893-4526-8878-653f0b2e6de7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_good_inds(trajs, y):\n",
    "    \"\"\"Boolean mask of trajectories that contain no NaN at any step.\n",
    "\n",
    "    Args:\n",
    "        trajs: array whose leading axis indexes trajectories, e.g. (n, steps, 2).\n",
    "        y: class label; currently unused, kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Boolean array of shape (n,), True where the whole trajectory is NaN-free.\n",
    "\n",
    "    TODO: earlier drafts also required the final sample to land inside the class-y\n",
    "    square (x within 0.5 of sep_scale * (2 * y - 1), |y-coordinate| <= 0.5), or\n",
    "    every coordinate to stay below 1000 in magnitude; reinstate if needed.\n",
    "    \"\"\"\n",
    "    # Reduce over every non-leading axis; unlike reshape(n, -1) this also works for n = 0.\n",
    "    per_traj_axes = tuple(range(1, trajs.ndim))\n",
    "    return ~jnp.isnan(trajs).any(axis=per_traj_axes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aab1c04-d329-47b8-9f26-931d67de7a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the guidance weight for class y = 1, threading `key` through every call.\n",
    "y = 1\n",
    "w_vals = [1, 3, 7, 15]\n",
    "all_guide_trajs = []\n",
    "for w in w_vals:\n",
    "    key, guide_trajs = conditional_sample(key, w=w, y=y)\n",
    "    all_guide_trajs += [guide_trajs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a882e84f-9983-4818-9f16-4db5a7678d30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per guidance weight: final states and the NaN-free trajectory mask.\n",
    "all_samples = [tr[:, -1] for tr in all_guide_trajs]\n",
    "all_good_inds = [filter_good_inds(tr, y) for tr in all_guide_trajs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6127294a-38b3-4808-a7ef-1d2bc450af0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per guidance weight: final samples plus the fraction of NaN-free trajectories.\n",
    "fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)\n",
    "props = dict(boxstyle='round', facecolor='gray', alpha=0.3)\n",
    "for i, w in enumerate(w_vals):\n",
    "    row, col = divmod(i, 2)\n",
    "    ax = axs[row, col]\n",
    "    samples = all_samples[i]\n",
    "    support_frac = all_good_inds[i].sum() / len(samples)\n",
    "\n",
    "    ax.scatter(samples[:, 0], samples[:, 1], color=\"red\")\n",
    "    ax.set_title(f\"Guidance w = {w}\")\n",
    "    ax.text(0.05,\n",
    "            0.95,\n",
    "            f\"Good: {support_frac:.2f}\",\n",
    "            transform=ax.transAxes,\n",
    "            fontsize=14,\n",
    "            verticalalignment='top',\n",
    "            bbox=props)\n",
    "# The axes are shared, so these limits apply to every panel.\n",
    "plt.xlim(-4, 30)\n",
    "plt.ylim(-5, 5)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"synthetic_plots/gaussian_guidance_grid.jpg\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "355b12c0-10f4-4d67-899b-4bc608888d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean +/- half a standard deviation of the x-coordinate over NaN-free trajectories.\n",
    "fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)\n",
    "normal_vec = jnp.array([1, 0])  # projects each 2-D state onto the x-axis\n",
    "for i, w in enumerate(w_vals):\n",
    "    row, col = divmod(i, 2)\n",
    "    ax = axs[row, col]\n",
    "    good_trajs = all_guide_trajs[i][all_good_inds[i]]\n",
    "    proj_trajs = good_trajs @ normal_vec\n",
    "\n",
    "    mean_traj = proj_trajs.mean(axis=0)\n",
    "    traj_std = proj_trajs.std(axis=0)\n",
    "    steps = jnp.arange(mean_traj.shape[0])\n",
    "\n",
    "    ax.plot(steps, mean_traj)\n",
    "    ax.fill_between(steps,\n",
    "                    mean_traj + traj_std / 2,\n",
    "                    mean_traj - traj_std / 2,\n",
    "                    alpha=0.3)\n",
    "    ax.grid()\n",
    "    ax.set_title(f\"Guidance w = {w}\")\n",
    "\n",
    "# The y-axis is shared, so this limit applies to every panel.\n",
    "plt.ylim(0, 10)\n",
    "fig.text(0.53, 0.04, 'Iteration (t)', ha='center')\n",
    "fig.text(0.04, 0.53, 'Mean X Coordinate', va='center', rotation='vertical')\n",
    "plt.tight_layout(rect=[0.04, 0.04, 1, 1])\n",
    "plt.savefig(\"synthetic_plots/gaussian_trajectory_grid.jpg\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
