{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "mount_file_id": "1MW0Y820wzXWK-m04iAP1WQaa3fxfda_Y",
   "authorship_tag": "ABX9TyM+B1VOGw9MkyoKQK/TkWb/"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7fmgRJo4XFDx"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import tqdm\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Initialize all these things"
   ],
   "metadata": {
    "id": "vfsdFdedTYyn"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# a normal neural network\n",
    "class DefaultNet(nn.Module):\n",
    "    \"\"\"Plain ReLU MLP 1 -> 4 -> ... -> 4 -> 1 with `depth` hidden layers of width 4.\"\"\"\n",
    "\n",
    "    def __init__(self, depth):\n",
    "        super().__init__()\n",
    "        layers = [nn.Linear(1, 4), nn.ReLU()]\n",
    "        for _ in range(depth - 1):\n",
    "            layers += [nn.Linear(4, 4), nn.ReLU()]\n",
    "        layers.append(nn.Linear(4, 1))\n",
    "        self.network = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, f):\n",
    "        return self.network(f)\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Return the output of the n-th Linear module (n counted from 1, before its ReLU).\"\"\"\n",
    "        for i in range(2 * n - 1):\n",
    "            f = self.network[i](f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "scFUsRZDXkw1"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# shared building blocks for the Phase 1/2/3 networks defined below\n",
    "def peak_weights(x1, x2=None, scale=1.0, convex=True, hidden=True):\n",
    "    \"\"\"Build the structured weight matrix used by every Phase 1/2/3 layer.\n",
    "\n",
    "    Arguments:\n",
    "    x1 -- peak location read by this layer\n",
    "    x2 -- next peak location (only used when hidden=True)\n",
    "    scale -- factor applied to every entry except [0][0]\n",
    "    convex -- True: row 0 holds (-1/x1, +1/(x1-x1^2)); False: (+1/x1, -1/(x1-x1^2))\n",
    "    hidden -- True returns the 4x4 hidden-layer matrix, False the 1x4 output row\n",
    "    Returns:\n",
    "    w -- float32 tensor; it stays attached to the autograd graph of x1, x2 and scale\n",
    "    \"\"\"\n",
    "    # always start from fresh zeros so a NaN/inf from an earlier step cannot linger\n",
    "    w = torch.zeros((4 if hidden else 1, 4), dtype=torch.float32)\n",
    "    rise = 1 / x1\n",
    "    fall = 1 / (x1 - x1**2)\n",
    "    sign = -1 if convex else 1\n",
    "    w[0][1] = sign * rise\n",
    "    w[0][2] = -sign * fall\n",
    "    if hidden:\n",
    "        for row in (1, 2):\n",
    "            w[row][1] = rise\n",
    "            w[row][2] = -fall\n",
    "        w[2][3] = -1 * x2\n",
    "        w[3][3] = 1\n",
    "    w *= scale\n",
    "    # index [0][0] is set to 1 only after the scaling is applied\n",
    "    w[0][0] = 1\n",
    "    return w\n",
    "\n",
    "\n",
    "def phase1_scale(x, d, convex):\n",
    "    \"\"\"Scaling factor Phase 1 derives from the peaks: (1 - x[d]) * x[d+1].\n",
    "\n",
    "    The first convex layer (d == 0) uses x[0] * x[1] instead, because the convex\n",
    "    networks subtract their curves from y = x.\n",
    "    \"\"\"\n",
    "    lead = x[d] if (convex and d == 0) else 1 - x[d]\n",
    "    return lead * x[d + 1]\n",
    "\n",
    "\n",
    "class PhaseInput(nn.Module):\n",
    "    \"\"\"Input layer shared by Phases 1 and 2: maps f to ReLU of (f, f, f - x[0], 1).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = torch.tensor([[1], [1], [1], [0]], dtype=torch.float32)\n",
    "        self.bias = torch.tensor([0, 0, 0, 1], dtype=torch.float32)\n",
    "\n",
    "    def lift(self, x, f):\n",
    "        new_bias = self.bias.clone()  # big autodiff related bugs without this\n",
    "        new_bias[2] = -1 * x[0]\n",
    "        return self.activation(f @ self.weights.T + new_bias)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# subtracts curves from y=x for convex functions like x^3 etc\n",
    "# changes the first scaling factor because of this\n",
    "class Phase1LayerConvex(nn.Module):\n",
    "    \"\"\"Hidden layer whose 4x4 weights are rebuilt from the peak vector on every call.\"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d  # index of the peak (triangle tip) this layer reads\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = torch.zeros((4, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        \"\"\"x -- peak vector, y -- (..., 4) output of the previous layer.\"\"\"\n",
    "        d = self.d\n",
    "        self.weights = peak_weights(x[d], x[d + 1], phase1_scale(x, d, convex=True), convex=True)\n",
    "        return self.activation(y @ self.weights.T)\n",
    "\n",
    "\n",
    "class Phase1InputConvex(PhaseInput):\n",
    "    def forward(self, x, y):\n",
    "        return self.lift(x, y)\n",
    "\n",
    "\n",
    "class Phase1OutputConvex(nn.Module):\n",
    "    \"\"\"Linear read-out (no ReLU) that collapses the 4 channels to one output.\"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        self.weights = torch.zeros((1, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        d = self.d\n",
    "        self.weights = peak_weights(x[d], scale=phase1_scale(x, d, convex=True), convex=True, hidden=False)\n",
    "        return y @ self.weights.T\n",
    "\n",
    "\n",
    "class Phase1Convex(nn.Module):\n",
    "    \"\"\"Phase 1 network: the depth+1 peak locations `x` are the only trainable parameters.\"\"\"\n",
    "\n",
    "    def __init__(self, depth):\n",
    "        super().__init__()\n",
    "        self.x = nn.Parameter(torch.tensor(np.random.uniform(0, 1, depth + 1), dtype=torch.float32))\n",
    "        # the layers own no parameters (their weights are derived from self.x), so a plain list is enough\n",
    "        self.layers = [Phase1InputConvex()]\n",
    "        self.layers += [Phase1LayerConvex(i) for i in range(depth - 1)]\n",
    "        self.layers.append(Phase1OutputConvex(depth - 1))\n",
    "\n",
    "    def forward(self, f):\n",
    "        with torch.no_grad():\n",
    "            # keep the peaks away from 0 and 1, where 1/x and 1/(x - x**2) blow up\n",
    "            self.x[:] = self.x.clamp(0.01, 0.99)\n",
    "        for l in self.layers:\n",
    "            f = l.forward(self.x, f)\n",
    "        return f\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Output after the first n layers (n = 1 is the input layer); does not clamp x.\"\"\"\n",
    "        for i in range(n):\n",
    "            f = self.layers[i].forward(self.x, f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "lJT7wWZnXkzj"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "class Phase2LayerConvex(nn.Module):\n",
    "    \"\"\"Like Phase1LayerConvex, but the scaling factor is the independent parameter y[d].\"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = torch.zeros((4, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y, f):\n",
    "        \"\"\"x -- peak vector, y -- scale vector, f -- (..., 4) output of the previous layer.\"\"\"\n",
    "        d = self.d\n",
    "        self.weights = peak_weights(x[d], x[d + 1], y[d], convex=True)\n",
    "        return self.activation(f @ self.weights.T)\n",
    "\n",
    "\n",
    "class Phase2InputConvex(PhaseInput):\n",
    "    def forward(self, x, y, f):\n",
    "        return self.lift(x, f)\n",
    "\n",
    "\n",
    "class Phase2OutputConvex(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weights = torch.zeros((1, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y, f):\n",
    "        self.weights = peak_weights(x[-1], scale=y[-1], convex=True, hidden=False)\n",
    "        return f @ self.weights.T\n",
    "\n",
    "\n",
    "class Phase2Convex(nn.Module):\n",
    "    \"\"\"Phase 2 network: peak locations `x` and scaling factors `y` are trained separately.\n",
    "\n",
    "    Arguments:\n",
    "    depth -- number of peaks / scales kept\n",
    "    peaks -- peak parameters from the last stage; needs depth+1 entries when scales_from_peaks\n",
    "    scales_from_peaks -- initialize y on the Phase 1 manifold (True) or uniformly at random (False)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, depth, peaks, scales_from_peaks=True):  # uses peak parameters from last stage\n",
    "        super().__init__()\n",
    "        # peak location parameters\n",
    "        self.x = nn.Parameter(torch.tensor([peaks[i] for i in range(depth)], dtype=torch.float32))\n",
    "\n",
    "        # independent scaling parameters\n",
    "        if scales_from_peaks:  # option to initialize on manifold or with random scales\n",
    "            y = [peaks[0] * peaks[1]]\n",
    "            y += [(1 - peaks[i]) * peaks[i + 1] for i in range(1, depth)]\n",
    "            self.y = nn.Parameter(torch.tensor(y, dtype=torch.float32))\n",
    "        else:\n",
    "            self.y = nn.Parameter(torch.tensor(np.random.uniform(0, 1, depth), dtype=torch.float32))\n",
    "\n",
    "        self.layers = [Phase2InputConvex()]\n",
    "        self.layers += [Phase2LayerConvex(i) for i in range(depth - 1)]\n",
    "        self.layers.append(Phase2OutputConvex())\n",
    "\n",
    "    def forward(self, f):\n",
    "        with torch.no_grad():\n",
    "            self.x[:] = self.x.clamp(0.01, 0.99)\n",
    "            self.y[:] = self.y.clamp(0.01, 0.99)\n",
    "        for l in self.layers:\n",
    "            f = l.forward(self.x, self.y, f)\n",
    "        return f\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Output after the first n layers (n = 1 is the input layer); does not clamp x or y.\"\"\"\n",
    "        for i in range(n):\n",
    "            f = self.layers[i].forward(self.x, self.y, f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "FPzEzADYXk2T"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "class Phase3LayerConvex(nn.Module):\n",
    "    \"\"\"Ordinary trainable 4x4 layer, initialized from peaks x1, x2 and scale y.\"\"\"\n",
    "\n",
    "    def __init__(self, x1, x2, y):\n",
    "        super().__init__()\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = nn.Parameter(peak_weights(x1, x2, y, convex=True))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.activation(x @ self.weights.T)\n",
    "\n",
    "\n",
    "class Phase3InputConvex(nn.Module):\n",
    "    def __init__(self, x):\n",
    "        super().__init__()\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = nn.Parameter(torch.tensor([[1], [1], [1], [0]], dtype=torch.float32))\n",
    "        self.bias = nn.Parameter(torch.tensor([0, 0, -1 * x, 1], dtype=torch.float32))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.activation(x @ self.weights.T + self.bias)\n",
    "\n",
    "\n",
    "class Phase3OutputConvex(nn.Module):\n",
    "    def __init__(self, x, y):\n",
    "        super().__init__()\n",
    "        self.weights = nn.Parameter(peak_weights(x, scale=y, convex=True, hidden=False))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x @ self.weights.T\n",
    "\n",
    "\n",
    "class Phase3Convex(nn.Module):\n",
    "    \"\"\"Phase 3 network: every weight and bias is a free parameter, initialized from Phase 2.\"\"\"\n",
    "\n",
    "    def __init__(self, x, y):  # uses peaks and scales from phase 2\n",
    "        super().__init__()\n",
    "        self.layers = [Phase3InputConvex(x[0])]\n",
    "        for i in range(len(x) - 1):\n",
    "            self.layers.append(Phase3LayerConvex(x[i], x[i + 1], y[i]))\n",
    "        self.layers.append(Phase3OutputConvex(x[-1], y[-1]))\n",
    "        # nn.Sequential registers the layers so model.parameters() sees their weights\n",
    "        self.network = nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, f):\n",
    "        return self.network(f)\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Output after the first n layers (n = 1 is the input layer).\"\"\"\n",
    "        for i in range(n):\n",
    "            f = self.layers[i].forward(f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "VpoXkYjTXk49"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# approximates concave functions (sine, tanh)\n",
    "class Phase1LayerConcave(nn.Module):\n",
    "    \"\"\"Concave counterpart of Phase1LayerConvex: row 0 adds the peak term instead of subtracting it.\"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = torch.zeros((4, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        \"\"\"x -- peak vector, y -- (..., 4) output of the previous layer.\"\"\"\n",
    "        d = self.d\n",
    "        self.weights = peak_weights(x[d], x[d + 1], phase1_scale(x, d, convex=False), convex=False)\n",
    "        return self.activation(y @ self.weights.T)\n",
    "\n",
    "\n",
    "class Phase1InputConcave(PhaseInput):\n",
    "    def forward(self, x, y):\n",
    "        return self.lift(x, y)\n",
    "\n",
    "\n",
    "class Phase1OutputConcave(nn.Module):\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        self.weights = torch.zeros((1, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        d = self.d\n",
    "        self.weights = peak_weights(x[d], scale=phase1_scale(x, d, convex=False), convex=False, hidden=False)\n",
    "        return y @ self.weights.T\n",
    "\n",
    "\n",
    "class Phase1Concave(nn.Module):\n",
    "    \"\"\"Phase 1 network for concave targets: the depth+1 peak locations `x` are the only parameters.\"\"\"\n",
    "\n",
    "    def __init__(self, depth):\n",
    "        super().__init__()\n",
    "        self.x = nn.Parameter(torch.tensor(np.random.uniform(0, 1, depth + 1), dtype=torch.float32))\n",
    "        self.layers = [Phase1InputConcave()]\n",
    "        self.layers += [Phase1LayerConcave(i) for i in range(depth - 1)]\n",
    "        self.layers.append(Phase1OutputConcave(depth - 1))\n",
    "\n",
    "    def forward(self, f):\n",
    "        with torch.no_grad():\n",
    "            self.x[:] = self.x.clamp(0.001, 0.999)\n",
    "        for l in self.layers:\n",
    "            f = l.forward(self.x, f)\n",
    "        return f\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Output after the first n layers (n = 1 is the input layer); does not clamp x.\"\"\"\n",
    "        for i in range(n):\n",
    "            f = self.layers[i].forward(self.x, f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "lIAr2hUfXk7h"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "class Phase2LayerConcave(nn.Module):\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = torch.zeros((4, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y, f):\n",
    "        \"\"\"x -- peak vector, y -- scale vector, f -- (..., 4) output of the previous layer.\"\"\"\n",
    "        d = self.d\n",
    "        self.weights = peak_weights(x[d], x[d + 1], y[d], convex=False)\n",
    "        return self.activation(f @ self.weights.T)\n",
    "\n",
    "\n",
    "class Phase2InputConcave(PhaseInput):\n",
    "    def forward(self, x, y, f):\n",
    "        return self.lift(x, f)\n",
    "\n",
    "\n",
    "class Phase2OutputConcave(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weights = torch.zeros((1, 4), dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x, y, f):\n",
    "        self.weights = peak_weights(x[-1], scale=y[-1], convex=False, hidden=False)\n",
    "        return f @ self.weights.T\n",
    "\n",
    "\n",
    "class Phase2Concave(nn.Module):\n",
    "    \"\"\"Concave Phase 2 network; same arguments as Phase2Convex.\"\"\"\n",
    "\n",
    "    def __init__(self, depth, peaks, scales_from_peaks=True):  # uses peak parameters from last stage\n",
    "        super().__init__()\n",
    "        # peak location parameters\n",
    "        self.x = nn.Parameter(torch.tensor([peaks[i] for i in range(depth)], dtype=torch.float32))\n",
    "\n",
    "        # independent scaling parameters\n",
    "        if scales_from_peaks:  # option to initialize on manifold or with random scales\n",
    "            y = [(1 - peaks[i]) * peaks[i + 1] for i in range(depth)]\n",
    "            self.y = nn.Parameter(torch.tensor(y, dtype=torch.float32))\n",
    "        else:\n",
    "            self.y = nn.Parameter(torch.tensor(np.random.uniform(0, 1, depth), dtype=torch.float32))\n",
    "\n",
    "        self.layers = [Phase2InputConcave()]\n",
    "        self.layers += [Phase2LayerConcave(i) for i in range(depth - 1)]\n",
    "        self.layers.append(Phase2OutputConcave())\n",
    "\n",
    "    def forward(self, f):\n",
    "        with torch.no_grad():\n",
    "            self.x[:] = self.x.clamp(0.001, 0.999)\n",
    "            self.y[:] = self.y.clamp(0.001, 0.999)\n",
    "        for l in self.layers:\n",
    "            f = l.forward(self.x, self.y, f)\n",
    "        return f\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Output after the first n layers (n = 1 is the input layer); does not clamp x or y.\"\"\"\n",
    "        for i in range(n):\n",
    "            f = self.layers[i].forward(self.x, self.y, f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "GrfpHa39Xk-E"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "class Phase3LayerConcave(nn.Module):\n",
    "    def __init__(self, x1, x2, y):\n",
    "        super().__init__()\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = nn.Parameter(peak_weights(x1, x2, y, convex=False))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.activation(x @ self.weights.T)\n",
    "\n",
    "\n",
    "class Phase3InputConcave(nn.Module):\n",
    "    def __init__(self, x):\n",
    "        super().__init__()\n",
    "        self.activation = nn.ReLU()\n",
    "        self.weights = nn.Parameter(torch.tensor([[1], [1], [1], [0]], dtype=torch.float32))\n",
    "        self.bias = nn.Parameter(torch.tensor([0, 0, -1 * x, 1], dtype=torch.float32))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.activation(x @ self.weights.T + self.bias)\n",
    "\n",
    "\n",
    "class Phase3OutputConcave(nn.Module):\n",
    "    def __init__(self, x, y):\n",
    "        super().__init__()\n",
    "        self.weights = nn.Parameter(peak_weights(x, scale=y, convex=False, hidden=False))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x @ self.weights.T\n",
    "\n",
    "\n",
    "class Phase3Concave(nn.Module):\n",
    "    \"\"\"Concave Phase 3 network: every weight and bias is a free parameter, initialized from Phase 2.\"\"\"\n",
    "\n",
    "    def __init__(self, x, y):  # uses peaks and scales from phase 2\n",
    "        super().__init__()\n",
    "        self.layers = [Phase3InputConcave(x[0])]\n",
    "        for i in range(len(x) - 1):\n",
    "            self.layers.append(Phase3LayerConcave(x[i], x[i + 1], y[i]))\n",
    "        self.layers.append(Phase3OutputConcave(x[-1], y[-1]))\n",
    "        self.network = nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, f):\n",
    "        return self.network(f)\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Output after the first n layers (n = 1 is the input layer).\"\"\"\n",
    "        for i in range(n):\n",
    "            f = self.layers[i].forward(f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "j8_QkZipXlAX"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def RAAI(fan_in, fan_out, k=100, variance_weights=0.9):\n",
    "    \"\"\"Randomized Asymmetric Anti-correlated Initializer (RAAI) - Singh and Sreejith (2021)\n",
    "\n",
    "    Arguments:\n",
    "    fan_in -- the number of neurons in the previous layer\n",
    "    fan_out -- the number of neurons in the next layer\n",
    "    k -- sets the correlation strength for the Gaussian weights, corr = k / (1 + k)\n",
    "    variance_weights -- variance of the weights\n",
    "    Returns:\n",
    "    W, b -- float32 weight matrix of shape (fan_out, fan_in) and bias vector of shape (fan_out,)\n",
    "    \"\"\"\n",
    "    corr = k / (1 + k)\n",
    "    mean = np.zeros(fan_in + 1)\n",
    "    J = np.ones((fan_in + 1, fan_in + 1))\n",
    "    cov = (np.identity(fan_in + 1) - J * (corr / (fan_in + 1))) * variance_weights / fan_in\n",
    "    # one row per output neuron; the last column becomes the bias\n",
    "    P = np.random.multivariate_normal(mean=mean, cov=cov, size=fan_out)\n",
    "    for j in range(P.shape[0]):\n",
    "        # overwrite one randomly chosen entry per row with a Beta(2, 1) draw\n",
    "        col = np.random.randint(0, high=fan_in + 1)\n",
    "        P[j, col] = np.random.beta(2, 1)\n",
    "    W = P[:, :-1]\n",
    "    b = P[:, -1]\n",
    "    return W.astype(np.float32), b.astype(np.float32)"
   ],
   "metadata": {
    "id": "XIWbs9zgXy_0"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def raai_linear(fan_in, fan_out):\n",
    "    \"\"\"nn.Linear(fan_in, fan_out) whose weight and bias are overwritten with a RAAI draw.\"\"\"\n",
    "    layer = nn.Linear(fan_in, fan_out)\n",
    "    w, b = RAAI(fan_in, fan_out)\n",
    "    with torch.no_grad():\n",
    "        # copy in place so the existing Parameters (and their requires_grad) are kept\n",
    "        layer.weight.copy_(torch.from_numpy(w))\n",
    "        layer.bias.copy_(torch.from_numpy(b))\n",
    "    return layer\n",
    "\n",
    "\n",
    "class RAAINet(nn.Module):\n",
    "    \"\"\"ReLU MLP 1 -> width -> ... -> width -> 1 with `depth` hidden layers, RAAI-initialized.\"\"\"\n",
    "\n",
    "    def __init__(self, width, depth):\n",
    "        super().__init__()\n",
    "        layers = [raai_linear(1, width), nn.ReLU()]\n",
    "        for _ in range(depth - 1):\n",
    "            layers += [raai_linear(width, width), nn.ReLU()]\n",
    "        layers.append(raai_linear(width, 1))\n",
    "        self.network = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, f):\n",
    "        return self.network(f)\n",
    "\n",
    "    def layer_output(self, f, n):\n",
    "        \"\"\"Return the output of the n-th Linear module (same convention as DefaultNet).\"\"\"\n",
    "        for i in range(2 * n - 1):\n",
    "            f = self.network[i](f)\n",
    "        return f"
   ],
   "metadata": {
    "id": "zx9j6LTXXzCb"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Training Code"
   ],
   "metadata": {
    "id": "uvt_UwY2Txdx"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "def fit(model, inputs, targets, lr=1e-3, num_iters=1000):\n",
    "    \"\"\"Train `model` with Adam on the MSE between model(inputs) and targets.\n",
    "\n",
    "    Returns the list of per-step training losses (Python floats).\n",
    "    \"\"\"\n",
    "    optim = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    objective = nn.MSELoss()\n",
    "    history = []\n",
    "    for _ in range(num_iters):\n",
    "        loss = objective(model(inputs), targets)\n",
    "        optim.zero_grad()\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        history.append(loss.item())\n",
    "    return history\n",
    "\n",
    "\n",
    "def test_loss(model, inputs, targets):\n",
    "    \"\"\"MSE of `model` on held-out data, as a Python float.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        return nn.MSELoss()(model(inputs), targets).item()"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#lots of data\n",
    "# NOTE: on a full top-to-bottom run the sparse-data cell below overrides x/labels/test_x/test_labels;\n",
    "# skip that cell to train on this dense grid instead\n",
    "grid_bits = 9  # sample the target on 2**grid_bits + 1 evenly spaced points in [0, 1]\n",
    "grid = [i / 2**grid_bits for i in range(2**grid_bits + 1)]\n",
    "#s = [np.sin(0.5*np.pi*t) for t in grid]\n",
    "s = [t**3 for t in grid]\n",
    "#s = [np.tanh(3*t) for t in grid]\n",
    "test_x = torch.tensor(grid, dtype=torch.float32).reshape((-1, 1))\n",
    "test_labels = torch.tensor(s, dtype=torch.float32).reshape((-1, 1))\n",
    "x = torch.tensor(grid, dtype=torch.float32).reshape((-1, 1))\n",
    "labels = torch.tensor(s, dtype=torch.float32).reshape((-1, 1))\n",
    "n = 5  # network depth used by the training cells"
   ],
   "metadata": {
    "id": "-J0r9CD1XzFJ"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#sparse data\n",
    "# even tenths are used for training, odd tenths for testing\n",
    "train_x = [i / 10 for i in range(10) if i % 2 == 0]\n",
    "train_y = [t**3 for t in train_x]\n",
    "test_x = [i / 10 for i in range(10) if i % 2 != 0]\n",
    "test_y = [t**3 for t in test_x]\n",
    "x = torch.tensor(train_x, dtype=torch.float32).reshape((5, 1))\n",
    "labels = torch.tensor(train_y, dtype=torch.float32).reshape(5, 1)\n",
    "test_x = torch.tensor(test_x, dtype=torch.float32).reshape((5, 1))\n",
    "test_labels = torch.tensor(test_y, dtype=torch.float32).reshape(5, 1)\n",
    "n = 5  # network depth used by the training cells"
   ],
   "metadata": {
    "id": "CQceMOrRYxkH"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "RANDOM_SEEDS_PATH = Path(\"/path/Random_Seeds.npy\")  # placeholder: point this at your saved seeds\n",
    "NUM_FALLBACK_SEEDS = 30  # matches the number of benchmark trials below\n",
    "\n",
    "# each row is used as the initial peak vector of one Phase 1 model, so it needs n + 1 entries\n",
    "if RANDOM_SEEDS_PATH.exists():\n",
    "    Random_Seeds = np.load(RANDOM_SEEDS_PATH)\n",
    "else:\n",
    "    # same uniform(0, 1) distribution the Phase 1 models use for their own initial peaks\n",
    "    print(f\"{RANDOM_SEEDS_PATH} not found - drawing {NUM_FALLBACK_SEEDS} random peak initializations instead\")\n",
    "    Random_Seeds = np.random.uniform(0, 1, (NUM_FALLBACK_SEEDS, n + 1))"
   ],
   "metadata": {
    "id": "YmMG8gUdYH1H"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#if loss looks bad try changing the 3 lines that look like model = Phase1Convex to concave, which is necessary to learn the concave functions\n",
    "TRAIN_PHASE2 = False  # set True to also train the scaling factors separately\n",
    "lr = 1e-3\n",
    "min_loss = float('inf')\n",
    "losses = []  # one test loss per seed\n",
    "best_model = None\n",
    "models = []\n",
    "for i in tqdm.tqdm(range(len(Random_Seeds))):\n",
    "    #Trains output to be differentiable\n",
    "    model = Phase1Convex(n)\n",
    "    model.x = nn.Parameter(torch.tensor(Random_Seeds[i], dtype=torch.float32))\n",
    "    fit(model, x, labels, lr=lr)  #comment out to skip\n",
    "\n",
    "    #trains scaling factors separately\n",
    "    # detach: the next phase only needs the values, not the previous phase's autograd graph\n",
    "    model = Phase2Convex(n, model.x.detach())  # add False to initialize to a non-differentiable output\n",
    "    if TRAIN_PHASE2:\n",
    "        fit(model, x, labels, lr=lr)\n",
    "\n",
    "    #Phase 3 - regular parameter training\n",
    "    model = Phase3Convex(model.x.detach(), model.y.detach())\n",
    "    fit(model, x, labels, lr=lr)\n",
    "\n",
    "    L = test_loss(model, test_x, test_labels)\n",
    "    losses.append(L)\n",
    "    models.append(model)\n",
    "    if L < min_loss:\n",
    "        min_loss = L\n",
    "        best_model = model\n",
    "\n",
    "# kept under its own name because the benchmark cell below reuses best_model\n",
    "reparam_best_model = best_model"
   ],
   "metadata": {
    "id": "YC_Ojj76YH3s"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "print(np.average(losses))\n",
    "print(min_loss)"
   ],
   "metadata": {
    "id": "FiNRpergYH6c"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#Trains the regular network benchmarks\n",
    "lr = 1e-3\n",
    "min_loss = float('inf')\n",
    "losses = []  # one test loss per trial, same meaning as in the reparameterized run above\n",
    "train_curves = []  # per-step training losses, one list per trial\n",
    "best_model = None\n",
    "for i in tqdm.tqdm(range(30)):\n",
    "    model = DefaultNet(n)\n",
    "    #model = RAAINet(4,5)\n",
    "    train_curves.append(fit(model, x, labels, lr=lr))\n",
    "    L = test_loss(model, test_x, test_labels)\n",
    "    losses.append(L)\n",
    "    if L < min_loss:\n",
    "        min_loss = L\n",
    "        best_model = model\n",
    "\n",
    "default_best_model = best_model"
   ],
   "metadata": {
    "id": "nGxVGXRRkOWd"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# legend handles for the four channels drawn in the per-layer plots below\n",
    "red_patch = mpatches.Patch(color='red', label='Bias')\n",
    "blue_patch = mpatches.Patch(color='blue', label='Sum')\n",
    "orange_patch = mpatches.Patch(color='orange', label='$t_1$')\n",
    "green_patch = mpatches.Patch(color='green', label='$t_2$')"
   ],
   "metadata": {
    "id": "f4cficiRmZ6k"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# plotting code (only for depth 5 networks)\n",
    "network = best_model\n",
    "inputs = x.detach().numpy()\n",
    "fig, axs = plt.subplots(2, 3)\n",
    "fig.set_figheight(7)\n",
    "fig.set_figwidth(11)\n",
    "titles = ['Layer 1', 'Layer 2', 'Layer 3', 'Layer 4', 'Layer 5', 'Output']\n",
    "for layer_idx, (ax, title) in enumerate(zip(axs.flat, titles), start=1):\n",
    "    ax.plot(inputs, network.layer_output(x, layer_idx).detach().numpy())\n",
    "    ax.set_title(title)\n",
    "axs[1, 0].set_ybound(lower=0, upper=0.25)\n",
    "axs[1, 1].set_ybound(lower=0, upper=0.05)\n",
    "axs[1, 2].legend(handles=[red_patch, blue_patch, orange_patch, green_patch], prop={'size': 15});"
   ],
   "metadata": {
    "id": "YUjdp28rXzLH"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# compares the best benchmark network with the best reparameterized network\n",
    "network1 = default_best_model\n",
    "network2 = reparam_best_model\n",
    "plot_points = torch.tensor(np.linspace(0, 1, 500, endpoint=False), dtype=torch.float32).reshape((500, 1))\n",
    "train_pts = (x.detach().numpy().ravel(), labels.detach().numpy().ravel())\n",
    "test_pts = (test_x.detach().numpy().ravel(), test_labels.detach().numpy().ravel())\n",
    "fig, axs = plt.subplots(1, 2)\n",
    "fig.set_figheight(4)\n",
    "fig.set_figwidth(10)\n",
    "panels = [(network1, 'Default Network'), (network2, 'Reparameterized Pretraining')]\n",
    "for ax, (net, title) in zip(axs, panels):\n",
    "    ax.plot(plot_points.detach().numpy(), net(plot_points).detach().numpy())\n",
    "    ax.scatter(*train_pts, s=10, label='Training Data')\n",
    "    ax.scatter(*test_pts, s=10, label='Testing Data')\n",
    "    ax.set_title(title)\n",
    "    ax.legend()\n"
   ],
   "metadata": {
    "id": "Wvwqozo6Zg8k"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}