{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "\n",
    "class NN_sim:\n",
    "    \"\"\"\n",
    "    Neural network simulator with population structure and tanh nonlinearity.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    - Time indexing: we return sequences of length `T` (0..T-1). At t=0 we apply the\n",
    "      external input pulse to the first population: S_0 = tanh(x_0 + WI0), with x_0 = 0.\n",
    "      For t >= 1: x_t = J @ S_{t-1} + noise, S_t = tanh(x_t).\n",
    "    - All simulations use a NumPy Generator (`self.rng`) for reproducibility.\n",
    "    - `isFF` is stored but not used directly here (kept for compatibility).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        g: np.ndarray,\n",
    "        isFF: bool,\n",
    "        N: int,\n",
    "        f: np.ndarray,\n",
    "        w: np.ndarray,\n",
    "        theta: float,\n",
    "        std_x: float,\n",
    "        std_obs: float,\n",
    "        T_lim: int,\n",
    "        num_traj: int,\n",
    "        num_J: int,\n",
    "        seed: Optional[int] = None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Parameters\n",
    "        ----------\n",
    "        g : (P, P) array\n",
    "            Block-level gain/scale for connections between P populations.\n",
    "        isFF : bool\n",
    "            If True, marks network as feedforward (not used in this class).\n",
    "        N : int\n",
    "            Total number of neurons.\n",
    "        f : (P,) array\n",
    "            Fraction of neurons in each population; sum(f) == 1.\n",
    "        w : array-like\n",
    "            Weighted input to the first population (kept for compatibility).\n",
    "        theta : float\n",
    "            Initial input pulse applied to the first population at t=0.\n",
    "        std_x : float\n",
    "            Std. dev. of Gaussian process noise added to x_t.\n",
    "        std_obs : float\n",
    "            Std. dev. of observation noise (allocated but not used here).\n",
    "        T_lim : int\n",
    "            Number of time steps to simulate; output length is T_lim.\n",
    "        num_traj : int\n",
    "            Number of trajectories to simulate in batched mode.\n",
    "        num_J : int\n",
    "            Number of connectivity realizations (not used here).\n",
    "        seed : int, optional\n",
    "            RNG seed for reproducibility.\n",
    "        \"\"\"\n",
    "        # Basic assignments / validation\n",
    "        self.g = np.asarray(g, dtype=float)\n",
    "        self.isFF = bool(isFF)\n",
    "        self.N = int(N)\n",
    "        self.f = np.asarray(f, dtype=float)\n",
    "        self.w = np.asarray(w)  # retained, not used in current dynamics\n",
    "        self.theta = float(theta)\n",
    "        self.std_x = float(std_x)\n",
    "        self.std_obs = float(std_obs)\n",
    "        self.T_lim = int(T_lim)\n",
    "        self.num_traj = int(num_traj)\n",
    "        self.num_J = int(num_J)\n",
    "        self.rng = np.random.default_rng(seed)\n",
    "\n",
    "        # Shape checks\n",
    "        if self.g.ndim != 2 or self.g.shape[0] != self.g.shape[1]:\n",
    "            raise ValueError(\"g must be a square (P, P) array.\")\n",
    "        if self.f.ndim != 1 or self.f.shape[0] != self.g.shape[0]:\n",
    "            raise ValueError(\"f must have length equal to number of populations P.\")\n",
    "        if not np.isclose(self.f.sum(), 1.0):\n",
    "            raise ValueError(\"Population fractions f must sum to 1.\")\n",
    "\n",
    "        # Compute neuron counts per population with remainder to the last group\n",
    "        self.n = self._compute_population_sizes(self.N, self.f)\n",
    "\n",
    "        # Input initialization flag (kept for compatibility)\n",
    "        self.input_initialization = 1\n",
    "\n",
    "    @staticmethod\n",
    "    def _compute_population_sizes(N: int, f: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"Compute integer neuron counts per population that sum exactly to N.\"\"\"\n",
    "        n = np.floor(N * f).astype(int)\n",
    "        # Allocate any leftover neurons to the last population to ensure sum == N\n",
    "        remainder = N - n.sum()\n",
    "        n[-1] += remainder\n",
    "        if n.min() <= 0:\n",
    "            raise ValueError(\"Each population must contain at least one neuron.\")\n",
    "        return n\n",
    "\n",
    "    # -------- Optional helpers (leave as-is if defined elsewhere) --------\n",
    "    def calculate_g_bar(self) -> float:\n",
    "        \"\"\"\n",
    "        Returns an effective gain scalar g_bar from `self.g`.\n",
    "\n",
    "        This method depends on external helpers `garr_to_G_arr` and `Garr_to_gbar`.\n",
    "        Provide them in your codebase or replace this function with your own logic.\n",
    "        \"\"\"\n",
    "        # These functions are assumed to exist in your environment:\n",
    "        G_arr = garr_to_G_arr(self.g, self)      # noqa: F821 - external\n",
    "        g_bar = Garr_to_gbar(G_arr, self)        # noqa: F821 - external\n",
    "        return g_bar\n",
    "\n",
    "    # ---------------- Connectivity construction ----------------\n",
    "    def J_matrix(self) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Sample a block-structured connectivity matrix J ~ N(0, g_ij / sqrt(N)).\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        J : (N, N) array\n",
    "            Dense connectivity matrix consistent with population sizes `self.n`.\n",
    "        \"\"\"\n",
    "        # Recompute `n` from `f` to reflect any external changes to N/f.\n",
    "        self.n = self._compute_population_sizes(self.N, self.f)\n",
    "\n",
    "        P = self.g.shape[0]\n",
    "        if P != self.n.size:\n",
    "            raise ValueError(\"Mismatch between g shape and number of populations.\")\n",
    "\n",
    "        J = np.zeros((self.N, self.N), dtype=float)\n",
    "\n",
    "        r0 = 0\n",
    "        for i in range(P):\n",
    "            c0 = 0\n",
    "            ni = self.n[i]\n",
    "            for j in range(P):\n",
    "                nj = self.n[j]\n",
    "                scale = self.g[i, j] / np.sqrt(self.N)\n",
    "                J[r0:r0 + ni, c0:c0 + nj] = self.rng.normal(loc=0.0, scale=scale, size=(ni, nj))\n",
    "                c0 += nj\n",
    "            r0 += ni\n",
    "\n",
    "        return J\n",
    "\n",
    "    # ---------------- Core single-trajectory simulation ----------------\n",
    "    def _build_WI0_vector(self, theta: float) -> np.ndarray:\n",
    "        \"\"\"External input vector applied at t=0 to the first population.\"\"\"\n",
    "        WI0 = np.zeros((self.N, 1), dtype=float)\n",
    "        n1 = self.n[0]\n",
    "        WI0[:n1, 0] = theta\n",
    "        return WI0\n",
    "\n",
    "    def simulation(\n",
    "        self,\n",
    "        J: Optional[np.ndarray] = None,\n",
    "        theta: Optional[float] = None,\n",
    "        T: Optional[int] = None,\n",
    "        std_x: Optional[float] = None,\n",
    "    ) -> Tuple[np.ndarray, np.ndarray]:\n",
    "        \"\"\"\n",
    "        Simulate a single trajectory.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        J : (N, N) array, optional\n",
    "            Connectivity. If None, sampled via `self.J_matrix()`.\n",
    "        theta : float, optional\n",
    "            Input pulse to first population at t=0. Defaults to `self.theta`.\n",
    "        T : int, optional\n",
    "            Number of time steps. Defaults to `self.T_lim`.\n",
    "        std_x : float, optional\n",
    "            Process noise std. Defaults to `self.std_x`.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        x_t : (N, T) array\n",
    "            States over time (including t=0).\n",
    "        S_t : (N, T) array\n",
    "            Activations over time (including t=0).\n",
    "        \"\"\"\n",
    "        if J is None:\n",
    "            J = self.J_matrix()\n",
    "        if theta is None:\n",
    "            theta = self.theta\n",
    "        if T is None:\n",
    "            T = self.T_lim\n",
    "        if std_x is None:\n",
    "            std_x = self.std_x\n",
    "\n",
    "        if J.shape != (self.N, self.N):\n",
    "            raise ValueError(\"J must have shape (N, N).\")\n",
    "\n",
    "        x_t = np.zeros((self.N, T), dtype=float)\n",
    "        S_t = np.zeros((self.N, T), dtype=float)\n",
    "\n",
    "        # t = 0\n",
    "        x0 = np.zeros((self.N, 1), dtype=float)\n",
    "        WI0 = self._build_WI0_vector(theta)\n",
    "        S0 = np.tanh(x0 + WI0)\n",
    "        S_t[:, 0] = S0[:, 0]\n",
    "\n",
    "        # t >= 1\n",
    "        for t in range(1, T):\n",
    "            noise = self.rng.normal(0.0, std_x, size=(self.N, 1))\n",
    "            x = J @ S0 + noise\n",
    "            x_t[:, t] = x[:, 0]\n",
    "            S0 = np.tanh(x)\n",
    "            S_t[:, t] = S0[:, 0]\n",
    "\n",
    "        return x_t, S_t\n",
    "\n",
    "    # ---------------- Multiple-trajectory simulation (looped) ----------------\n",
    "    def simulation_traj(self, J: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n",
    "        \"\"\"\n",
    "        Simulate `self.num_traj` independent trajectories with the same J (looped version).\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        traj_x : (num_traj, N, T) array\n",
    "        traj_S : (num_traj, N, T) array\n",
    "        \"\"\"\n",
    "        T = self.T_lim\n",
    "        traj_x = np.zeros((self.num_traj, self.N, T), dtype=float)\n",
    "        traj_S = np.zeros((self.num_traj, self.N, T), dtype=float)\n",
    "        for i in range(self.num_traj):\n",
    "            xi, Si = self.simulation(J=J)\n",
    "            traj_x[i] = xi\n",
    "            traj_S[i] = Si\n",
    "        return traj_x, traj_S\n",
    "\n",
    "    # ---------------- Multiple-trajectory simulation (batched) ----------------\n",
    "    def simulation_traj_V2(self, J: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n",
    "        \"\"\"\n",
    "        Batched simulation of `self.num_traj` trajectories with the same J.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        traj_x : (num_traj, N, T) array\n",
    "            States over time per trajectory.\n",
    "        traj_S : (num_traj, N, T) array\n",
    "            Activations over time per trajectory.\n",
    "        \"\"\"\n",
    "        if J.shape != (self.N, self.N):\n",
    "            raise ValueError(\"J must have shape (N, N).\")\n",
    "\n",
    "        T = self.T_lim\n",
    "        B = self.num_traj\n",
    "\n",
    "        traj_x = np.zeros((B, self.N, T), dtype=float)\n",
    "        traj_S = np.zeros((B, self.N, T), dtype=float)\n",
    "\n",
    "        # t = 0\n",
    "        x = np.zeros((self.N, B), dtype=float)\n",
    "        WI0 = np.zeros((self.N, B), dtype=float)\n",
    "        n1 = self.n[0]\n",
    "        WI0[:n1, :] = self.theta\n",
    "        S = np.tanh(x + WI0)\n",
    "        traj_S[:, :, 0] = S.T\n",
    "\n",
    "        # t >= 1\n",
    "        for t in range(1, T):\n",
    "            noise = self.rng.normal(0.0, self.std_x, size=(self.N, B))\n",
    "            x = J @ S + noise\n",
    "            traj_x[:, :, t] = x.T\n",
    "            S = np.tanh(x)\n",
    "            traj_S[:, :, t] = S.T\n",
    "\n",
    "        return traj_x, traj_S\n",
    "\n",
    "    # ---------------- Fisher information runs for a single J ----------------\n",
    "    def fisher_information_per_J(\n",
    "        self,\n",
    "        delta_theta: float,\n",
    "        use_batched: bool = True,\n",
    "    ) -> Tuple[\n",
    "        np.ndarray, np.ndarray,\n",
    "        np.ndarray, np.ndarray,\n",
    "        np.ndarray, np.ndarray\n",
    "    ]:\n",
    "        \"\"\"\n",
    "        Run three simulations (theta, theta + delta, theta - delta) for a single J.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        delta_theta : float\n",
    "            Small perturbation applied around the baseline `self.theta`.\n",
    "        use_batched : bool\n",
    "            If True, use `simulation_traj_V2`; else, use `simulation_traj`.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        traj_theta, S_theta,\n",
    "        traj_theta_plus, S_theta_plus,\n",
    "        traj_theta_minus, S_theta_minus\n",
    "            Each is an array of shape:\n",
    "            - Batched: (num_traj, N, T)\n",
    "            - Looped:  (num_traj, N, T)\n",
    "        \"\"\"\n",
    "        J = self.J_matrix()\n",
    "\n",
    "        # Baseline theta\n",
    "        theta0 = self.theta\n",
    "\n",
    "        if use_batched:\n",
    "            # θ\n",
    "            traj_theta, S_theta = self.simulation_traj_V2(J)\n",
    "            # θ + δ\n",
    "            theta_backup = self.theta\n",
    "            self.theta = theta0 + delta_theta\n",
    "            traj_theta_plus, S_theta_plus = self.simulation_traj_V2(J)\n",
    "            # θ - δ\n",
    "            self.theta = theta0 - delta_theta\n",
    "            traj_theta_minus, S_theta_minus = self.simulation_traj_V2(J)\n",
    "            # restore\n",
    "            self.theta = theta_backup\n",
    "        else:\n",
    "            # θ\n",
    "            traj_theta, S_theta = self.simulation_traj(J)\n",
    "            # θ + δ\n",
    "            theta_backup = self.theta\n",
    "            self.theta = theta0 + delta_theta\n",
    "            traj_theta_plus, S_theta_plus = self.simulation_traj(J)\n",
    "            # θ - δ\n",
    "            self.theta = theta0 - delta_theta\n",
    "            traj_theta_minus, S_theta_minus = self.simulation_traj(J)\n",
    "            # restore\n",
    "            self.theta = theta_backup\n",
    "\n",
    "        return (\n",
    "            traj_theta, S_theta,\n",
    "            traj_theta_plus, S_theta_plus,\n",
    "            traj_theta_minus, S_theta_minus\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN model for the copy task and sequential MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "### model\n",
    "class SimpleRNN(nn.Module):\n",
    "    def __init__(self, input_dim=10, hidden_dim=100, output_dim=10,nonlinear_f = 'relu'):\n",
    "        super().__init__()\n",
    "        self.rnn = nn.RNN(input_dim, hidden_dim, nonlinearity=nonlinear_f, batch_first=True)\n",
    "        self.readout = nn.Linear(hidden_dim, output_dim)\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        # Recurrent weights as identity (iRNN)\n",
    "        with torch.no_grad():\n",
    "            W_hh = self.rnn.weight_hh_l0\n",
    "            nn.init.eye_(W_hh)\n",
    "            # Small init for input weights helps stability\n",
    "            k = 1.0 / math.sqrt(W_hh.shape[0])\n",
    "            nn.init.uniform_(self.rnn.weight_ih_l0, -k, k)\n",
    "            nn.init.zeros_(self.rnn.bias_ih_l0)\n",
    "            nn.init.zeros_(self.rnn.bias_hh_l0)\n",
    "        nn.init.zeros_(self.readout.bias)\n",
    "\n",
    "    def forward(self, x):  # x: [B,L,10]\n",
    "        h, _ = self.rnn(x)                # [B,L,H]\n",
    "        logits = self.readout(h)          # [B,L,C]\n",
    "        return logits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "condamac",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
