{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4b2507d1-c494-4b2a-9369-93db7c698617",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import normaltest\n",
    "from scipy.fftpack import dct, idct\n",
    "rng = np.random.default_rng(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "85eb5ef1-1cbf-4cd8-b62a-c3ff21e98b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample a normal N conditional on |N| < kappa\n",
    "def condnormal(kappa):\n",
    "    while True:\n",
    "        x = rng.uniform(-kappa, kappa)\n",
    "        y = rng.uniform(0.0, 1.0)\n",
    "        if y < np.exp(- .5 * (x ** 2)):\n",
    "            return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "67b7f161-8890-4ade-932b-cddd7d1baa91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample a n-dimensional normal conditioned on the sum of entries being y\n",
    "def condsum(n, y):\n",
    "    invrootn = n ** (-.5)\n",
    "    z = rng.normal(size = n, scale = invrootn)\n",
    "    z[0] = invrootn * y\n",
    "    return idct(z, norm='ortho') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e108cb19-1c3b-41cb-a35d-7770a91e2b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample an instance backdoored by {+1}^m\n",
    "def samplematrix(n, m, kappa):\n",
    "    rhs = np.array([condnormal(kappa) for _ in range(m)])\n",
    "    lhs = np.array([condsum(n, rhs[j]) for j in range(m)])\n",
    "    return lhs, rhs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9eb81f50-a3fa-4f8f-9a3b-33f45bca4c4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# shortest column vector\n",
    "def mincol(A):\n",
    "    n = A.shape[1]\n",
    "    return np.min([np.linalg.norm(A[:,i]) for i in range(n)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "49803e4f-2506-40e3-abc6-89cabb231d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# online norm minimization\n",
    "def online(A, permute = False):\n",
    "    m, n = A.shape[0], A.shape[1]\n",
    "    s = np.zeros(m)\n",
    "    order = range(n)\n",
    "    if permute:\n",
    "        order = rng.permutation(order)\n",
    "    for i in order:\n",
    "        sp, sm = s + A[:,i], s - A[:,i]\n",
    "        if np.linalg.norm(sp) < np.linalg.norm(sm):\n",
    "            s = sp\n",
    "        else:\n",
    "            s = sm\n",
    "    return np.linalg.norm(s) * (n ** (-.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "97c1795e-b2aa-400d-868a-82e8c501bea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sign of random orthogonal vector\n",
    "def orthsign(A):\n",
    "    m, n = A.shape[0], A.shape[1]\n",
    "    invrootn = n ** (-.5)\n",
    "    B = np.vstack((A, rng.normal(size = n, scale = invrootn)))\n",
    "    sol = np.sign(np.linalg.qr(B.T)[0][:, m])\n",
    "    return invrootn * np.linalg.norm(A.dot(np.sign(sol)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f79f0ff2-9450-4c93-8bf7-0ddb0175f580",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment(n, m, kappa, trials = 1):\n",
    "    A, y = samplematrix(n, m, kappa)\n",
    "    truth = np.linalg.norm(y) * (n ** (-.5))\n",
    "    alg1 = mincol(A)\n",
    "    alg2 = np.min([online(A, True) for _ in range(trials)])\n",
    "    alg3 = np.min([orthsign(A) for _ in range(trials)])\n",
    "    return truth, alg1, alg2, alg3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a11feb4c-ea57-46e3-922e-9f6f81a5c4cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(np.float64(1.5822449770203932e-11),\n",
       " np.float64(0.13806155369377351),\n",
       " np.float64(0.027835785537033603),\n",
       " np.float64(0.09418950718885981))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "experiment(100, 10, 1e-10, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4d065370-257c-491c-98b6-6e7d650b905a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(np.float64(2.5626732342841875e-11),\n",
       " np.float64(0.3112373045580075),\n",
       " np.float64(0.0909608666344407),\n",
       " np.float64(0.1642954768157144))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "experiment(100, 20, 1e-10, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6dc28dd6-048b-4563-b9d6-b49dc84625b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(np.float64(3.31723040821967e-11),\n",
       " np.float64(0.36262827813599036),\n",
       " np.float64(0.13248269084977216),\n",
       " np.float64(0.2208142658469229))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "experiment(100, 30, 1e-10, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5e3730e3-7e48-49a6-81e8-9558a4d3886d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympy import ZZ\n",
    "from sympy.polys.matrices import DomainMatrix, DM\n",
    "\n",
    "# LLL algortihm for short combinations\n",
    "def lllattack(A, kappa):\n",
    "    m, n = A.shape[0], A.shape[1]\n",
    "    scaling = 1.0 / (kappa * (n ** .5))\n",
    "    B = np.hstack([np.identity(n), scaling * A.T])\n",
    "    M = DM(np.round(B), ZZ)\n",
    "    N = M.lll()\n",
    "    x = np.array([float(N[0,i].to_sympy().as_expr()) for i in range(n)])\n",
    "    val = np.linalg.norm(A.dot(x)) / np.linalg.norm(x)\n",
    "    return x, val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "adc8140b-7517-4549-b07d-5df74dd03bf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 3 0.01868876924163716 (array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]), np.float64(0.20048890758086796))\n",
      "20 6 0.0026098591847644264 (array([ 0., -1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
      "        0.,  0.,  1.,  1.,  0.,  0.,  0.]), np.float64(0.1213577451953749))\n",
      "30 9 0.00035405311338038105 (array([-1.,  1.,  2., -1.,  1., -1., -1., -3.,  5.,  1.,  0.,  0., -1.,\n",
      "        1.,  2., -1., -5.,  0., -1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,\n",
      "        0.,  0.,  0.,  0.]), np.float64(0.010524542974345205))\n",
      "40 12 3.627002247433539e-05 (array([-10.,   0.,  -5.,   1.,  11.,   2.,  -2.,  -2.,  -2.,   0.,  -9.,\n",
      "         4.,  -5.,   0.,   6.,  -3.,  -6.,  -7.,   1.,  -6.,  -5.,   1.,\n",
      "         4.,  -1.,   3.,   5.,  -7.,  -3.,   1.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.,   0.]), np.float64(0.0006797968454243592))\n",
      "50 15 3.2133376309578724e-06 (array([ 19.,  22.,  -8.,  -5.,  -4.,   8.,   8.,  14.,   5., -21.,   5.,\n",
      "        -3.,   5.,   3.,  12.,  31.,   5.,  12.,  11., -17.,  -5.,   2.,\n",
      "         5.,  17.,   3.,   3.,   5.,   1.,  -9.,   9.,  17.,   2., -11.,\n",
      "        -4.,   5.,   3.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
      "         0.,   0.,   0.,   0.,   0.,   0.]), np.float64(7.863389308262932e-05))\n"
     ]
    }
   ],
   "source": [
    "# LLL test with parameters m = 3t, n = 10t and kappa = 10**(-t)\n",
    "for t in range(1, 6):\n",
    "    kappa = 10 ** (-t)\n",
    "    n = 10 * t\n",
    "    m = 3 * t\n",
    "    A, y = samplematrix(n, m, kappa)\n",
    "    print(n, m, np.linalg.norm(y) * (n ** (-.5)), lllattack(A, kappa))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b8656c2e-b312-4c05-8b2b-6be65437a1e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NormaltestResult(statistic=array([2.06930434, 2.88019138, 0.2660114 , 3.53931714, 1.43174038,\n",
       "       0.41588142, 1.48205014, 1.88375761, 0.52597037, 2.987965  ,\n",
       "       1.06540063, 1.92564548, 0.3787936 , 0.2448641 , 1.16444334,\n",
       "       0.25556481, 1.77561349, 0.28081116, 1.05725784, 0.61365912,\n",
       "       2.2332363 , 0.15406719, 4.56878044, 0.27017596, 1.89292101,\n",
       "       1.01579755, 0.04871762, 3.02898827, 2.18737385, 3.47945545]), pvalue=array([0.35534996, 0.23690509, 0.8754601 , 0.17039116, 0.48876661,\n",
       "       0.81225519, 0.47662509, 0.38989461, 0.76875329, 0.22447689,\n",
       "       0.58701769, 0.3818136 , 0.82745811, 0.88476602, 0.55865584,\n",
       "       0.88004485, 0.41155741, 0.86900571, 0.58941255, 0.735776  ,\n",
       "       0.32738509, 0.92585875, 0.10183614, 0.87363904, 0.38811232,\n",
       "       0.60175868, 0.97593547, 0.21991941, 0.33497917, 0.1755682 ]))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# normality test\n",
    "A, b = samplematrix(100, 30, 1e-10) \n",
    "rootn = A.shape[1] ** .5\n",
    "normaltest(rootn * A.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7973d4c6-c155-483d-999d-009e75ec986e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
