{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9a1861a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import neural_tangents as nt\n",
    "from neural_tangents import stax\n",
    "\n",
    "from templates import multiplication\n",
    "from templates import addition\n",
    "from templates import permutation\n",
    "from templates import utils\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "\n",
    "def get_ntk_predictor(X_train, Y_train):\n",
    "    \"\"\"Build an NTK-based predictor fitted to ``(X_train, Y_train)``.\n",
    "\n",
    "    A Dense(1024) -> ReLU -> Dense(n0) architecture is described with\n",
    "    ``stax.serial`` and only its kernel function is kept. That kernel is handed\n",
    "    to ``nt.predict.gradient_descent_mse_ensemble`` together with the training\n",
    "    data, and predictions are read from the ``'ntk'`` mean.\n",
    "\n",
    "    Args:\n",
    "        X_train: training inputs; ``X_train.shape[0]`` sets the width of the\n",
    "            last dense layer.\n",
    "        Y_train: training targets, passed unchanged to the ensemble predictor.\n",
    "\n",
    "    Returns:\n",
    "        A one-argument callable. Given test inputs ``x`` it returns entry 0 of\n",
    "        the predicted mean (the prediction for the first test row, assuming\n",
    "        ``x`` is laid out as (n_test, features) -- callers in this notebook\n",
    "        pass a single reshaped row).\n",
    "    \"\"\"\n",
    "    n0 = X_train.shape[0]\n",
    "\n",
    "    # stax.serial returns a tuple of functions; the last one is the kernel function.\n",
    "    kernel_fn = stax.serial(\n",
    "        stax.Dense(1024),  # hidden dense layer with 1024 units\n",
    "        stax.Relu(),       # ReLU activation\n",
    "        stax.Dense(n0),    # output dense layer with n0 = X_train.shape[0] units\n",
    "    )[-1]\n",
    "    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, X_train, Y_train)\n",
    "\n",
    "    def predict_mean(x):\n",
    "        \"\"\"Return the NTK mean prediction for the first row of ``x``.\"\"\"\n",
    "        # Only the mean is consumed, so the covariance is not requested. With\n",
    "        # compute_cov=False the predictor returns the mean array itself rather\n",
    "        # than a (mean, covariance) pair, hence the direct [0] indexing.\n",
    "        return predict_fn(x_test=x, get='ntk', compute_cov=False)[0]\n",
    "\n",
    "    return predict_mean\n",
    "\n",
    "eps = 1e-10  # small positive threshold: predictions above it are rounded to 1, all others to 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39b258de",
   "metadata": {},
   "source": [
    "### Permutation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6728d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 10  # word length in bits\n",
    "\n",
    "# Random m x m permutation matrix: the identity with its rows shuffled.\n",
    "P = np.eye(m)[np.random.permutation(m)]\n",
    "\n",
    "# Training data: targets come from the template, inputs are P.T\n",
    "# (the product with the identity also requires len(X) == m).\n",
    "X, Y_train = permutation.get_dataset(m)\n",
    "X_train = jnp.array(P.T @ np.eye(len(X)), dtype=jnp.float64)\n",
    "Y_train = jnp.array(Y_train, dtype=jnp.float64)\n",
    "predict_fn = get_ntk_predictor(X_train, Y_train)\n",
    "\n",
    "# Random m-bit input and the word expected after applying P.T to its bits.\n",
    "number = np.random.randint(0, 2**m)\n",
    "number_bin = np.array(list(map(int, np.binary_repr(number, width=m))))\n",
    "out_bin = (P.T @ number_bin).tolist()\n",
    "out = int(''.join(str(int(bit)) for bit in out_bin), 2)\n",
    "print(f\"Input: {number}\")\n",
    "\n",
    "# Initial state: the template sample with the input bits written into 'p'.\n",
    "X_test = permutation.get_sample(m)[1]\n",
    "X_test['p'] = number_bin.tolist()\n",
    "\n",
    "print(\"Iteration 0:\")\n",
    "print(X_test)\n",
    "print()\n",
    "\n",
    "# Single predictor step: encode the state as one row, predict,\n",
    "# threshold at eps, then decode back into a sample dict.\n",
    "x_encoded = jnp.array(utils.encode_data(X_test, X), dtype=jnp.float64).reshape(1, -1)\n",
    "y_pred = predict_fn(x_encoded)\n",
    "y_pred_round = np.where(y_pred > eps, 1, 0).tolist()\n",
    "X_test = permutation.unflatten_sample(y_pred_round, m)\n",
    "\n",
    "print(\"Iteration 1:\")\n",
    "print(X_test)\n",
    "print()\n",
    "\n",
    "# The predicted bits must match the directly permuted input.\n",
    "assert out_bin == X_test['p']\n",
    "\n",
    "# Read the predicted bit list back as an integer.\n",
    "prediction = int(''.join(str(int(bit)) for bit in X_test['p']), 2)\n",
    "print(f\"Expected Output: {out}\")\n",
    "print(f\"Predicted Output: {prediction}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2426dbe1",
   "metadata": {},
   "source": [
    "### Addition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29945de3",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 4  # operand width in bits\n",
    "\n",
    "# Training data: one-hot inputs (identity matrix) against the template targets.\n",
    "X, Y_train = addition.get_dataset(m)\n",
    "X_train = jnp.eye(len(X))\n",
    "Y_train = jnp.array(Y_train, dtype=jnp.float64)\n",
    "predict_fn = get_ntk_predictor(X_train, Y_train)\n",
    "\n",
    "# Two random m-bit operands and their true sum.\n",
    "p = np.random.randint(0, 2**m)\n",
    "q = np.random.randint(0, 2**m)\n",
    "out = p + q\n",
    "\n",
    "# Initial state: the template sample with both operands stored LSB-first.\n",
    "X_test = addition.get_sample(m)[1]\n",
    "X_test['sum_p'] = [int(bit) for bit in reversed(np.binary_repr(p, width=m))]\n",
    "X_test['sum_q'] = [int(bit) for bit in reversed(np.binary_repr(q, width=m))]\n",
    "\n",
    "print(f\"Input: {p} + {q}\")\n",
    "print(\"Iteration 0:\")\n",
    "print(X_test)\n",
    "print()\n",
    "\n",
    "max_steps = 2 * m\n",
    "for i in range(max_steps):\n",
    "    X_test_old = X_test.copy()\n",
    "\n",
    "    # One predictor step: encode as a single row, predict, threshold at eps, decode.\n",
    "    x_encoded = jnp.array(utils.encode_data(X_test, X), dtype=jnp.float64).reshape(1, -1)\n",
    "    y_pred = predict_fn(x_encoded)\n",
    "    y_pred_round = np.where(y_pred > eps, 1, 0).tolist()\n",
    "    X_test = addition.unflatten_sample(y_pred_round, m)\n",
    "\n",
    "    # Report only the variables whose value changed in this step.\n",
    "    print(\"Iteration:\", i+1)\n",
    "    print(\"Updated variables:\")\n",
    "    for key, value in X_test.items():\n",
    "        if value != X_test_old[key]:\n",
    "            print(f\"{key}: {value}\")\n",
    "    print()\n",
    "\n",
    "    # Early exit to keep the log short: stop once every carry bit except the\n",
    "    # last is zero and 'sum_q' is all zeros. Removing this check simply runs\n",
    "    # (and prints) all 2*m iterations.\n",
    "    if X_test['sum_c'][:-1] == [0]*(m-1) and X_test['sum_q'] == [0]*m:\n",
    "        break\n",
    "\n",
    "# Expected result as m+1 bits, LSB-first: the low m bits are compared with\n",
    "# 'sum_p' and the top bit with the final carry.\n",
    "out_bin = [int(bit) for bit in reversed(np.binary_repr(out, width=m+1))]\n",
    "assert out_bin[-1] == X_test['sum_c'][-1]\n",
    "assert out_bin[:-1] == X_test['sum_p']\n",
    "\n",
    "# Read the result back as an integer: join the LSB-first bits, then reverse to MSB-first.\n",
    "result_bits = X_test['sum_p'] + X_test['sum_c'][-1:]\n",
    "prediction = int(''.join(str(int(bit)) for bit in result_bits)[::-1], 2)\n",
    "print(f\"Expected Output: {out}\")\n",
    "print(f\"Predicted Output: {prediction}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c92b6d2",
   "metadata": {},
   "source": [
    "### Multiplication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f7df5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 3  # operand width in bits\n",
    "\n",
    "# Training data: one-hot inputs against the template targets, extended with\n",
    "# `pad` extra one-hot rows whose targets are all zeros.\n",
    "X, Y_train = multiplication.get_dataset(m)\n",
    "pad = m\n",
    "X_train = jnp.eye(len(X) + pad)\n",
    "Y_pad = np.zeros((pad, Y_train.shape[1]), dtype=np.float64)\n",
    "Y_train = jnp.array(np.vstack((Y_train, Y_pad)), dtype=jnp.float64)\n",
    "predict_fn = get_ntk_predictor(X_train, Y_train)\n",
    "\n",
    "# Two random m-bit factors and their true product.\n",
    "multiplier = np.random.randint(0, 2**m)\n",
    "multiplicand = np.random.randint(0, 2**m)\n",
    "out = multiplier*multiplicand\n",
    "print(f\"Multiplier: {multiplier}, Multiplicand: {multiplicand}\")\n",
    "\n",
    "# Initial state: the template sample with both factors stored LSB-first\n",
    "# (the multiplicand is widened to 2*m bits) and the first 'to_check_lsb' entry set.\n",
    "X_test = multiplication.get_sample(m)[1]\n",
    "X_test['multiplier'] = [int(bit) for bit in reversed(np.binary_repr(multiplier, width=m))]\n",
    "X_test['multiplicand'] = [int(bit) for bit in reversed(np.binary_repr(multiplicand, width=2*m))]\n",
    "X_test['to_check_lsb'][0] = 1\n",
    "\n",
    "print(\"Iteration: 0\")\n",
    "print(X_test)\n",
    "print()\n",
    "\n",
    "max_steps = 4 * m**2 + 3 * m\n",
    "for i in range(max_steps):\n",
    "    X_test_old = X_test.copy()\n",
    "\n",
    "    # One predictor step: encode, append `pad` zeros so the width matches the\n",
    "    # padded training inputs, predict, threshold at eps, decode.\n",
    "    x_padded = np.concatenate([utils.encode_data(X_test, X), np.zeros(pad)])\n",
    "    x_encoded = jnp.array(x_padded, dtype=jnp.float64).reshape(1, -1)\n",
    "    y_pred = np.where(predict_fn(x_encoded) > eps, 1, 0).tolist()\n",
    "    X_test = multiplication.unflatten_sample(y_pred, m)\n",
    "\n",
    "    # Report only the variables whose value changed in this step.\n",
    "    print(\"Iteration:\", i+1)\n",
    "    print(\"Updated variables:\")\n",
    "    for key, value in X_test.items():\n",
    "        if value != X_test_old[key]:\n",
    "            print(f\"{key}: {value}\")\n",
    "    print()\n",
    "\n",
    "    # Early exit to keep the log short: stop once 'multiplier' is all zeros.\n",
    "    # Removing this check simply runs (and prints) every iteration.\n",
    "    if X_test['multiplier'] == [0]*m:\n",
    "        break\n",
    "\n",
    "# Expected product as 2*m bits, LSB-first, compared with 'sum_p'.\n",
    "out_bin = [int(bit) for bit in reversed(np.binary_repr(out, width=2*m))]\n",
    "assert out_bin == X_test['sum_p']\n",
    "\n",
    "# Read the result back as an integer: join the LSB-first bits, then reverse to MSB-first.\n",
    "prediction = int(''.join(str(int(bit)) for bit in X_test['sum_p'])[::-1], 2)\n",
    "print(f\"Expected Output: {out}\")\n",
    "print(f\"Predicted Output: {prediction}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurips",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
