{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YQAkPdsJrJwj",
        "outputId": "da2b716f-c85e-4574-c2e5-42ada5a35419"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting entropy_estimators\n",
            "  Downloading entropy_estimators-0.0.1.tar.gz (7.6 kB)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from entropy_estimators) (1.4.1)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from entropy_estimators) (1.19.5)\n",
            "Building wheels for collected packages: entropy-estimators\n",
            "  Building wheel for entropy-estimators (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for entropy-estimators: filename=entropy_estimators-0.0.1-py3-none-any.whl size=8314 sha256=c7fd0f2cfff3b77825b87f4e98f0390adf9eb286b8772cadb44b627b5990bf57\n",
            "  Stored in directory: /root/.cache/pip/wheels/c8/75/ce/83a79892fd974fe4b87cde76bc4e84a2d0a88bb9bd639a1cd5\n",
            "Successfully built entropy-estimators\n",
            "Installing collected packages: entropy-estimators\n",
            "Successfully installed entropy-estimators-0.0.1\n"
          ]
        }
      ],
      "source": [
        "!pip install entropy_estimators"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tiB9XL48hKSc"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import numpy.random as npr\n",
        "import math\n",
        "import scipy\n",
        "from scipy import special\n",
        "from scipy import linalg\n",
        "\n",
        "from tqdm import tqdm\n",
        "from google.colab import files\n",
        "\n",
        "import matplotlib \n",
        "from matplotlib import pyplot as plt\n",
        "from mpl_toolkits import mplot3d\n",
        "\n",
        "import sklearn\n",
        "from sklearn import datasets\n",
        "\n",
        "from entropy_estimators import continuous\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "device = torch.device('cpu')\n",
        "\n",
        "import random\n",
        "import time\n",
        "import datetime"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iCYz7MHRFwhJ"
      },
      "source": [
        "## **Parameters**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rGJ2mW1vhQSH"
      },
      "source": [
        "###*Regularization*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nHJzzS-ohT1w"
      },
      "outputs": [],
      "source": [
        "lambda1 = 1e-3\n",
        "lambda2 = 1e-5\n",
        "gamma = 0.5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qPnEc7nlLs1l"
      },
      "source": [
        "###*Particle-SDCA*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VyavJtLyL1ex"
      },
      "outputs": [],
      "source": [
        "T0_PSDCA = 32\n",
        "T1_PSDCA = 10\n",
        "T2_PSDCA = 1000\n",
        "eta_PSDCA = 1e-4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yYhXWhjM8DzD"
      },
      "source": [
        "###*PDA*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x90ijrwo8GuL"
      },
      "outputs": [],
      "source": [
        "T0_PDA = 320\n",
        "T1_PDA = 1\n",
        "T2_PDA = 50\n",
        "eta_PDA = 1e-5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qLq1pXQkK72o"
      },
      "source": [
        "###*SGD*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ATK0h2OYK_LU"
      },
      "outputs": [],
      "source": [
        "T0_SGD = 6100\n",
        "T1_SGD = 50\n",
        "eta_SGD = 5*1e-3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UXUjDe-9S4nh"
      },
      "source": [
        "###*Model*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ijVg6QXQFzvQ"
      },
      "outputs": [],
      "source": [
        "layer_size = [50,1] \n",
        "M = 200\n",
        "nonlinearity = \"tanh\" \n",
        "nonlinearity_scaler = 1.0\n",
        "bias = True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1RRCpjd9zTpT"
      },
      "source": [
        "Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sKemYmq8zUvr"
      },
      "outputs": [],
      "source": [
        "n = 1000\n",
        "nt = 1000\n",
        "d = 50\n",
        "Mt = 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o4wpC5JTHp0e"
      },
      "source": [
        "## **Model**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D2g9T5Q6iTQK"
      },
      "source": [
        "###*Neural Network*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-9jc4gwr-RBO"
      },
      "outputs": [],
      "source": [
        "def model(X, lambda1, lambda2, params1, params2 = None, innervalues = False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  if nonlinearity==\"tanh\":\n",
        "    def activation_function(x):\n",
        "      return torch.tanh(x) * nonlinearity_scaler\n",
        "  if nonlinearity in [\"relu\" , \"ReLU\" , \"RELU\"]:\n",
        "    def activation_function(x):\n",
        "      return torch.relu(x) * nonlinearity_scaler\n",
        "  num_layer = len(params1)\n",
        "  num_data = X.shape[0]\n",
        "  innervalues_ = []\n",
        "  forward = X.repeat((params1[0].shape[0], 1, 1))\n",
        "  if params2 == None:\n",
        "    for i in range(0, len(params1)):\n",
        "      innervalues_ = innervalues_ + [forward]\n",
        "      forward = torch.einsum('mbi,mji->mbj', forward, params1[i])\n",
        "      forward = activation_function(forward)\n",
        "  else:\n",
        "    for i in range(0, len(params1)):\n",
        "      innervalues_ = innervalues_ + [forward]\n",
        "      forward = torch.einsum('mbi,mji->mbj', forward, params1[i]) + params2[i].repeat((num_data,1,1)).permute(1,0,2)\n",
        "      forward = activation_function(forward)\n",
        "  if innervalues:\n",
        "    return forward[:,:,0], innervalues_ \n",
        "  else:\n",
        "    return forward[:,:,0] \n",
        "\n",
        "def model_grad(X, lambda1, lambda2, params1, params2=None, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  if innervalues == None:\n",
        "    _, innervalues_ = model(X, lambda1, lambda2, params1, params2 = params2, innervalues = True, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "  if nonlinearity==\"tanh\":\n",
        "    def activation_function_grad(x):\n",
        "      return (1-torch.tanh(x)**2) * nonlinearity_scaler\n",
        "  if nonlinearity in [\"relu\" , \"ReLU\" , \"RELU\"]:\n",
        "    def activation_function_grad(x):\n",
        "      return x * (x > 0.0) * nonlinearity_scaler\n",
        "  num_data = X.shape[0]\n",
        "  grad1 = []\n",
        "  grad2 = []\n",
        "  for i in range(len(params1)):\n",
        "    if params2 == None:\n",
        "      sigma_i_grad = activation_function_grad(torch.einsum('mij,mbj->mbi',params1[i],innervalues_[i]))\n",
        "      for j in range(i):\n",
        "        grad1[j] = torch.einsum('mbi,mbij->mbij',torch.einsum('mbk,mki->mbi',sigma_i_grad ,params1[i]),grad1[j])\n",
        "      grad1 = grad1 + [torch.einsum('mbi,mbj->mbij',sigma_i_grad,innervalues_[i])]\n",
        "    else:\n",
        "      sigma_i_grad = activation_function_grad(torch.einsum('mij,mbj->mbi',params1[i],innervalues_[i])+ params2[i].repeat((num_data,1,1)).permute(1,0,2))\n",
        "      for j in range(i):\n",
        "        grad1[j] = torch.einsum('mbi,mbij->mbij',torch.einsum('mbk,mki->mbi',sigma_i_grad ,params1[i]),grad1[j])\n",
        "        grad2[j] = torch.einsum('mbi,mbi->mbi',torch.einsum('mbk,mki->mbi', sigma_i_grad ,params1[i]),grad2[j])   \n",
        "      grad1 = grad1 + [torch.einsum('mbi,mbj->mbij',sigma_i_grad,innervalues_[i])]  \n",
        "      grad2 = grad2 + [sigma_i_grad]\n",
        "  return grad1, grad2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C76XQ1_9inrt"
      },
      "source": [
        "###*Loss functions and corresponding functions for conjugate, initialization, update*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VcjROqSbhiP6"
      },
      "outputs": [],
      "source": [
        "def MSEloss(Y,Z, gamma=0.5):\n",
        "  return (Y-Z)**2/(2*gamma)\n",
        "\n",
        "def MSEconjugate(G,Y, gamma=0.5):\n",
        "  return G**2 *gamma/2 + G*Y\n",
        "\n",
        "def MSEupdate(integral, n,y_i, g_i_t,max_iter_update = 0, gamma=0.5, eps = 1e-18):\n",
        "  return (integral - y_i + g_i_t/(n*lambda2) )/(gamma+1/(n*lambda2))\n",
        "\n",
        "def MSEgrad(Y,Z, gamma=0.5): #Y: variable, Z: constant\n",
        "  return (Y-Z)/gamma\n",
        "\n",
        "def MSEconjugategrad(G,Y, gamma=0.5):\n",
        "  return G * gamma + Y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "guzCJlJwWMV-"
      },
      "outputs": [],
      "source": [
        "loss_func = MSEloss\n",
        "lossconjugate = MSEconjugate\n",
        "update = MSEupdate\n",
        "lossgrad = MSEgrad\n",
        "lossconjugategrad = MSEconjugategrad"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h6l6IPRKg7bx"
      },
      "source": [
        "##**Particle-SDCA Algorithm**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dqvRwSHMLhfQ"
      },
      "source": [
        "###*Sampling*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XYHjo5lTUOS-"
      },
      "outputs": [],
      "source": [
        "def LMC_PSDCA(g, X, lambda1, lambda2, params1, params2 = None, eta=eta_PSDCA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  num_data = X.shape[0]\n",
        "  grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler) \n",
        "  for i in range(len(params1)):\n",
        "    dU1 =  torch.einsum('mbij,b->mij',grad1[i],g) / (num_data*lambda2)  +  (2 * lambda1 / lambda2) * params1[i]\n",
        "    params1[i] = params1[i] - eta * dU1 + np.sqrt(2*eta) * torch.FloatTensor(params1[i].shape[0],params1[i].shape[1],params1[i].shape[2]).normal_().to(device)\n",
        "    if params2 != None:\n",
        "      dU2 =  torch.einsum('mbi,b->mi',grad2[i],g) / (num_data*lambda2)  +  (2 * lambda1 / lambda2) * params2[i]\n",
        "      params2[i] = params2[i] - eta * dU2 + np.sqrt(2*eta) * torch.FloatTensor(params2[i].shape[0],params2[i].shape[1]).normal_().to(device)\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fqRr1eL1e2__"
      },
      "outputs": [],
      "source": [
        "def sample_PSDCA(g, X, M, layer_size, lambda1, lambda2, T1, method=LMC_PSDCA, params1=None, params2=None, bias = False, eta=eta_PSDCA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  if (bias == False) and params2 != None:\n",
        "    print(\"Error(sample): bias=False but params2 exists\")\n",
        "  if params1 == None:\n",
        "    params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "    if bias:\n",
        "      params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "  for t in range(T1):\n",
        "    params1, params2  = method(g, X, lambda1, lambda2, params1, params2 =params2, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HeS6qOPbeXAm"
      },
      "source": [
        "###*Main*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "srwWMPEDecB-"
      },
      "outputs": [],
      "source": [
        "def PSDCA(X,Y,Xt,Yt,g,M,lambda1, lambda2, layer_size,record_PSDCA, T0 = T0_PSDCA, T1 = T1_PSDCA, T2 = T2_PSDCA, eta = eta_PSDCA, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  num_data = X.shape[0]\n",
        "  num_data_test = Xt.shape[0]\n",
        "  num_grad = -T1*num_data\n",
        "  params1 = None\n",
        "  params2 = None\n",
        "\n",
        "  for T in range(0,T0):\n",
        "    if resampling:\n",
        "      params1 = None\n",
        "      params2 = None\n",
        "    if T==0:\n",
        "      params1, params2 = sample_PSDCA(g, X, M,layer_size,lambda1,lambda2, 0, params1=params1, params2 = params2, bias = bias, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "    else:\n",
        "      params1, params2 = sample_PSDCA(g, X, M,layer_size,lambda1,lambda2, T1, params1=params1, params2 = params2, bias = bias, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "    num_grad += T1* num_data\n",
        "    rs = torch.ones(M) / M\n",
        "    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n",
        "    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n",
        "    if bias:\n",
        "      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n",
        "    else:\n",
        "      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n",
        "    for i in range(1,len(params1)):\n",
        "      if bias:\n",
        "        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n",
        "      else:\n",
        "        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n",
        "    L2 = (torch.norm(params)**2)/ M\n",
        "    negativeentropy = - continuous.get_h(params, k=5) \n",
        "    record_PSDCA[0,T] = train_loss\n",
        "    record_PSDCA[1,T] = L2\n",
        "    record_PSDCA[2,T] = negativeentropy\n",
        "    record_PSDCA[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n",
        "    record_PSDCA[4,T] = test_loss\n",
        "    record_PSDCA[5,T] = num_grad\n",
        "    if T % 3 == 0:\n",
        "      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, Entropy %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n",
        "\n",
        "    indices = list(range(num_data))\n",
        "    random.shuffle(indices) \n",
        "    for i_t in indices[:T2]:\n",
        "        integral = sum(model(X[i_t:i_t+1,:],lambda1, lambda2, params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)[:,0] * rs) /sum(rs)\n",
        "        delg = - g[i_t]\n",
        "        g[i_t] = update(integral, num_data, Y[i_t], g[i_t])\n",
        "        delg += g[i_t]\n",
        "        rs = rs * torch.exp(- model(X[i_t:i_t+1,:], lambda1, lambda2, params1,params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)[:,0]*delg/(num_data*lambda2))\n",
        "        rs = rs/sum(rs)\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tn7WPap4wvHc"
      },
      "source": [
        "##**PDA Algorithm**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W_EtVPbx2g_u"
      },
      "source": [
        "###*Sampling*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "moNVrUnAxSmH"
      },
      "outputs": [],
      "source": [
        "def LMC_PDA(g, X, A, t,  lambda1, lambda2, params1, params2 = None, eta=eta_PDA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  num_data = X.shape[0]\n",
        "  grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler) \n",
        "  for i in range(len(params1)):\n",
        "    dU1 = 2* torch.einsum('mbij,b->mij',grad1[i],A) / (lambda2*(t+2)*(t+3))  +  2 * lambda1 * (t+1) / (lambda2*(t+3)) * params1[i]\n",
        "    params1[i] = params1[i] - eta * dU1 + np.sqrt(2*eta) * torch.FloatTensor(params1[i].shape[0],params1[i].shape[1],params1[i].shape[2]).normal_().to(device)\n",
        "    if params2 != None:\n",
        "      dU2 = 2* torch.einsum('mbi,b->mi',grad2[i],A) / (lambda2*(t+2)*(t+3))  +  2 * lambda1 * (t+1) / (lambda2*(t+3))  * params2[i]\n",
        "      params2[i] = params2[i] - eta * dU2 + np.sqrt(2*eta) * torch.FloatTensor(params2[i].shape[0],params2[i].shape[1]).normal_().to(device)\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S2gNVW8F2gRp"
      },
      "outputs": [],
      "source": [
        "def sample_PDA(g, X, A, t, M, layer_size, lambda1, lambda2, T1, method=LMC_PDA, params1=None, params2=None, bias = False, eta=eta_PDA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  if (bias == False) and params2 != None:\n",
        "    print(\"Error(sample): bias=False but params2 exists\")\n",
        "  if params1 == None:\n",
        "    params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "    if bias:\n",
        "      params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "  for i in range(T1):\n",
        "    params1, params2  = method(g, X, A, t,lambda1, lambda2, params1, params2 =params2, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M8l--1ZQk0dl"
      },
      "source": [
        "###*Main*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GI2DSJvWk9DU"
      },
      "outputs": [],
      "source": [
        "def PDA(X,Y,Xt,Yt,M,lambda1, lambda2, layer_size,record_PDA, T0 = T0_PDA, T1 = T1_PDA, T2 = T2_PDA, eta = eta_PDA, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  num_data = X.shape[0]\n",
        "  num_data_test = Xt.shape[0]\n",
        "  num_grad = 0\n",
        "  params1 = None\n",
        "  params2 = None\n",
        "  A = torch.zeros(num_data)\n",
        "  for T in range(0,T0):\n",
        "    if resampling:\n",
        "      params1 = None\n",
        "      params2 = None\n",
        "    params1, params2 = sample_PDA(g, X, A, T ,M,layer_size,lambda1,lambda2, T1, params1=params1, params2 = params2, bias = bias, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "    num_grad += T1* num_data\n",
        "    rs = torch.ones(M) / M\n",
        "    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n",
        "    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n",
        "    if bias:\n",
        "      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n",
        "    else:\n",
        "      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n",
        "    for i in range(1,len(params1)):\n",
        "      if bias:\n",
        "        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n",
        "      else:\n",
        "        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n",
        "    L2 = (torch.norm(params)**2)/ M\n",
        "    negativeentropy = - continuous.get_h(params, k=5) \n",
        "    record_PDA[0,T] = train_loss\n",
        "    record_PDA[1,T] = L2\n",
        "    record_PDA[2,T] = negativeentropy\n",
        "    record_PDA[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n",
        "    record_PDA[4,T] = test_loss\n",
        "    record_PDA[5,T] = num_grad\n",
        "    if T % 30 == 0:\n",
        "      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, Entropy: %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n",
        "\n",
        "    indices = list(range(num_data))\n",
        "    random.shuffle(indices)\n",
        "    target_coordinate_list =  indices[:T2]\n",
        "    diff = lossgrad(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs),Y)\n",
        "    A[target_coordinate_list] = A[target_coordinate_list] + (T+1) * diff[target_coordinate_list] / T2\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z9Xr_44ZG8o-"
      },
      "source": [
        "##**SGD Algorithm**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2mQMkQsCHAWE"
      },
      "source": [
        "###*Main*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "85POS1fkIWON"
      },
      "outputs": [],
      "source": [
        "def SGD(X,Y,Xt,Yt,g,M,lambda1, lambda2, layer_size,record_SGD, T0 = T0_SGD, T1 = T1_SGD, eta = eta_SGD, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  num_data = X.shape[0]\n",
        "  num_data_test = Xt.shape[0]\n",
        "  num_grad = 0\n",
        "  params2 = None\n",
        "  params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "  if bias:\n",
        "    params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "  for T in range(T0):\n",
        "    rs = torch.ones(M) / M\n",
        "    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n",
        "    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n",
        "    if bias:\n",
        "      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n",
        "    else:\n",
        "      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n",
        "    for i in range(1,len(params1)):\n",
        "      if bias:\n",
        "        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n",
        "      else:\n",
        "        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n",
        "    L2 = (torch.norm(params)**2)/ M\n",
        "    negativeentropy = - continuous.get_h(params, k=5) \n",
        "    record_SGD[0,T] = train_loss\n",
        "    record_SGD[1,T] = L2\n",
        "    record_SGD[2,T] = negativeentropy\n",
        "    record_SGD[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n",
        "    record_SGD[4,T] = test_loss\n",
        "    record_SGD[5,T] = num_grad\n",
        "    if T % 600 == 0:\n",
        "      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, entropy: %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n",
        "    \n",
        "    indices = list(range(num_data))\n",
        "    random.shuffle(indices)\n",
        "    target_coordinate_list =  indices[:T1]\n",
        "    grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "    dl = lossgrad(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs),Y)\n",
        "    for i in range(len(params1)):\n",
        "      dU1 =  torch.einsum('mbij,b->mij',grad1[i][:,target_coordinate_list,:,:], dl[target_coordinate_list]) +  2 * lambda1 * params1[i]\n",
        "      params1[i] = params1[i] - eta * dU1\n",
        "      if params2 != None:\n",
        "        dU2 =  torch.einsum('mbi,b->mi',grad2[i][:,target_coordinate_list,:], dl[target_coordinate_list]) +  2 * lambda1 * params2[i]\n",
        "        params2[i] = params2[i] - eta * dU2\n",
        "    num_grad += T1\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nt1o4i6OOD3w"
      },
      "outputs": [],
      "source": [
        "def SGD_NTK(X,Y,Xt,Yt,g,M,lambda1, lambda2, layer_size,record_SGD_NTK, T0 = T0_SGD, T1 = T1_SGD, eta = eta_SGD, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n",
        "  num_data = X.shape[0]\n",
        "  num_data_test = Xt.shape[0]\n",
        "  num_grad = 0\n",
        "  params2 = None\n",
        "  params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "  if bias:\n",
        "    params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n",
        "  for T in range(T0):\n",
        "    rs = torch.ones(M) / M**0.5\n",
        "    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n",
        "    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n",
        "    if bias:\n",
        "      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n",
        "    else:\n",
        "      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n",
        "    for i in range(1,len(params1)):\n",
        "      if bias:\n",
        "        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n",
        "      else:\n",
        "        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n",
        "    L2 = (torch.norm(params)**2)/ M\n",
        "    negativeentropy = - continuous.get_h(params, k=5) \n",
        "    record_SGD_NTK[0,T] = train_loss\n",
        "    record_SGD_NTK[1,T] = L2\n",
        "    record_SGD_NTK[2,T] = negativeentropy\n",
        "    record_SGD_NTK[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n",
        "    record_SGD_NTK[4,T] = test_loss\n",
        "    record_SGD_NTK[5,T] = num_grad\n",
        "    if T % 600 == 0:\n",
        "      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, entropy: %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n",
        "    \n",
        "    indices = list(range(num_data))\n",
        "    random.shuffle(indices)\n",
        "    target_coordinate_list =  indices[:T1]\n",
        "    grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n",
        "    dl = lossgrad(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs),Y)\n",
        "    for i in range(len(params1)):\n",
        "      dU1 =  torch.einsum('mbij,b->mij',grad1[i][:,target_coordinate_list,:,:], dl[target_coordinate_list]) +  2 * lambda1 * params1[i]\n",
        "      params1[i] = params1[i] - eta * dU1\n",
        "      if params2 != None:\n",
        "        dU2 =  torch.einsum('mbi,b->mi',grad2[i][:,target_coordinate_list,:], dl[target_coordinate_list]) +  2 * lambda1 * params2[i]\n",
        "        params2[i] = params2[i] - eta * dU2\n",
        "    num_grad += T1\n",
        "  return params1, params2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gXA28NmpQqk_",
        "outputId": "d290c974-e5ec-40ab-b116-dfa340712291"
      },
      "outputs": [],
      "source": [
        "def Gaussian(Sigma_X,n):\n",
        "    d,_ = Sigma_X.shape\n",
        "    X = torch.FloatTensor(n,d).normal_().to(device)\n",
        "    return X@Sigma_X #torch.cat([X@Sigma_X,torch.ones(n,1).to(device)],dim=1)\n",
        "\n",
        "psdca_res = []\n",
        "pda_res = []\n",
        "sgd_res = []\n",
        "sgdntk_res = []\n",
        "\n",
        "for i in range(60,90):\n",
        "    print(i)\n",
        "    Wt = Gaussian(torch.eye(d).to(device),Mt).T  \n",
        "    #Wt = torch.ones(d,1).to(device)\n",
        "    Wt = Wt / torch.norm(Wt)\n",
        "    bt = torch.ones(Mt,1).to(device) / Mt\n",
        "    def target(X):\n",
        "        return torch.tanh(X@Wt)@bt\n",
        "\n",
        "    Sigma_X = torch.eye(d).to(device)\n",
        "\n",
        "    X = Gaussian(Sigma_X,n)\n",
        "    Y = target(X)\n",
        "    Xt = Gaussian(Sigma_X,nt)\n",
        "    Yt = target(Xt)\n",
        "\n",
        "\n",
        "    Yt = Yt * (n**0.5 )/ torch.norm(Y)\n",
        "    Y = Y * (n**0.5 )/ torch.norm(Y)\n",
        "    noise = torch.FloatTensor(n,1).normal_(mean=0,std=1.0).to(device)\n",
        "    Y = torch.sign(Y)\n",
        "    Yt = torch.sign(Yt)\n",
        "\n",
        "    print(\"SNR: %f\" % ((torch.norm(Y)/ torch.norm(noise * torch.norm(Y) / torch.norm(noise) / np.sqrt(0.98)))**2))\n",
        "    Y = Y + noise * torch.norm(Y) / torch.norm(noise) / np.sqrt(0.98)\n",
        "    Y = Y[:,0]\n",
        "    Yt = Yt[:,0]\n",
        "\n",
        "    g = torch.zeros_like(Y)\n",
        "    record_PSDCA = torch.zeros(6,T0_PSDCA+1) \n",
        "    params1, params2 = PSDCA(X,Y,Xt,Yt,g,M, lambda1,lambda2, layer_size, record_PSDCA, T0 = T0_PSDCA, T1=T1_PSDCA, T2= T2_PSDCA , eta=eta_PSDCA,resampling = False, bias = True, nonlinearity_scaler=1.0, gamma=0.5)\n",
        "\n",
        "    record_PDA = torch.zeros(6,500+1)\n",
        "    params1, params2 = PDA(X,Y,Xt,Yt,M, lambda1, lambda2, layer_size, record_PDA, T0 = T0_PDA, T1=T1_PDA, T2=T2_PDA,eta=eta_PDA, resampling = False, bias = True, nonlinearity_scaler=1.0, gamma=0.5)\n",
        "\n",
        "    record_SGD = torch.zeros(6,T0_SGD+1)\n",
        "    params1, params2 = SGD(X,Y,Xt,Yt,g,M, lambda1, lambda2, layer_size, record_SGD, T0 = T0_SGD, T1=T1_SGD, eta=eta_SGD, resampling = False, bias = True, nonlinearity_scaler=1.0, gamma=0.5)\n",
        "\n",
        "    record_SGD_NTK = torch.zeros(6,T0_SGD+1)\n",
        "    params1, params2 = SGD_NTK(X,Y,Xt,Yt,g,M, lambda1, lambda2, layer_size, record_SGD_NTK, T0 = T0_SGD, T1=T1_SGD, eta=eta_SGD/10, resampling = False, bias = True, nonlinearity_scaler=1.0, gamma=0.5)\n",
        "\n",
        "    torch.save(record_PSDCA ,f\"./PSDCA_{i}.pth\")\n",
        "    torch.save(record_PDA ,f\"./PDA_{i}.pth\")\n",
        "    torch.save(record_SGD ,f\"./SGD_{i}.pth\")\n",
        "    torch.save(record_SGD_NTK ,f\"./SDG_NTK_{i}.pth\")\n",
        "\n",
        "    psdca_res = psdca_res +[record_PSDCA]\n",
        "    pda_res =pda_res + [record_PDA]\n",
        "    sgd_res = sgd_res + [record_SGD]\n",
        "    sgdntk_res = sgdntk_res + [record_SGD_NTK]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "anGHyBC2SGcT"
      },
      "outputs": [],
      "source": [
        "num_test = 10\n",
        "\n",
        "psdca_ = torch.tensor([psdca_res[i].tolist() for i in range(0,len(psdca_res))])\n",
        "pda_ = torch.tensor([pda_res[i].tolist() for i in range(0,len(psdca_res))])\n",
        "sgd_ = torch.tensor([sgd_res[i].tolist() for i in range(0,len(psdca_res))])\n",
        "sgdntk_ = torch.tensor([sgdntk_res[i].tolist() for i in range(0,len(psdca_res))])\n",
        "\n",
        "psdca = psdca_[:,4,:]\n",
        "mean = sum(psdca)/num_test\n",
        "reg = psdca - mean\n",
        "var = (sum(reg**2)/(num_test**2))**0.5\n",
        "psdca_mean = mean\n",
        "psdca_var = var\n",
        "\n",
        "pda = pda_[:,4,:]\n",
        "mean = sum(pda)/num_test\n",
        "reg = pda - mean\n",
        "var = (sum(reg**2)/(num_test**2))**0.5\n",
        "pda_mean = mean\n",
        "pda_var = var\n",
        "\n",
        "sgd = sgd_[:,4,:]\n",
        "mean = sum(sgd)/num_test\n",
        "reg = sgd - mean\n",
        "var = (sum(reg**2)/(num_test**2))**0.5\n",
        "sgd_mean = mean\n",
        "sgd_var = var\n",
        "\n",
        "sgdntk = sgdntk_[:,4,:]\n",
        "mean = sum(sgdntk)/num_test\n",
        "reg = sgdntk - mean\n",
        "var = (sum(reg**2)/(num_test**2))**0.5\n",
        "sgdntk_mean = mean\n",
        "sgdntk_var = var\n",
        "\n",
        "psdca_iters = psdca_res[0][5]\n",
        "pda_iters = pda_res[0][5]\n",
        "sgd_iters = sgd_res[0][5]\n",
        "sgdntk_iters = sgdntk_res[0][5]\n",
        "\n",
        "fig = plt.figure()\n",
        "FONT_SIZE = 18\n",
        "plt.rc('font',size=FONT_SIZE)\n",
        "\n",
        "ax1 = fig.add_subplot(1,1,1)\n",
        "\n",
        "ax1.set_xlabel( \"Gradient evalueations\", fontsize=16 )\n",
        "ax1.set_ylabel( \"Test loss\", fontsize=16 )\n",
        "\n",
        "ax1.ticklabel_format(style='sci',axis='y',scilimits=(0,0))\n",
        "ax1.ticklabel_format(style='sci',axis='x',scilimits=(0,0))\n",
        "\n",
        "plt.setp(ax1.get_xticklabels(), fontsize=14)\n",
        "plt.setp(ax1.get_yticklabels(), fontsize=14)\n",
        "\n",
        "global_opt = 0\n",
        "\n",
        "ax1.errorbar(sgd_iters[:-2:200], sgd_mean[:-2:200] - global_opt, sgd_var[:-2:200],label = \"SGD (mean field)\",ecolor= (0.4196078431372549, 0.5568627450980392, 0.13725490196078433,0.3), color = (0.4196078431372549, 0.5568627450980392, 0.13725490196078433,1.0))#, capsize=4, ecolor=\"olivedrab\")\n",
        "ax1.errorbar(sgdntk_iters[:-2:200], sgdntk_mean[:-2:200] - global_opt, sgdntk_var[:-2:200],label = \"SGD (NTK)\",ecolor= (0.4196078431372549, 0.5568627450980392, 0.13725490196078433,0.3), color = (0.4196078431372549, 0.5568627450980392, 0.13725490196078433,1.0))#, capsize=4, ecolor=\"olivedrab\")\n",
        "ax1.errorbar(pda_iters[:-2:10], pda_mean[:-2:10] - global_opt, pda_var[:-2],label = \"PDA\",ecolor= (0.27450980392156865, 0.5098039215686274, 0.7058823529411765,0.3), color = (0.27450980392156865, 0.5098039215686274, 0.7058823529411765,1.0))#,color=\"steelblue\")#, capsize=4, ecolor=\"steelblue\")\n",
        "ax1.errorbar(psdca_iters[:-2], psdca_mean[:-2] - global_opt, psdca_var[:-2],label = \"P-SDCA\",ecolor= (0.8627450980392157, 0.0784313725490196, 0.23529411764705882,0.3), color = (0.8627450980392157, 0.0784313725490196, 0.23529411764705882,1.0))#,coler=\"crimson\")#, capsize=4, ecolor=\"crimson\")\n",
        "\n",
        "ax1.xaxis.offsetText.set_fontsize(14)\n",
        "ax1.yaxis.offsetText.set_fontsize(14)\n",
        "\n",
        "ax1.set_ylim([0.1,1.2])\n",
        "ax1.set_xlim([0,300000])\n",
        "\n",
        "ax1.set_yscale('log')\n",
        "ax1.legend(loc=u\"upper right\", prop={'size':15} )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WKrIOrwYSG1l"
      },
      "outputs": [],
      "source": [
        "fig.savefig(\"./exp.pdf\", bbox_inches=\"tight\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Fig1-(b).ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
