{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "80938cd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import scipy.special,scipy.linalg\n",
    "import numpy as np\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "#from sklearn.datasets import fetch_mldata\n",
    "from tensorflow.keras.datasets import mnist,fashion_mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93193f88",
   "metadata": {},
   "source": [
    "## GENERATE DATA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e56d39dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(testcase,Tr,Te,prop,means=None,covs=None):\n",
    "    rng = np.random\n",
    "    \n",
    "    if testcase is 'MNIST':\n",
    "        #mnist=fetch_mldata('MNIST original')\n",
    "        (X_train_full, y_train_full), (X_test_full, y_test_full) = mnist.load_data()\n",
    "        X_train_full = X_train_full.reshape(-1,784)\n",
    "        X_test_full = X_test_full.reshape(-1,784)\n",
    "        #X,y = mnist.data,mnist.target\n",
    "        #X_train_full, X_test_full = X[:60000], X[60000:]\n",
    "        #y_train_full, y_test_full = y[:60000], y[60000:]\n",
    "\n",
    "        selected_target = [7,9]\n",
    "        K=len(selected_target)\n",
    "        X_train = np.array([]).reshape(p,0)\n",
    "        X_test = np.array([]).reshape(p,0)        \n",
    "        \n",
    "        y_train = []\n",
    "        y_test = []\n",
    "        ind=0\n",
    "        for i in selected_target:\n",
    "            locate_target_train = np.where(y_train_full==i)[0][range(np.int(prop[ind]*Tr))]\n",
    "            locate_target_test  = np.where(y_test_full==i)[0][range(np.int(prop[ind]*Te))]\n",
    "            X_train = np.concatenate( (X_train,X_train_full[locate_target_train].T),axis=1)\n",
    "            y_train = np.concatenate( (y_train,2*(ind-K/2+.5)*np.ones(np.int(Tr*prop[ind]))) )\n",
    "            X_test  = np.concatenate( (X_test,X_test_full[locate_target_test].T),axis=1)\n",
    "            y_test = np.concatenate( (y_test,2*(ind-K/2+.5)*np.ones(np.int(Te*prop[ind]))) )\n",
    "            ind+=1                       \n",
    "        \n",
    "        X_train = X_train - np.mean(X_train,axis=1).reshape(p,1)\n",
    "        X_train = X_train*np.sqrt(784)/np.sqrt(np.sum(X_train**2,(0,1))/Tr)\n",
    "        \n",
    "        X_test = X_test - np.mean(X_test,axis=1).reshape(p,1)\n",
    "        X_test = X_test*np.sqrt(784)/np.sqrt(np.sum(X_test**2,(0,1))/Te)\n",
    "        \n",
    "    else:\n",
    "        X_train = np.array([]).reshape(p,0)\n",
    "        Omega = np.array([]).reshape(p,0)\n",
    "        X_test = np.array([]).reshape(p,0)       \n",
    "        y_train = []\n",
    "        y_test = []\n",
    "        K = len(prop)\n",
    "        for i in range(K):    \n",
    "            tmp = rng.multivariate_normal(means[i], covs[i], size=np.int(Tr * prop[i])).T\n",
    "            X_train = np.concatenate((X_train,rng.multivariate_normal(means[i],covs[i],size=np.int(Tr*prop[i])).T),axis=1)\n",
    "            Omega = np.concatenate((Omega, tmp - np.outer(means[i], np.ones((1, np.int(Tr * prop[i]))))), axis=1)\n",
    "            X_test  = np.concatenate((X_test, rng.multivariate_normal(means[i],covs[i],size=np.int(Te*prop[i])).T),axis=1)\n",
    "            y_train = np.concatenate( (y_train,2*(i-K/2+.5)*np.ones(np.int(Tr*prop[i]))) )\n",
    "            y_test = np.concatenate( (y_test,2*(i-K/2+.5)*np.ones(np.int(Te*prop[i]))) )            \n",
    "            \n",
    "    X_train = X_train/math.sqrt(p)\n",
    "    X_test  = X_test/math.sqrt(p)\n",
    "            \n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c03573aa",
   "metadata": {},
   "source": [
    "## Generate σ(⋅) activation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5cb3eb49",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from scipy.special import erfc\n",
    "def gen_sig(fun,Z,tau, polynom=None, s1=None, s2=None):\n",
    "    \n",
    "    if fun is 'poly2':\n",
    "        sig = polynom[0]*Z**2+polynom[1]*Z+polynom[2]\n",
    "    elif fun is 'ReLu':\n",
    "        sig = np.maximum(Z,0)\n",
    "        d = [(1/2-1/(2*np.pi))*tau,1/4,1/(8*np.pi*tau)]\n",
    "    elif fun is 'sign':\n",
    "        sig = np.sign(Z)\n",
    "        d = []\n",
    "    elif fun is 'posit':\n",
    "        sig = (Z>0).astype(int)\n",
    "    elif fun is 'erf':\n",
    "        sig = scipy.special.erf(Z)\n",
    "    elif fun is 'cos':\n",
    "        sig = np.cos(Z)\n",
    "        d = []\n",
    "    elif fun is 'rff':\n",
    "        sig = np.concatenate([np.cos(Z), np.sin(Z)])\n",
    "        d = []\n",
    "    elif fun is 'abs':\n",
    "        sig = np.abs(Z)\n",
    "    elif fun is 'ternary':\n",
    "        sig =  (Z>(np.sqrt(2)*s2)).astype(int) - (Z<(np.sqrt(2)*s1)).astype(int)\n",
    "        d = []\n",
    "        #d = [1-(erfc(-s1/np.sqrt(tau))-erfc(s2/np.sqrt(tau)))/2-((erfc(-s1/np.sqrt(tau))+erfc(s2/np.sqrt(tau)))/2-1)**2, 1,1]\n",
    "        \n",
    "    return sig, d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c1911331",
   "metadata": {},
   "outputs": [],
   "source": [
    "def estim_tau(X):\n",
    "    tau = np.mean(np.diag(X.T@X))\n",
    "    \n",
    "    return tau"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "814be295",
   "metadata": {},
   "source": [
    "## Generate original kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aa59dda4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_Phi(fun,A,B,polynom=None,distrib=None,nu=None):\n",
    "    normA = np.sqrt(np.sum(A**2,axis=0))\n",
    "    normB = np.sqrt(np.sum(B**2,axis=0))\n",
    "    \n",
    "    AB = A.T @ B\n",
    "    angle_AB = np.minimum( (1/normA).reshape((len(normA),1)) * AB * (1/normB).reshape( (1,len(normB)) ) ,1.)\n",
    "      \n",
    "    if fun is 'poly2':\n",
    "        mom = {'gauss': [1,0,3],'bern': [1,0,1],'bern_skewed': [1,-2/math.sqrt(3),7/3],'student':[1,0,6/(nu-4)+3]}\n",
    "        A2 = A**2\n",
    "        B2 = B**2\n",
    "        Phi = polynom[0]**2*(mom[distrib][0]**2*(2*AB**2+(normA**2).reshape((len(normA),1))*(normB**2).reshape((1,len(normB))) )+(mom[distrib][2]-3*mom[distrib][0]**2)*(A2.T@B2))+polynom[1]**2*mom[distrib][0]*AB+polynom[1]*polynom[0]*mom[distrib][1]*(A2.T@B+A.T@B2)+polynom[2]*polynom[0]*mom[distrib][0]*( (normA**2).reshape( (len(normA),1) )+(normB**2).reshape( (1,len(normB)) ) )+polynom[2]**2\n",
    "        \n",
    "    elif fun is 'ReLu':\n",
    "        Phi = 1/(2*math.pi)* normA.reshape((len(normA),1)) * (angle_AB*np.arccos(-angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) )\n",
    "        \n",
    "    elif fun is 'abs':\n",
    "        Phi = 2/math.pi* normA.reshape((len(normA),1)) * (angle_AB*np.arcsin(angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) )\n",
    "        \n",
    "    elif fun is 'posit':\n",
    "        Phi = 1/2-1/(2*math.pi)*np.arccos(angle_AB)\n",
    "        \n",
    "    elif fun is 'sign':\n",
    "        Phi = 1-2/math.pi*np.arccos(angle_AB)\n",
    "        \n",
    "    elif fun is 'cos':\n",
    "        Phi = np.exp(-.5*( (normA**2).reshape((len(normA),1))+(normB**2).reshape((1,len(normB))) ))*np.cosh(AB)\n",
    "        \n",
    "    elif fun is 'erf':\n",
    "        Phi = 2/math.pi*np.arcsin(2*AB/np.sqrt((1+2*(normA**2).reshape((len(normA),1)))*(1+2*(normB**2).reshape((1,len(normB))))))\n",
    "\n",
    "    return Phi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d831f2e1",
   "metadata": {},
   "source": [
    "## MAIN CODE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "05ee532a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1\n",
      "0.1\n",
      "0.1\n",
      "0.1\n",
      "0.1\n",
      "Time for Simulations Computation 23min 17s\n"
     ]
    }
   ],
   "source": [
    "from sklearn.utils.extmath import safe_sparse_dot\n",
    "## Parameter setting\n",
    "n = 50000 #512     # number of rows of the random matrix W\n",
    "p= 512             # data dimension (reset to 784 below for MNIST)\n",
    "Tr= 1024      # Training length\n",
    "Te=Tr             # Testing length\n",
    "\n",
    "prop=[.5,.5]       # proportions of each class\n",
    "K=len(prop)        # number of data classes\n",
    "\n",
    "#gammas = [10**x for x in np.arange(-4,2.25,.25)] # Range of gamma for simulations\n",
    "\n",
    "gammas = [10**x for x in np.arange(-2,4.25,.25)] # Range of gamma for simulations\n",
    "\n",
    "#gamma = 10\n",
    "\n",
    "testcase='MNIST'   # testcase for simulation, among 'iid','means','var','orth','mixed','MNIST'\n",
    "sigma='rff'       # activation function, among 'ReLu', 'sign', 'posit', 'erf', 'poly2', 'cos', 'rff', 'abs', 'ternary'\n",
    "\n",
    "\n",
    "# polynom is only used for sigma='poly2'\n",
    "polynom=[-.5,0,1]  # sigma(t)=polynom[0].t²+polynom[1].t+polynom[2]\n",
    "# distribution of Wij: 'gauss','bern','bern_skewed','student' for sigma='poly2'; 'gauss' or 'ternary' otherwise\n",
    "distrib='gauss'\n",
    "\n",
    "# Only used for sigma='poly2' and distrib='student'\n",
    "nu=7             # degrees of freedom of Student-t distribution\n",
    "\n",
    "\n",
    "## Generate X_train,X_test,y_train,y_test\n",
    "# Strings are compared with == (value), not `is` (identity).\n",
    "if testcase == 'MNIST':\n",
    "    p=784\n",
    "    X_train, X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop)\n",
    "else:\n",
    "    means=[]\n",
    "    covs=[]\n",
    "    if testcase == 'iid':\n",
    "        for i in range(K):\n",
    "            #means.append(np.zeros(p))\n",
    "            means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n",
    "            #covs.append(np.eye(p))\n",
    "            covs.append(np.eye(p)*(1+8*i/np.sqrt(p)))\n",
    "    elif testcase == 'means':\n",
    "        for i in range(K):\n",
    "            means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n",
    "            covs.append(np.eye(p))\n",
    "    elif testcase == 'var':\n",
    "        for i in range(K):\n",
    "            means.append(np.zeros(p))\n",
    "            covs.append(np.eye(p)*(1+8*i/np.sqrt(p)))\n",
    "    elif testcase == 'orth':\n",
    "        for i in range(K):\n",
    "            means.append(np.zeros(p))\n",
    "            # Class i has variance 40 on its own block of int(prop[i]*p) coordinates and 1 elsewhere.\n",
    "            n_before = int(p*sum(prop[:i]))\n",
    "            n_block = int(prop[i]*p)\n",
    "            n_after = int(p*sum(prop[i+1:]))\n",
    "            covs.append( np.diag(np.concatenate( (np.ones(n_before),40*np.ones(n_block),np.ones(n_after)) )) )\n",
    "    elif testcase == 'mixed':\n",
    "        for i in range(K):\n",
    "            means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n",
    "            covs.append((1+4*i/np.sqrt(p))*scipy.linalg.toeplitz( [(.4*i)**x for x in range(p)] ))\n",
    "    else:\n",
    "        raise ValueError('unknown testcase: {!r}'.format(testcase))\n",
    "\n",
    "    X_train, X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop,means,covs)\n",
    "\n",
    "from scipy.optimize import least_squares\n",
    "pi = np.pi\n",
    "tau = estim_tau(X_train)\n",
    "def compute_thresholds(tau):\n",
    "    \"\"\"Solve the two-equation system below for x in [0, 1]^2 by bounded least squares.\n",
    "\n",
    "    Starts from x0 = (1, 1) and returns the solver's solution vector `x`.\n",
    "    \"\"\"\n",
    "    def residuals(x):\n",
    "        lower, upper = x[0], x[1]\n",
    "        first = (np.exp(-lower ** 2 / tau) + np.exp(-upper ** 2 / tau)) / np.sqrt(2 * pi * tau) - np.exp(-tau / 2)\n",
    "        second = (-lower * np.exp(-lower ** 2 / tau) + upper * np.exp(-upper ** 2 / tau)) / (\n",
    "            np.sqrt(2 * pi * tau ** 3)) - np.exp(-tau / 2) / 2\n",
    "        return (first, second)\n",
    "    ### relu\n",
    "    #F = lambda x: ((np.exp(-x[0] ** 2 / tau) + np.exp(-x[1] ** 2 / tau)) / np.sqrt(2 * pi * tau) - 1/2,\n",
    "    #               (-x[0] * np.exp(-x[0] ** 2 / tau) + x[1] * np.exp(-x[1] ** 2 / tau)) / (\n",
    "    #                   np.sqrt(2 * pi * tau ** 3)) - 1/np.sqrt(8*pi*tau))\n",
    "\n",
    "    solution = least_squares(residuals, (1, 1), bounds=((0, 0), (1, 1)))\n",
    "    return solution.x\n",
    "\n",
    "# Solve for the thresholds once and reuse the result for both bounds.\n",
    "thresholds = compute_thresholds(tau)\n",
    "s1 = -min(thresholds)\n",
    "s2 = max(thresholds)\n",
    "#print(s1, s2)\n",
    "## Simulations\n",
    "\n",
    "start_sim_calculus = time.time()\n",
    "\n",
    "loops = 5        # Number of generations of W to be averaged over\n",
    "epsilons = np.linspace(0,1,10)\n",
    "\n",
    "# Train / test mean-squared errors, averaged over `loops`, one entry per gamma\n",
    "E_train_ter=np.zeros(len(gammas))\n",
    "E_test_ter =np.zeros(len(gammas))\n",
    "eps = 0.1        # probability of a zero entry in the ternary W\n",
    "\n",
    "\n",
    "rng = np.random\n",
    "\n",
    "for loop in range(loops):\n",
    "    print(eps)\n",
    "    # Draw the random matrix W (n x p); strings are compared with ==, not `is`.\n",
    "    if sigma == 'poly2':\n",
    "        if distrib == 'student':\n",
    "            W = rng.standard_t(nu,n*p).reshape(n,p)/np.sqrt(nu/(nu-2))\n",
    "        elif distrib == 'bern':\n",
    "            W = np.sign(rng.randn(n,p))\n",
    "        elif distrib == 'bern_skewed':\n",
    "            Z = rng.rand(n,p)\n",
    "            W = (Z<.75)/np.sqrt(3)+(Z>.75)*(-np.sqrt(3))\n",
    "        elif distrib == 'gauss':\n",
    "            W = rng.randn(n,p)\n",
    "        else:\n",
    "            raise ValueError('unknown distrib for poly2: {!r}'.format(distrib))\n",
    "    else:\n",
    "        if distrib == 'gauss':\n",
    "            W = rng.randn(n,p)\n",
    "        elif distrib == 'ternary':\n",
    "            elements = [-1, 0, 1]\n",
    "            probabilities = [(1-eps)/2, eps, (1-eps)/2]\n",
    "            W = np.random.choice(elements, (n,p), p=probabilities)\n",
    "            W = scipy.sparse.csr_matrix(W)\n",
    "        else:\n",
    "            # Without this, W would be undefined or silently reused from a previous run.\n",
    "            raise ValueError('unknown distrib: {!r}'.format(distrib))\n",
    "\n",
    "    S_train, _ = gen_sig(sigma,safe_sparse_dot(W, X_train), tau, polynom, s1 = s1, s2=s2)\n",
    "    S_train_sparse = scipy.sparse.csr_matrix(S_train)\n",
    "    SS = safe_sparse_dot(S_train_sparse.T, S_train_sparse)\n",
    "\n",
    "    S_test, _ = gen_sig(sigma, W @ X_test, tau, polynom, s1 = s1, s2=s2)\n",
    "    S_test_sparse = scipy.sparse.csr_matrix(S_test)\n",
    "\n",
    "    for ind, gamma in enumerate(gammas):\n",
    "        # Ridge solve: inv_resolv = (SS/Tr + gamma*I)^{-1} y_train\n",
    "        inv_resolv = np.linalg.solve( SS/Tr+gamma*np.eye(Tr),y_train)\n",
    "        beta = safe_sparse_dot(S_train, inv_resolv/Tr)\n",
    "\n",
    "        z_test = safe_sparse_dot(S_test_sparse.T, beta)\n",
    "\n",
    "        # y_train - S_train.T @ beta equals gamma*inv_resolv, so the training\n",
    "        # predictions need not be formed to get the training error.\n",
    "        E_train_ter[ind] += gamma**2*np.linalg.norm(inv_resolv)**2/Tr/loops\n",
    "        E_test_ter[ind]  += np.linalg.norm(y_test-z_test)**2/Te/loops\n",
    "\n",
    "end_sim_calculus = time.time()\n",
    "\n",
    "m,s = divmod(end_sim_calculus-start_sim_calculus,60)\n",
    "print('Time for Simulations Computation {:d}min {:d}s'.format( int(m),math.ceil(s) ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6c66b8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.01,0.1506593464998466)(0.01778279410038923,0.1506438066505016)(0.03162277660168379,0.15061701289924687)(0.05623413251903491,0.15057191780774934)(0.1,0.1504992447209039)(0.1778279410038923,0.1503910859087268)(0.31622776601683794,0.1502532846177693)(0.5623413251903491,0.15013339616322652)(1.0,0.15016353330257168)(1.7782794100389228,0.1506011636887789)(3.1622776601683795,0.15185169776321186)(5.623413251903491,0.15447574893567262)(10.0,0.15916498176678529)(17.78279410038923,0.1666579152371331)(31.622776601683793,0.17763982696407377)(56.23413251903491,0.19273643836496027)(100.0,0.21269549979493924)(177.82794100389228,0.23882564681174168)(316.22776601683796,0.2739667463812602)(562.341325190349,0.3238795768206397)(1000.0,0.39660177951029996)(1778.2794100389228,0.4960227120137166)(3162.2776601683795,0.6131445757411381)(5623.413251903491,0.7278578996330463)(10000.0,0.822374445629078)\n"
     ]
    }
   ],
   "source": [
    "# Emit the (gamma, test error) pairs back to back as \"(x,y)(x,y)...\" on one line\n",
    "pairs = ['({},{})'.format(gamma_value, test_error) for gamma_value, test_error in zip(gammas, E_test_ter)]\n",
    "res = ''.join(pairs)\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ffe4770",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a reference to the test-error curve produced by the most recent run of the main cell.\n",
    "# No copy is made: this binds the same array object as E_test_ter.\n",
    "# NOTE(review): presumably meant to be run right after a run with distrib='gauss' - confirm.\n",
    "E1_gauss = E_test_ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43da346f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a reference to the test-error curve produced by the most recent run of the main cell.\n",
    "# No copy is made: this binds the same array object as E_test_ter.\n",
    "# NOTE(review): presumably meant for a run with distrib='ternary' and eps=0.1 - confirm.\n",
    "E1_ternary = E_test_ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ef860b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a reference to the test-error curve produced by the most recent run of the main cell.\n",
    "# No copy is made: this binds the same array object as E_test_ter.\n",
    "# NOTE(review): presumably meant for a run with distrib='ternary' and eps=0.3 - confirm.\n",
    "E3_ternary = E_test_ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "738a0b45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a reference to the test-error curve produced by the most recent run of the main cell.\n",
    "# No copy is made: this binds the same array object as E_test_ter.\n",
    "# NOTE(review): presumably meant for a run with distrib='ternary' and eps=0.5 - confirm.\n",
    "E5_ternary = E_test_ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e3959f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a reference to the test-error curve produced by the most recent run of the main cell.\n",
    "# No copy is made: this binds the same array object as E_test_ter.\n",
    "# NOTE(review): presumably meant for a run with distrib='ternary' and eps=0.7 - confirm.\n",
    "E7_ternary = E_test_ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d72499b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a reference to the test-error curve produced by the most recent run of the main cell.\n",
    "# No copy is made: this binds the same array object as E_test_ter.\n",
    "# NOTE(review): presumably meant for a run with distrib='ternary' and eps=0.9 - confirm.\n",
    "E9_ternary = E_test_ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8d33c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Plots\n",
    "# One (values, line format, legend label) triple per stored test-error curve.\n",
    "# Labels are the variable names, so each of the six curves gets its own legend entry.\n",
    "curves = [\n",
    "    (E1_ternary, 'bo', 'E1_ternary'),\n",
    "    (E3_ternary, 'ko', 'E3_ternary'),\n",
    "    (E5_ternary, 'go', 'E5_ternary'),\n",
    "    (E7_ternary, 'co', 'E7_ternary'),\n",
    "    (E9_ternary, 'ro', 'E9_ternary'),\n",
    "    (E1_gauss, 'k', 'E1_gauss'),\n",
    "]\n",
    "handles = [plt.plot(gammas, values, fmt)[0] for values, fmt, _ in curves]\n",
    "plt.xscale('log')\n",
    "plt.xlim( gammas[0],gammas[-1] )\n",
    "plt.xlabel('gamma')\n",
    "plt.ylabel('test error')\n",
    "plt.legend(handles, [label for _, _, label in curves], bbox_to_anchor=(1, 1), loc='upper left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a61175",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b067b86b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def F(x):\n",
    "    \"\"\"Return the two exponential terms of the threshold equations, without the constant offsets.\"\"\"\n",
    "    first = (np.exp(-x[0] ** 2 / tau) + np.exp(-x[1] ** 2 / tau)) / np.sqrt(2 * pi * tau)\n",
    "    second = (-x[0] * np.exp(-x[0] ** 2 / tau) + x[1] * np.exp(-x[1] ** 2 / tau)) / (\n",
    "        np.sqrt(2 * pi * tau ** 3))\n",
    "    return (first, second)\n",
    "\n",
    "# Evaluate at a fixed point close to the corner (0, 1) of the solver's bounds\n",
    "x = np.array([2.84094572e-07, 9.99999948e-01])\n",
    "print(F(x))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
