{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from jax.nn import relu\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "import adaptive_latents\n",
    "from adaptive_latents import StreamingKalmanFilter, ArrayWithTime, Pipeline, proSVD, CenteringTransformer\n",
    "from adaptive_latents.input_sources.autoregressor import AdamOptimizer\n",
    "from adaptive_latents.regressions import BaseKNearestNeighborRegressor\n",
    "from adaptive_latents.stim_regressor import StimRegressor\n",
    "\n",
    "# Seeded module-level generator; `design_stim` draws its random restarts from it.\n",
    "rng = np.random.default_rng(seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset whose `neural_data` stream is fed to `do_experiment` below.\n",
    "d = adaptive_latents.datasets.Odoherty21Dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def high_d_S(low_d_point, high_d_stim):\n",
    "    \"\"\"Effect of a stimulus in the observed (high-d) space.\n",
    "\n",
    "    Currently the identity on ``high_d_stim``; ``low_d_point`` is accepted but unused.\n",
    "    \"\"\"\n",
    "    return high_d_stim\n",
    "\n",
    "def S(low_d_point, high_d_stim, pro):\n",
    "    \"\"\"Pass the high-d stimulus effect through ``pro.transform`` as a single-row batch.\"\"\"\n",
    "    effect = high_d_S(low_d_point, high_d_stim)\n",
    "    single_row = effect[None, :]\n",
    "    return pro.transform(single_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(s, v, lam_1=1e-3):\n",
    "    \"\"\"Stimulus-design objective minimized by ``design_stim``.\n",
    "\n",
    "    Rewards the norm of the projection ``v^T s`` (not squared), penalizes the squared\n",
    "    residual ``s - v v^T s`` (the part of ``s`` outside span(v) when ``v`` has orthonormal\n",
    "    columns), and adds an L1 penalty on ``s`` weighted by ``lam_1``.\n",
    "    \"\"\"\n",
    "    u = s  # this assumes for now that the dynamics S function is an identity\n",
    "    aligned = v.T @ s\n",
    "    orthogonal = s - v @ aligned  # v @ (v^T s) avoids forming the n x n matrix v v^T\n",
    "    return (\n",
    "            - jnp.linalg.norm(aligned)  # maximize |v^T s|\n",
    "            + jnp.sum(orthogonal ** 2)  # minimize orthogonal component; gradient stays finite at 0, unlike norm(.)**2\n",
    "            + jnp.linalg.norm(u, ord=1) * lam_1  # L1 penalty\n",
    "    )\n",
    "\n",
    "grad_loss = jax.jit(jax.value_and_grad(loss))\n",
    "\n",
    "\n",
    "def design_stim(v, N=30, convergence_threshold=1e-2, max_outer_iters=20, max_inner_iters=250, lam_1=1e-3, lr=0.005):\n",
    "    \"\"\"Design a sparse, non-negative stimulus aligned with the direction(s) in ``v``.\n",
    "\n",
    "    Each outer iteration restarts from a random point and runs ReLU-projected Adam\n",
    "    steps on ``loss``. Between restarts the L1 weight ``lam_1`` is doubled when the\n",
    "    result has more than ``N`` nonzero entries. It is divided by 1.2 when the result\n",
    "    is all-zero or contains NaN.\n",
    "\n",
    "    Args:\n",
    "        v: target vector of shape (n,) or matrix of shape (n, k); ``loss`` treats the\n",
    "            columns as orthonormal (e.g. a proSVD basis vector).\n",
    "        N: maximum number of nonzero stimulus entries accepted.\n",
    "        convergence_threshold: a restart stops once consecutive iterates differ by less\n",
    "            than this (checked only after the first 10 steps of each restart).\n",
    "        max_outer_iters: maximum number of random restarts.\n",
    "        max_inner_iters: maximum Adam steps per restart.\n",
    "        lam_1: initial L1 penalty weight.\n",
    "        lr: Adam learning rate.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (n,) scaled so its maximum is 1, or all zeros if no\n",
    "        finite stimulus with a positive entry was found.\n",
    "    \"\"\"\n",
    "    v = np.atleast_2d(v.T).T\n",
    "    n_channels = v.shape[0]  # stimulus lives in the same space as the rows of v\n",
    "\n",
    "    s = np.zeros(n_channels)\n",
    "    for _restart in range(max_outer_iters):\n",
    "        s = rng.uniform(size=(n_channels,)) * .1\n",
    "        s_optimizer = AdamOptimizer(lr=lr)\n",
    "\n",
    "        for i in range(max_inner_iters):\n",
    "            _loss_value, grad = grad_loss(s, v, lam_1=lam_1)\n",
    "            s_prev = s\n",
    "            s = relu(s_optimizer.update(s, grad))\n",
    "            # the 10-step warm-up applies to every restart, not only the first one\n",
    "            if i >= 10 and jnp.linalg.norm(s_prev - s) < convergence_threshold:\n",
    "                break\n",
    "\n",
    "        # check NaN first: NaN entries count as nonzero and could pass the sparsity test\n",
    "        if jnp.isnan(s).any():\n",
    "            lam_1 /= 1.2\n",
    "            continue\n",
    "        l0 = jnp.count_nonzero(s)\n",
    "        if 0 < l0 <= N:\n",
    "            break\n",
    "        if l0 == 0:\n",
    "            lam_1 /= 1.2\n",
    "        else:\n",
    "            lam_1 *= 2\n",
    "\n",
    "    s = np.asarray(s)\n",
    "    peak = s.max()\n",
    "    # avoid returning 0/0 or NaN stimuli, which would propagate NaN into any data they are added to\n",
    "    if not np.isfinite(s).all() or peak <= 0:\n",
    "        return np.zeros(n_channels)\n",
    "    return s / peak\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _evaluate_stim_reg(sr, s_inputs):\n",
    "    \"\"\"Query ``sr.stim_reg`` at each saved (latent_location, stim, _) triple; return flattened predictions.\"\"\"\n",
    "    point_evals = []\n",
    "    for eval_location, eval_stim, _ in s_inputs:\n",
    "        stim_reg_input = np.hstack([eval_location.flatten(), eval_stim.flatten()])\n",
    "        point_evals.append(sr.stim_reg.predict(stim_reg_input).flatten())\n",
    "    return point_evals\n",
    "\n",
    "\n",
    "def do_experiment(input_arrays, decay_rate=.9, stim_scale=1, stim_p=0.05, rng=None, s_inputs_to_evaluate_on=None):\n",
    "    \"\"\"Stream data through a centerer -> proSVD -> StimRegressor pipeline while injecting designed stimuli.\n",
    "\n",
    "    At each step, with probability ``stim_p`` (once ``pro`` is initialized), a stimulus is\n",
    "    designed along the first proSVD basis vector and scaled by ``stim_scale``. Its high-d\n",
    "    effect is added to the incoming data, and that effect decays by ``decay_rate`` per step.\n",
    "\n",
    "    Args:\n",
    "        input_arrays: sequence of ArrayWithTime streams processed in order; the stimulus\n",
    "            buffer is sized from ``input_arrays[0].shape[1]``.\n",
    "        decay_rate: per-step multiplicative decay of the injected stimulus effect.\n",
    "        stim_scale: factor applied to each designed (max-normalized) stimulus.\n",
    "        stim_p: per-step probability of stimulating.\n",
    "        rng: np.random.Generator deciding when to stimulate; defaults to ``np.random.default_rng(0)``.\n",
    "            ``design_stim`` draws from the notebook-level ``rng``, not this one.\n",
    "        s_inputs_to_evaluate_on: one of three modes.\n",
    "            None: no evaluation.\n",
    "            ``'record'``: save a deep copy of (latent_location, stim, pro) and the true\n",
    "            ``S`` output for every stimulus.\n",
    "            A sequence of such triples: ``sr.stim_reg`` is queried at each triple after\n",
    "            every stimulus.\n",
    "\n",
    "    Returns:\n",
    "        ``(predictions, latents, stims), (s_hat_evals, s_eval, s_inputs_evaluated_on)``.\n",
    "    \"\"\"\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng(0)\n",
    "\n",
    "    record_inputs = isinstance(s_inputs_to_evaluate_on, str) and s_inputs_to_evaluate_on == 'record'\n",
    "    evaluate_inputs = s_inputs_to_evaluate_on is not None and not isinstance(s_inputs_to_evaluate_on, str)\n",
    "\n",
    "    centerer = CenteringTransformer()\n",
    "    pro = proSVD(k=10)\n",
    "    sr = StimRegressor(\n",
    "        autoreg=StreamingKalmanFilter(),\n",
    "        stim_reg=BaseKNearestNeighborRegressor(k=1, maxlen=100),\n",
    "        attempt_correction=True\n",
    "    )\n",
    "\n",
    "    stims = []\n",
    "    predictions = []\n",
    "    latents = []\n",
    "    s_hat_evals = []\n",
    "    s_eval = []\n",
    "    s_inputs_evaluated_on = []\n",
    "\n",
    "    # decaying high-d stimulus effect added to every incoming sample\n",
    "    to_add = np.zeros(input_arrays[0].shape[1])\n",
    "\n",
    "    for input_array in input_arrays:\n",
    "        # the context manager closes each progress bar when its stream ends, even on error\n",
    "        with tqdm(total=round(input_array.t[-1], 2)) as pbar:\n",
    "            for data in Pipeline().streaming_run_on(input_array):\n",
    "                # latent location of the raw sample, taken before the stimulus effect is added\n",
    "                # and before this sample updates centerer/pro\n",
    "                latent_location = pro.transform(centerer.transform(data))\n",
    "\n",
    "                stim = np.zeros(data.shape[1])\n",
    "                if rng.random() < stim_p and pro.is_initialized:\n",
    "                    stim = design_stim(pro.Q[:, 0]) * stim_scale\n",
    "                    to_add += high_d_S(latent_location, stim)\n",
    "\n",
    "                if record_inputs and np.any(stim):\n",
    "                    s_inputs_evaluated_on.append(copy.deepcopy((latent_location, stim, pro)))\n",
    "                    s_eval.append(S(latent_location, stim, pro))\n",
    "\n",
    "                data = data + to_add\n",
    "                to_add = to_add * decay_rate\n",
    "\n",
    "                data = centerer.partial_fit_transform(data)\n",
    "                data = pro.partial_fit_transform(data)\n",
    "                latents.append(data)\n",
    "\n",
    "                qX = ArrayWithTime([[1]], data.t)\n",
    "                sr.partial_fit_transform(np.array([[stim]]), stream='stim')\n",
    "                prediction = sr.partial_fit_transform(qX, stream='dt_X')\n",
    "                sr.partial_fit_transform(data, stream='X')\n",
    "\n",
    "                stims.append(ArrayWithTime(stim, data.t))\n",
    "                predictions.append(prediction)\n",
    "\n",
    "                if evaluate_inputs and np.any(stim):\n",
    "                    point_evals = _evaluate_stim_reg(sr, s_inputs_to_evaluate_on)\n",
    "                    s_hat_evals.append(ArrayWithTime(point_evals, data.t))\n",
    "                    if not s_eval:\n",
    "                        s_eval.extend(S(*inputs) for inputs in s_inputs_to_evaluate_on)\n",
    "\n",
    "                # fitting is switched off at the end of each stream (below) and back on after\n",
    "                # the following step\n",
    "                if not sr.autoreg.parameter_fitting:\n",
    "                    sr.autoreg.toggle_parameter_fitting(True)\n",
    "\n",
    "                pbar.update(round(data.t, 2) - pbar.n)\n",
    "        sr.autoreg.toggle_parameter_fitting(False)\n",
    "\n",
    "    stims = ArrayWithTime.from_list(stims)\n",
    "    predictions = ArrayWithTime.from_list(predictions, drop_early_nans=True, squeeze_type='to_2d')\n",
    "    latents = ArrayWithTime.from_list(latents, squeeze_type='to_2d')\n",
    "    # NOTE(review): s_hat_evals stays empty unless evaluation inputs were given; confirm\n",
    "    # ArrayWithTime.from_list accepts an empty list\n",
    "    s_hat_evals = ArrayWithTime.from_list(s_hat_evals, drop_early_nans=True)\n",
    "    s_eval = np.squeeze(s_eval)\n",
    "    return (predictions, latents, stims), (s_hat_evals, s_eval, s_inputs_evaluated_on)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the stimulation experiment on the dataset's neural data; the evaluation\n",
    "# outputs (second element of the result) are not needed here.\n",
    "stim_scale = 20\n",
    "experiment_outputs, _ = do_experiment([d.neural_data], stim_p=0.005, stim_scale=stim_scale)\n",
    "predictions, latents, stims = experiment_outputs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the StimRegressor predictions and the proSVD latents.\n",
    "(predictions.shape, latents.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
