{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import numpy as np\n",
    "import adaptive_latents as al\n",
    "from adaptive_latents.prediction_regression_run import pred_reg_run\n",
    "import matplotlib.pyplot as plt\n",
    "import functools\n",
    "from collections import namedtuple\n",
    "from scipy.linalg import null_space\n",
    "from adaptive_latents import proSVD\n",
    "from tqdm.autonotebook import tqdm\n",
    "import warnings\n",
    "import matplotlib as mpl\n",
    "rng = np.random.default_rng()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Odoherty21 dataset; its `neural_data` / `behavioral_data` arrays feed every simulation below.\n",
    "dataset = al.datasets.Odoherty21Dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: np.sort(...)[-k] gives the k-th largest value (idiom used in towards_prosvd_direction_keep_top).\n",
    "np.sort([1,2,3])[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stimulation onset times for the evaluate_in_series experiments; compared against `d.t` of the streamed data.\n",
    "stim_times = [100, 110, 200]\n",
    "\n",
    "def normalize(x):\n",
    "    return x/np.linalg.norm(x)\n",
    "\n",
    "def get_pipeline_proj_matrix(p):\n",
    "    \"\"\"Return `Q.T` from the last proSVD step of pipeline `p` (any object with a `.steps` list).\n",
    "\n",
    "    Steps are scanned from the end. An sjPCA or mmICA step met before the proSVD step only\n",
    "    triggers a warning, since the returned matrix does not account for it. With no proSVD\n",
    "    step at all, warns and returns the scalar 1 (callers that index it or read `.shape` will fail).\n",
    "    \"\"\"\n",
    "    order_msg = 'check that this order of ops is correct'\n",
    "    for step in reversed(p.steps):\n",
    "        if isinstance(step, al.proSVD):\n",
    "            return step.Q.T\n",
    "        if isinstance(step, al.sjPCA):\n",
    "            warnings.warn(order_msg)\n",
    "        if isinstance(step, al.mmICA):\n",
    "            warnings.warn(order_msg)\n",
    "    # Fall-through: no proSVD step was found.\n",
    "    warnings.warn('check that this shape works')\n",
    "    return 1\n",
    "\n",
    "def give_decay(x, decay_time=6, divisor=1.5):\n",
    "    response_time = np.arange(decay_time)\n",
    "    response_decay = np.exp(-response_time/divisor)\n",
    "    return response_decay[:,None] @ x.flatten()[None,:]\n",
    "\n",
    "def in_out_ratio(pipeline, ratio=.5, magnitude=10):\n",
    "    \"\"\"Blend an in-subspace and an out-of-subspace stimulation direction.\n",
    "\n",
    "    in-space: the first row of the projection matrix (first proSVD basis vector).\n",
    "    out-of-space: the all-ones vector minus `proj.T @ proj @ ones`, scaled to unit norm.\n",
    "    Returns `give_decay((ratio * in_space + (1 - ratio) * out_of_space) * magnitude)`.\n",
    "\n",
    "    NOTE(review): if Q has orthonormal columns the blend has norm\n",
    "    sqrt(ratio**2 + (1 - ratio)**2), so total stimulation size varies with `ratio`.\n",
    "    \"\"\"\n",
    "    proj = get_pipeline_proj_matrix(p=pipeline)\n",
    "    ones = np.ones(shape=proj.shape[1])\n",
    "    residual = ones - proj.T @ proj @ ones\n",
    "    out_of_space = residual.flatten() / np.linalg.norm(residual)\n",
    "    in_space = proj[0].flatten()\n",
    "    blend = ratio * in_space + (1 - ratio) * out_of_space\n",
    "    return give_decay(blend * magnitude)\n",
    "    \n",
    "def single_neuron_null(pipeline, magnitude=10):\n",
    "    \"\"\"Stimulate only the input column whose summed |loading| across projection rows is smallest.\"\"\"\n",
    "    loadings = np.abs(get_pipeline_proj_matrix(p=pipeline)).sum(axis=0)\n",
    "    response_direction = np.zeros(loadings.shape[0])\n",
    "    response_direction[np.argmin(loadings)] = magnitude\n",
    "    return give_decay(response_direction)\n",
    "\n",
    "def towards_null_direction(pipeline, magnitude=10):\n",
    "    \"\"\"Stimulate along the all-ones vector minus `proj.T @ proj @ ones`, scaled to norm `magnitude`.\"\"\"\n",
    "    proj = get_pipeline_proj_matrix(p=pipeline)\n",
    "    ones = np.ones(shape=proj.shape[1])\n",
    "    residual = ones - proj.T @ proj @ ones\n",
    "    unit_residual = residual / np.linalg.norm(residual)\n",
    "\n",
    "    return give_decay(unit_residual * magnitude)\n",
    "\n",
    "def towards_prosvd_direction(pipeline, magnitude=10, component=0):\n",
    "    \"\"\"Stimulate along row `component` of the projection matrix, scaled to norm `magnitude`.\"\"\"\n",
    "    basis_vector = get_pipeline_proj_matrix(p=pipeline)[component]\n",
    "    scaled = normalize(basis_vector) * magnitude\n",
    "\n",
    "    return give_decay(scaled)\n",
    "\n",
    "def towards_prosvd_direction_keep_top(pipeline, magnitude=10, component=0, zero_negative=False, keep_top=30):\n",
    "    \"\"\"Stimulate along a sparsified proSVD basis vector (row `component` of the projection).\n",
    "\n",
    "    Entries whose square is below the `keep_top`-th largest square are zeroed (ties at the\n",
    "    threshold survive); with `zero_negative`, negative entries are zeroed too. The result is\n",
    "    rescaled to norm `magnitude` and expanded with `give_decay`. A `keep_top` larger than the\n",
    "    vector keeps every entry.\n",
    "\n",
    "    NOTE: if nothing survives (e.g. `zero_negative` with no positive entries left),\n",
    "    `normalize` divides by zero and the response is NaN.\n",
    "    \"\"\"\n",
    "    mat = get_pipeline_proj_matrix(p=pipeline)\n",
    "    # Copy before editing: for an ndarray Q, `Q.T[component]` is a view into the pipeline's\n",
    "    # live proSVD basis, so zeroing entries in place would silently corrupt that pipeline\n",
    "    # (in evaluate_in_parallel: the unstimulated baseline and every later fork).\n",
    "    response_direction = mat[component].copy()\n",
    "    # Clamp so asking for more entries than exist keeps the full vector instead of IndexError.\n",
    "    keep_top = min(keep_top, response_direction.size)\n",
    "    threshold = np.sort(response_direction**2)[-keep_top]\n",
    "    response_direction[response_direction**2 < threshold] = 0\n",
    "    if zero_negative:\n",
    "        response_direction[response_direction < 0] = 0\n",
    "    response_direction = normalize(response_direction) * magnitude\n",
    "\n",
    "    return give_decay(response_direction)\n",
    "\n",
    "def exp_resp(pipeline, magnitude=10):\n",
    "    \"\"\"Stimulate along the mean population-activity direction, scaled to norm `magnitude`.\n",
    "\n",
    "    `pipeline` is unused; it is accepted to match the make_response(pipeline) callback\n",
    "    signature used by evaluate_in_series / evaluate_in_parallel.\n",
    "    \"\"\"\n",
    "    # Normalize first, then scale: normalizing after multiplying by `magnitude` discards\n",
    "    # the magnitude and always yields a unit-norm stimulation.\n",
    "    response_direction = normalize(np.mean(dataset.neural_data, axis=0)) * magnitude\n",
    "\n",
    "    return give_decay(response_direction)\n",
    "\n",
    "def zero_resp(pipeline):\n",
    "    \"\"\"No-stimulation control: an all-zero decay pattern over every neural channel (`pipeline` is unused).\"\"\"\n",
    "    n_channels = dataset.neural_data.shape[1]\n",
    "    return give_decay(np.zeros(n_channels))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "# Toy check of the null-space control on synthetic 3-D data whose first axis has low variance.\n",
    "p = al.proSVD(k=2)\n",
    "\n",
    "points = rng.normal(size=(300,3)) * np.array([.1,1,1])\n",
    "for chunk in np.split(points, 3):\n",
    "    p.partial_fit_transform(chunk)\n",
    "\n",
    "# towards_null_direction only reads `.steps`, so a minimal stand-in object is enough.\n",
    "class StubPipeline:\n",
    "    steps = [p]\n",
    "\n",
    "response = towards_null_direction(StubPipeline, magnitude=5)\n",
    "\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "ax.scatter(*points.T)\n",
    "ax.scatter(*response.T, color='r')\n",
    "ax.set(xlabel='null axis', ylabel='manifold axis 1', zlabel='manifold axis 2',\n",
    "       title='null space stimulation control')\n",
    "ax.legend(['observed points', 'stimulation'])\n",
    "ax.axis('equal');\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_in_series(stim_times, make_response):\n",
    "    \"\"\"Stream one pro-SVD pipeline over the dataset, injecting stimulations as it goes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    stim_times : sequence of float\n",
    "        Stimulation times (need not be sorted). Each fires on the first neural sample\n",
    "        (stream 0) whose `t` exceeds it.\n",
    "    make_response : callable\n",
    "        `make_response(pipeline)` -> array whose rows are added, one per subsequent\n",
    "        neural sample, to the incoming data (see `give_decay`).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    latents : al.ArrayWithTime\n",
    "        Stacked stream-0 pipeline outputs, until 10 time units after the latest stimulation.\n",
    "    evaluation\n",
    "        The object returned by `pred_reg_run` (provides `.pipeline` and `.sources`).\n",
    "    \"\"\"\n",
    "    # Use max() rather than stim_times[-1] so unsorted stim_times do not stop the run early.\n",
    "    stop_time = max(stim_times) + 10\n",
    "    stim_passed = [False for _ in stim_times]\n",
    "\n",
    "    evaluation = pred_reg_run(\n",
    "        neural_data=dataset.neural_data,\n",
    "        behavioral_data=dataset.behavioral_data[:,:0],\n",
    "        target_data=dataset.behavioral_data[:,:0],\n",
    "        dim_red_method='pro',\n",
    "        predict=False, evaluate=False)\n",
    "\n",
    "    pipeline = evaluation.pipeline\n",
    "\n",
    "    pending_responses = []\n",
    "    outputs = []\n",
    "    for d, s in al.Pipeline().streaming_run_on(evaluation.sources, return_output_stream=True):\n",
    "        if s == 0:\n",
    "            for i, stim_time in enumerate(stim_times):\n",
    "                if not stim_passed[i] and d.t > stim_time:\n",
    "                    # Each row of the response is consumed by one neural sample below.\n",
    "                    pending_responses.append(list(make_response(pipeline)))\n",
    "                    stim_passed[i] = True\n",
    "\n",
    "            for r in pending_responses:\n",
    "                if r:\n",
    "                    d = d + r.pop(0)\n",
    "\n",
    "        d, s = pipeline.partial_fit_transform(d, s, return_output_stream=True)\n",
    "        if s == 0:\n",
    "            outputs.append(d)\n",
    "\n",
    "        if d.t > stop_time:\n",
    "            break\n",
    "\n",
    "    latents = al.ArrayWithTime.from_list(outputs, squeeze_type=\"to_2d\")\n",
    "\n",
    "    return latents, evaluation\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def subtract_aligned_indices(a, b):\n",
    "    \"\"\"Thin wrapper around `al.ArrayWithTime.subtract_aligned_indices(a, b)` (presumably a - b on aligned time points).\"\"\"\n",
    "    subtract = al.ArrayWithTime.subtract_aligned_indices\n",
    "    return subtract(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One simulated run: its pipeline, remaining response rows, collected outputs, and how long\n",
    "# (in data time after its first output) it keeps being advanced.\n",
    "Sim = namedtuple('Sim', ['pipeline', 'response', 'outputs', 'expire_time'])\n",
    "\n",
    "def evaluate_in_parallel(stim_times, make_response, end_t=200, response_cutoff=5, freeze_prosvd=False):\n",
    "    \"\"\"Run an unstimulated pipeline plus one stimulated fork per stimulation time.\n",
    "\n",
    "    At each stimulation (first neural sample with `t > stim_time`), the base pipeline is\n",
    "    deep-copied; the copy receives the rows of `make_response(base_pipeline)`, one per neural\n",
    "    sample, and is advanced for `response_cutoff` time units after its first output. The base\n",
    "    pipeline is advanced for `end_t` time units after its first output.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    latents : list of al.ArrayWithTime\n",
    "        latents[0] is the unstimulated run; latents[i] (i >= 1) is the fork made at the\n",
    "        i-th stimulation in sorted order.\n",
    "    evaluation\n",
    "        The object returned by `pred_reg_run` for the base pipeline.\n",
    "    \"\"\"\n",
    "    stim_times = list(sorted(stim_times))\n",
    "    assert len(stim_times) == len(set(stim_times))\n",
    "\n",
    "    evaluation = pred_reg_run(\n",
    "        neural_data=dataset.neural_data,\n",
    "        behavioral_data=dataset.behavioral_data[:,:0],\n",
    "        target_data=dataset.behavioral_data[:,:0],\n",
    "        dim_red_method='pro',\n",
    "        predict=False, evaluate=False)\n",
    "\n",
    "    sims = [Sim(evaluation.pipeline, [], [], end_t)]\n",
    "\n",
    "    def is_expired(sim, t):\n",
    "        # A sim's clock starts at its first output; one with no output yet measures from `t`.\n",
    "        start = sim.outputs[0].t if sim.outputs else t\n",
    "        return t - start > sim.expire_time\n",
    "\n",
    "    for d, s in al.Pipeline().streaming_run_on(evaluation.sources, return_output_stream=True):\n",
    "        if s == 0 and stim_times and d.t > stim_times[0]:\n",
    "            stim_times.pop(0)\n",
    "            response = make_response(sims[0].pipeline)\n",
    "            sims.append(Sim(copy.deepcopy(sims[0].pipeline), list(response), [], response_cutoff))\n",
    "            if freeze_prosvd:\n",
    "                # NOTE(review): assumes the proSVD step is third from the end of the pipeline\n",
    "                # built by pred_reg_run(dim_red_method='pro') -- confirm.\n",
    "                sims[-1].pipeline.steps[-3].freeze()\n",
    "\n",
    "        for sim in sims:\n",
    "            if is_expired(sim, d.t):\n",
    "                continue\n",
    "            # Each sim mutates its own copy of the sample (`+=` below is in place).\n",
    "            inner_d = copy.deepcopy(d)\n",
    "            inner_s = copy.deepcopy(s)\n",
    "            if inner_s == 0 and len(sim.response):\n",
    "                inner_d += sim.response.pop(0)\n",
    "\n",
    "            inner_d, inner_s = sim.pipeline.partial_fit_transform(inner_d, inner_s, return_output_stream=True)\n",
    "\n",
    "            if inner_s == 0:\n",
    "                sim.outputs.append(inner_d)\n",
    "\n",
    "        # `sims` only ever grows, so checking it for emptiness could never end the loop and the\n",
    "        # whole dataset was streamed. Once every sim has expired and no stimulation is pending,\n",
    "        # no further output can be recorded (assuming `d.t` is non-decreasing), so stop here.\n",
    "        if not stim_times and all(is_expired(sim, d.t) for sim in sims):\n",
    "            break\n",
    "\n",
    "    latents = [al.ArrayWithTime.from_list(sim.outputs, squeeze_type=\"to_2d\") for sim in sims]\n",
    "\n",
    "    return latents, evaluation\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "ys = np.linspace(.1, 50, 10)  # stimulation magnitudes\n",
    "xs = np.linspace(0, 1, 7)     # fraction of the stimulation inside the proSVD subspace\n",
    "rrs = []\n",
    "for freeze in (False, True):\n",
    "    rr = []\n",
    "    for y in tqdm(ys):\n",
    "        rms = []\n",
    "        for x in xs:\n",
    "            stim = functools.partial(in_out_ratio, ratio=x, magnitude=y)\n",
    "            latents, _ = evaluate_in_parallel([100], stim, end_t=120, freeze_prosvd=freeze)\n",
    "            delta = subtract_aligned_indices(latents[1].slice(None, 60), latents[0])\n",
    "            rms.append(np.sqrt((delta ** 2).mean()))\n",
    "        # Express each response relative to the fully in-space stimulation (xs[-1] == 1).\n",
    "        rr.append(np.array(rms) / rms[-1])\n",
    "    rr = np.array(rr)\n",
    "    rrs.append(rr)\n",
    "rrs = np.array(rrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(10,4), sharex=True, sharey=True)\n",
    "\n",
    "# One color per stimulation magnitude; sized from `ys` rather than the loop-leftover `rr`.\n",
    "colors = plt.cm.jet(np.linspace(0, 1, len(ys)))\n",
    "# rrs[0] / rrs[1] come from freeze_prosvd=False / True in the sweep cell.\n",
    "for ax, inner_rr, freeze in zip(axs, rrs, [False, True]):\n",
    "    for color, rrr in zip(colors, inner_rr):\n",
    "        ax.plot(xs, rrr, color=color)\n",
    "    ax.set_xlabel('proportion of stimulation in-space')\n",
    "    ax.set_title(f'freeze_prosvd={freeze}')\n",
    "axs[0].set_ylabel('magnitude of total response (compared to in-space stim)')\n",
    "axs[-1].legend([f'{y:.1f}' for y in ys], title='stim. magnitude')\n",
    "fig.suptitle('large stimulations scale nonlinearly');\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `latents` is whatever the last evaluate_in_parallel call in the sweep produced.\n",
    "stimulated, unstimulated = latents[1], latents[0]\n",
    "difference = subtract_aligned_indices(stimulated.slice(None, 60), unstimulated)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(difference.t, difference)\n",
    "ax.set_xlabel('experiment time')\n",
    "ax.set_ylabel('difference magnitude')\n",
    "ax.set_title('Difference between stimulated and unstimulated latents')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side view of the last sweep run; the x window brackets the stimulation at t=100.\n",
    "fig, axs = plt.subplots(ncols=2, figsize=(10,5), sharex=True)\n",
    "axs[0].plot(latents[0].t, latents[0])\n",
    "axs[1].plot(latents[1].t, latents[1])\n",
    "# sharex=True, so this limit applies to both panels.\n",
    "axs[0].set_xlim([99.9, 105.1])\n",
    "for ax in axs:\n",
    "    ax.set_xlabel('experiment time')\n",
    "\n",
    "axs[0].set_title('unstimulated latents')\n",
    "axs[1].set_title('stimulated latents')\n",
    "axs[0].set_ylabel('latent magnitude (a.u.)')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sparsify the first proSVD basis vector to its top-n entries. The last setting (130;\n",
    "# presumably the channel count -- confirm) is the full vector, so negatives are kept there.\n",
    "n_components = [1, 5, 10, 20, 50, 130]\n",
    "latents = []\n",
    "for n in tqdm(n_components):\n",
    "    stim = functools.partial(towards_prosvd_direction_keep_top, magnitude=20, keep_top=n, zero_negative=(n != 130))\n",
    "    l, _ = evaluate_in_parallel([100], stim, end_t=120, freeze_prosvd=True)\n",
    "    latents.append(l[1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=3, ncols=len(n_components), sharex=True, sharey='row', figsize=(15,5))\n",
    "\n",
    "# `l[0]` is the unstimulated run from the last iteration of the keep-top sweep.\n",
    "full_vector_resp = latents[-1]\n",
    "for col, (n_kept, resp) in enumerate(zip(n_components, latents)):\n",
    "    axs[0, col].plot(resp)\n",
    "    axs[1, col].plot(resp - full_vector_resp)\n",
    "    axs[2, col].plot(subtract_aligned_indices(resp, l[0]))\n",
    "    axs[0, col].set_title(f'kept top {n_kept}')\n",
    "axs[0, -1].set_title('full vector')\n",
    "axs[0, 0].set_ylabel('resp.')\n",
    "axs[1, 0].set_ylabel('resp. - resp.[-1]')\n",
    "axs[2, 0].set_ylabel('resp. - no stim')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "fig, ax = plt.subplots()\n",
    "l1, _ = evaluate_in_series(stim_times, functools.partial(towards_null_direction, magnitude=10))\n",
    "\n",
    "ax.plot(l1.t, l1, '.-')\n",
    "ax.set_xlim([99, 104])\n",
    "ax.set_xlabel('experiment time')\n",
    "# Only the stimulated run is plotted here (no control is subtracted), so title it as such.\n",
    "ax.set_title('Stimulated latents (null-space stim, magnitude 10)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "# Null-space stimulation at magnitude 100 minus a zero-stimulation control run.\n",
    "null_stim = functools.partial(towards_null_direction, magnitude=100)\n",
    "l1, _ = evaluate_in_series(stim_times, null_stim)\n",
    "l2, _ = evaluate_in_series(stim_times, zero_resp)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(l1.t, l1 - l2, '.-')\n",
    "ax.set_xlim([99, 104])\n",
    "ax.set_title(\"Stimulated latents - unstimulated latents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-latent spread (ignoring NaNs) of the zero-stimulation control run, for scale when reading the differences above.\n",
    "np.nanstd(l2, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# The zero-stimulation control does not depend on the magnitude, so run it once instead of\n",
    "# once per magnitude (assumes evaluate_in_series is deterministic for fixed inputs).\n",
    "l2, _ = evaluate_in_series(stim_times, zero_resp)\n",
    "\n",
    "magnitudes = [.5, 1, 2, 4, 8, 16, 32, 64, 128]\n",
    "differences = []\n",
    "for m in tqdm(magnitudes):\n",
    "    l1, _ = evaluate_in_series(stim_times, functools.partial(towards_null_direction, magnitude=m))\n",
    "    differences.append(np.nanmean(np.abs(l1 - l2)))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(magnitudes, differences, '.')\n",
    "# The magnitudes double at each step, so a log axis spaces them evenly.\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel('stimulation magnitude')\n",
    "ax.set_ylabel('perturbation magnitude')\n",
    "ax.set_title('Null-space stimulation: mean |stimulated - control| latent difference')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib qt\n",
    "# `l1` is the last (largest-magnitude) run from the magnitude sweep above.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(l1.t, l1)\n",
    "ax.set_xlabel('experiment time')\n",
    "ax.set_title(f'Stimulated latents (null-space stim, magnitude {magnitudes[-1]})');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
