{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from learn_s_hat_toy import make_srs\n",
    "import copy\n",
    "from adaptive_latents.utils import save_to_cache\n",
    "from adaptive_latents import ArrayWithTime\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Notebook parameters (this cell is tagged `parameters`, so values can be overridden externally).\n",
    "# output: path passed to fig.savefig for the main error figure; None skips saving.\n",
    "output = None\n",
    "# use_cache: if False, the cached simulation results are recomputed (see the f(...) call below).\n",
    "use_cache = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the toy simulations via make_srs and return (rng, srs).\n",
    "# Wrapped in save_to_cache('spinning_toy'), presumably so results can be reused between notebook runs.\n",
    "# make_srs receives a deep copy of the rng seeded with 0, so the returned rng is still freshly seeded.\n",
    "@save_to_cache('spinning_toy')\n",
    "def f(u_function='curvy spins', n_runs=100):\n",
    "    rng = np.random.default_rng(0)\n",
    "    srs = make_srs(copy.deepcopy(rng), n_runs=n_runs, u_function=u_function, show_tqdm=True, n_rotations=100)\n",
    "    return rng, srs\n",
    "\n",
    "# srs is indexed below by condition name ('unaware of stim', 'learning from stim'), each iterated over runs.\n",
    "# NOTE(review): this call uses n_runs=50, not f's default of 100.\n",
    "# use_cache=False passes _recalculate_cache_value=True, presumably consumed by save_to_cache to force recomputation.\n",
    "rng, srs = f('curvy spins', n_runs=50, _recalculate_cache_value=not use_cache)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "def square_err_array(errs, skip=10):\n",
    "    \"\"\"Stack per-run prediction errors into an array of squared error norms.\n",
    "\n",
    "    Runs may have different lengths. Shorter runs are left-padded with NaN so that\n",
    "    every run is aligned to the timestamps of the longest run; the runs' trailing\n",
    "    timestamps must agree (checked by an assertion).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    errs : list of ArrayWithTime\n",
    "        One 2-D error trace per run, indexed as (time step, component), each with a\n",
    "        `.t` array of timestamps.\n",
    "    skip : int, optional\n",
    "        Number of leading time steps dropped from the result (default 10).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ArrayWithTime\n",
    "        Squared Euclidean norm of each error vector, shape (n_runs, max_length - skip),\n",
    "        NaN where a run had no data, with the longest run's timestamps.\n",
    "    \"\"\"\n",
    "    lengths = [len(e.t) for e in errs]\n",
    "    common_length = min(lengths)\n",
    "    max_length = max(lengths)\n",
    "    longest_idx = int(np.argmax(lengths))\n",
    "\n",
    "    # Left-padding only lines the runs up in time if they all end on the same timestamps.\n",
    "    trailing_times = np.vstack([e.t[-common_length:] for e in errs])\n",
    "    assert np.std(trailing_times, axis=0).max() < 1e-10\n",
    "\n",
    "    padded_errors = []\n",
    "    for e in errs:\n",
    "        # Pad with the run's own column count rather than a hard-coded 3.\n",
    "        early_nans = np.full((max_length - len(e), np.shape(e)[1]), np.nan)\n",
    "        padded_errors.append(np.vstack([early_nans, e]))\n",
    "\n",
    "    # np.array (not np.squeeze) keeps the (n_runs, n_times) shape even for a single run,\n",
    "    # so the column slice below cannot fail.\n",
    "    squared_norms = np.array([np.linalg.norm(p, axis=1) ** 2 for p in padded_errors])\n",
    "    t = errs[longest_idx].t\n",
    "    return ArrayWithTime(squared_norms[:, skip:], t[skip:])\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(figsize=np.array((56, 23.6)) / 5, layout='constrained')\n",
    "# Single-trial traces are drawn on `ax`; smoothed across-run means on the twin y-axis `ax2`.\n",
    "ax2 = ax.twinx()\n",
    "\n",
    "# Trailing (causal) moving average over smooth_window samples: np.convolve flips the kernel,\n",
    "# so with (smooth_window - 1) zeros followed by smooth_window ones, each 'valid' output is the\n",
    "# mean of smooth_window consecutive samples. Convolving the timestamps with a unit impulse at\n",
    "# the kernel's center labels each average with the time of the last sample in its window.\n",
    "smooth_window = 51\n",
    "kernel = np.hstack([np.zeros(smooth_window - 1), np.ones(smooth_window)])\n",
    "kernel = kernel / kernel.sum()\n",
    "time_kernel = np.zeros_like(kernel)\n",
    "time_kernel[len(kernel) // 2] = 1\n",
    "\n",
    "\n",
    "def plot_condition_errors(runs, label_prefix, color, ax_trial, ax_mean, kernel, time_kernel):\n",
    "    \"\"\"Plot one condition's squared prediction errors.\n",
    "\n",
    "    Draws the first run's trace (faint) on ax_trial and the NaN-aware mean over runs,\n",
    "    smoothed by convolving with kernel (timestamps with time_kernel), on ax_mean.\n",
    "    Returns the stacked errors and the smoothed mean, both as ArrayWithTime.\n",
    "    \"\"\"\n",
    "    errs = square_err_array([sr.log['pred_error'] for sr in runs])\n",
    "    mean_errors = np.nanmean(errs, axis=0)\n",
    "    ax_trial.plot(errs.t, errs[0], label=f'{label_prefix} (single trial)', color=color, alpha=.1)\n",
    "    smoothed_mean_errs = ArrayWithTime(\n",
    "        np.convolve(mean_errors, kernel, 'valid'),\n",
    "        np.convolve(errs.t, time_kernel, 'valid'),\n",
    "    )\n",
    "    ax_mean.plot(smoothed_mean_errs.t, smoothed_mean_errs, label=f'{label_prefix} (averaged, smoothed)', color=color)\n",
    "    return errs, smoothed_mean_errs\n",
    "\n",
    "\n",
    "unaware_errs, unaware_smoothed = plot_condition_errors(\n",
    "    srs['unaware of stim'], 'unaware', '#4d4d4dff', ax, ax2, kernel, time_kernel)\n",
    "aware_errs, aware_smoothed = plot_condition_errors(\n",
    "    srs['learning from stim'], 'aware', '#ca1469ff', ax, ax2, kernel, time_kernel)\n",
    "\n",
    "ax.set_xlim([0, 100])\n",
    "ax2.set_xlim([0, 100])\n",
    "ax.set_ylim([0, 290])\n",
    "ax2.set_ylim([0, 13])\n",
    "\n",
    "# NOTE(review): what happens at these times is not visible in this notebook - confirm and label.\n",
    "for marker_time in (25, 45, 75):\n",
    "    ax2.axvline(marker_time, color='gray', linestyle='--')\n",
    "\n",
    "ax.set_xlabel('time (s)')\n",
    "ax.set_ylabel('error')\n",
    "ax2.set_ylabel('averaged error')\n",
    "ax2.legend(loc='upper right')\n",
    "ax.legend(loc='upper left')\n",
    "\n",
    "if output is not None:\n",
    "    fig.savefig(output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay plot_length_scales for every 'learning from stim' run on one shared axes.\n",
    "fig, ax = plt.subplots()\n",
    "aware_runs = srs['learning from stim']\n",
    "for run in aware_runs:\n",
    "    run.stim_reg.plot_length_scales(ax)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
