{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29c3d42b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import sys,os\n",
    "from pathlib import Path\n",
    "\n",
    "# Trained dictionaries live two directories above the notebook's working directory.\n",
    "project_root = Path.cwd().parents[1]\n",
    "dict_path = project_root / 'trained_dicts' / 'kds_dict_100_atoms_16_sz_patch_bs_256_lam_1e-1.npy'\n",
    "\n",
    "# presumably (num_atoms, patch_size**2) with one atom per row; the filename suggests 100 atoms of 16x16 patches\n",
    "phi_kds = np.load(dict_path)\n",
    "\n",
    "def encode_accelerated(y, W, penalty, layers=15):\n",
    "    \"\"\"Encode each row of `y` over the dictionary `W` with accelerated projected gradient.\n",
    "\n",
    "    Per row, minimises 0.5 * ||x @ W - y||^2 + penalty * sum_k ||y - W_k||^2 * x_k,\n",
    "    projecting every iterate with `activate` (row-wise simplex projection).\n",
    "\n",
    "    Args:\n",
    "        y: (batch, dim) signals.\n",
    "        W: (num_atoms, dim) dictionary, one atom per row; cast to y's dtype/device.\n",
    "        penalty: weight of the distance-to-atom penalty.\n",
    "        layers: number of iterations; 0 returns the all-zero code.\n",
    "\n",
    "    Returns:\n",
    "        (batch, num_atoms) codes.\n",
    "    \"\"\"\n",
    "    # torch matmul does not promote dtypes: a float64 dictionary (e.g. from np.load)\n",
    "    # would fail against float32 signals, so align W with y up front.\n",
    "    W = W.to(dtype=y.dtype, device=y.device)\n",
    "    # Step size 1 / L, with L = sigma_max(W)^2 the Lipschitz constant of the data-term gradient.\n",
    "    step = W.svd().S[0] ** -2\n",
    "    x_tmp = torch.zeros(y.shape[0], W.shape[0], dtype=y.dtype, device=y.device)\n",
    "    x_old = torch.zeros_like(x_tmp)\n",
    "    x_new = x_tmp  # returned unchanged when layers == 0\n",
    "    # Squared distance between every signal and every atom, shape (batch, num_atoms).\n",
    "    # It is the gradient of the linear penalty term and does not change across iterations.\n",
    "    dist_penalty = penalty * (y.square().sum(dim=1, keepdim=True) + W.T.square().sum(dim=0, keepdim=True) - 2 * y @ W.T)\n",
    "    for layer in range(layers):\n",
    "        grad = (x_tmp @ W - y) @ W.T + dist_penalty\n",
    "        x_new = activate(x_tmp - grad * step)\n",
    "        # Nesterov/FISTA-style extrapolation with momentum layer / (layer + 3).\n",
    "        x_old, x_tmp = x_new, x_new + layer / (layer + 3) * (x_new - x_old)\n",
    "    return x_new\n",
    "\n",
    "def activate(x):\n",
    "    \"\"\"Project each row of the 2-D tensor `x` onto the probability simplex (sparsemax).\n",
    "\n",
    "    Returns relu(x - tau) with a per-row threshold tau, so each output row is\n",
    "    non-negative and sums to one.\n",
    "    \"\"\"\n",
    "    n_rows, n_cols = x.shape\n",
    "    # Sort each row in descending order and form the candidate thresholds\n",
    "    # tau_k = (sum of the k largest entries - 1) / k for k = 1..n_cols.\n",
    "    sorted_desc = torch.sort(x, dim=1, descending=True).values\n",
    "    ks = torch.arange(1, n_cols + 1, device=x.device)\n",
    "    candidates = (sorted_desc.cumsum(dim=1) - 1) / ks\n",
    "    # Support size = number of sorted entries lying above their candidate threshold.\n",
    "    support = (sorted_desc > candidates).sum(dim=1)\n",
    "    rows = torch.arange(n_rows, device=x.device)\n",
    "    tau = candidates[rows, support - 1][:, None]\n",
    "    return (x - tau).relu()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00bc5141",
   "metadata": {},
   "source": [
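    "## Spike-triggered average (STA)\n",
    "\n",
    "Probe one KDS atom with Gaussian white noise, average the stimuli weighted by that atom's response, and compare the result with the Ringach approximation and with the atom itself."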
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adce5ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "idx = 0  # atom whose spike-triggered average is computed\n",
    "time_steps = 10000  # number of white-noise stimuli\n",
    "penalty = 0.01  # encoder distance-penalty weight\n",
    "layers = 15  # encoder iterations\n",
    "beta = 0.2  # ridge term of the Ringach approximation\n",
    "num_atoms, patch_dim = phi_kds.shape\n",
    "patch_size = int(round(np.sqrt(patch_dim)))\n",
    "assert patch_size ** 2 == patch_dim, 'dictionary atoms must be square patches'\n",
    "\n",
    "# Gaussian white-noise stimuli, one flattened patch per row.\n",
    "gwn = torch.randn((time_steps, patch_size**2))\n",
    "# np.load may give float64 and torch matmul does not promote, so match the stimulus dtype.\n",
    "W_kds = torch.as_tensor(phi_kds, dtype=gwn.dtype)\n",
    "z = encode_accelerated(gwn, W=W_kds, penalty=penalty, layers=layers)\n",
    "\n",
    "# Response of atom `idx` to every stimulus; its total normalises the STA.\n",
    "responses = z[:, idx]\n",
    "num_spikes = responses.sum()\n",
    "assert num_spikes > 0, f'atom {idx} never responded to the noise; STA is undefined'\n",
    "z_0 = responses.unsqueeze(dim=1)\n",
    "\n",
    "STA_raw = (gwn.T @ z_0) / num_spikes\n",
    "STA = STA_raw.reshape(patch_size, patch_size).numpy()\n",
    "\n",
    "# Ringach rule of thumb: (Phi Phi^T + beta I)^{-1} Phi, via solve() rather than an explicit inverse.\n",
    "ringach_rule_of_thumb = np.linalg.solve(phi_kds @ phi_kds.T + beta * np.eye(num_atoms), phi_kds)\n",
    "\n",
    "fig, axarr = plt.subplots(1, 4, figsize=(26, 15))\n",
    "# NOTE(review): axis=0 normalises each column separately, not the patch as a whole; confirm intent.\n",
    "axarr[0].imshow(STA / np.linalg.norm(STA, axis=0))\n",
    "axarr[0].set_title('Normalized STA', fontsize=20)\n",
    "axarr[1].imshow(STA)  # unnormalized\n",
    "axarr[1].set_title('Unnormalized STA', fontsize=20)\n",
    "axarr[2].imshow(ringach_rule_of_thumb[idx].reshape(patch_size, patch_size))\n",
    "axarr[2].set_title('Ringach Approx in Appendix', fontsize=20)\n",
    "axarr[3].imshow(phi_kds[idx].reshape(patch_size, patch_size))\n",
    "axarr[3].set_title('KDS Basis Function', fontsize=20)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9987754",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
