{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cff38960",
   "metadata": {},
   "source": [
    "This notebook provides the necessary ingredients for our TRF method (Algorithm 1 in the paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32d02955",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import least_squares\n",
    "import math\n",
    "import scipy.special,scipy.linalg\n",
    "import numpy as np\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "pi = np.pi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "94a7a234",
   "metadata": {},
   "outputs": [],
   "source": [
    "### The function below takes as input the normalized data and returns the parameter tau used in our TRF alforithm\n",
    "def estim_tau(X):\n",
    "    tau = np.mean(np.diag(X.T@X))\n",
    "    \n",
    "    return tau\n",
    "\n",
    "\n",
    "### The function below takes as input the parameter tau (computed from data) using the function estim_tau(X), and return\n",
    "### the thresholds s_minus and s_plus used to set the ternary function sigma_ter\n",
    "def compute_thresholds(tau):\n",
    "    F = lambda x: ((np.exp(-x[0] ** 2 / tau) + np.exp(-x[1] ** 2 / tau)) / np.sqrt(2 * pi * tau) - np.exp(-tau / 2),\n",
    "                   (-x[0] * np.exp(-x[0] ** 2 / tau) + x[1] * np.exp(-x[1] ** 2 / tau)) / (\n",
    "                       np.sqrt(2 * pi * tau ** 3)) - np.exp(-tau / 2) / 2)\n",
    "    ### relu\n",
    "    #F = lambda x: ((np.exp(-x[0] ** 2 / tau) + np.exp(-x[1] ** 2 / tau)) / np.sqrt(2 * pi * tau) - 1/2,\n",
    "    #               (-x[0] * np.exp(-x[0] ** 2 / tau) + x[1] * np.exp(-x[1] ** 2 / tau)) / (\n",
    "    #                   np.sqrt(2 * pi * tau ** 3)) - 1/np.sqrt(8*pi*tau))\n",
    "\n",
    "    res = least_squares(F, (1, 1), bounds=((0, 0), (1, 1)))\n",
    "    s_minus = -min(res.x)\n",
    "    s_plus = max(res.x)\n",
    "    return s_minus , s_plus \n",
    "\n",
    "\n",
    "### Given a sparsity level eps, data size n, data dimension p, the function below return an i.i.d zero mean unit variance \n",
    "### random matrix taking values -1, 0, 1\n",
    "def ternary_weight(eps, n, p):\n",
    "    elements = [-1, 0, 1]\n",
    "    probabilities = [(1-eps)/2, eps, (1-eps)/2]\n",
    "    W = np.random.choice(elements, (n,p), p=probabilities)\n",
    "    W = scipy.sparse.csr_matrix(W)\n",
    "    return W\n",
    "\n",
    "\n",
    "### Given s_minus, s_plus, and projected data Z,  the function below returns the activation \\sigma^{ter}(Z)\n",
    "def gen_sig(Z, s1=None, s2=None):\n",
    "    sig =  (Z>(np.sqrt(2)*s_plus)).astype(int) - (Z<(np.sqrt(2)*s_minus)).astype(int)\n",
    "    return sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dbeda2f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
