{"cells":[
{"cell_type":"markdown","metadata":{},"source":["# PSDCA on a teacher-student regression task\n","\n","Fits a student that averages `M` tanh neurons (particles) to labels from a ReLU teacher plus Gaussian noise, with L2 (`lambda_1`) and entropy (`lambda_2`) regularization. PSDCA alternates Langevin sampling of the particles with stochastic dual coordinate updates; the last cell plots the excess training and test loss per outer iteration."]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"YQAkPdsJrJwj"},"outputs":[],"source":["# Install the entropy estimator (entropy_estimators.continuous.get_h) used in psdca.\n","%pip install -q entropy_estimators"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"8SdngKcCi5h2"},"outputs":[],"source":["import datetime\n","import math\n","import random\n","import time\n","\n","import matplotlib\n","import numpy as np\n","import numpy.random as npr\n","import scipy\n","import sklearn\n","import torch\n","import torch.nn as nn\n","from matplotlib import pyplot as plt\n","from mpl_toolkits import mplot3d\n","from scipy import linalg, special\n","from sklearn import datasets\n","from tqdm import tqdm\n","\n","from entropy_estimators import continuous\n","\n","try:\n","    from google.colab import files  # only available when running on Google Colab\n","except ImportError:\n","    files = None\n","\n","device = torch.device('cpu')"]},
{"cell_type":"markdown","metadata":{"id":"FRIh6-PIldEg"},"source":["# Basic functions and hyperparameters"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"rK9lmXf9rA-V"},"outputs":[],"source":["# Gaussian input\n","def Gaussian(Sigma_X, n=1):\n","    \"\"\"Draw n inputs Z @ Sigma_X with Z ~ N(0, I) and append a constant 1 column.\n","\n","    Sigma_X (d x d) acts as a linear map, so the features have covariance\n","    Sigma_X^T Sigma_X. Returns a tensor of shape (n, d + 1); the extra column\n","    of ones plays the role of the bias unit.\n","    \"\"\"\n","    d, _ = Sigma_X.shape\n","    Z = torch.randn(n, d).to(device)\n","    ones = torch.ones(n, 1).to(device)\n","    return torch.cat([Z @ Sigma_X, ones], dim=1)\n","\n","\n","# activation functions\n","def tanh(X):\n","    return torch.tanh(X)\n","\n","def tanh_grad(X):\n","    \"\"\"Derivative of tanh evaluated at X.\"\"\"\n","    return 1 - torch.tanh(X) ** 2\n","\n","\n","# Training loss\n","# Y1 - true labels\n","# Y2 - model prediction\n","def mse(Y1, Y2):\n","    \"\"\"Mean squared error with a 1/2 factor: mean((Y1 - Y2)^2) / 2.\"\"\"\n","    return torch.mean((Y1 - Y2) ** 2) / 2.\n","\n","def mse_grad(Y1, Y2):\n","    # NOTE: this is the per-sample derivative of (Y1 - Y2)^2 / 2 with respect to Y1\n","    # (not w.r.t. the prediction Y2, and without mse's 1/N factor). It is assigned to\n","    # loss_grad below but never called in this notebook.\n","    return Y1 - Y2"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"-6qFoFC3OhV_"},"outputs":[],"source":["lambda_1 = 1e-4  # L2 regularization strength\n","lambda_2 = 1e-4  # entropy regularization strength\n","\n","M = 200        # number of particles\n","n = 500        # training set size\n","n_test = 1000  # test set size\n","d = 50         # input dimension (Gaussian() appends one extra bias column)\n","\n","sigma = 1/4    # standard deviation of the label noise\n","\n","# Seed every random number generator used below so Restart & Run All is reproducible.\n","SEED = 0\n","random.seed(SEED)\n","np.random.seed(SEED)\n","torch.manual_seed(SEED);"]},
{"cell_type":"markdown","metadata":{"id":"Us-WJBBAzijJ"},"source":["# Data Distribution"]},
{"cell_type":"markdown","metadata":{"id":"7vPKxXFO1Nm5"},"source":["**Regression**"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Rq9NIpXc1NHG"},"outputs":[],"source":["classification = False\n","\n","loss_fn = mse\n","loss_grad = mse_grad\n","\n","# Teacher: average of mt ReLU neurons. Gaussian() appends a constant 1, which becomes\n","# the teacher's bias weight; the weights are then scaled to unit Frobenius norm.\n","mt = 1\n","Wt = Gaussian(torch.eye(d).to(device), mt).T\n","Wt = Wt / torch.norm(Wt)\n","bt = torch.ones(mt, 1).to(device) / mt\n","\n","def target(X):\n","    \"\"\"Noise-free teacher output relu(X @ Wt) @ bt, shape (X.shape[0], 1).\"\"\"\n","    return torch.relu(X @ Wt) @ bt\n","\n","Sigma_X = torch.eye(d).to(device)\n","X = Gaussian(Sigma_X, n)\n","Y = target(X)\n","Xt = Gaussian(Sigma_X, n_test)\n","Yt = target(Xt)\n","\n","# Label noise is added to the training targets only; the test targets stay noise-free.\n","noise = sigma * torch.randn(n, 1).to(device)\n","print(\"SNR: %f\" % (torch.norm(Y)**2 / torch.norm(noise)**2))\n","Y = Y + noise"]},
{"cell_type":"markdown","metadata":{"id":"T57QVlk7INrw"},"source":["# Student Model"]},
{"cell_type":"markdown","metadata":{"id":"tnZLa0aKA4np"},"source":["**Sampling**"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"ysKNWhUZGSPm"},"outputs":[],"source":["def MSEloss(Y, Z, gamma=1):\n","    \"\"\"Element-wise squared loss (Y - Z)^2 / (2 * gamma).\"\"\"\n","    return (Y - Z) * (Y - Z) / (2 * gamma)\n","\n","def MSEconjugate(G, Y, gamma=1):\n","    \"\"\"Convex conjugate of z -> MSEloss(Y, z, gamma) at G: gamma * G^2 / 2 + G * Y.\"\"\"\n","    return (G * G) * gamma / 2 + G * Y\n","\n","def MSEinit(Y):\n","    \"\"\"Initial dual variables used by psdca: g = -Y.\"\"\"\n","    return -Y\n","\n","def MSEupdate(integral, n, lambda_2, y_i, g_i_t):\n","    \"\"\"New value of the i-th dual variable in psdca.\n","\n","    integral is the importance-weighted student output on sample i, y_i its label\n","    and g_i_t the current dual variable.\n","    \"\"\"\n","    return (integral - y_i + g_i_t / (n * lambda_2)) / (1 + 1 / (n * lambda_2))"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Q5tsPVwt9QNn"},"outputs":[],"source":["def model1(X, params):\n","    \"\"\"Per-particle outputs tanh(X @ params), transposed to shape (M, n).\n","\n","    params holds one particle (neuron weight vector) per column.\n","    \"\"\"\n","    return torch.t(torch.tanh(X @ params))\n","\n","def model1_grad(X):\n","    \"\"\"Derivative of tanh evaluated at X.\"\"\"\n","    return 1 - torch.tanh(X) ** 2\n","\n","def nabla_U(g, X, p):\n","    \"\"\"Column-wise gradient of the particle potential\n","\n","        U(p) = sum_i g_i * tanh(x_i . p) / (n * lambda_2) + lambda_1 * |p|^2 / lambda_2,\n","\n","    evaluated for every particle (column) of p.\n","    \"\"\"\n","    n = g.shape[0]\n","    g = g.reshape(n, 1)\n","    dg = model1_grad(X @ p)\n","    part1 = torch.t(X) @ (g * dg) / (n * lambda_2)\n","    part2 = (2 * lambda_1 / lambda_2) * p\n","    return part1 + part2"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"pBXmFLk2wj0z"},"outputs":[],"source":["def LMC(g, p, X, dU, eta=0.001):\n","    \"\"\"One unadjusted Langevin step: p - eta * dU(g, X, p) + sqrt(2 * eta) * N(0, I).\"\"\"\n","    noise = torch.randn_like(p)\n","    return p - eta * dU(g, X, p) + math.sqrt(2 * eta) * noise\n","\n","def _potential(g, X, p):\n","    \"\"\"Per-particle potential U(p) whose gradient is nabla_U; returns shape (M,).\"\"\"\n","    n = X.shape[0]\n","    g = g.reshape(-1, 1)  # (n, 1) so it broadcasts over the particle columns\n","    data_term = (torch.tanh(X @ p) * g).sum(dim=0) / (n * lambda_2)\n","    reg_term = lambda_1 * (p ** 2).sum(dim=0) / lambda_2\n","    return data_term + reg_term\n","\n","def MALA(g, p, X, dU, eta=0.001):\n","    \"\"\"One Metropolis-adjusted Langevin step, accepted or rejected per particle.\"\"\"\n","    step = math.sqrt(2 * eta) * torch.randn_like(p)\n","    p_ = p - eta * dU(g, X, p) + step\n","    # log q(p | p_) - log q(p_ | p) for the Gaussian Langevin proposal.\n","    reverse = p - p_ + eta * dU(g, X, p_)\n","    log_q_ratio = -((reverse ** 2).sum(dim=0) - (step ** 2).sum(dim=0)) / (4 * eta)\n","    log_alpha = -(_potential(g, X, p_) - _potential(g, X, p)) + log_q_ratio\n","    # One uniform draw per particle (column of p).\n","    u = torch.rand(p.shape[1]).to(device)\n","    accept = torch.log(u) < log_alpha  # a NaN log_alpha counts as a rejection\n","    return torch.where(accept, p_, p)\n","\n","def sample(g, X, d, M, T1, method=LMC, params=None):\n","    \"\"\"Run T1 steps of `method` on M particles of dimension d and return them.\n","\n","    When params is None the particles start from N(0, lambda_2 / (2 * lambda_1)),\n","    the Gaussian exp(-lambda_1 |p|^2 / lambda_2) defined by the L2 term alone.\n","    \"\"\"\n","    if params is None:\n","        params = (lambda_2 / (2 * lambda_1)) ** 0.5 * torch.randn(d, M).to(device)\n","    for _ in range(T1):\n","        params = method(g, params, X, nabla_U)\n","    return params"]},
{"cell_type":"markdown","metadata":{"id":"5oat6V5HIvnj"},"source":["**PSDCA Algorithm**"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"VKYl6pZ423yV"},"outputs":[],"source":["# Metrics recorded by psdca (one entry per outer iteration).\n","psdca_iters = []\n","psdca_train_loss = []\n","psdca_test_loss = []\n","negativeentropy = []\n","L2 = []\n","\n","# ----------------------PSDCA--------------------------\n","def psdca(X, Y, g, T0, M, Xt, Yt, classification=False, T1=500, n_dual=30, log_every=10):\n","    \"\"\"Run T0 outer PSDCA iterations, appending metrics to the lists above.\n","\n","    Each outer iteration\n","      1. moves the M particles with T1 LMC steps under the current dual variables g,\n","      2. records the regularized training objective, negative entropy estimate,\n","         L2 term and test loss,\n","      3. updates n_dual randomly chosen dual variables in closed form, tracking\n","         their effect on the particle distribution with importance weights rs.\n","\n","    X, Y: training inputs (n, D) and n targets; Xt, Yt: test inputs and targets.\n","    g: n initial dual variables; flattened with torch.reshape and updated in place,\n","       so the caller's tensor may be modified too.\n","    classification: unused; kept for interface compatibility.\n","    \"\"\"\n","    n, d = X.shape\n","    g = torch.reshape(g, (-1,))\n","    Y = torch.reshape(Y, (-1,))\n","    Yt = torch.reshape(Yt, (-1,))\n","    print('n: %d, d: %d' % (n, d))\n","\n","    tot_iters = 0\n","    params = None\n","    for T in range(T0):\n","        params = sample(g, X, d, M, T1, LMC, params)\n","        rs = torch.ones(M) / M  # uniform importance weights over the freshly sampled particles\n","\n","        tot_iters += T1 * n\n","        psdca_iters.append(tot_iters)\n","\n","        loss = MSEloss(model1(X, params).mean(dim=0), Y).mean()\n","        ent_est = -continuous.get_h(torch.t(params).cpu().numpy(), k=5)\n","        l2_term = lambda_1 * torch.norm(params) ** 2 / M\n","        negativeentropy.append(ent_est)\n","        L2.append(l2_term)\n","        psdca_train_loss.append(loss + l2_term + lambda_2 * ent_est)\n","        psdca_test_loss.append(MSEloss(model1(Xt, params).mean(dim=0), Yt).mean())\n","\n","        if T % log_every == 0:\n","            print(\"Step: %d, Train: %f, Test: %f\" % (tot_iters, psdca_train_loss[-1], psdca_test_loss[-1]))\n","            print(\"EntEst: %f, Loss: %f\" % (ent_est, loss))\n","\n","        indices = list(range(n))\n","        random.shuffle(indices)\n","        for t in indices[:n_dual]:\n","            out_t = model1(X[t:t + 1, :], params)[:, 0]  # every particle's output on sample t\n","            integral = (out_t * rs).sum() / rs.sum()\n","            g_old = g[t].clone()\n","            g[t] = MSEupdate(integral, n, lambda_2, Y[t], g[t])\n","            delg = g[t] - g_old\n","            rs = rs * torch.exp(-out_t * delg / (n * lambda_2))\n","            rs = rs / rs.sum()"]},
{"cell_type":"markdown","metadata":{"id":"9eeC4Wx--uig"},"source":["# Training"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"RTU-FDpJHzXc"},"outputs":[],"source":["T0 = 12000  # outer PSDCA iterations\n","\n","# Clear metrics from any previous run so re-running this cell starts fresh.\n","for history in (psdca_iters, psdca_train_loss, psdca_test_loss, negativeentropy, L2):\n","    history.clear()\n","\n","g = MSEinit(Y)\n","psdca(X, Y, g, T0, M, Xt, Yt)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"HLKDH_O_LFY3"},"outputs":[],"source":["train = torch.tensor([float(v) for v in psdca_train_loss])\n","test = torch.tensor([float(v) for v in psdca_test_loss])"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"fcWxe6Gz9_7s"},"outputs":[],"source":["FONT_SIZE = 18\n","plt.rc('font', size=FONT_SIZE)\n","\n","def excess_and_fit(losses):\n","    \"\"\"Return (losses - min(losses), least-squares fit of log10 of that excess vs. index).\"\"\"\n","    excess = (losses - losses.min()).numpy()\n","    steps = np.arange(len(excess))\n","    log_excess = np.log10(excess + 1e-18)  # offset avoids log10(0) at the minimum\n","    design = np.column_stack([steps, np.ones(len(excess))])\n","    (a, b), *_ = np.linalg.lstsq(design, log_excess, rcond=None)\n","    return excess, 10 ** (a * steps + b)\n","\n","train_excess, train_fit = excess_and_fit(train)\n","test_excess, test_fit = excess_and_fit(test)\n","steps = np.arange(len(train_excess))\n","\n","fig, ax1 = plt.subplots()\n","ax1.plot(steps, train_fit, linewidth=2, color='rosybrown', linestyle='--', label=\"train fit\")\n","ax1.plot(steps, train_excess, linewidth=2, color='crimson', label=\"train\")\n","ax1.set_yscale('log')\n","ax1.tick_params(labelsize=14)\n","ax1.set_ylim([0.00015, 0.30])\n","\n","ax2 = ax1.twinx()\n","ax2.plot(steps, test_excess, linewidth=1.5, color='steelblue', label=\"test\")\n","ax2.plot(steps, test_fit, linewidth=1.5, color='lightslategrey', linestyle='--', label=\"test fit\")\n","ax2.set_yscale('log')\n","ax2.tick_params(labelsize=14)\n","ax2.set_ylim([0.00015, 0.30])\n","\n","# Draw ax1 on top of ax2 with a transparent background so both sets of curves stay visible.\n","ax1.set_zorder(2)\n","ax2.set_zorder(1)\n","ax1.patch.set_alpha(0)\n","\n","# Legend order: train, test, train fit, test fit.\n","h1, l1 = ax1.get_legend_handles_labels()\n","h2, l2 = ax2.get_legend_handles_labels()\n","ax1.legend([h1[1], h2[0], h1[0], h2[1]], [l1[1], l2[0], l1[0], l2[1]], loc='upper right', prop={'size': 15})\n","\n","ax1.set_xlabel('iterations', fontsize=16)\n","ax1.set_ylabel(\"Excess training loss\", fontsize=16)\n","ax2.set_ylabel(\"Excess test loss\", fontsize=16)\n","plt.show()"]}
],"metadata":{"colab":{"collapsed_sections":[],"name":"main1.ipynb","provenance":[{"file_id":"1yZnagfpvAgtDPw35uyjjJ6erDPOUFfZT","timestamp":1622808673790},{"file_id":"1WszNn6oC4_z03ebmZF4x3tTrNeGGPlWh","timestamp":1622166988617},{"file_id":"1mt31dXRwIjOzXM3w5paZcX1Vta2Yq0Dr","timestamp":1622143022767},{"file_id":"1TWGfhjiSZsiCbVuXHhJHHh0ggr0wFdE0","timestamp":1622133227006},{"file_id":"1lDOYwQOWw6-V8w6OPQyfbECWvbFfK_95","timestamp":1622034903083}]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
