{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7fmgRJo4XFDx"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "from torch import nn\n",
        "import numpy as np\n",
        "import tqdm"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# a normal neural network\n",
        "class DefaultNet(nn.Module):\n",
        "  def __init__(self,depth):\n",
        "        super(DefaultNet, self).__init__()\n",
        "        layers = []\n",
        "        layers.append(nn.Linear(1,4))\n",
        "        layers.append(nn.ReLU())\n",
        "        for i in range(depth-1):\n",
        "          layers.append(nn.Linear(4,4))\n",
        "          layers.append(nn.ReLU())\n",
        "        layers.append(nn.Linear(4,1))\n",
        "        self.network= nn.Sequential(*layers)\n",
        "  def forward(self,f):\n",
        "        return self.network(f)\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(2*n - 1):\n",
        "          f = self.network[i](f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "scFUsRZDXkw1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# subtracts curves from y=x for convex functions like x^3 etc\n",
        "# changes the first scaling factor because of this\n",
        "class Phase1LayerConvex(nn.Module):\n",
        "    def __init__(self,d):\n",
        "        super().__init__()\n",
        "        self.d = d\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x ,y):\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = -1/x[self.d]\n",
        "        self.weights[0][2] = 1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[1][1] = 1/x[self.d]\n",
        "        self.weights[1][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[2][1] = 1/x[self.d]\n",
        "        self.weights[2][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "        self.weights[2][3] = -1*x[self.d+1]\n",
        "\n",
        "        self.weights[3][3] = 1\n",
        "        #scaling\n",
        "        if self.d==0: #uses x instead of 1-x because of subtraction from y=x\n",
        "          self.weights*=(x[self.d])*x[self.d+1]\n",
        "        else:\n",
        "          self.weights*=(1-x[self.d])*x[self.d+1]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return self.activation(y@self.weights.T)\n",
        "\n",
        "class Phase1InputConvex(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[1],[1],[1],[0]],dtype = torch.float32)\n",
        "        self.bias = torch.tensor([0,0,0,1],dtype = torch.float32)\n",
        "    def forward(self, x, y):\n",
        "        self.bias[2] = -1*x[0]\n",
        "        return self.activation(y@self.weights.T+self.bias)\n",
        "\n",
        "class Phase1OutputConvex(nn.Module):\n",
        "    def __init__(self,d):\n",
        "        super().__init__()\n",
        "        self.d = d\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x, y):\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = -1/x[self.d]\n",
        "        self.weights[0][2] = 1/(x[self.d]-x[self.d]**2)\n",
        "        #scaling\n",
        "        if self.d==0:\n",
        "          self.weights*=(x[self.d])*x[self.d+1]\n",
        "        else:\n",
        "          self.weights*=(1-x[self.d])*x[self.d+1]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return y@self.weights.T\n",
        "\n",
        "class Phase1Convex(nn.Module):\n",
        "  def __init__(self,depth):\n",
        "        super(Phase1Convex, self).__init__()\n",
        "        self.x = nn.Parameter(torch.tensor(np.random.uniform(0,1,depth+1),dtype=torch.float32))\n",
        "        self.layers = []\n",
        "        self.layers.append(Phase1InputConvex())\n",
        "        for i in range(depth-1):\n",
        "          self.layers.append(Phase1LayerConvex(i))\n",
        "        self.layers.append(Phase1OutputConvex(depth-1))\n",
        "  def forward(self,f):\n",
        "        with torch.no_grad():\n",
        "          self.x[:]=self.x.clamp(0.001,0.999)\n",
        "        for l in self.layers:\n",
        "          f = l.forward(self.x,f)\n",
        "        return f\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(n):\n",
        "          f = self.layers[i].forward(self.x,f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "lJT7wWZnXkzj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class Phase2LayerConvex(nn.Module):\n",
        "    def __init__(self,d):\n",
        "        super().__init__()\n",
        "        self.d = d\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x ,y, f):\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = -1/x[self.d]\n",
        "        self.weights[0][2] = 1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[1][1] = 1/x[self.d]\n",
        "        self.weights[1][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[2][1] = 1/x[self.d]\n",
        "        self.weights[2][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "        self.weights[2][3] = -1*x[self.d+1]\n",
        "\n",
        "        self.weights[3][3] = 1\n",
        "        #scaling\n",
        "        self.weights*=y[self.d]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return self.activation(f@self.weights.T)\n",
        "\n",
        "class Phase2InputConvex(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[1],[1],[1],[0]],dtype = torch.float32)\n",
        "        self.bias = torch.tensor([0,0,0,1],dtype = torch.float32)\n",
        "    def forward(self, x, y, f):\n",
        "        self.bias[2] = -1*x[0]\n",
        "        return self.activation(f@self.weights.T+self.bias)\n",
        "\n",
        "class Phase2OutputConvex(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x, y, f):\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = -1/x[-1]\n",
        "        self.weights[0][2] = 1/(x[-1]-x[-1]**2)\n",
        "        #scaling\n",
        "        self.weights*=y[-1]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return f@self.weights.T\n",
        "\n",
        "class Phase2Convex(nn.Module):\n",
        "  def __init__(self,depth,peaks, scales_from_peaks = True): #uses peak parameters from last stage\n",
        "        super(Phase2Convex, self).__init__()\n",
        "        # peak location parameters\n",
        "        x = []\n",
        "        for i in range(depth):\n",
        "          x.append(peaks[i])\n",
        "        self.x = nn.Parameter(torch.tensor(x,dtype=torch.float32))\n",
        "\n",
        "        #independent scaling parameters\n",
        "        if scales_from_peaks: #option to initialize on manifold or with random scales\n",
        "          y = []\n",
        "          y.append(peaks[0]*peaks[1])\n",
        "          for i in range(1,depth):\n",
        "            y.append((1-peaks[i])*peaks[i+1])\n",
        "          self.y = nn.Parameter(torch.tensor(y,dtype=torch.float32))\n",
        "        else:\n",
        "          self.y = nn.Parameter(torch.tensor(np.random.uniform(0,1,depth),dtype=torch.float32))\n",
        "\n",
        "        self.layers = []\n",
        "        self.layers.append(Phase2InputConvex())\n",
        "        for i in range(depth-1):\n",
        "          self.layers.append(Phase2LayerConvex(i))\n",
        "        self.layers.append(Phase2OutputConvex())\n",
        "  def forward(self,f):\n",
        "        with torch.no_grad():\n",
        "          self.x[:]=self.x.clamp(0.001,0.999)\n",
        "          self.y[:]=self.y.clamp(0.001,0.999)\n",
        "        for l in self.layers:\n",
        "          f = l.forward(self.x, self.y, f)\n",
        "        return f\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(n):\n",
        "          f = self.layers[i].forward(self.x,self.y,f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "FPzEzADYXk2T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class Phase3LayerConvex(nn.Module):\n",
        "    def __init__(self,x1,x2,y):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "\n",
        "        weights[0][1] = -1/x1\n",
        "        weights[0][2] = 1/(x1-x1**2)\n",
        "\n",
        "        weights[1][1] = 1/x1\n",
        "        weights[1][2] = -1/(x1-x1**2)\n",
        "\n",
        "        weights[2][1] = 1/x1\n",
        "        weights[2][2] = -1/(x1-x1**2)\n",
        "        weights[2][3] = -1*x2\n",
        "\n",
        "        weights[3][3] = 1\n",
        "        #scaling\n",
        "        weights*=y\n",
        "\n",
        "        weights[0][0] = 1\n",
        "        self.weights = nn.Parameter(weights)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.activation(x@self.weights.T)\n",
        "\n",
        "class Phase3InputConvex(nn.Module):\n",
        "    def __init__(self,x):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = nn.Parameter(torch.tensor([[1],[1],[1],[0]],dtype = torch.float32))\n",
        "        self.bias = nn.Parameter(torch.tensor([0,0,-1*x,1],dtype = torch.float32))\n",
        "    def forward(self, x):\n",
        "        return self.activation(x@self.weights.T+self.bias)\n",
        "\n",
        "class Phase3OutputConvex(nn.Module):\n",
        "    def __init__(self,x,y):\n",
        "        super().__init__()\n",
        "        weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "\n",
        "        weights[0][1] = -1/x\n",
        "        weights[0][2] = 1/(x-x**2)\n",
        "\n",
        "        weights*=y\n",
        "\n",
        "        weights[0][0] = 1\n",
        "        self.weights = nn.Parameter(weights)\n",
        "    def forward(self, x):\n",
        "        return x@self.weights.T\n",
        "\n",
        "class Phase3Convex(nn.Module):\n",
        "  def __init__(self,x,y): # uses peaks and scales from phase 2\n",
        "        super(Phase3Convex, self).__init__()\n",
        "        self.layers = []\n",
        "        self.layers.append(Phase3InputConvex(x[0]))\n",
        "        for i in range(len(x)-1):\n",
        "          self.layers.append(Phase3LayerConvex(x[i],x[i+1],y[i]))\n",
        "        self.layers.append(Phase3OutputConvex(x[-1],y[-1]))\n",
        "        self.network = nn.Sequential(*self.layers)\n",
        "  def forward(self,f):\n",
        "        return self.network(f)\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(n):\n",
        "          f = self.layers[i].forward(f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "VpoXkYjTXk49"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# approximates concave functions (sine, tanh)\n",
        "class Phase1LayerConcave(nn.Module):\n",
        "    def __init__(self,d):\n",
        "        super().__init__()\n",
        "        self.d = d\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x ,y):\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = 1/x[self.d]\n",
        "        self.weights[0][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[1][1] = 1/x[self.d]\n",
        "        self.weights[1][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[2][1] = 1/x[self.d]\n",
        "        self.weights[2][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "        self.weights[2][3] = -1*x[self.d+1]\n",
        "\n",
        "        self.weights[3][3] = 1\n",
        "\n",
        "        self.weights*=(1-x[self.d])*x[self.d+1]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return self.activation(y@self.weights.T)\n",
        "\n",
        "class Phase1InputConcave(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[1],[1],[1],[0]],dtype = torch.float32)\n",
        "        self.bias = torch.tensor([0,0,0,1],dtype = torch.float32)\n",
        "    def forward(self, x, y):\n",
        "        self.bias[2] = -1*x[0]\n",
        "        return self.activation(y@self.weights.T+self.bias)\n",
        "\n",
        "class Phase1OutputConcave(nn.Module):\n",
        "    def __init__(self,d):\n",
        "        super().__init__()\n",
        "        self.d = d\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x, y):\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = 1/x[self.d]\n",
        "        self.weights[0][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "        #scaling\n",
        "\n",
        "        self.weights*=(1-x[self.d])*x[self.d+1]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return y@self.weights.T\n",
        "\n",
        "class Phase1Concave(nn.Module):\n",
        "  def __init__(self,depth):\n",
        "        super(Phase1Concave, self).__init__()\n",
        "        self.x = nn.Parameter(torch.tensor(np.random.uniform(0,1,depth+1),dtype=torch.float32))\n",
        "        self.layers = []\n",
        "        self.layers.append(Phase1InputConcave())\n",
        "        for i in range(depth-1):\n",
        "          self.layers.append(Phase1LayerConcave(i))\n",
        "        self.layers.append(Phase1OutputConcave(depth-1))\n",
        "  def forward(self,f):\n",
        "        with torch.no_grad():\n",
        "          self.x[:]=self.x.clamp(0.001,0.999)\n",
        "        for l in self.layers:\n",
        "          f = l.forward(self.x,f)\n",
        "        return f\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(n):\n",
        "          f = self.layers[i].forward(self.x,f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "lIAr2hUfXk7h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class Phase2LayerConcave(nn.Module):\n",
        "    def __init__(self,d):\n",
        "        super().__init__()\n",
        "        self.d = d\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x ,y, f):\n",
        "        self.weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = 1/x[self.d]\n",
        "        self.weights[0][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[1][1] = 1/x[self.d]\n",
        "        self.weights[1][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "\n",
        "        self.weights[2][1] = 1/x[self.d]\n",
        "        self.weights[2][2] = -1/(x[self.d]-x[self.d]**2)\n",
        "        self.weights[2][3] = -1*x[self.d+1]\n",
        "\n",
        "        self.weights[3][3] = 1\n",
        "        #scaling\n",
        "        self.weights*=y[self.d]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return self.activation(f@self.weights.T)\n",
        "\n",
        "class Phase2InputConcave(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = torch.tensor([[1],[1],[1],[0]],dtype = torch.float32)\n",
        "        self.bias = torch.tensor([0,0,0,1],dtype = torch.float32)\n",
        "    def forward(self, x, y, f):\n",
        "        self.bias[2] = -1*x[0]\n",
        "        return self.activation(f@self.weights.T+self.bias)\n",
        "\n",
        "class Phase2OutputConcave(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "    def forward(self, x, y, f):\n",
        "        self.weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "        #set index 0 0 to 1 after scaling is applied\n",
        "        self.weights[0][1] = 1/x[-1]\n",
        "        self.weights[0][2] = -1/(x[-1]-x[-1]**2)\n",
        "        #scaling\n",
        "        self.weights*=y[-1]\n",
        "\n",
        "        self.weights[0][0] = 1\n",
        "        return f@self.weights.T\n",
        "\n",
        "class Phase2Concave(nn.Module):\n",
        "  def __init__(self,depth,peaks, scales_from_peaks = True): #uses peak parameters from last stage\n",
        "        super(Phase2Concave, self).__init__()\n",
        "        # peak location parameters\n",
        "        x = []\n",
        "        for i in range(depth):\n",
        "          x.append(peaks[i])\n",
        "        self.x = nn.Parameter(torch.tensor(x,dtype=torch.float32))\n",
        "\n",
        "        #independent scaling parameters\n",
        "        if scales_from_peaks: #option to initialize on manifold or with random scales\n",
        "          y = []\n",
        "          for i in range(depth):\n",
        "            y.append((1-peaks[i])*peaks[i+1])\n",
        "          self.y = nn.Parameter(torch.tensor(y,dtype=torch.float32))\n",
        "        else:\n",
        "          self.y = nn.Parameter(torch.tensor(np.random.uniform(0,1,depth),dtype=torch.float32))\n",
        "\n",
        "        self.layers = []\n",
        "        self.layers.append(Phase2InputConcave())\n",
        "        for i in range(depth-1):\n",
        "          self.layers.append(Phase2LayerConcave(i))\n",
        "        self.layers.append(Phase2OutputConcave())\n",
        "  def forward(self,f):\n",
        "        with torch.no_grad():\n",
        "          self.x[:]=self.x.clamp(0.001,0.999)\n",
        "          self.y[:]=self.y.clamp(0.001,0.999)\n",
        "        for l in self.layers:\n",
        "          f = l.forward(self.x, self.y, f)\n",
        "        return f\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(n):\n",
        "          f = self.layers[i].forward(self.x,self.y,f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "GrfpHa39Xk-E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class Phase3LayerConcave(nn.Module):\n",
        "    def __init__(self,x1,x2,y):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        weights = torch.tensor([[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]],dtype = torch.float32)\n",
        "\n",
        "        weights[0][1] = 1/x1\n",
        "        weights[0][2] = -1/(x1-x1**2)\n",
        "\n",
        "        weights[1][1] = 1/x1\n",
        "        weights[1][2] = -1/(x1-x1**2)\n",
        "\n",
        "        weights[2][1] = 1/x1\n",
        "        weights[2][2] = -1/(x1-x1**2)\n",
        "        weights[2][3] = -1*x2\n",
        "\n",
        "        weights[3][3] = 1\n",
        "        #scaling\n",
        "        weights*=y\n",
        "\n",
        "        weights[0][0] = 1\n",
        "        self.weights = nn.Parameter(weights)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.activation(x@self.weights.T)\n",
        "\n",
        "class Phase3InputConcave(nn.Module):\n",
        "    def __init__(self,x):\n",
        "        super().__init__()\n",
        "        self.activation = nn.ReLU()\n",
        "        self.weights = nn.Parameter(torch.tensor([[1],[1],[1],[0]],dtype = torch.float32))\n",
        "        self.bias = nn.Parameter(torch.tensor([0,0,-1*x,1],dtype = torch.float32))\n",
        "    def forward(self, x):\n",
        "        return self.activation(x@self.weights.T+self.bias)\n",
        "\n",
        "class Phase3OutputConcave(nn.Module):\n",
        "    def __init__(self,x,y):\n",
        "        super().__init__()\n",
        "        weights = torch.tensor([[0,0,0,0]],dtype = torch.float32)\n",
        "\n",
        "        weights[0][1] = 1/x\n",
        "        weights[0][2] = -1/(x-x**2)\n",
        "\n",
        "        weights*=y\n",
        "\n",
        "        weights[0][0] = 1\n",
        "        self.weights = nn.Parameter(weights)\n",
        "    def forward(self, x):\n",
        "        return x@self.weights.T\n",
        "\n",
        "class Phase3Concave(nn.Module):\n",
        "  def __init__(self,x,y): # uses peaks and scales from phase 2\n",
        "        super(Phase3Concave, self).__init__()\n",
        "        self.layers = []\n",
        "        self.layers.append(Phase3InputConcave(x[0]))\n",
        "        for i in range(len(x)-1):\n",
        "          self.layers.append(Phase3LayerConcave(x[i],x[i+1],y[i]))\n",
        "        self.layers.append(Phase3OutputConcave(x[-1],y[-1]))\n",
        "        self.network = nn.Sequential(*self.layers)\n",
        "  def forward(self,f):\n",
        "        return self.network(f)\n",
        "  def layer_output(self,f,n):\n",
        "        for i in range(n):\n",
        "          f = self.layers[i].forward(f)\n",
        "        return f"
      ],
      "metadata": {
        "id": "j8_QkZipXlAX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def RAAI ( fan_in , fan_out , k = 100 , variance_weights = 0.9) :\n",
        "  \"\"\" Randomized Asymmetric Anti - correlated Initializer ( RAAI ) - Singh and Sreejith (2021)\n",
        "  Arguments :\n",
        "  fan_in -- the number of neurons in the previous layer\n",
        "  fan_out -- the number of neurons in the next layer\n",
        "  corr -- correlation strength for the Gaussian weights\n",
        "  variance_weights -- variance of the weights\n",
        "  Returns :\n",
        "  W, b -- weight and bias matrices with shape (fan_in , fan_out ), and\n",
        "  ( fan_out , )\n",
        "  \"\"\"\n",
        "  corr = k /(1+ k )\n",
        "  mean = np . zeros ( fan_in + 1)\n",
        "  J = np . ones (( fan_in + 1 , fan_in + 1) )\n",
        "  cov = ( np . identity ( fan_in + 1) - J *( corr /( fan_in +1) ) ) * variance_weights / fan_in\n",
        "  P = np . random . multivariate_normal ( mean = mean , cov = cov , size = (\n",
        "  fan_out ) )\n",
        "  for j in range ( P . shape [0]) :\n",
        "      k = np . random . randint (0 , high = fan_in + 1)\n",
        "      P [j , k ] = np . random . beta (2 , 1)\n",
        "  W = P [: , : -1]\n",
        "  b = P [: , -1]\n",
        "  return W . astype ( np . float32 ) , b . astype ( np . float32 )"
      ],
      "metadata": {
        "id": "XIWbs9zgXy_0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class RAAINet(nn.Module):\n",
        "  def __init__(self,width,depth):\n",
        "        super(RAAINet, self).__init__()\n",
        "        layers = []\n",
        "        t = nn.Linear(1,width)\n",
        "        w,b = RAAI(1,width)\n",
        "        t.bias.data = torch.tensor(b,dtype=torch.float32,requires_grad=True)\n",
        "        t.weight.data = torch.tensor(w,dtype=torch.float32,requires_grad=True)\n",
        "        layers.append(t)\n",
        "        layers.append(nn.ReLU())\n",
        "        for i in range(depth-1):\n",
        "          t = nn.Linear(width,width)\n",
        "          w,b = RAAI(width,width)\n",
        "          t.bias.data = torch.tensor(b,dtype=torch.float32,requires_grad=True)\n",
        "          t.weight.data = torch.tensor(w,dtype=torch.float32,requires_grad=True)\n",
        "          layers.append(t)\n",
        "          layers.append(nn.ReLU())\n",
        "        t = nn.Linear(width,1)\n",
        "        w,b = RAAI(width,1)\n",
        "        t.bias.data = torch.tensor(b,dtype=torch.float32,requires_grad=True)\n",
        "        t.weight.data = torch.tensor(w,dtype=torch.float32,requires_grad=True)\n",
        "        layers.append(t)\n",
        "        self.network= nn.Sequential(*layers)\n",
        "  def forward(self,f):\n",
        "        return self.network(f)"
      ],
      "metadata": {
        "id": "zx9j6LTXXzCb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "n = 9\n",
        "x = []\n",
        "s = []\n",
        "for i in range(2**n+1):\n",
        "  x.append(i/2**n)\n",
        "  #s.append(np.sin(0.5*np.pi*(i/2**n)))\n",
        "  #s.append((i/2**n)**11)\n",
        "  s.append(np.tanh(3*(i/2**n)))\n",
        "x = torch.tensor(x,dtype=torch.float32).reshape((2**n+1,1))\n",
        "labels = torch.tensor(s,dtype = torch.float32).reshape(2**n+1,1)\n",
        "n = 5"
      ],
      "metadata": {
        "id": "-J0r9CD1XzFJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "Random_Seeds = np.load(\"/content/drive/MyDrive/DeepLearning/Saved_Models/Random_Seeds.npy\")"
      ],
      "metadata": {
        "id": "YmMG8gUdYH1H"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "min_loss = 100000000\n",
        "losses = []\n",
        "best_model = None\n",
        "lr=1e-3\n",
        "\n",
        "for i in tqdm.tqdm(range(len(Random_Seeds))):\n",
        "  model=Phase1Concave(n)\n",
        "  model.x = nn.Parameter(torch.tensor(Random_Seeds[i],dtype = torch.float32))\n",
        "  #comment out below to skip phase 1\n",
        "\n",
        "  optim=torch.optim.Adam(model.parameters(),lr=lr)\n",
        "  objective = nn.MSELoss()\n",
        "  iter = 1000\n",
        "  for j in range(iter):\n",
        "    f=model(x)\n",
        "    loss = objective(f,labels)\n",
        "    optim.zero_grad()\n",
        "    loss.backward(retain_graph = True)\n",
        "    optim.step()\n",
        "\n",
        "\n",
        "  #Phase 2\n",
        "  model=Phase2Concave(n,model.x) # add false to initialize off differentiable manifold\n",
        "  #comment below to skip phase 2\n",
        "\n",
        "  optim=torch.optim.Adam(model.parameters(),lr=lr)\n",
        "  objective = nn.MSELoss()\n",
        "  iter = 1000\n",
        "  for j in range(iter):\n",
        "    f=model(x)\n",
        "    loss = objective(f,labels)\n",
        "    optim.zero_grad()\n",
        "    loss.backward(retain_graph = True)\n",
        "    optim.step()\n",
        "\n",
        "\n",
        "  #Phase 3\n",
        "  model=Phase3Concave(model.x,model.y)\n",
        "  optim=torch.optim.Adam(model.parameters(),lr=lr)\n",
        "  objective = nn.MSELoss()\n",
        "  iter = 1000\n",
        "  for j in range(iter):\n",
        "    f=model(x)\n",
        "    loss = objective(f,labels)\n",
        "    optim.zero_grad()\n",
        "    loss.backward(retain_graph = True)\n",
        "    optim.step()\n",
        "\n",
        "  L = objective(model(x),labels).detach().numpy()\n",
        "  losses.append(L)\n",
        "  if L<min_loss:\n",
        "    min_loss = L\n",
        "    best_model = model"
      ],
      "metadata": {
        "id": "YC_Ojj76YH3s"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(np.average(losses))\n",
        "print(min_loss)"
      ],
      "metadata": {
        "id": "FiNRpergYH6c"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# plotting code (only for depth 5 networks)\n",
        "network = best_model\n",
        "fig, axs = plt.subplots(2, 3)\n",
        "fig.set_figheight(7)\n",
        "fig.set_figwidth(11)\n",
        "axs[0, 0].plot(x.detach().numpy(),network.layer_output(x,1).detach().numpy())\n",
        "axs[0, 0].set_title(\"Layer 1\")\n",
        "axs[0, 1].plot(x.detach().numpy(),network.layer_output(x,2).detach().numpy())\n",
        "axs[0, 1].set_title(\"Layer 2\")\n",
        "axs[0, 2].plot(x.detach().numpy(),network.layer_output(x,3).detach().numpy())\n",
        "axs[0, 2].set_title(\"Layer 3\")\n",
        "axs[1, 0].plot(x.detach().numpy(),network.layer_output(x,4).detach().numpy())\n",
        "axs[1, 0].set_title(\"Layer 4\")\n",
        "axs[1,0].set_ybound(lower = 0,upper=0.2)\n",
        "axs[1, 1].plot(x.detach().numpy(),network.layer_output(x,5).detach().numpy())\n",
        "axs[1, 1].set_title(\"Layer 5\")\n",
        "axs[1,1].set_ybound(lower = 0,upper = 0.05)\n",
        "axs[1, 2].plot(x.detach().numpy(),network.layer_output(x,6).detach().numpy())\n",
        "axs[1, 2].set_title(\"Output\")"
      ],
      "metadata": {
        "id": "YUjdp28rXzLH"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}