{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":5259,"status":"ok","timestamp":1622810577309,"user":{"displayName":"Kazusato Oko","photoUrl":"","userId":"02834413560214203088"},"user_tz":-540},"id":"YQAkPdsJrJwj","outputId":"3bc816a1-171e-4767-c1df-971072fdf3fb"},"outputs":[{"name":"stdout","output_type":"stream","text":["Collecting entropy_estimators\n","  Downloading https://files.pythonhosted.org/packages/ad/bb/c3fbb485e7ab05821c8ac940405b7b87504a86d660a500f8c09acda15ea6/entropy_estimators-0.0.1.tar.gz\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from entropy_estimators) (1.4.1)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from entropy_estimators) (1.19.5)\n","Building wheels for collected packages: entropy-estimators\n","  Building wheel for entropy-estimators (setup.py) ... \u001b[?25l\u001b[?25hdone\n","  Created wheel for entropy-estimators: filename=entropy_estimators-0.0.1-cp37-none-any.whl size=8314 sha256=983b8cc4170390b4fa91ead314aafb5510569ceead6952e30ef9779fa2454429\n","  Stored in directory: /root/.cache/pip/wheels/9e/24/8c/01f7df763748d2cfe45e323cc02e9d575e81c25b94fe35d147\n","Successfully built entropy-estimators\n","Installing collected packages: entropy-estimators\n","Successfully installed entropy-estimators-0.0.1\n"]}],"source":["!pip install entropy_estimators"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"tiB9XL48hKSc"},"outputs":[],"source":["import numpy as np\n","import numpy.random as npr\n","import math\n","import scipy\n","from scipy import special\n","from scipy import linalg\n","\n","from tqdm import tqdm\n","from google.colab import files\n","\n","import matplotlib \n","from matplotlib import pyplot as plt\n","from mpl_toolkits import mplot3d\n","\n","import sklearn\n","from sklearn import datasets\n","\n","from entropy_estimators import continuous\n","\n","import torch\n","import torch.nn as nn\n","device = torch.device('cpu')\n","\n","import random\n","import time\n","import datetime"]},{"cell_type":"markdown","metadata":{"id":"iCYz7MHRFwhJ"},"source":["## **Parameters**"]},{"cell_type":"markdown","metadata":{"id":"rGJ2mW1vhQSH"},"source":["###*Regularization*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"nHJzzS-ohT1w"},"outputs":[],"source":["lambda1 = 1e-3\n","lambda2 = 1e-3\n","gamma = 1"]},{"cell_type":"markdown","metadata":{"id":"qPnEc7nlLs1l"},"source":["###*Particle-SDCA*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"VyavJtLyL1ex"},"outputs":[],"source":["T0_PSDCA = 50 #total number of outer-loop iterations\n","T1_PSDCA = 200 #sampling\n","T2_PSDCA = 30 #batch size\n","eta_PSDCA = 1e-3 #step size"]},{"cell_type":"markdown","metadata":{"id":"yYhXWhjM8DzD"},"source":["###*PDA*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"x90ijrwo8GuL"},"outputs":[],"source":["T0_PDA = 50 #total number of outer-loop iterations\n","T1_PDA = 200 #sampling\n","T2_PDA = 30 #batch size\n","eta_PDA = 1e-3 #step size"]},{"cell_type":"markdown","metadata":{"id":"qLq1pXQkK72o"},"source":["###*SGD*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ATK0h2OYK_LU"},"outputs":[],"source":["T0_SGD = 100000 #total number of outer-loop iterations\n","T1_SGD = 30 #batch size\n","eta_SGD = 1e-3 #step size"]},{"cell_type":"markdown","metadata":{"id":"UXUjDe-9S4nh"},"source":["###*Model*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ijVg6QXQFzvQ"},"outputs":[],"source":["layer_size = [10,1]\n","M = 200\n","nonlinearity = \"tanh\"\n","nonlinearity_scaler = 1\n","bias = False"]},{"cell_type":"markdown","metadata":{"id":"SrQtj07sS_2p"},"source":["###*Data*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ujWGaa3GSbSW"},"outputs":[],"source":["n = 100\n","nt = 1000\n","Data = \"Gaussian_to_NN\" \n","teacher_layer_size = [10, 1] \n","num_neuron_teacher = 1\n","nonlinearity_teacher = \"relu\" \n","nonlinearity_scaler_teacher = 1 \n","bias_teacher = False \n","sigma_noise = 1/2"]},{"cell_type":"markdown","metadata":{"id":"a2MXORIfvJQK"},"source":["###*Others*"]},{"cell_type":"markdown","metadata":{"id":"-rgaWQ6zTvin"},"source":["###Check and Preprocess"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-3XDPSKvT-Qq"},"outputs":[],"source":["if Data == \"Gaussian_to_NN\":\n","  if layer_size[0] != teacher_layer_size[0]:\n","    print(\"input data size does not much between Model and Data\")\n","  if layer_size[-1] != teacher_layer_size[-1]:\n","    print(\"output data size does not much between Model and Data\")   "]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-8wEFb1uXM4r"},"outputs":[],"source":["if Data in [\"Gaussian_to_NN\"]:\n","  classification = False\n","else:\n","  classification = True"]},{"cell_type":"markdown","metadata":{"id":"o4wpC5JTHp0e"},"source":["## **Model**"]},{"cell_type":"markdown","metadata":{"id":"D2g9T5Q6iTQK"},"source":["###*Neural Network*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-9jc4gwr-RBO"},"outputs":[],"source":["def model(X, lambda1, lambda2, params1, params2 = None, innervalues = False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  if nonlinearity==\"tanh\":\n","    def activation_function(x):\n","      return torch.tanh(x) * nonlinearity_scaler\n","  num_layer = len(params1)\n","  num_data = X.shape[0]\n","  innervalues_ = []\n","  forward = X.repeat((params1[0].shape[0], 1, 1))\n","  if params2 == None:\n","    for i in range(0, len(params1)):\n","      innervalues_ = innervalues_ + [forward]\n","      forward = torch.einsum('mbi,mji->mbj', forward, params1[i])\n","      forward = activation_function(forward)\n","  else:\n","    for i in range(0, len(params1)):\n","      innervalues_ = innervalues_ + [forward]\n","      forward = torch.einsum('mbi,mji->mbj', forward, params1[i]) + params2[i].repeat((num_data,1,1)).permute(1,0,2)\n","      forward = activation_function(forward)\n","  if innervalues:\n","    return forward[:,:,0], innervalues_ \n","  else:\n","    return forward[:,:,0] \n","\n","def model_grad(X, lambda1, lambda2, params1, params2=None, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  if innervalues == None:\n","    _, innervalues_ = model(X, lambda1, lambda2, params1, params2 = params2, innervalues = True, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n","  if nonlinearity==\"tanh\":\n","    def activation_function_grad(x):\n","      return (1-torch.tanh(x)**2) * nonlinearity_scaler\n","  num_data = X.shape[0]\n","  grad1 = []\n","  grad2 = []\n","  for i in range(len(params1)):\n","    if params2 == None:\n","      sigma_i_grad = activation_function_grad(torch.einsum('mij,mbj->mbi',params1[i],innervalues_[i]))\n","      for j in range(i):\n","        grad1[j] = torch.einsum('mbi,mbij->mbij',torch.einsum('mbk,mki->mbi',sigma_i_grad ,params1[i]),grad1[j])\n","      grad1 = grad1 + [torch.einsum('mbi,mbj->mbij',sigma_i_grad,innervalues_[i])]\n","    else:\n","      sigma_i_grad = activation_function_grad(torch.einsum('mij,mbj->mbi',params1[i],innervalues_[i])+ params2[i].repeat((num_data,1,1)).permute(1,0,2))\n","      for j in range(i):\n","        grad1[j] = torch.einsum('mbi,mbij->mbij',torch.einsum('mbk,mki->mbi',sigma_i_grad ,params1[i]),grad1[j])\n","        grad2[j] = torch.einsum('mbi,mbi->mbi',torch.einsum('mbk,mki->mbi', sigma_i_grad ,params1[i]),grad2[j])   \n","      grad1 = grad1 + [torch.einsum('mbi,mbj->mbij',sigma_i_grad,innervalues_[i])]  \n","      grad2 = grad2 + [sigma_i_grad]\n","  return grad1, grad2"]},{"cell_type":"markdown","metadata":{"id":"ei5MhH8eegpk"},"source":["##**Data**"]},{"cell_type":"markdown","metadata":{"id":"yvDGaJmliWw8"},"source":["###*Neural Network, whose input and parameters are Gaussian (or whatever you designate)*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"927MI-UgN_8k"},"outputs":[],"source":["def noise(X,sigma):\n","  if X.dim() == 1:\n","    return torch.FloatTensor(X.shape[0]).normal_(mean=0,std=sigma).to(device)\n","  elif X.dim() == 2:\n","    return torch.FloatTensor(X.shape[0], X.shape[1]).normal_(mean=0,std=sigma).to(device)\n","  elif X.dim() == 3:\n","    return torch.FloatTensor(X.shape[0], X.shape[1], X.shape[2]).normal_(mean=0,std=sigma).to(device)\n","  else:\n","    print(\"Error(noise): input X has an incorrect dimension or is not a tensor.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4SDQzpOsemGv"},"outputs":[],"source":["def Gaussian_to_NN(num_data, num_data_test, num_neuron_teacher, teacher_layer_size, Sigma_X = None, sigma_noise = 0.1, W_list = None, bias_list_teacher = None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  if (bias_list_teacher != None) and (bias_teacher == False):\n","    print(\"Error(Gaussian_to_NN): bias_teacher=False but bias_list_teacher exists\")\n","  if Sigma_X == None:\n","    Sigma_X = torch.eye(teacher_layer_size[0]).to(device)\n","  X = torch.FloatTensor(num_data,teacher_layer_size[0]).normal_().to(device) @ Sigma_X\n","  Xt = torch.FloatTensor(num_data_test,teacher_layer_size[0]).normal_().to(device) @ Sigma_X\n","  Y = X.repeat((num_neuron_teacher, 1, 1))\n","  Yt = Xt.repeat((num_neuron_teacher, 1, 1))\n","  if nonlinearity_teacher==\"tanh\":\n","    def activation_function(x):\n","      return torch.tanh(x) * nonlinearity_scaler_teacher\n","  if nonlinearity_teacher in [\"relu\" , \"ReLU\" , \"RELU\"]:\n","    def activation_function(x):\n","      return torch.relu(x) * nonlinearity_scaler_teacher\n","  for i in range(len(teacher_layer_size)-1):\n","    if W_list == None:\n","      W = torch.FloatTensor(num_neuron_teacher, teacher_layer_size[i],teacher_layer_size[i+1]).normal_().to(device) \n","      W_weight = torch.norm(W)**2 / num_neuron_teacher\n","      if bias_teacher:\n","        B = torch.FloatTensor(num_neuron_teacher, teacher_layer_size[i+1]).normal_().to(device)\n","      W = W/ (torch.norm(W) / num_neuron_teacher**0.5)\n","    else:\n","      W = W_list[i] \n","    if bias_list_teacher != None:\n","      B = bias_list_teacher[i]\n","    if bias_teacher:\n","      Y = activation_function(torch.einsum('mni,mij->mnj', Y, W) + B.repeat((num_data,1,1)).permute(1,0,2))\n","      Yt = activation_function(torch.einsum('mni,mij->mnj', Yt, W)+ B.repeat((num_data_test,1,1)).permute(1,0,2))\n","    else:\n","      Y = activation_function(torch.einsum('mni,mij->mnj', Y, W))\n","      Yt = activation_function(torch.einsum('mni,mij->mnj', Yt, W))\n","  Y = torch.einsum('mni,m->ni', Y, torch.ones(num_neuron_teacher)/num_neuron_teacher) \n","  dY = noise(Y,sigma_noise)\n","  print(\"SNR: %f\" % (torch.norm(Y)**2 / torch.norm(dY)**2))\n","  #Y = Y + dY\n","  Yt = torch.einsum('mni,m->ni', Yt, torch.ones(num_neuron_teacher)/num_neuron_teacher)\n","  return X, Y[:,0], Xt, Yt[:,0]"]},{"cell_type":"markdown","metadata":{"id":"C76XQ1_9inrt"},"source":["###*Loss functions and corresponding functions for conjugate, initialization, update*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"VcjROqSbhiP6"},"outputs":[],"source":["def MSEloss(Y,Z, gamma=1):\n","  return (Y-Z)**2/(2*gamma)\n","\n","def MSEconjugate(G,Y, gamma=1):\n","  return G**2 *gamma/2 + G*Y\n","\n","def MSEupdate(integral, n,y_i, g_i_t,max_iter_update = 0, gamma=1, eps = 1e-18):\n","  return (integral - y_i + g_i_t/(n*lambda2) )/(gamma+1/(n*lambda2))\n","\n","def MSEgrad(Y,Z, gamma=1): #Y: variable, Z: constant\n","  return (Y-Z)/gamma\n","\n","def MSEconjugategrad(G,Y, gamma=1):\n","  return G * gamma + Y"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"guzCJlJwWMV-"},"outputs":[],"source":["loss_func = MSEloss\n","lossconjugate = MSEconjugate\n","update = MSEupdate\n","lossgrad = MSEgrad\n","lossconjugategrad = MSEconjugategrad"]},{"cell_type":"markdown","metadata":{"id":"h6l6IPRKg7bx"},"source":["##**Particle-SDCA Algorithm**"]},{"cell_type":"markdown","metadata":{"id":"dqvRwSHMLhfQ"},"source":["###*Sampling*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"XYHjo5lTUOS-"},"outputs":[],"source":["def LMC_PSDCA(g, X, lambda1, lambda2, params1, params2 = None, eta=eta_PSDCA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  num_data = X.shape[0]\n","  grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler) \n","  for i in range(len(params1)):\n","    dU1 =  torch.einsum('mbij,b->mij',grad1[i],g) / (num_data*lambda2)  +  (2 * lambda1 / lambda2) * params1[i]\n","    params1[i] = params1[i] - eta * dU1 + np.sqrt(2*eta) * torch.FloatTensor(params1[i].shape[0],params1[i].shape[1],params1[i].shape[2]).normal_().to(device)\n","    if params2 != None:\n","      dU2 =  torch.einsum('mbi,b->mi',grad2[i],g) / (num_data*lambda2)  +  (2 * lambda1 / lambda2) * params2[i]\n","      params2[i] = params2[i] - eta * dU2 + np.sqrt(2*eta) * torch.FloatTensor(params2[i].shape[0],params2[i].shape[1]).normal_().to(device)\n","  return params1, params2"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"fqRr1eL1e2__"},"outputs":[],"source":["def sample_PSDCA(g, X, M, layer_size, lambda1, lambda2, T1, method=LMC_PSDCA, params1=None, params2=None, bias = False, eta=eta_PSDCA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  if (bias == False) and params2 != None:\n","    print(\"Error(sample): bias=False but params2 exists\")\n","  if params1 == None:\n","    params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n","    if bias:\n","      params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n","  for t in range(T1):\n","    params1, params2  = method(g, X, lambda1, lambda2, params1, params2 =params2, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n","  return params1, params2"]},{"cell_type":"markdown","metadata":{"id":"HeS6qOPbeXAm"},"source":["###*Main*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"shDCAoElhEPn"},"outputs":[],"source":["record_PSDCA = torch.zeros(6,T0_PSDCA+1) \n","###[0] : training loss\n","###[1] : L2\n","###[2] : negativeentropy\n","###[3] : training loss with regularization\n","###[4] : test loss\n","###[5] : number of gradient evaluation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"srwWMPEDecB-"},"outputs":[],"source":["def PSDCA(X,Y,Xt,Yt,g,M,lambda1, lambda2, layer_size,record_PSDCA, T0 = T0_PSDCA, T1 = T1_PSDCA, T2 = T2_PSDCA, eta = eta_PSDCA, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  num_data = X.shape[0]\n","  num_data_test = Xt.shape[0]\n","  num_grad = 0\n","  params1 = None\n","  params2 = None\n","  for T in range(0,T0):\n","    if resampling:\n","      params1 = None\n","      params2 = None\n","    params1, params2 = sample_PSDCA(g, X, M,layer_size,lambda1,lambda2, T1, params1=params1, params2 = params2, bias = bias, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n","    num_grad += T1* num_data\n","    rs = torch.ones(M) / M\n","    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n","    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n","    if bias:\n","      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n","    else:\n","      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n","    for i in range(1,len(params1)):\n","      if bias:\n","        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n","      else:\n","        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n","    L2 = (torch.norm(params)**2)/ M\n","    negativeentropy = - continuous.get_h(params, k=5) \n","    record_PSDCA[0,T] = train_loss\n","    record_PSDCA[1,T] = L2\n","    record_PSDCA[2,T] = negativeentropy\n","    record_PSDCA[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n","    record_PSDCA[4,T] = test_loss\n","    record_PSDCA[5,T] = num_grad\n","    if T % 10 == 0:\n","      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, Entropy %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n","\n","    indices = list(range(num_data))\n","    random.shuffle(indices) \n","    for i_t in indices[:T2]:\n","        integral = sum(model(X[i_t:i_t+1,:],lambda1, lambda2, params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)[:,0] * rs) /sum(rs)\n","        delg = - g[i_t]\n","        g[i_t] = update(integral, num_data, Y[i_t], g[i_t])\n","        delg += g[i_t]\n","        rs = rs * torch.exp(- model(X[i_t:i_t+1,:], lambda1, lambda2, params1,params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)[:,0]*delg/(num_data*lambda2))\n","        rs = rs/sum(rs)\n","  return params1, params2"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"o4ncjiZu5t0L"},"outputs":[],"source":["X, Y, Xt, Yt = Gaussian_to_NN(n,nt, num_neuron_teacher,teacher_layer_size, sigma_noise=sigma_noise)\n","g = torch.zeros_like(Y)\n","PSDCA(X,Y,Xt,Yt,g,M, lambda1, lambda2, layer_size, record_PSDCA, T0 = T0_PSDCA, T1=T1_PSDCA, T2= T2_PSDCA,eta=eta_PSDCA ,resampling = False, bias = bias)"]},{"cell_type":"markdown","metadata":{"id":"tn7WPap4wvHc"},"source":["##**PDA Algorithm**"]},{"cell_type":"markdown","metadata":{"id":"W_EtVPbx2g_u"},"source":["###*Sampling*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"moNVrUnAxSmH"},"outputs":[],"source":["def LMC_PDA(g, X, A, t,  lambda1, lambda2, params1, params2 = None, eta=eta_PDA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  num_data = X.shape[0]\n","  grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler) \n","  for i in range(len(params1)):\n","    dU1 = 2* torch.einsum('mbij,b->mij',grad1[i],A) / (lambda2*(t+2)*(t+3))  +  2 * lambda1 * (t+1) / (lambda2*(t+3)) * params1[i]\n","    params1[i] = params1[i] - eta * dU1 + np.sqrt(2*eta) * torch.FloatTensor(params1[i].shape[0],params1[i].shape[1],params1[i].shape[2]).normal_().to(device)\n","    if params2 != None:\n","      dU2 = 2* torch.einsum('mbi,b->mi',grad2[i],A) / (lambda2*(t+2)*(t+3))  +  2 * lambda1 * (t+1) / (lambda2*(t+3))  * params2[i]\n","      params2[i] = params2[i] - eta * dU2 + np.sqrt(2*eta) * torch.FloatTensor(params2[i].shape[0],params2[i].shape[1]).normal_().to(device)\n","  return params1, params2"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"S2gNVW8F2gRp"},"outputs":[],"source":["def sample_PDA(g, X, A, t, M, layer_size, lambda1, lambda2, T1, method=LMC_PDA, params1=None, params2=None, bias = False, eta=eta_PDA, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  if (bias == False) and params2 != None:\n","    print(\"Error(sample): bias=False but params2 exists\")\n","  if params1 == None:\n","    params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n","    if bias:\n","      params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n","  for i in range(T1):\n","    params1, params2  = method(g, X, A, t,lambda1, lambda2, params1, params2 =params2, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n","  return params1, params2"]},{"cell_type":"markdown","metadata":{"id":"M8l--1ZQk0dl"},"source":["###*Main*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Hc9B4Uvyk5Zd"},"outputs":[],"source":["record_PDA = torch.zeros(6,T0_PDA+1) "]},{"cell_type":"code","execution_count":null,"metadata":{"id":"GI2DSJvWk9DU"},"outputs":[],"source":["def PDA(X,Y,Xt,Yt,M,lambda1, lambda2, layer_size,record_PDA, T0 = T0_PDA, T1 = T1_PDA, T2 = T2_PDA, eta = eta_PDA, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  num_data = X.shape[0]\n","  num_data_test = Xt.shape[0]\n","  num_grad = 0\n","  params1 = None\n","  params2 = None\n","  A = torch.zeros(num_data)\n","  for T in range(0,T0):\n","    if resampling:\n","      params1 = None\n","      params2 = None\n","    params1, params2 = sample_PDA(g, X, A, T ,M,layer_size,lambda1,lambda2, T1, params1=params1, params2 = params2, bias = bias, eta=eta, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n","    num_grad += T1* num_data\n","    rs = torch.ones(M) / M\n","    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n","    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n","    if bias:\n","      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n","    else:\n","      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n","    for i in range(1,len(params1)):\n","      if bias:\n","        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n","      else:\n","        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n","    L2 = (torch.norm(params)**2)/ M\n","    negativeentropy = - continuous.get_h(params, k=5) \n","    record_PDA[0,T] = train_loss\n","    record_PDA[1,T] = L2\n","    record_PDA[2,T] = negativeentropy\n","    record_PDA[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n","    record_PDA[4,T] = test_loss\n","    record_PDA[5,T] = num_grad\n","    if T % 10 == 0:\n","      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, Entropy: %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n","\n","    indices = list(range(num_data))\n","    random.shuffle(indices)\n","    target_coordinate_list =  indices[:T2]\n","    diff = lossgrad(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs),Y)\n","    A[target_coordinate_list] = A[target_coordinate_list] + (T+1) * diff[target_coordinate_list] / T2\n","  return params1, params2"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-DT-v7U7qEWr"},"outputs":[],"source":["PDA(X,Y,Xt,Yt,M, lambda1, lambda2, layer_size, record_PDA, T0 = T0_PDA, T1=T1_PDA, T2= T2_PDA,eta=eta_PSDCA ,resampling = False, bias = bias)"]},{"cell_type":"markdown","metadata":{"id":"z9Xr_44ZG8o-"},"source":["##**SGD Algorithm**"]},{"cell_type":"markdown","metadata":{"id":"2mQMkQsCHAWE"},"source":["###*Main*"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"pKfFIwODL8Iu"},"outputs":[],"source":["record_SGD = torch.zeros(6,T0_SGD+1) "]},{"cell_type":"code","execution_count":null,"metadata":{"id":"85POS1fkIWON"},"outputs":[],"source":["def SGD(X,Y,Xt,Yt,g,M,lambda1, lambda2, layer_size,record_SGD, T0 = T0_SGD, T1 = T1_SGD, eta = eta_SGD, gamma= gamma, resampling = False, bias=False, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler):\n","  num_data = X.shape[0]\n","  num_data_test = Xt.shape[0]\n","  num_grad = 0\n","  params2 = None\n","  params1 = [torch.FloatTensor(M,layer_size[i+1],layer_size[i]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n","  if bias:\n","    params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=(lambda2/(2*lambda1))**0.5)    for i in range(len(layer_size)-1)]\n","  for T in range(T0):\n","    rs = torch.ones(M) / M\n","    train_loss = sum(loss_func(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Y, gamma = gamma))/num_data\n","    test_loss = sum(loss_func(torch.einsum('mb,m->b',model(Xt,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs), Yt, gamma = gamma))/num_data_test\n","    if bias:\n","      params = torch.cat([params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2])), params2[0]],dim=1)  \n","    else:\n","      params = params1[0].reshape((-1,params1[0].shape[1]*params1[0].shape[2]))\n","    for i in range(1,len(params1)):\n","      if bias:\n","        params =torch.cat([params, torch.cat([params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2])), params2[i]],dim=1)],dim=1)\n","      else:\n","        params =torch.cat([params, params1[i].reshape((-1,params1[i].shape[1]*params1[i].shape[2]))],dim=1)\n","    L2 = (torch.norm(params)**2)/ M\n","    negativeentropy = - continuous.get_h(params, k=5) \n","    record_SGD[0,T] = train_loss\n","    record_SGD[1,T] = L2\n","    record_SGD[2,T] = negativeentropy\n","    record_SGD[3,T] = train_loss + lambda1*L2 + lambda2*negativeentropy\n","    record_SGD[4,T] = test_loss\n","    record_SGD[5,T] = num_grad\n","    if T % 100 == 0:\n","      print(\"Step: %d, Gradient evaluation: %d, Training loss: %f, L2 regularization: %f, entropy: %f, Regularized loss : %f, Test loss: %f\" % (T,num_grad,train_loss,L2,negativeentropy,train_loss + lambda1*L2 + lambda2*negativeentropy,test_loss))\n","    \n","    indices = list(range(num_data))\n","    random.shuffle(indices)\n","    target_coordinate_list =  indices[:T1]\n","    grad1, grad2 = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler)\n","    dl = lossgrad(torch.einsum('mb,m->b',model(X,lambda1, lambda2,params1,params2 = params2, nonlinearity=nonlinearity, nonlinearity_scaler=nonlinearity_scaler),rs),Y)\n","    for i in range(len(params1)):\n","      dU1 =  torch.einsum('mbij,b->mij',grad1[i][:,target_coordinate_list,:,:], dl[target_coordinate_list]) +  2 * lambda1 * params1[i]\n","      params1[i] = params1[i] - eta * dU1\n","      if params2 != None:\n","        dU2 =  torch.einsum('mbi,b->mi',grad2[i][:,target_coordinate_list,:], dl[target_coordinate_list]) +  2 * lambda1 * params2[i]\n","        params2[i] = params2[i] - eta * dU2\n","    num_grad += T1\n","  return params1, params2"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9l8_h2FmKz52"},"outputs":[],"source":["params1, params2 = SGD(X,Y,Xt,Yt,g,M, lambda1, lambda2, layer_size, record_SGD, T0 = T0_SGD, T1=T1_SGD, eta=eta_SGD ,resampling = False, bias = bias)"]},{"cell_type":"markdown","metadata":{"id":"-41cNTtMwn5D"},"source":["#**Figure**"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"g0n1ZlFOAmOk"},"outputs":[],"source":["PSDCA_opt_train = min(record_PSDCA[3,:-2])\n","PDA_opt_train = min(record_PDA[3,:-2])\n","SGD_opt_train = min(record_SGD[3,:-2])\n","global_opt_train = min(torch.tensor([PSDCA_opt_train, PDA_opt_train, SGD_opt_train]))\n","\n","fig = plt.figure()\n","FONT_SIZE = 18\n","plt.rc('font',size=FONT_SIZE)\n","\n","ax1 = fig.add_subplot(1,1,1)\n","\n","ax1.set_xlabel( \"Gradient evalueations\", fontsize=16 )\n","ax1.set_ylabel( \"Excess training loss\", fontsize=16 )\n","\n","ax1.ticklabel_format(style='sci',axis='y',scilimits=(0,0))\n","ax1.ticklabel_format(style='sci',axis='x',scilimits=(0,0))\n","\n","plt.setp(ax1.get_xticklabels(), fontsize=14)\n","plt.setp(ax1.get_yticklabels(), fontsize=14)\n","\n","ax1.plot(record_PSDCA[5,:], record_PSDCA[3,] -  global_opt_train , linewidth=2, color=\"crimson\", label=\"P-SDCA\")\n","ax1.plot(record_PDA[5,:],record_PDA[3,:] -  global_opt_train , linewidth=2, color=\"steelblue\", label=\"PDA\" )\n","ax1.plot(record_SGD[5,:],record_SGD[3,:] -  global_opt_train, linewidth=2, color=\"olivedrab\", label=\"SGD\" )\n","\n","ax1.xaxis.offsetText.set_fontsize(14)\n","ax1.yaxis.offsetText.set_fontsize(14)\n","\n","ax1.set_ylim([0.0001,0.5])\n","ax1.set_xlim([0,6000000])\n","\n","ax1.set_yscale('log')\n","ax1.legend(loc=u\"upper right\", prop={'size':15} )"]}],"metadata":{"colab":{"collapsed_sections":[],"name":"supplementary1.ipynb","provenance":[{"file_id":"1bSUYtDIzIqlsDC9SbLA0P28KvQRlZa-x","timestamp":1622611527874},{"file_id":"1WszNn6oC4_z03ebmZF4x3tTrNeGGPlWh","timestamp":1622166988617},{"file_id":"1mt31dXRwIjOzXM3w5paZcX1Vta2Yq0Dr","timestamp":1622143022767},{"file_id":"1TWGfhjiSZsiCbVuXHhJHHh0ggr0wFdE0","timestamp":1622133227006},{"file_id":"1lDOYwQOWw6-V8w6OPQyfbECWvbFfK_95","timestamp":1622034903083}],"toc_visible":true},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
