{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a49fde41-5245-411f-8058-93d9e523ad2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from functorch import vmap, grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eabf5c96-0f5e-4ca5-b3d5-d6a44e17fd98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.distributions as D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "79a47b3d-bf3c-4ac8-8670-7ee4967bdd9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mix = D.Categorical(torch.ones(2,))\n",
    "comp = D.Normal(torch.tensor([-1.0, 1.0]), torch.ones(2,))\n",
    "gmm = D.MixtureSameFamily(mix, comp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f1e9840b-bf51-41d4-8166-1ca9fa04fa23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-element evaluation point (value 1.0) that autograd tracks.\n",
    "x = torch.ones(1, requires_grad=True)\n",
    "\n",
    "# Eager (non-vmapped) log-density of the mixture at x.\n",
    "loss = gmm.log_prob(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "c90ed32e-11a1-4aec-90fa-436cce73ace3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fn(x):\n",
    "    \"\"\"Return the mixture log-density of `x`, summed down to a scalar.\"\"\"\n",
    "    log_density = gmm.log_prob(x)\n",
    "    return log_density.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5a9177f4-2600-4338-9dff-7e5c5eece318",
   "metadata": {},
   "outputs": [],
   "source": [
    "# functorch transform: grad_fn(x) returns the gradient of fn with respect to x.\n",
    "grad_fn = grad(fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8df323ea-45c9-467c-a406-9f2c9160080d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectorise grad_fn over the leading dimension: dim 0 of the input is the batch\n",
    "# axis, and the per-sample gradients are stacked along dim 0 of the output.\n",
    "v_grad_fn = vmap(grad_fn, in_dims=(0,), out_dims=(0,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "76639118-8051-4015-974c-5f81eeb62fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jiamings/.miniconda3/envs/ddrm/lib/python3.9/site-packages/torch/distributions/distribution.py:292: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::all. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:83.)\n",
      "  if not valid.all():\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . For (3): please file an issue.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1669761/768997495.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mv_grad_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/.miniconda3/envs/ddrm/lib/python3.9/site-packages/functorch/_src/vmap.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    363\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    364\u001b[0m             \u001b[0mbatched_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_create_batched_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_in_dims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvmap_level\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs_spec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 365\u001b[0;31m             \u001b[0mbatched_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatched_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    366\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0m_unwrap_batched\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatched_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_dims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvmap_level\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    367\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/ddrm/lib/python3.9/site-packages/functorch/_src/eager_transforms.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1203\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1204\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1205\u001b[0;31m         \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrad_and_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhas_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1206\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1207\u001b[0m             \u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/ddrm/lib/python3.9/site-packages/functorch/_src/eager_transforms.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1073\u001b[0m                 \u001b[0mtree_map_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_create_differentiable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiff_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1074\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1075\u001b[0;31m                 \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1076\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1077\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1669761/324354906.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mgmm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/.miniconda3/envs/ddrm/lib/python3.9/site-packages/torch/distributions/mixture_same_family.py\u001b[0m in \u001b[0;36mlog_prob\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    146\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mlog_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    147\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_args\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_sample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    149\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    150\u001b[0m         \u001b[0mlog_prob_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomponent_distribution\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# [S, B, k]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.miniconda3/envs/ddrm/lib/python3.9/site-packages/torch/distributions/distribution.py\u001b[0m in \u001b[0;36m_validate_sample\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m    290\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0msupport\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    291\u001b[0m         \u001b[0mvalid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msupport\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 292\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mvalid\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    293\u001b[0m             raise ValueError(\n\u001b[1;32m    294\u001b[0m                 \u001b[0;34m\"Expected value argument \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . For (3): please file an issue."
     ]
    }
   ],
   "source": [
    "v_grad_fn(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe4391b-03f3-45dc-9940-45b3f058d17f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
