{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f88c77-2fb8-41d0-a459-247d8092f23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This program computes the cardinality of neuromanifolds corresponding to\n",
    "# shallow neural networks with monomial activation functions over prime fields."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c5b44054-e5af-4555-83fb-fc37e42a2b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import libraries\n",
    "from sage.all import *\n",
    "from itertools import chain, product\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from math import comb\n",
    "import itertools, math, time, csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7b1e7bb2-712e-44ed-b328-2eb19dc58bd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define all custom functions\n",
    "\n",
    "# compute the cardinality of the image of the parameter map\n",
    "def Psi_image(d, prime, r=2, show_progress=True, do_factor=False):\n",
    "    \"\"\"Count the distinct values taken by the parameter map Psi over GF(prime).\n",
    "\n",
    "    A feedforward network with layer widths ``d`` is built symbolically. The\n",
    "    activation ``u -> u**r`` is applied after each layer, and the pre-activation\n",
    "    output of the last layer is used. Each output entry is a polynomial in the\n",
    "    inputs whose coefficients are polynomials in the weights; these coefficients\n",
    "    are evaluated at every weight assignment in GF(prime)**N and the distinct\n",
    "    result tuples are counted.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    d : sequence of int\n",
    "        Layer widths, input layer first; at least two entries.\n",
    "    prime : int\n",
    "        Order of the field, passed to ``GF``.\n",
    "    r : int, default 2\n",
    "        Exponent of the monomial activation. Only numerators are collected,\n",
    "        so the count is meaningful for polynomial activations (integer r >= 1).\n",
    "    show_progress : bool, default True\n",
    "        Wrap the enumeration in a tqdm progress bar.\n",
    "    do_factor : bool, default False\n",
    "        Unused; kept so that existing calls keep working.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (int, float)\n",
    "        Number of distinct images and the elapsed wall-clock time in seconds.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``d`` has fewer than two entries (there is no layer to evaluate).\n",
    "    \"\"\"\n",
    "    if len(d) < 2:\n",
    "        raise ValueError(\"d must contain at least two layer widths\")\n",
    "    start = time.time()\n",
    "\n",
    "    field_k = GF(prime)\n",
    "\n",
    "    # Weight names in row-major order per layer. The '_' separators keep names\n",
    "    # unique once an index reaches 10 (otherwise w_{1,1,11} and w_{1,11,1} would\n",
    "    # both be called \"w_1111\"). Plain strings are used instead of var(...),\n",
    "    # which would also create symbolic variables as a side effect.\n",
    "    ww_names = [\n",
    "        f\"w_{k+1}_{i+1}_{j+1}\"\n",
    "        for k in range(len(d) - 1)\n",
    "        for i in range(d[k + 1])\n",
    "        for j in range(d[k])\n",
    "    ]\n",
    "    xx_names = [f\"x_{i}\" for i in range(d[0])]\n",
    "\n",
    "    # C holds the weights, P = C[x] the inputs; the fraction field of P is kept\n",
    "    # so that a non-polynomial activation can still be formed.\n",
    "    C = PolynomialRing(field_k, ww_names)\n",
    "    P = PolynomialRing(C, xx_names)\n",
    "    FF = P.fraction_field()\n",
    "\n",
    "    ww = C.gens()\n",
    "    xx = P.gens()\n",
    "\n",
    "    # Feedforward pass: consume the weights layer by layer.\n",
    "    shift = 0\n",
    "    sigma_Wi_x = Matrix(xx).transpose()  # column vector of inputs\n",
    "    Wi_x_last = None\n",
    "    for ell in range(len(d) - 1):\n",
    "        count = d[ell] * d[ell + 1]\n",
    "        Wi = matrix(FF, d[ell + 1], d[ell], list(ww[shift:shift + count]))\n",
    "        shift += count\n",
    "\n",
    "        Wi_x_last = Wi * sigma_Wi_x\n",
    "        sigma_Wi_x = Wi_x_last.apply_map(lambda u: u**r)\n",
    "\n",
    "    # Coefficients (elements of C) of the numerator of every output entry.\n",
    "    # For a polynomial activation the denominator is 1.\n",
    "    Psi_coeffs = tuple(\n",
    "        coeff\n",
    "        for i in range(d[-1])\n",
    "        for coeff in Wi_x_last[i][0].numerator().coefficients()\n",
    "    )\n",
    "\n",
    "    N = len(ww_names)      # total number of weights\n",
    "    total_pts = prime**N   # number of weight assignments to enumerate\n",
    "\n",
    "    iterator = product(field_k, repeat=N)\n",
    "    if show_progress:\n",
    "        iterator = tqdm(\n",
    "            iterator,\n",
    "            total=total_pts,\n",
    "            desc=\"Computing Psi\",\n",
    "            ncols=80,\n",
    "            bar_format=\"{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\"\n",
    "        )\n",
    "\n",
    "    # One tuple of coefficient values per weight assignment; the set removes duplicates.\n",
    "    image_Psi = set()\n",
    "    for point in iterator:\n",
    "        image_Psi.add(tuple(c(*point) for c in Psi_coeffs))\n",
    "\n",
    "    return len(image_Psi), time.time() - start\n",
    "\n",
    "# run through all the architectures \n",
    "def main(d_list, p, TIMEOUT, r):\n",
    "    \"\"\"Run Psi_image on every architecture in d_list and write the results to a CSV file.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    d_list : list of [n, m, k]\n",
    "        Three-layer architectures (input, hidden and output widths).\n",
    "    p : int\n",
    "        Prime order of the field.\n",
    "    TIMEOUT : any\n",
    "        Accepted for compatibility; no time limit is enforced in this function.\n",
    "    r : int\n",
    "        Exponent of the monomial activation.\n",
    "\n",
    "    The file ``result_p{p}_r{r}.csv`` is created in the working directory and\n",
    "    overwrites any existing file of that name. Each row is flushed as soon as\n",
    "    it is written so that finished results survive an interrupted run.\n",
    "    \"\"\"\n",
    "    csv_path = f\"result_p{p}_r{r}.csv\"\n",
    "    print(f\"Writing to: {csv_path}\")\n",
    "\n",
    "    with open(csv_path, \"w\", newline=\"\") as fh:\n",
    "        writer = csv.writer(fh)\n",
    "        writer.writerow([\"n\",\"m\",\"k\",\"prime\",\"r\",\"cardinality\",\n",
    "                         \"amb_space_card\",\"ratio\",\"dim_amb\",\"time_sec\"])\n",
    "        total = len(d_list)\n",
    "        # 1-based counter so the progress line reads \"1/1\" rather than \"0/1\"\n",
    "        for i, d in enumerate(d_list, start=1):\n",
    "            n, m, k = d\n",
    "            # p**(number of weights) is the number of points Psi_image enumerates\n",
    "            print(f\"Running for architecture d = {d} over F_{p}, # parameter points = {p**(m*n+k*m)}\")\n",
    "            card, elapsed = Psi_image(d, p, r)\n",
    "            # k outputs, each with comb(n + r - 1, r) coefficients (degree-r monomials in n inputs)\n",
    "            dim_amb      = k * comb(n + r - 1, r)\n",
    "            amb_space    = p ** dim_amb\n",
    "            # a None cardinality is reported as TIMEOUT; Psi_image currently always returns a count\n",
    "            ratio        = \"n/a\" if card is None else round(card / amb_space, 9)\n",
    "\n",
    "            print(f\"{i}/{total}:: cardinality={card if card is not None else 'TIMEOUT'}, amb_space={amb_space}, time={elapsed:.2f}s\\n\")\n",
    "            writer.writerow([n,m,k,p,r,\n",
    "                             card if card is not None else \"TIMEOUT\",\n",
    "                             amb_space, ratio, dim_amb, f\"{elapsed:.1f}\"])\n",
    "            fh.flush()\n",
    "\n",
    "# filter architectures so that the number of parameter points to check stays below max_points\n",
    "def valid_d_list(p, max_points, n_max=10, m_max=10, k_max=10):\n",
    "    \"\"\"List the architectures [n, m, k] that check_cardinality accepts for prime p.\n",
    "\n",
    "    n runs over 2..n_max-1, m over 2..m_max-1 and k over 1..k_max-1. For each\n",
    "    (n, m) pair, k is increased only until check_cardinality first rejects.\n",
    "    \"\"\"\n",
    "    def within_budget(arch):\n",
    "        return check_cardinality(arch, p, max_points)\n",
    "\n",
    "    accepted = []\n",
    "    for n, m in itertools.product(range(2, n_max), range(2, m_max)):\n",
    "        by_output_width = ([n, m, k] for k in range(1, k_max))\n",
    "        # takewhile stops at the first rejected k for this (n, m)\n",
    "        accepted.extend(itertools.takewhile(within_budget, by_output_width))\n",
    "    return accepted\n",
    "\n",
    "def check_cardinality(d, prime, max_points):\n",
    "    N = d[2]*d[1] + d[1]*d[0]\n",
    "    total_pts = prime**N\n",
    "    if total_pts < max_points:\n",
    "        return True\n",
    "    else:\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5bcb5a9-2839-47c6-91bb-cc83aff359e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In particular, the code below computes the architectures\n",
    "# whose expressive capacity approaches 1/2 as the prime approaches infinity\n",
    "# (this is done in the data_analysis file)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0da08632-360b-4c49-9b8e-9905fe75a6bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Running for r = 2, prime p = 2 ===\n",
      "Writing to: result_p2_r2.csv\n",
      "Running for architecture d = [2, 2, 2] over F_2, # parameters = 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Psi: 100%|███████████████████████| 256/256 [00:00<00:00, 29495.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0/1:: cardinality=16, amb_space=64, time=0.03s\n",
      "\n",
      "\n",
      "=== Running for r = 2, prime p = 3 ===\n",
      "Writing to: result_p3_r2.csv\n",
      "Running for architecture d = [2, 2, 2] over F_3, # parameters = 6561\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Psi: 100%|█████████████████████| 6561/6561 [00:00<00:00, 32687.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0/1:: cardinality=393, amb_space=729, time=0.21s\n",
      "\n",
      "\n",
      "=== Running for r = 2, prime p = 5 ===\n",
      "Writing to: result_p5_r2.csv\n",
      "Running for architecture d = [2, 2, 2] over F_5, # parameters = 390625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Psi: 100%|█████████████████| 390625/390625 [00:11<00:00, 32968.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0/1:: cardinality=7945, amb_space=15625, time=11.86s\n",
      "\n",
      "\n",
      "=== Running for r = 2, prime p = 7 ===\n",
      "Writing to: result_p7_r2.csv\n",
      "Running for architecture d = [2, 2, 2] over F_7, # parameters = 5764801\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Psi: 100%|███████████████| 5764801/5764801 [02:54<00:00, 32968.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0/1:: cardinality=59185, amb_space=117649, time=174.86s\n",
      "\n",
      "\n",
      "=== Running for r = 2, prime p = 11 ===\n",
      "Writing to: result_p11_r2.csv\n",
      "Running for architecture d = [2, 2, 2] over F_11, # parameters = 214358881\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Psi:   0%|            | 479726/214358881 [00:14<1:46:25, 33494.01it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 19\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m primes:    \n\u001b[1;32m     11\u001b[0m     \u001b[38;5;66;03m#d_list = valid_d_list(p, max_points)            \u001b[39;00m\n\u001b[1;32m     12\u001b[0m     \u001b[38;5;66;03m# check for the custom weights\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[38;5;66;03m#    d_list_new.append(check_cardinality(d, p, max_points))\u001b[39;00m\n\u001b[1;32m     16\u001b[0m     \u001b[38;5;66;03m#d_list = d_list_new\u001b[39;00m\n\u001b[1;32m     18\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m=== Running for r = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, prime p = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mp\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m ===\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 19\u001b[0m     \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43md_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mTIMEOUT\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mr\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[4], line 85\u001b[0m, in \u001b[0;36mmain\u001b[0;34m(d_list, p, TIMEOUT, r)\u001b[0m\n\u001b[1;32m     83\u001b[0m n, m, k \u001b[38;5;241m=\u001b[39m d\n\u001b[1;32m     84\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRunning for architecture d = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00md\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m over F_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mp\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, # parameters = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mp\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m(m\u001b[38;5;241m*\u001b[39mn\u001b[38;5;241m+\u001b[39mk\u001b[38;5;241m*\u001b[39mm)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 85\u001b[0m card, elapsed \u001b[38;5;241m=\u001b[39m \u001b[43mPsi_image\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     86\u001b[0m dim_amb      \u001b[38;5;241m=\u001b[39m k \u001b[38;5;241m*\u001b[39m comb(n \u001b[38;5;241m+\u001b[39m r \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m, r)\n\u001b[1;32m     87\u001b[0m amb_space    \u001b[38;5;241m=\u001b[39m p \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m dim_amb\n",
      "Cell \u001b[0;32mIn[4], line 66\u001b[0m, in \u001b[0;36mPsi_image\u001b[0;34m(d, prime, r, show_progress, do_factor)\u001b[0m\n\u001b[1;32m     64\u001b[0m image_Psi \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n\u001b[1;32m     65\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m point \u001b[38;5;129;01min\u001b[39;00m iterator:\n\u001b[0;32m---> 66\u001b[0m     image_Psi\u001b[38;5;241m.\u001b[39madd(\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpoint\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mc\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mPsi_coeffs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m     68\u001b[0m \u001b[38;5;66;03m#print(f\"Total points (|F_{prime}|^{N}): {total_pts}\")\u001b[39;00m\n\u001b[1;32m     69\u001b[0m \u001b[38;5;66;03m#print(f\"Distinct images (cardinality): {len(image_Psi)}\")\u001b[39;00m\n\u001b[1;32m     70\u001b[0m \u001b[38;5;66;03m#print(f\"Ambient codomain size bound (p^{len(Psi_coeffs)}): {prime**len(Psi_coeffs)}\")\u001b[39;00m\n\u001b[1;32m     72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(image_Psi), time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m start\n",
      "Cell \u001b[0;32mIn[4], line 66\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     64\u001b[0m image_Psi \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n\u001b[1;32m     65\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m point \u001b[38;5;129;01min\u001b[39;00m iterator:\n\u001b[0;32m---> 66\u001b[0m     image_Psi\u001b[38;5;241m.\u001b[39madd(\u001b[38;5;28mtuple\u001b[39m(\u001b[43mc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpoint\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m c \u001b[38;5;129;01min\u001b[39;00m Psi_coeffs))\n\u001b[1;32m     68\u001b[0m \u001b[38;5;66;03m#print(f\"Total points (|F_{prime}|^{N}): {total_pts}\")\u001b[39;00m\n\u001b[1;32m     69\u001b[0m \u001b[38;5;66;03m#print(f\"Distinct images (cardinality): {len(image_Psi)}\")\u001b[39;00m\n\u001b[1;32m     70\u001b[0m \u001b[38;5;66;03m#print(f\"Ambient codomain size bound (p^{len(Psi_coeffs)}): {prime**len(Psi_coeffs)}\")\u001b[39;00m\n\u001b[1;32m     72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(image_Psi), time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m start\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    TIMEOUT = 20_000  # forwarded to main(), which does not enforce a time limit\n",
    "    max_points = 10**7 - 1 # max points allowed in the parameter space\n",
    "    primes = [2,3,5,7,11] # primes to check\n",
    "    activations = [2]  # activation exponents r\n",
    "    d_list = [[2,2,2]] # custom architectures\n",
    "    # Alternative: generate the architectures per prime with valid_d_list(p, max_points)\n",
    "\n",
    "    for r in activations:\n",
    "        for p in primes:\n",
    "            # Keep only the architectures whose parameter space stays below\n",
    "            # max_points for this prime; larger ones are not enumerated.\n",
    "            runnable = [d for d in d_list if check_cardinality(d, p, max_points)]\n",
    "\n",
    "            print(f\"\\n=== Running for r = {r}, prime p = {p} ===\")\n",
    "            if not runnable:\n",
    "                # main() opens its CSV in write mode, so skip the call rather than\n",
    "                # replace an existing result file with a header-only one\n",
    "                print(f\"Skipping: every architecture exceeds max_points = {max_points}\")\n",
    "                continue\n",
    "            main(runnable, p, TIMEOUT, r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb8c285-61e6-4a9e-9068-5fc35fa4181f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SageMath 10.4",
   "language": "python",
   "name": "sage10.4"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
