{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.nn import sigmoid, relu\n",
    "from jax import random\n",
    "from adaptive_latents.input_sources.autoregressor import AdamOptimizer\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "key = random.key(0)  # single root PRNG key (seed 0); later cells split it before every random draw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the platform of JAX's default backend (e.g. 'cpu', 'gpu', 'tpu').\n",
    "print(jax.default_backend())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
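    "## Proof sanity check (numerical)\n",
    "\n",
    "Verify on random inputs that each step of the derivation equals the previous one, i.e. that\n",
    "$a_1\\lVert v v^\\top s\\rVert^2 - a_2\\lVert s - v v^\\top s\\rVert^2 = s^\\top\\left((a_1 - a_2)\\, v (v^\\top v) v^\\top + a_2 (2 v v^\\top - I)\\right) s$."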
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For random s, v, a1, a2, check that every rewriting step of the derivation agrees with the\n",
    "# previous one (and the first with the last).\n",
    "for _ in range(100):\n",
    "    key, old_key = jax.random.split(key)\n",
    "    s = random.uniform(old_key, shape=(130,)) * .1\n",
    "\n",
    "    key, old_key = jax.random.split(key)\n",
    "    v = random.normal(old_key, shape=(130,2))\n",
    "\n",
    "    key, old_key = jax.random.split(key)\n",
    "    a1, a2 = random.uniform(old_key, shape=(2,))\n",
    "\n",
    "    # request full float32 matmuls; some accelerators otherwise default to reduced precision (e.g. TF32)\n",
    "    with jax.default_matmul_precision('float32'):\n",
    "        steps = []\n",
    "        steps.append(a1*jnp.linalg.norm(v @ v.T @ s)**2 - a2*jnp.linalg.norm(s - v @ v.T @ s)**2)\n",
    "        steps.append(a1*s.T @ v @ (v.T @ v) @ v.T @ s - a2*jnp.linalg.norm(s - v @ v.T @ s)**2)\n",
    "        steps.append(a1*s.T @ v @ (v.T @ v) @ v.T @ s - a2*(s.T @ s - 2 * s.T @ v @ v.T @ s + s.T @ v @ (v.T @ v) @ v.T @ s))\n",
    "        steps.append(a1*s.T @ v @ (v.T @ v) @ v.T @ s + -a2*s.T @ s - -a2*2 * s.T @ v @ v.T @ s + -a2*s.T @ v @ (v.T @ v) @ v.T @ s)\n",
    "        steps.append((a1-a2)*s.T @ v @ (v.T @ v) @ v.T @ s + a2*(2 * s.T @ v @ v.T @ s -s.T @ s)  )\n",
    "        steps.append(s.T @ ((a1-a2)* v @ (v.T @ v) @ v.T + a2*(2 * v @ v.T - jnp.eye(v.shape[0]))) @ s  )\n",
    "\n",
    "    # Each step is a difference of terms of size ~||v v^T s||^2 computed in float32. When a1 is\n",
    "    # close to a2 the result is far smaller than those terms, so allclose's default tolerance\n",
    "    # (relative to the result) can fail on rounding alone; scale atol by the term size instead.\n",
    "    atol = 1e-5 * float(jnp.linalg.norm(v @ v.T @ s)**2 + s.T @ s)\n",
    "    for i in range(len(steps)-1):\n",
    "        assert jnp.allclose(steps[i], steps[i+1], atol=atol), i\n",
    "    assert jnp.allclose(steps[0], steps[-1], atol=atol)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
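    "## Using a real Q\n",
    "\n",
    "Load `Q` from `Q.npy`, take its first two columns as the target space `v`, and optimize a sparse, nonnegative stimulation vector `s` for it."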
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably the columns of Q are an orthonormal basis of latent directions - TODO confirm.\n",
    "Q = jnp.load('Q.npy')\n",
    "\n",
    "# The target space v is spanned by the first two columns of Q.\n",
    "v = Q[:,:2]\n",
    "# Keep v two-dimensional (n, k) even if the slice above is changed to a single column such as Q[:, 0].\n",
    "v = jnp.atleast_2d(v.T).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(s, v, lam_1=1e-3):\n",
    "    \"\"\"Objective minimized over the stimulation vector s.\n",
    "\n",
    "    loss = -||v^T s|| + ||s - v v^T s||^2 + lam_1 * ||s||_1\n",
    "\n",
    "    s: 1-D vector; v: matrix whose columns span the target space (v.shape[0] == s.shape[0]).\n",
    "    The first term is the *unsquared* norm of the projection coefficients, so it is linear in\n",
    "    the scale of s while the orthogonal penalty is quadratic.\n",
    "    NOTE: the gradient of a norm can be NaN where its argument is zero (e.g. v^T s == 0).\n",
    "    \"\"\"\n",
    "    u = s  # this assumes for now that the dynamics S function is an identity\n",
    "    return (\n",
    "        - jnp.linalg.norm(v.T @ s)  # maximize the length of s's projection coefficients onto v\n",
    "        + jnp.linalg.norm(s - v @ v.T @ s)**2  # minimize orthogonal component\n",
    "        + jnp.linalg.norm(u, ord=1) * lam_1  # L1 penalty\n",
    "    )\n",
    "\n",
    "# Alternative objective kept for reference (not used):\n",
    "# def loss(s, v, lam_1=1e-3):\n",
    "#     u = s  # this assumes for now that the dynamics S function is an identity\n",
    "#     return (\n",
    "#             - 10 * jnp.linalg.norm(v @ v.T @ s)**2  # maximize dot product with the target vector\n",
    "#             + jnp.linalg.norm(s - v @ v.T @ s)**2  # minimize orthogonal component\n",
    "#             + jnp.linalg.norm(u, ord=1) * lam_1  # L1 penalty\n",
    "#     )\n",
    "grad_loss = jax.jit(jax.value_and_grad(loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "s_history = []\n",
    "loss_history = []\n",
    "N = 30  # largest accepted number of nonzero entries (L0 norm) in s\n",
    "\n",
    "lam_1 = 1e-3\n",
    "convergence_threshold = 1e-2\n",
    "warmup_iterations = 10  # minimum Adam steps in each attempt before testing for convergence\n",
    "max_attempts = 100  # give up instead of cycling between raising and lowering lam_1 forever\n",
    "\n",
    "for attempt in range(max_attempts):\n",
    "    # make a random s to optimize (one entry per row of v)\n",
    "    key, old_key = jax.random.split(key)\n",
    "    s = random.uniform(old_key, shape=(v.shape[0],)) * .1\n",
    "    s_optimizer = AdamOptimizer(lr=0.005)\n",
    "\n",
    "    for i in range(250):\n",
    "        s_prev = s\n",
    "\n",
    "        # Adam update\n",
    "        val, grad = grad_loss(s, v, lam_1=lam_1)\n",
    "        s = s_optimizer.update(s, grad)\n",
    "\n",
    "        # set negative values to 0\n",
    "        s = relu(s)\n",
    "\n",
    "        # logging (kept across attempts, so the histories span every restart)\n",
    "        s_history.append(s)\n",
    "        loss_history.append(val)\n",
    "\n",
    "        # if converged, break. The warm-up is counted within this attempt rather than over the\n",
    "        # whole s_history, so every attempt gets at least `warmup_iterations` steps.\n",
    "        if i >= warmup_iterations and jnp.linalg.norm(s_prev - s) < convergence_threshold:\n",
    "            break\n",
    "\n",
    "    # accept s if it is finite and has between 1 and N nonzero entries; otherwise retry with a\n",
    "    # smaller L1 penalty (s collapsed to zero or became NaN) or a larger one (s not sparse enough)\n",
    "    l0 = jnp.linalg.norm(s, ord=0)\n",
    "    has_nan = bool(jnp.isnan(s).any())\n",
    "    if not has_nan and 0 < l0 <= N:\n",
    "        break\n",
    "    if l0 == 0 or has_nan:\n",
    "        print(f\"{l0}, {lam_1:.2e}, /\")\n",
    "        lam_1 /= 1.2\n",
    "    else:\n",
    "        print(f\"{l0}, {lam_1:.2e}, *\")\n",
    "        lam_1 *= 2\n",
    "else:\n",
    "    raise RuntimeError(f\"no s with 0 < L0 <= {N} found in {max_attempts} attempts (last lam_1={lam_1:.2e})\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left: size of each optimizer step ||s_t - s_{t-1}|| over the whole s_history (jumps between\n",
    "# restart attempts are included), with the convergence threshold in red. Right: loss per iteration.\n",
    "fig, axs = plt.subplots(ncols=2, layout='tight')\n",
    "axs[0].plot(jnp.linalg.norm(jnp.diff(jnp.array(s_history), axis=0), axis=1))\n",
    "axs[0].axhline(convergence_threshold, color='r')\n",
    "axs[0].semilogy()\n",
    "axs[0].set_xlabel(\"iteration\")\n",
    "axs[0].set_ylabel(r\"norm of $\\Delta s$ between iterations\")\n",
    "axs[1].plot(loss_history)\n",
    "axs[1].set_xlabel(\"iteration\")\n",
    "axs[1].set_ylabel(\"loss history\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the columns of the target space v (faint black) and the optimized s scaled to a max of 1.\n",
    "fig, ax = plt.subplots()\n",
    "v_lines = ax.plot(v, color='k', alpha=.1)\n",
    "v_lines[0].set_label('v (the target space)')  # label one column only, so the legend has a single v entry\n",
    "ax.plot(s/s.max(), color='C0', label='s (optimized stimulation vector)')\n",
    "ax.legend()\n",
    "print(jnp.linalg.norm(s, ord=0))  # number of nonzero entries in s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "projections = Q.T @ s  # coefficient of s along each latent direction (column of Q)\n",
    "energies = projections**2 / (projections**2).sum()  # fraction of total energy per direction; sums to 1\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(energies, '.-')\n",
    "ax.set_xlabel(\"latent direction\")\n",
    "ax.set_ylabel(\"sim energy received\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch experiment: gradient of a Gaussian-kernel (Nadaraya-Watson) regression loss when some\n",
    "# rows of X hold extreme values. Columns 0-9 of X act as inputs, columns 10-19 as targets.\n",
    "key, old_key = jax.random.split(key)\n",
    "X = random.normal(old_key, shape=(10000,20))\n",
    "\n",
    "key, old_key = jax.random.split(key)\n",
    "x = random.normal(old_key, shape=(10,))\n",
    "\n",
    "key, old_key = jax.random.split(key)\n",
    "y = random.normal(old_key, shape=(10,)) * 0  # target is all zeros; the split still advances key\n",
    "\n",
    "# X = X.at[100:].set(jnp.nan)\n",
    "\n",
    "# NOTE(review): 1e100 is outside the float32 range, so in JAX's default 32-bit mode these entries\n",
    "# become inf rather than a large finite value -- confirm which is intended.\n",
    "X = X.at[9000:, :10].set(1e100)\n",
    "X = X.at[9000:, 10:].set(0)\n",
    "\n",
    "\n",
    "def f(x, y, X):\n",
    "    # weight of row i is exp(-||X[i, :10] - x||^2), normalized so the weights sum to 1\n",
    "    w = jnp.exp(-(jnp.linalg.norm(X[:,:10] - x, axis=1))**2)\n",
    "    w = w / w.sum()\n",
    "\n",
    "    # weighted average of the target columns, scored against y by squared error\n",
    "    y_hat = w @ X[:,10:]\n",
    "    return jnp.linalg.norm(y_hat - y)**2\n",
    "\n",
    "value_and_grad_f = jax.value_and_grad(f)\n",
    "value, grad = value_and_grad_f(x,y,X)\n",
    "grad"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
