{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "import sys, os\n",
    "\n",
    "sys.path.append('../')\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pyrnn.model import RNN, predict\n",
    "from pyrnn.train import train_rnn\n",
    "from rnn.saving import save_model, load_model\n",
    "from pyrnn.train import save_rnn, load_rnn\n",
    "import matplotlib.pyplot as plt\n",
    "from itertools import combinations, chain\n",
    "\n",
    "import matplotlib as mpl\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "class Reaching(Dataset):\n",
    "    def __init__(self, task_params):\n",
    "        self.task_params = task_params\n",
    "        \n",
    "    def __len__(self):\n",
    "        \"\"\"Arbitrary number of trials, as they are randomly generated anyway\"\"\"\n",
    "        return 200#self.task_params['n_stim']\n",
    "    def __getitem__(self, idx):\n",
    "        return idx\n",
    "\n",
    "rnn_osc,model_params,task_params,training_params = load_rnn(\"reacb\")\n",
    "alpha = rnn_osc.rnn.dt/rnn_osc.rnn.tau\n",
    "z=np.ones(2)\n",
    "W2 = torch.clone(rnn_osc.rnn.m.detach()).numpy()\n",
    "W1= torch.clone(rnn_osc.rnn.n.detach()/model_params['n_rec']).numpy()*alpha\n",
    "tau = 1-alpha\n",
    "A =np.diag(np.ones(2)*tau)\n",
    "h2 = rnn_osc.rnn.b_rec.detach().numpy()\n",
    "h1 = np.zeros(2)\n",
    "print(A.dot(z) + W1.dot(np.maximum(W2.dot(z) + h2, 0)) + h1)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def powerset(iterable):\n",
    "    s = list(iterable)\n",
    "    return chain.from_iterable(combinations(s,r) for r in range(len(s)+1))\n",
    "\n",
    "def find_fixed_points(a,V,U,hz,h,d=1):\n",
    "    \"\"\"\n",
    "    Find fixed points of the model\n",
    "    Args:\n",
    "        a: numpy array of shape (R,)\n",
    "        V: numpy array of shape (N,R)\n",
    "        U: numpy array of shape (N,R)\n",
    "        hz: numpy array of shape (R,)\n",
    "        h: numpy array of shape (N,)\n",
    "    Returns:\n",
    "        D_list: numpy array of shape (n_Ds,N) containing all subspaces\n",
    "        D_inds: list of indices of subspaces in D_list that are fixed points\n",
    "        z_list: list of fixed points\n",
    "    \"\"\"\n",
    "    \n",
    "    n_inverses=0\n",
    "    N=U.shape[0]\n",
    "    R=U.shape[1]\n",
    "\n",
    "    # First solve for all intersection of hyperplanes\n",
    "    intersect_inds = np.array(list(combinations(np.arange(N),R)))\n",
    "    print(len(intersect_inds))\n",
    "\n",
    "    par_inds = []\n",
    "    if d == 2:\n",
    "        ni = N//2\n",
    "        for i, el in enumerate(intersect_inds):\n",
    "            if el[0]==el[1]+ni or el[1]==el[0]+ni:\n",
    "                par_inds.append(i)\n",
    "        intersect_inds=np.delete(intersect_inds,par_inds,axis=0)\n",
    "        print(\"removed parallel lines\")\n",
    "        print(len(intersect_inds))\n",
    "    \n",
    "    n_Ds_initial = len(list(powerset(range(R))))*len(intersect_inds)\n",
    "    print(len(list(powerset(range(R)))))\n",
    "    D_list = np.zeros((n_Ds_initial,N),dtype='uint8')\n",
    "    it = 0\n",
    "    n_singular = 0\n",
    "    for inds in intersect_inds:\n",
    "        b_hat = h[inds]\n",
    "        U_hat = U[inds]\n",
    "        if np.linalg.matrix_rank(U_hat)>0:#==R:\n",
    "            n_inverses+=1\n",
    "            z = np.linalg.solve(U_hat,b_hat)\n",
    "            # Find all subspaces bordering to this intersection\n",
    "            x = U@z-h\n",
    "            D_init = np.array(x > 0).astype('uint8')\n",
    "            D_init[inds]=0\n",
    "            D_list[it]=D_init\n",
    "            it+=1\n",
    "            D_inds = list(powerset(inds))[1:]\n",
    "            for D_ind in D_inds:\n",
    "                D=np.copy(D_init)\n",
    "                D[np.array(D_ind)]=1\n",
    "                D_list[it]=D\n",
    "                it+=1\n",
    "        else:\n",
    "            n_singular+=1\n",
    "    print(\"n singular\")\n",
    "    print(n_singular)\n",
    "    # Throw away duplicate subspaces\n",
    "    print(D_list.shape)\n",
    "    D_list = np.unique(D_list,axis=0)\n",
    "    print(D_list.shape)\n",
    "\n",
    "    # Finally solve for fixed points\n",
    "    z_list = []\n",
    "    D_inds = []\n",
    "    for D_ind,D_init in enumerate(D_list):\n",
    "\n",
    "        A = -np.eye(R)+np.diag(a)+V.T@np.diag(D_init)@U\n",
    "        b = V.T@np.diag(D_init)@h+hz\n",
    "        z_hat = np.linalg.solve(A,b)\n",
    "        n_inverses+=1\n",
    "\n",
    "        x_hat = U@z_hat-h\n",
    "        if np.allclose(D_init,np.array(x_hat > 0).astype('uint8')):\n",
    "            print(\"Found a fixed point\")\n",
    "            print(z_hat)\n",
    "            z_list.append(z_hat)\n",
    "            D_inds.append(D_ind)\n",
    "    print(\"Done, found \" + str(len(z_list)) + \" fixed points\")\n",
    "    return D_list,D_inds,z_list, n_singular, n_inverses\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "#using code from https://github.com/DurstewitzLab/CNS-2023/blob/main/CNS2023_tutorial.ipynb\n",
    "\n",
    "def construct_relu_matrix(number_quadrant: int, dim: int):\n",
    "    \"\"\"\n",
    "    Matrix describing the Relu function for different quadrants(subcompartments)\n",
    "    \"\"\"\n",
    "    quadrant_index = format(number_quadrant, f'0{dim}b')[::-1]\n",
    "    return np.diag(np.array([bool(int(bit)) for bit in quadrant_index]))\n",
    "\n",
    "def construct_relu_matrix_list(dim: int, order: int):\n",
    "    \"\"\"\n",
    "    Construct a list of relu matrices for a random sequence of quadrants\n",
    "    \"\"\"\n",
    "    relu_matrix_list = np.empty((dim, dim, order))\n",
    "    for i in range(order):\n",
    "        n = int(np.floor(np.random.rand(1)[0] * (2 ** dim)))\n",
    "        relu_matrix_list[:, :, i] = construct_relu_matrix(n, dim)\n",
    "    return relu_matrix_list\n",
    "\n",
    "def get_cycle_point_candidate(A, W1, W2, h1, h2, D_list, order):\n",
    "    \"\"\"\n",
    "    get the candidate for a cycle point by solving the cycle equation\n",
    "    \"\"\"\n",
    "    z_factor, h1_factor, h2_factor = get_factors(A, W1, W2, D_list, order)\n",
    "    try:\n",
    "        inverse_matrix = np.linalg.inv(np.eye(A.shape[0]) - z_factor)\n",
    "        z_candidate = inverse_matrix.dot(h1_factor.dot(h1) + h2_factor.dot(h2))\n",
    "        return z_candidate\n",
    "    except np.linalg.LinAlgError:\n",
    "        # Not invertible\n",
    "        return None\n",
    "\n",
    "def get_factors(A, W1, W2, D_list, order):\n",
    "    \"\"\"\n",
    "    recursively applying map gives us the factors of the cycle equation\n",
    "    \"\"\"\n",
    "    hidden_dim = W2.shape[0]\n",
    "    latent_dim = W1.shape[0]\n",
    "    factor_z = np.eye(A.shape[0])\n",
    "    factor_h1 = np.eye(A.shape[0])\n",
    "    factor_h2 = W1.dot(D_list[:, :, 0]).dot(np.eye(hidden_dim))\n",
    "    for i in range(order - 1):\n",
    "        factor_z = (A + W1.dot(D_list[:, :, i]).dot(W2)).dot(factor_z)\n",
    "        factor_h1 = (A + W1.dot(D_list[:, :, i + 1]).dot(W2)).dot(factor_h1) + np.eye(A.shape[0])\n",
    "        factor_h2 = (A + W1.dot(D_list[:, :, i + 1]).dot(W2)).dot(factor_h2) + W1.dot(D_list[:, :, i + 1])\n",
    "    factor_z = (A + W1.dot(D_list[:, :, order-1]).dot(W2)).dot(factor_z)\n",
    "    return factor_z, factor_h1, factor_h2\n",
    "\n",
    "def get_latent_time_series(time_steps, A, W1, W2, h1, h2, dz, z_0=None):\n",
    "    \"\"\"\n",
    "    Generate the time series by iteravely applying the PLRNN\n",
    "    \"\"\"\n",
    "    if z_0 is None:\n",
    "        z = np.random.randn(dz)\n",
    "    else:\n",
    "        z = z_0\n",
    "    trajectory = [z]\n",
    "\n",
    "    for t in range(1, time_steps):\n",
    "        z = latent_step(z, A, W1, W2, h1, h2)\n",
    "        trajectory.append(z)\n",
    "    return trajectory\n",
    "\n",
    "def latent_step(z, A, W1, W2, h1, h2):\n",
    "    \"\"\"\n",
    "    PLRNN step\n",
    "    \"\"\"\n",
    "    return A.dot(z) + W1.dot(np.maximum(W2.dot(z) + h2, 0)) + h1\n",
    "\n",
    "def get_eigvals(A, W1, W2, D_list, order):\n",
    "    \"\"\"\n",
    "    Get the eigenvalues for all the points along the trajectory to learn about the stability\n",
    "    \"\"\"\n",
    "    e = np.eye(A.shape[0])\n",
    "    for i in range(order):\n",
    "        e = (np.diag(A) + W1.dot(D_list[:, :, i]).dot(W2)).dot(e)\n",
    "    return np.linalg.eigvals(e)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "#using code from https://github.com/DurstewitzLab/CNS-2023/blob/main/CNS2023_tutorial.ipynb\n",
    "\n",
    "def scy_fi2(A, W1, W2, h1, h2, order, found_lower_orders, outer_loop_iterations=300, inner_loop_iterations=100,constrain=False,n_inverses_max =0):\n",
    "    \"\"\"\n",
    "    heuristic algorithm for calculating FP/k-cycle\n",
    "    adapted from https://github.com/DurstewitzLab/CNS-2023/blob/main/CNS2023_tutorial.ipynb\n",
    "    \"\"\"\n",
    "    hidden_dim = h2.shape[0]\n",
    "    latent_dim = h1.shape[0]\n",
    "    cycles_found = []\n",
    "    eigvals = []\n",
    "    n_inverses = 0\n",
    "    i = -1\n",
    "    if constrain:\n",
    "        N=W2.shape[0]\n",
    "        R=W2.shape[1]\n",
    "        # First solve for all intersection of hyperplanes\n",
    "        intersect_inds = np.array(list(combinations(np.arange(N),R)))\n",
    "\n",
    "        n_Ds_initial = len(list(powerset(range(R))))*len(intersect_inds)\n",
    "        D_list = np.zeros((n_Ds_initial,N),dtype='uint8')\n",
    "        it = 0\n",
    "        for inds in intersect_inds:\n",
    "            b_hat = -h2[inds]\n",
    "            U_hat = W2[inds]\n",
    "            n_inverses+=1\n",
    "            z = np.linalg.solve(U_hat,b_hat)\n",
    "            # Find all subspaces bordering to this intersection\n",
    "            x = W2@z+h2\n",
    "            D_init = np.array(x > 0).astype('uint8')\n",
    "            D_init[inds]=0\n",
    "            D_list[it]=D_init\n",
    "            it+=1\n",
    "            D_inds = list(powerset(inds))[1:]\n",
    "            for D_ind in D_inds:\n",
    "                D=np.copy(D_init)\n",
    "                D[np.array(D_ind)]=1\n",
    "                D_list[it]=D\n",
    "                it+=1\n",
    "        \n",
    "    while i < outer_loop_iterations and n_inverses<n_inverses_max:\n",
    "\n",
    "        i += 1\n",
    "        if constrain:\n",
    "            ind = np.random.randint(low=0,high=len(D_list))\n",
    "            relu_matrix_list=np.diag(D_list[ind])\n",
    "        else:\n",
    "            relu_matrix_list = construct_relu_matrix_list(hidden_dim, order)[:,:,0]\n",
    "        c = 0\n",
    "        while c < inner_loop_iterations and n_inverses<n_inverses_max:\n",
    "            c += 1\n",
    "            #z_candidate = get_cycle_point_candidate(A, W1, W2, h1, h2, relu_matrix_list, order)\n",
    "            As = -np.eye(latent_dim)+A+W1@relu_matrix_list@W2\n",
    "            bs = -W1@relu_matrix_list@h2-h1\n",
    "            z_candidate = np.linalg.solve(As,bs)\n",
    "            n_inverses +=1\n",
    "            trajectory =z_candidate\n",
    "            \n",
    "            trajectory_relu_matrix_list= np.diag((W2.dot(trajectory) + h2) > 0)\n",
    "            difference_relu_matrices = np.sum(np.abs(trajectory_relu_matrix_list.astype(int) - relu_matrix_list))\n",
    "      \n",
    "            if difference_relu_matrices == 0:\n",
    "                if not np.any(np.isin(np.round(trajectory[0], 4), np.round(cycles_found, 4))):\n",
    "                    e = 0#get_eigvals(A, W1, W2, relu_matrix_list, order)\n",
    "                    cycles_found.append(trajectory)\n",
    "                    eigvals.append(e)\n",
    "                    i = 0\n",
    "                    c = 0\n",
    "                    #print(\"found fixed point\")\n",
    "                if constrain:\n",
    "                    ind = np.random.randint(low=0,high=len(D_list))\n",
    "                    relu_matrix_list=np.diag(D_list[ind])\n",
    "                else:\n",
    "                    relu_matrix_list = construct_relu_matrix_list(hidden_dim, order)[:,:,0]            \n",
    "            else:\n",
    "                relu_matrix_list = trajectory_relu_matrix_list\n",
    "\n",
    "        #print(n_inverses,n_inverses_max)\n",
    "    return cycles_found, eigvals,n_inverses\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "D_list,D_inds,z_list, n_singular,n_inverses = find_fixed_points(A[range(2),range(2)],W1.T,W2,0,-h2)\n",
    "true_n_fps = len(z_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Number of linear solves used by the exhaustive search in the previous cell\n",
     "print(n_inverses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def main(A, W1, W2, h1, h2, order, outer_loop_iterations=None, inner_loop_iterations=None,constrain=True,n_inverses_max=1000):\n",
    "    found_lower_orders = []\n",
    "    found_eigvals = []\n",
    "\n",
    "    for i in range(1, order + 1):\n",
    "        cycles_found, eigvals,n_inverses = scy_fi2(A, W1, W2, h1, h2, i, found_lower_orders, \n",
    "                                                  outer_loop_iterations=outer_loop_iterations, inner_loop_iterations=inner_loop_iterations,\n",
    "                                                  constrain=constrain,n_inverses_max=n_inverses_max)\n",
    "\n",
    "        found_lower_orders.append(cycles_found)\n",
    "        found_eigvals.append(eigvals)\n",
    "\n",
    "    return [found_lower_orders, found_eigvals,n_inverses]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "n_iterations = 20\n",
    "all_results = []\n",
    "all_results_constrain = []\n",
    "all_inverses = []\n",
    "all_inverses_constrain = []\n",
    "n_inverses_maxs=np.arange(5000,52000,5000)\n",
    "\n",
    "for _ in range(n_iterations):\n",
    "\n",
    "    results = []\n",
    "    results_constrain = []\n",
    "    inverses = []\n",
    "    inverses_constrain = []\n",
    "    for n_inverses_max in n_inverses_maxs:\n",
    "        for outer_loop_iterations in [10000000]:\n",
    "            for inner_loop_iterations in [100]:\n",
    "                dyn_objects, eigenvals,n_inverses = main(A, W1, W2, h1, h2, 1, outer_loop_iterations=outer_loop_iterations, \n",
    "                                                        inner_loop_iterations=inner_loop_iterations,constrain=False, n_inverses_max=n_inverses_max)\n",
    "                results.append(len(dyn_objects[0]))\n",
    "                print(len(dyn_objects[0]))\n",
    "                inverses.append(n_inverses)\n",
    "                dyn_objects, eigenvals,n_inverses = main(A, W1, W2, h1, h2, 1, outer_loop_iterations=outer_loop_iterations, \n",
    "                                                        inner_loop_iterations=inner_loop_iterations, constrain=True, n_inverses_max=n_inverses_max)\n",
    "                results_constrain.append(len(dyn_objects[0]))\n",
    "                inverses_constrain.append(n_inverses)\n",
    "                print(len(dyn_objects[0]))\n",
    "    all_results.append(results)\n",
    "    all_results_constrain.append(results_constrain)\n",
    "    all_inverses.append(inverses)\n",
    "    all_inverses_constrain.append(inverses_constrain)\n",
    "\n",
    "all_results = np.array(all_results)\n",
    "all_results_constrain = np.array(all_results_constrain)\n",
    "all_inverses = np.array(all_inverses)\n",
    "all_inverses_constrain = np.array(all_inverses_constrain)\n",
    "#65537"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "print(len(list(combinations(range(rnn_osc.rnn.N),2))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
     "# NOTE(review): `n_inverses` is assigned by several cells (exhaustive search and\n",
     "# the benchmark loop), so the value shown depends on which of them ran last.\n",
     "n_inverses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
     "# Number of fixed points found by the exhaustive search (len(z_list))\n",
     "true_n_fps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "plt.plot(n_inverses_maxs,all_results.T, color = 'blue',alpha =.6)\n",
    "plt.plot(n_inverses_maxs,all_results_constrain.T, color = 'orange',alpha =.6)\n",
    "plt.scatter(n_inverses,true_n_fps,zorder = 1000, color='red')\n",
    "plt.axvline(n_inverses,ls='--',color='red',ymax=true_n_fps-1,ymin =0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
     "# Per-budget statistics across repetitions (rows of all_results are repetitions)\n",
     "mean = np.mean(all_results,axis=0)\n",
     "var = np.var(all_results,axis=0)\n",
     "mean_constrain = np.mean(all_results_constrain,axis=0)\n",
     "var_constrain = np.var(all_results_constrain,axis=0)\n",
     "max_constrain = np.max(all_results_constrain,axis=0)\n",
     "min_constrain = np.min(all_results_constrain,axis=0)\n",
     "# NOTE(review): the next two names shadow the built-ins max() and min() for the\n",
     "# rest of the kernel session; consider renaming them together with their readers.\n",
     "max = np.max(all_results,axis=0)\n",
     "min = np.min(all_results,axis=0)\n",
     "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
     "# Re-run the exhaustive search so that D_list, D_inds, z_list, n_singular and\n",
     "# n_inverses hold its results for the figure cells below (other cells also\n",
     "# assign n_inverses).\n",
     "D_list,D_inds,z_list, n_singular,n_inverses = find_fixed_points(A[range(2),range(2)],W1.T,W2,0,-h2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Solve count of the exhaustive search re-run in the previous cell (the analytic point)\n",
     "n_inverses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_fill_max =  np.maximum(mean + 2 * np.sqrt(var), true_n_fps)\n",
    "plt_fill_max_constrain =  np.maximum(mean_constrain + 2 * np.sqrt(var), true_n_fps)\n",
    "n_start = 1\n",
    "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
    "\n",
    "    plt.figure(figsize=(1,1))\n",
    "    plt.plot(n_inverses_maxs,mean, label=\"approximate\", marker=\"o\", color=\"C0\")\n",
    "    plt.plot(n_inverses_maxs[n_start:],mean_constrain[n_start:], label=\"combined\", marker=\"o\", color=\"C1\")\n",
    "    plt.fill_between(n_inverses_maxs, min, max, alpha=0.2, color=\"C0\")\n",
    "    plt.fill_between(n_inverses_maxs[n_start:], min_constrain[n_start:] ,max_constrain[n_start:], alpha=0.2, color=\"C1\")\n",
    "    #plt.plot(smoothed_particles[:,:], color=\"C1\", alpha=0.5, label=\"Smoothed Particles\")\n",
    "    #plt.plot(smoothed_state_argmax, label=\"Smoothed States Argmax\", color=\"purple\",ls=\"--\")\n",
    "    plt.scatter(n_inverses,true_n_fps,zorder = 1000, color='purple',marker = '*',s=100,label=\"analytic\")\n",
    "    plt.gca().set_box_aspect(1)\n",
    "    plt.legend(loc='upper right',bbox_to_anchor=(2.1,1))\n",
    "    plt.xlabel(\"# inverses\")\n",
    "    plt.ylabel(\"# fixed points found\")\n",
    "    plt.yticks([1,true_n_fps])\n",
    "    plt.ylim(0,true_n_fps+2)\n",
    "    plt.xticks([5000,25000,50000])\n",
    "    plt.xlim(5000,51000)\n",
    "    plt.savefig(\"../figures/FigFP1282.pdf\")#, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "plt_fill_max =  np.maximum(mean + 2 * np.sqrt(var), true_n_fps)\n",
    "plt_fill_max_constrain =  np.maximum(mean_constrain + 2 * np.sqrt(var), true_n_fps)\n",
    "n_start = 2\n",
    "with mpl.rc_context(fname=\"matplotlibrc\"):\n",
    "\n",
    "    plt.figure(figsize=(1,1))\n",
    "    plt.plot(n_inverses_maxs,mean, label=\"approximate\", marker=\"o\", color=\"C0\")\n",
    "    plt.plot(n_inverses_maxs[n_start:],mean_constrain[n_start:], label=\"combined\", marker=\"o\", color=\"C1\")\n",
    "    plt.fill_between(n_inverses_maxs, min, max, alpha=0.2, color=\"C0\")\n",
    "    plt.fill_between(n_inverses_maxs[n_start:], min_constrain[n_start:] ,max_constrain[n_start:], alpha=0.2, color=\"C1\")\n",
    "    #plt.plot(smoothed_particles[:,:], color=\"C1\", alpha=0.5, label=\"Smoothed Particles\")\n",
    "    #plt.plot(smoothed_state_argmax, label=\"Smoothed States Argmax\", color=\"purple\",ls=\"--\")\n",
    "    plt.scatter(n_inverses,true_n_fps,zorder = 1000, color='purple',marker = '*',s=100,label=\"analytic\")\n",
    "    plt.gca().set_box_aspect(1)\n",
    "    plt.legend(loc='upper right',bbox_to_anchor=(2.1,1))\n",
    "    plt.xlabel(\"# inverses\")\n",
    "    plt.ylabel(\"# fixed points found\")\n",
    "    plt.yticks([0,10,true_n_fps])\n",
    "    plt.ylim(10,true_n_fps+.75)\n",
    "    plt.xticks([1000,4000,7000])\n",
    "    plt.xlim(1000,7000)\n",
    "    plt.savefig(\"../figures/FigFP.pdf\")#, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "mean_constrain[1:] - 2 * np.sqrt(var_constrain[1:]),"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rnns",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
