{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import plotly.graph_objects as go\n",
    "import plotly.figure_factory as ff\n",
    "import numpy as np\n",
    "import torch\n",
    "import scipy"
   ],
   "id": "740f4a9b33cd1cad"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "\n",
    "# data generation\n",
    "\n",
    "scaleratio=0.8\n",
    "scale=0.4\n",
    "\n",
    "# create data: xy[0] / xy[1] hold the x / y coordinates of an N x N grid (shape (2, N, N));\n",
    "# np.arange excludes its stop value, hence the stop of 1.1\n",
    "xy = np.array(np.meshgrid(np.arange(-1, 1.1, .01), np.arange(-1, 1.1, .01)))\n",
    "\n",
    "xy_ = torch.tensor(xy, requires_grad=True)\n",
    "\n",
    "# Two scalar landscapes f(x, y); both act elementwise on the grid tensors.\n",
    "_f1 = lambda _x, _y: ((_x - _y ** 2 + .3) ** 2 / 10 +\n",
    "                      ((_x + .3 * _y - .2) ** 2) ** 1 / 5) ** 1 / 2\n",
    "_f2 = lambda _x, _y: (((_x + _y)*0.7 - ((_x -_y)*0.7) ** 2 + .3) ** 2 / 10 +\n",
    "                      (_x - .3 * _y + .4) ** 2 / 6) ** .9 / 2\n",
    "\n",
    "# Alternative landscapes to experiment with (swap one pair in for the definitions above):\n",
    "# _f1 = lambda _x, _y: ((_x-1) ** 2 * torch.exp(-_y ** 2) + _y**4) / 5\n",
    "# _f2 = lambda _x, _y: ((_y-1) ** 2 * torch.exp(-_y ** 2) + _x**4) / 5\n",
    "#\n",
    "# _f1 = lambda _x, _y: (2*(_x-1) ** 2 + _y**2) / 5\n",
    "# _f2 = lambda _x, _y: ((_y-1) ** 2 + _x**2) / 5\n",
    "#\n",
    "# _f1 = lambda _x, _y: (2*(_x-1) ** 2 + .5 * _y**2) / 5\n",
    "# _f2 = lambda _x, _y: (.3*(_x+1) ** 2 + 1 * _y**2) / 2\n",
    "\n",
    "\n",
    "def descent_field(f, points):\n",
    "    \"\"\"Evaluate f on the grid and return (value tensor, value array, -grad f array).\n",
    "\n",
    "    f acts elementwise, so the gradient of f(*points).sum() w.r.t. points is the\n",
    "    pointwise gradient of f, with the same shape as points.\n",
    "    \"\"\"\n",
    "    value = f(*points)\n",
    "    (grad,) = torch.autograd.grad(value.sum(), points)\n",
    "    return value, value.detach().cpu().numpy(), -grad.detach().cpu().numpy()\n",
    "\n",
    "\n",
    "def descent_jacobian(f, points):\n",
    "    \"\"\"Return the pointwise Jacobian of the field -grad f, i.e. minus the Hessian of f.\n",
    "\n",
    "    The result has shape (2, 2, *grid): entry [i, j] is -d^2 f / (dx_i dx_j) at each grid point.\n",
    "    torch.autograd.grad is used instead of .backward(): backward() accumulates into\n",
    "    points.grad, which would silently add grad f onto every Hessian row.\n",
    "    \"\"\"\n",
    "    # create_graph=True keeps the first derivative differentiable for the second pass\n",
    "    (grad,) = torch.autograd.grad(f(*points).sum(), points, create_graph=True)\n",
    "    # retain_graph=True because the same graph is traversed once per Hessian row\n",
    "    rows = [torch.autograd.grad(grad[i].sum(), points, retain_graph=True)[0]\n",
    "            for i in range(len(points))]\n",
    "    return -torch.stack(rows).detach().cpu().numpy()\n",
    "\n",
    "\n",
    "f1_, f1, v1 = descent_field(_f1, xy_)\n",
    "f2_, f2, v2 = descent_field(_f2, xy_)\n",
    "\n",
    "# hessians[0] / hessians[1] are the pointwise Jacobians of v1 / v2: hessians[k][b, a] = d v_b / d x_a\n",
    "hessians = np.array([descent_jacobian(_f, xy_) for _f in (_f1, _f2)])\n",
    "\n",
    "# comm = J1 v2 - J2 v1, with Jk = hessians[k-1] (pointwise matrix-vector products over the grid)\n",
    "comm = np.einsum('baxy,axy->bxy',  hessians[0], v2) - np.einsum('baxy,axy->bxy',  hessians[1], v1)"
   ],
   "id": "b6e29ff24f15f21c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# vector fields plotting\n",
    "\n",
    "STEP = 10       # draw one arrow every STEP grid points along each axis\n",
    "COMM_GAIN = 3   # magnification applied to the commutator arrows only\n",
    "\n",
    "\n",
    "def quiver_figure(field, name):\n",
    "    \"\"\"Return a plotly quiver figure for a (2, N, N) vector field sampled on the xy grid.\"\"\"\n",
    "    return ff.create_quiver(*xy[:, ::STEP, ::STEP], *field[:, ::STEP, ::STEP],\n",
    "                            scale=scale, scaleratio=scaleratio, name=name)\n",
    "\n",
    "\n",
    "# The v1 figure keeps the layout create_quiver builds; the other fields are overlaid as extra traces.\n",
    "fig = quiver_figure(v1, 'v1 = -grad f1')\n",
    "fig.add_traces([\n",
    "    quiver_figure(v2, 'v2 = -grad f2').data[0],\n",
    "    quiver_figure(comm * COMM_GAIN, f'comm x{COMM_GAIN}').data[0],\n",
    "])\n",
    "\n",
    "fig.update_layout(\n",
    "        autosize=False,\n",
    "        margin=dict(l=20, r=20, t=20, b=20),\n",
    "        width=750,\n",
    "        height=600,\n",
    "        xaxis={'range':[-1,1]},\n",
    "        yaxis={'range':[-1,1]},\n",
    ")\n",
    "fig.show()"
   ],
   "id": "8ec4cc60ef4b30b5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "94376e425210fbef"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "ff1d91b833c61c25"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "50d74633d409e186"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "36d16854b0ea40db"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "564b535b7d0ca03b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "cd58b89535279363"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "829a7152bbda806e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "6d5ed17f8b48836f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "2db7694719529674"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
