{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from adaptive_latents.vjf import VJF, BaseVJF\n",
    "import vjf.online\n",
    "import torch\n",
    "from tqdm.notebook import trange\n",
    "from adaptive_latents.input_sources import LDS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set seed(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook (and the libraries it calls) may draw from.\n",
    "seed = 0\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)  # seeds every CUDA device, so a separate manual_seed is redundant\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "# Legacy global NumPy RNG: seeded too, in case library code draws from np.random directly.\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Explicit generator handed to the data simulator and the VJF wrappers.\n",
    "rng = np.random.default_rng(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Simulate the nested dynamical system: x holds the true latent states, y the observations.\n",
    "x, y, stim = LDS.run_nest_dynamical_system(500, rng=rng)\n",
    "\n",
    "# Multiply the stimulus by zero so the models receive a control input shaped like stim but without its signal.\n",
    "u = stim * 0\n",
    "\n",
    "# Dimensions used to configure the models.\n",
    "xdim = 2  # dimension of the hidden state the models fit\n",
    "ydim = y.shape[-1]\n",
    "udim = u.shape[-1]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters for vjf.online.VJF (the same dict is passed to BaseVJF later on)\n",
    "config = dict(\n",
    "    resume=False,\n",
    "    xdim=xdim,  # dimension of hidden state\n",
    "    ydim=ydim,  # dimension of observations\n",
    "    udim=udim,  # dimension of control vector; must agree with the shape of B below\n",
    "    Ydim=udim,  # NOTE(review): set from udim rather than ydim -- confirm this is intended\n",
    "    Udim=udim,\n",
    "    rdim=50,  # number of RBFs\n",
    "    hdim=100,  # number of MLP hidden units\n",
    "    lr=1e-3,  # learning rate\n",
    "    clip_gradients=5.0,\n",
    "    debug=False,\n",
    "    likelihood='gaussian',\n",
    "    system='rbf',\n",
    "    recognizer='mlp',\n",
    "    C=(None, True),  # loading matrix: (initial, estimate)\n",
    "    b=(None, True),  # bias: (initial, estimate)\n",
    "    A=(None, False),  # transition matrix if LDS\n",
    "    B=(np.zeros((xdim, udim)), False),  # interaction matrix\n",
    "    Q=(1.0, True),  # state noise\n",
    "    R=(1.0, True),  # observation noise\n",
    ")\n",
    "\n"
   ]
  },
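  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Minimal run\n",
    "\n",
    "Feed the observations one time step at a time to `vjf.online.VJF`, then plot the last 200 latent estimates."
   ]
  },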
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feed the observations to the online VJF model one time step at a time.\n",
    "mdl = vjf.online.VJF(config=config)\n",
    "\n",
    "ys = torch.from_numpy(y).float()\n",
    "us = torch.from_numpy(u).float()\n",
    "\n",
    "mu = torch.zeros(ys.shape[0], xdim)  # first component of q (presumably the state mean) at each step\n",
    "q = None  # current state; fed back in as q0 on the next step\n",
    "\n",
    "for i in range(ys.shape[0]):\n",
    "    q, _ = mdl.feed((ys[i:i+1], us[i:i+1]), q0=q)\n",
    "    # Detach before storing: copying a grad-tracking tensor into `mu` would make `mu`\n",
    "    # part of the autograd graph and keep every step's graph alive.\n",
    "    mu[i, :] = q[0].detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last 200 steps of the online VJF latent trajectory.\n",
    "fig, ax = plt.subplots()\n",
    "mu_tail = mu[-200:].detach().numpy()\n",
    "ax.plot(*mu_tail.T)\n",
    "ax.axis('equal');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the adaptive_latents VJF wrapper over the whole observation sequence.\n",
    "v = VJF(\n",
    "    input_streams={0: 'X'},\n",
    "    latent_d=2,\n",
    "    rng=rng,\n",
    ")\n",
    "ret = v.offline_run_on([y], show_tqdm=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Side by side (last 200 steps): the VJF estimate and the true latent trajectory.\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(10, 4))\n",
    "\n",
    "panels = [(ret, 'VJF estimated latent state'), (x, 'true latent state')]\n",
    "for panel_ax, (traj, title) in zip(axs, panels):\n",
    "    panel_ax.plot(traj[-200:, 0], traj[-200:, 1])\n",
    "    panel_ax.set_title(title)\n",
    "    panel_ax.axis('equal')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Cloud at time 0 and the decoder output for each of its points.\n",
    "cloud = v.get_cloud_at_time_t(0).detach()\n",
    "preds = v._vjf.decoder(cloud).detach().numpy()\n",
    "\n",
    "# Evaluate the cloud's log-probability at the centre of every cell of a 30x30 grid on [-20, 20]^2.\n",
    "n = 31\n",
    "x_edges = np.linspace(-20, 20, n)\n",
    "y_edges = np.linspace(-20, 20, n)\n",
    "x_centers = (x_edges[:-1] + x_edges[1:]) / 2\n",
    "y_centers = (y_edges[:-1] + y_edges[1:]) / 2\n",
    "log_probs = np.zeros((y_centers.size, x_centers.size))  # rows index y, columns index x\n",
    "for row, y_c in enumerate(y_centers):\n",
    "    for col, x_c in enumerate(x_centers):\n",
    "        # NOTE(review): the third coordinate is fixed at 0 -- confirm it matches the point dimension.\n",
    "        log_probs[row, col] = v.get_logprob_for_cloud(cloud=cloud, point=np.array([x_c, y_c, 0]))\n",
    "best_row, best_col = np.unravel_index(np.argmax(log_probs), log_probs.shape)\n",
    "\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(10, 4), sharey=True, sharex=True, subplot_kw={'adjustable': 'box', 'aspect': 1})\n",
    "\n",
    "# Left: true latent trajectory, decoder output for each cloud point (C1) and their mean (C2).\n",
    "axs[0].plot(x[-200:, 0], x[-200:, 1])\n",
    "axs[0].scatter(preds[:, 0], preds[:, 1], color='C1', s=5, zorder=3)\n",
    "axs[0].scatter(preds[:, 0].mean(), preds[:, 1].mean(), color='C2', zorder=3)\n",
    "\n",
    "# Right: log-probability heat map (colour scale starts at the median) with its argmax marked.\n",
    "axs[1].pcolormesh(x_edges, y_edges, log_probs, vmin=np.quantile(log_probs.flatten(), .5), vmax=log_probs.max(), cmap='plasma')\n",
    "axs[1].scatter(x_centers[best_col], y_centers[best_row], color='C2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 'asdf' looks like a placeholder method name -- confirm it is a valid option for predict.\n",
    "v.predict(0, method='asdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a fresh VJF wrapper on all but the last 15 observations.\n",
    "# Note: this rebinds `v` and `mu`, replacing the objects created by the earlier cells.\n",
    "n_dropped = 15\n",
    "v = VJF(latent_d=2, rng=rng)\n",
    "mu, logvar, losses = v.fit(y=y[:-n_dropped])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectory returned by v.fit (first two columns of mu).\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "fit_x, fit_y = mu[:, 0], mu[:, 1]\n",
    "ax.plot(fit_x, fit_y)\n",
    "ax.axis('equal');\n"
   ]
  },
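  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run with log_pred_p evaluation\n",
    "\n",
    "Before each observation is fed to the model, the model's one-step-ahead prediction for it is scored (log probability and distance; see `log_step`)."
   ]
  },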
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Torch device for this section; it stays on CPU because 'cuda' does not work here.\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_step(mdl: 'BaseVJF', ys, t, S=1000, T=10):\n",
    "    \"\"\"Score a model's open-loop predictions of the next ``T`` observations.\n",
    "\n",
    "    A cloud is drawn with ``mdl.generate_cloud()`` and stepped forward once per\n",
    "    prediction with ``mdl.step_for_cloud``. The prediction for ``ys[t + i]``\n",
    "    (i = 0 .. T-1) is scored with ``mdl.get_logprob_for_cloud`` and\n",
    "    ``mdl.get_distance_for_cloud``. Predictions that run past the end of ``ys``\n",
    "    are scored against an all-NaN observation of the same shape.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mdl : BaseVJF\n",
    "        Model providing generate_cloud, step_for_cloud, get_logprob_for_cloud\n",
    "        and get_distance_for_cloud.\n",
    "    ys : torch.Tensor\n",
    "        Observations, indexed by time along the first axis.\n",
    "    t : int\n",
    "        Time index of the first observation to predict.\n",
    "    S : int, optional\n",
    "        Unused; kept so existing calls keep working.\n",
    "        NOTE(review): presumably intended as the cloud size -- confirm.\n",
    "    T : int, optional\n",
    "        Number of steps ahead to predict.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    logprobs, distances : list, list\n",
    "        One entry per step ahead (length ``T``).\n",
    "    \"\"\"\n",
    "    n_obs = ys.shape[0]\n",
    "    cloud = mdl.generate_cloud()\n",
    "\n",
    "    logprobs = []\n",
    "    distances = []\n",
    "    for i in range(T):\n",
    "        if t + i < n_obs:\n",
    "            y_tprime = ys[t + i].cpu().numpy()\n",
    "        else:\n",
    "            # No observation exists this far ahead: score against NaNs instead.\n",
    "            y_tprime = ys[t].cpu().numpy() * np.nan\n",
    "\n",
    "        cloud = mdl.step_for_cloud(cloud)\n",
    "        logprobs.append(mdl.get_logprob_for_cloud(cloud, y_tprime))\n",
    "        distances.append(mdl.get_distance_for_cloud(cloud, y_tprime))\n",
    "\n",
    "    return logprobs, distances\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Online loop: before y[t] is observed, score the model's one-step-ahead prediction of it, then update the model.\n",
    "mdl = BaseVJF(config=config, latent_d=2)\n",
    "mdl.init_vjf(ydim, udim)\n",
    "\n",
    "ys = torch.from_numpy(y).float()\n",
    "us = torch.from_numpy(u).float()\n",
    "\n",
    "logprobs = []\n",
    "distances = []\n",
    "# NOTE(review): mu2 is never written to in this loop, so it stays all zeros.\n",
    "mu2 = np.zeros((ys.shape[0], xdim))\n",
    "\n",
    "for t in trange(ys.shape[0]):\n",
    "    step_logprobs, step_distances = log_step(mdl, ys, t, T=1)\n",
    "    logprobs.append(step_logprobs)\n",
    "    distances.append(step_distances)\n",
    "\n",
    "    mdl.observe(ys[t].unsqueeze(0), us[t].unsqueeze(0))\n",
    "\n",
    "logprobs, distances = np.array(logprobs), np.array(distances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory: differences between observations i steps apart.\n",
    "i = 10\n",
    "later, earlier = y[i:], y[:len(y) - i]\n",
    "later - earlier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inline backend, like the other plotting cells: the Qt backend needs a GUI and breaks headless runs.\n",
    "%matplotlib inline\n",
    "# Prediction quality over time, one line per number of steps ahead.\n",
    "fig, axs = plt.subplots(nrows=2)\n",
    "\n",
    "for i in range(logprobs.shape[-1]):\n",
    "    label = f\"{i+1} step{'s' if i > 0 else ''} ahead\"\n",
    "    axs[0].plot(logprobs[:, i], label=label)\n",
    "    axs[1].plot(distances[:, i], label=label)\n",
    "\n",
    "for ax in axs:\n",
    "    ax.set_xlabel(\"time\")\n",
    "\n",
    "axs[0].set_ylabel(\"log probability\")\n",
    "axs[1].set_ylabel(\"average prediction distance\")\n",
    "axs[1].legend(bbox_to_anchor=(1.01, 0.95))\n",
    "# axs[0].set_ylim(-300, 0)  # optional: zoom in on the log-probability axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checks that the randomness is actually controlled by `seed`.\n",
    "# This method is obviously hacky, but I'm keeping it because the logprobs\n",
    "# and distances were inconsistent between runs before, despite seeding.\n",
    "# Each variable is compared with the copy saved by the previous run of this cell\n",
    "# (True = identical, NEW = no previous copy), then the saved copy is overwritten.\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def _arrays_identical(a, b):\n",
    "    \"\"\"Return True if a and b have equal shapes and exactly equal entries.\n",
    "\n",
    "    NaNs count as equal only when both arrays have them in the same position,\n",
    "    and empty arrays compare equal instead of raising.\n",
    "    \"\"\"\n",
    "    a, b = np.asarray(a), np.asarray(b)\n",
    "    return a.shape == b.shape and bool(np.allclose(a, b, rtol=0, atol=0, equal_nan=True))\n",
    "\n",
    "\n",
    "# Distinct names so this cell neither clobbers the model `v` nor shadows the builtin `vars`.\n",
    "snapshot_dir = tempfile.gettempdir()\n",
    "snapshots = {}\n",
    "for name in ['seed', 'y', 'mu', 'ys', 'mu2', 'distances', 'logprobs']:\n",
    "    value = globals()[name]\n",
    "    if isinstance(value, torch.Tensor):\n",
    "        value = value.detach().cpu().numpy()\n",
    "    snapshots[name] = value\n",
    "\n",
    "    path = os.path.join(snapshot_dir, f'asdf_{name}')\n",
    "    try:\n",
    "        old_value = np.load(f'{path}.npy')\n",
    "    except FileNotFoundError:\n",
    "        old_value = None\n",
    "    np.save(path, value)\n",
    "\n",
    "    if old_value is not None:\n",
    "        print(f'{name}: {_arrays_identical(value, old_value)}')\n",
    "    else:\n",
    "        print(f'{name}: NEW')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
