{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "860cda74",
      "metadata": {},
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "6108d1a8",
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import os, sys\n",
        "sys.path.insert(0, \"..\") \n",
        "import numpy as np\n",
        "from tensorflow.keras.losses import BinaryCrossentropy\n",
        "# import tensorflow_addons as tfa\n",
        "from tensorflow.keras import activations\n",
        "from tensorflow.keras import utils\n",
        "from tensorflow.keras.losses import BinaryCrossentropy\n",
        "from tensorflow.keras.layers import GlobalAveragePooling2D, Dense\n",
        "from tensorflow.keras.models import Model\n",
        "from tensorflow.keras.optimizers import Adam\n",
        "from tensorflow.keras import backend as K\n",
        "import tensorflow as tf\n",
        "# import pandas as pd\n",
        "from tensorflow.keras.models import load_model\n",
        "from tensorflow.keras.utils import plot_model\n",
        "from datasets import get_dataset\n",
        "from dataloaders.datasetFromSequence import DatasetFromSequenceClass \n",
        "from auxiliary.viz_utils import *\n",
        "import pickle as pkl\n",
        "import math"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "b9ced5e8",
      "metadata": {},
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "SEED = 42\n",
        "# Apply the seed to every RNG in play so that weight initialisation, and hence\n",
        "# every score computed below, is reproducible across runs.\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "tf.random.set_seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "3c639fda",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "3.0"
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import math\n",
        "\n",
        "# square root of 9 via pow (expected: 3.0)\n",
        "math.pow(9, 0.5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "95af07d0",
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow.keras.backend as K"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "id": "aef9e9cc",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "You selected Cifar10\n"
          ]
        }
      ],
      "source": [
        "# unpack the data and the metadata returned by get_dataset\n",
        "(dataset, input_shape, n_classes, TRAIN_WITH_GEN,\n",
        " TRAIN_WITH_LOGITS, batch_size, normalize) = get_dataset('cifar10')\n",
        "(x_train, y_train), (x_test, y_test) = dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "40313ed2",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<module 'tensorflow.keras.initializers' from 'C:\\\\Users\\\\XXX (Anonimised for double-blind review)\\\\AppData\\\\Roaming\\\\Python\\\\Python38\\\\site-packages\\\\tensorflow\\\\keras\\\\initializers\\\\__init__.py'>"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# `initializers` is otherwise imported only in the model-definition cell further\n",
        "# down, so import it here for this cell to run on a fresh kernel.\n",
        "from tensorflow.keras import initializers\n",
        "initializers.LecunUniform()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 108,
      "id": "5e0a7fd7",
      "metadata": {},
      "outputs": [],
      "source": [
        "from tensorflow.keras.models import Sequential\n",
        "from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense,Input,ReLU,Flatten\n",
        "from tensorflow.keras import initializers\n",
        "\n",
        "# Keras port of the PyTorch CIFAR-10 tutorial network, kept for reference:\n",
        "# class Net(nn.Module):\n",
        "#     def __init__(self):\n",
        "#         super(Net, self).__init__()\n",
        "#         self.conv1 = nn.Conv2d(3, 6, 5)\n",
        "#         self.pool = nn.MaxPool2d(2, 2)\n",
        "#         self.conv2 = nn.Conv2d(6, 16, 5)\n",
        "#         self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
        "#         self.fc2 = nn.Linear(120, 84)\n",
        "#         self.fc3 = nn.Linear(84, 10)\n",
        "#\n",
        "#     def forward(self, x):\n",
        "#         x = self.pool(F.relu(self.conv1(x)))\n",
        "#         x = self.pool(F.relu(self.conv2(x)))\n",
        "#         x = x.view(-1, 16 * 5 * 5)\n",
        "#         x = F.relu(self.fc1(x))\n",
        "#         x = F.relu(self.fc2(x))\n",
        "#         x = self.fc3(x)\n",
        "\n",
        "# NOTE(review): PyTorch's default init for Conv2d/Linear is kaiming_uniform_(a=sqrt(5))\n",
        "# on the kernel, i.e. U(-1/sqrt(fan_in), 1/sqrt(fan_in)), with the same bound for the\n",
        "# bias. LecunUniform gives U(-sqrt(3/fan_in), sqrt(3/fan_in)) and, applied to a 1-D\n",
        "# bias, takes fan_in from the bias length, so this only approximates PyTorch;\n",
        "# initializers.VarianceScaling(1/3, 'fan_in', 'uniform') matches the kernel init.\n",
        "init = initializers.LecunUniform()\n",
        "inp = Input(shape=(32,32,3))\n",
        "conv1 = Conv2D(6,kernel_size=(5,5),kernel_initializer=init,bias_initializer=init)\n",
        "pool = MaxPool2D((2,2))\n",
        "conv2 = Conv2D(16,kernel_size=(5,5),kernel_initializer=init,bias_initializer=init)\n",
        "fc1 = Dense(120,kernel_initializer=init,bias_initializer=init)\n",
        "fc2 = Dense(84,kernel_initializer=init, bias_initializer=init)\n",
        "fc3 = Dense(10,kernel_initializer=init, bias_initializer=init)\n",
        "\n",
        "# The reference applies plain F.relu. The math.sqrt(5) seen in PyTorch is the `a`\n",
        "# argument of kaiming_uniform_ (an initialisation detail), not an activation slope,\n",
        "# so these are ordinary ReLUs rather than ReLU(negative_slope=sqrt(5)).\n",
        "x1 = conv1(inp)\n",
        "x2 = ReLU()(x1)\n",
        "x3 = pool(x2)\n",
        "x4 = conv2(x3)\n",
        "x5 = ReLU()(x4)\n",
        "x6 = pool(x5)\n",
        "x7 = Flatten()(x6)\n",
        "x8 = fc1(x7)\n",
        "x9 = ReLU()(x8)\n",
        "x10 = fc2(x9)\n",
        "x11 = ReLU()(x10)\n",
        "x12 = fc3(x11)\n",
        "tf_mod = Model(inp,x12)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "id": "f72cd74e",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[[0.1, 1. ],\n",
              "        [5. , 2. ],\n",
              "        [3. , 3. ]],\n",
              "\n",
              "       [[4. , 4. ],\n",
              "        [8. , 5. ],\n",
              "        [5. , 5. ]]])"
            ]
          },
          "execution_count": 36,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "t_a = np.array([[[0.1, -1], [5, 2], [3, 3]],\n",
        "                [[4, 4], [8, 5], [5, 5]]])\n",
        "# multiplying an array by its own sign yields its absolute value\n",
        "np.sign(t_a) * t_a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 38,
      "id": "096158fa",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(2, 3)"
            ]
          },
          "execution_count": 38,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "a = np.array([1, 2, 3])\n",
        "b = np.array([5, 6, 7])\n",
        "# stacking two length-3 vectors gives a (2, 3) matrix\n",
        "np.vstack((a, b)).shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "id": "faf1fe14",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "-11.131913602352142"
            ]
          },
          "execution_count": 44,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# total of every weight value in the model\n",
        "sum(np.sum(w) for w in tf_mod.get_weights())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "id": "e723d1ec",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "3207.1890330314636"
            ]
          },
          "execution_count": 46,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# total of every weight value in the model (re-evaluated)\n",
        "sum(map(np.sum, tf_mod.get_weights()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 109,
      "id": "3ca1b1b5",
      "metadata": {},
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": 101,
      "id": "6ec6cf42",
      "metadata": {},
      "outputs": [],
      "source": [
        "# `inps` is a local of get_synflow_scores, so build the all-ones probe here\n",
        "inps = tf.ones([1] + list(input_shape))\n",
        "outs = tf_mod(inps)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 103,
      "id": "9725a710",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Product of all network outputs, computed in float64: the float32 product\n",
        "# overflowed to inf. (Note that the SynFlow objective is the *sum* of outputs.)\n",
        "f_out = tf.reduce_prod(tf.cast(outs, tf.float64))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 106,
      "id": "8874af93",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(), dtype=float32, numpy=inf>"
            ]
          },
          "execution_count": 106,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# product of all network outputs, computed in the previous cell\n",
        "f_out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e8e55868",
      "metadata": {},
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": 86,
      "id": "928184d3",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tensorflow.python.keras.layers.convolutional.Conv2D at 0x224e5b74820>,\n",
              " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224812df100>,\n",
              " <tensorflow.python.keras.layers.pooling.MaxPooling2D at 0x224e5b741c0>,\n",
              " <tensorflow.python.keras.layers.convolutional.Conv2D at 0x224e5ba79d0>,\n",
              " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224e5b69fd0>,\n",
              " <tensorflow.python.keras.layers.core.Flatten at 0x224e3253ca0>,\n",
              " <tensorflow.python.keras.layers.core.Dense at 0x224e5b74d60>,\n",
              " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224e5b67670>,\n",
              " <tensorflow.python.keras.layers.core.Dense at 0x224e5ba73d0>,\n",
              " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224e326f370>,\n",
              " <tensorflow.python.keras.layers.core.Dense at 0x224e5ba7790>]"
            ]
          },
          "execution_count": 86,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# all layers except the Input placeholder\n",
        "list(filter(lambda lay: not lay.name.startswith('input'), tf_mod.layers))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 172,
      "id": "be3a8d7c",
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_synflow_scores(tf_mod, input_shape=None):\n",
        "    \"\"\"Compute per-parameter SynFlow scores for a tf.keras model.\n",
        "\n",
        "    Every weight is replaced by its absolute value, a single all-ones input is\n",
        "    fed through the network, all outputs are summed into a scalar R and each\n",
        "    trainable parameter theta is scored as |theta * dR/dtheta|.\n",
        "\n",
        "    Args:\n",
        "        tf_mod: built tf.keras Model. Its weights are temporarily replaced by\n",
        "            their absolute values and restored before returning, also when an\n",
        "            exception is raised.\n",
        "        input_shape: shape of one input sample, without the batch dimension.\n",
        "            Defaults to the model's own input shape (tf_mod.input_shape[1:]).\n",
        "\n",
        "    Returns:\n",
        "        List of tensors, one per entry of tf_mod.trainable_variables (same\n",
        "        order), holding the score of every parameter.\n",
        "    \"\"\"\n",
        "    if input_shape is None:\n",
        "        input_shape = tf_mod.input_shape[1:]\n",
        "    # Step 1 - remember the signed weights, then switch to absolute values\n",
        "    original_weights = tf_mod.get_weights()\n",
        "    tf_mod.set_weights([np.abs(w) for w in original_weights])\n",
        "    try:\n",
        "        # Step 2 - take a single data point filled with 1s\n",
        "        inps = tf.ones([1] + list(input_shape))\n",
        "        # Steps 3+4 - feed it through the network and sum all outputs into a\n",
        "        # single number (the pseudo loss R)\n",
        "        with tf.GradientTape() as tape:\n",
        "            rsf = tf.reduce_sum(tf_mod(inps))\n",
        "        # Step 5 - gradients of R w.r.t. every trainable variable\n",
        "        variables = tf_mod.trainable_variables\n",
        "        gradients = tape.gradient(rsf, variables)\n",
        "        # Step 6 - multiply each weight by its backpropagated signal. All\n",
        "        # trainable variables are scored, whatever their number per layer; a\n",
        "        # variable that R does not depend on (gradient None) scores 0.\n",
        "        scores = [tf.zeros_like(var) if grad is None else tf.abs(var * grad)\n",
        "                  for var, grad in zip(variables, gradients)]\n",
        "    finally:\n",
        "        # Step 7 - restore the original signed weights so the model is unchanged\n",
        "        tf_mod.set_weights(original_weights)\n",
        "    return scores\n",
        "\n",
        "\n",
        "scores = get_synflow_scores(tf_mod)\n",
        "# network-level score: log of the sum of all per-parameter scores\n",
        "synflow_score = np.log(sum(np.sum(sc) for sc in scores))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 177,
      "id": "42eaa39b",
      "metadata": {},
      "outputs": [],
      "source": [
        "# mean score of each parameter tensor\n",
        "avg_scores = [np.average(sc) for sc in scores]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 179,
      "id": "c8cd287b",
      "metadata": {},
      "outputs": [],
      "source": [
        "def sum_arr(arr):\n",
        "    \"\"\"Return the total of all elements of all arrays in `arr` as a Python float.\n",
        "\n",
        "    `arr` may be any iterable of array-likes (numpy arrays, tf tensors, ...);\n",
        "    an empty iterable gives 0.0.\n",
        "    \"\"\"\n",
        "    s = 0.\n",
        "    for a in arr:\n",
        "        s += np.sum(a)\n",
        "    # float() instead of .item(): s is still the plain Python float 0. when arr is empty\n",
        "    return float(s)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 182,
      "id": "979d22b1",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "15.558528548993964"
            ]
          },
          "execution_count": 182,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "synflow_score = np.log(sum_arr(scores))\n",
        "# show the value (an assignment on its own displays nothing)\n",
        "synflow_score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 210,
      "id": "20a0a3e8",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[1087970.1,\n",
              " 53070.24,\n",
              " 1141039.0,\n",
              " 2303.957,\n",
              " 1143345.1,\n",
              " 54.785492,\n",
              " 1143357.2,\n",
              " 8.138836,\n",
              " 1143412.8,\n",
              " 2.5569668]"
            ]
          },
          "execution_count": 210,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# total score of each parameter tensor\n",
        "[np.sum(sc) for sc in scores]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 211,
      "id": "cf210d7d",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "5714563.926607132"
            ]
          },
          "execution_count": 211,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": 200,
      "id": "3fd12393",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "5721704.103704132"
            ]
          },
          "execution_count": 200,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Invert the log to recover the total SynFlow score. 2.7185 is not e\n",
        "# (e = 2.718281...), so `2.7185**15.5585...` overshot the true total\n",
        "# (5721704.1 instead of the 5714563.9 obtained by summing the scores).\n",
        "math.exp(synflow_score)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 120,
      "id": "743431ba",
      "metadata": {},
      "outputs": [],
      "source": [
        "# `t` and `rsf` are locals of get_synflow_scores; rebuild the pseudo loss here\n",
        "# (with the model's current weights) before asking for its gradients.\n",
        "inps = tf.ones([1] + list(input_shape))\n",
        "with tf.GradientTape() as t:\n",
        "    rsf = tf.reduce_sum(tf_mod(inps))\n",
        "gradients = t.gradient(rsf, [l.trainable_variables for l in tf_mod.layers])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 165,
      "id": "8c273b7c",
      "metadata": {},
      "outputs": [],
      "source": [
        "# snapshot of the model's current weights: a list with one array per weight\n",
        "# (the variable name's spelling is kept because a later cell reads it)\n",
        "old_weigths = tf_mod.get_weights()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 168,
      "id": "156f6c5f",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(6,)"
            ]
          },
          "execution_count": 168,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# second weight array; its recorded shape (6,) matches the bias of conv1 (6 filters)\n",
        "old_weigths[1].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 169,
      "id": "ff37ecbe",
      "metadata": {},
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5bb8abeb",
      "metadata": {},
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "f9c8d76f",
      "metadata": {},
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "'Variable' object is not iterable.",
          "output_type": "error",
          "traceback": [
            "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[1;32m<ipython-input-12-2f089bbe93ba>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[0msigns\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mparam\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtf_mod\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweights\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      3\u001b[0m     \u001b[0msigns\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msign\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m     \u001b[0mparam\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mabs_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
            "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\ops\\variables.py\u001b[0m in \u001b[0;36m__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m   1114\u001b[0m       \u001b[0mTypeError\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mwhen\u001b[0m \u001b[0minvoked\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1115\u001b[0m     \"\"\"\n\u001b[1;32m-> 1116\u001b[1;33m     \u001b[1;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"'Variable' object is not iterable.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1117\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1118\u001b[0m   \u001b[1;31m# NOTE(mrry): This enables the Variable's overloaded \"right\" binary\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
            "\u001b[1;31mTypeError\u001b[0m: 'Variable' object is not iterable."
          ]
        }
      ],
      "source": [
        "# TF version of the PyTorch `linearize` step: record each weight's sign and\n",
        "# replace the weight by its absolute value in place. NOTE: this mutates tf_mod;\n",
        "# undo with `for w in tf_mod.weights: w.assign(w * signs[w.name])`.\n",
        "signs = {}\n",
        "for w in tf_mod.weights:\n",
        "    signs[w.name] = tf.sign(w)\n",
        "    w.assign(tf.abs(w))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 124,
      "id": "6d36e29d",
      "metadata": {},
      "outputs": [],
      "source": [
        "# one all-ones image of shape (32, 32, 3), as a batch of size 1\n",
        "inputs = tf.ones((1, 32, 32, 3))\n",
        "output = tf_mod.predict(inputs, steps=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 125,
      "id": "ca76e271",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[  60.419643, -160.96631 ,  -33.074963,  -22.291834,  -16.564161,\n",
              "         -28.718136,  -63.27598 , -108.74212 ,  -23.881102,  -89.41574 ]],\n",
              "      dtype=float32)"
            ]
          },
          "execution_count": 125,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# model prediction for the all-ones input built in the previous cell\n",
        "output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 72,
      "id": "3fe8397d",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[[ 0.4955282   0.37695056 -0.49467397  0.4096442  -2.5028992   0.9173297\n",
            "  -0.09562689  0.9182922   0.3671408   0.15159744]]\n"
          ]
        }
      ],
      "source": [
        "# Under TF2 (this notebook's kernel) execution is eager: calling the model\n",
        "# already returns a concrete tensor, so no Session or variable initializer is\n",
        "# needed (tf.global_variables_initializer only exists under tf.compat.v1).\n",
        "# Not reassigning `init` also leaves the initializer used by the model intact.\n",
        "f = tf_mod(inputs)\n",
        "res = f.numpy()\n",
        "print(res)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "7695132d",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From C:\\Users\\XXX (Anonymised for double-blind review)\\anaconda3\\envs\\tfDML\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "If using Keras pass *_constraint arguments to layers.\n"
          ]
        }
      ],
      "source": [
        "# ResNet-18 from the `classification_models` package. weights=None and\n",
        "# include_top=False presumably mean random initialisation and no classifier\n",
        "# head - verify against the package documentation.\n",
        "from classification_models.tfkeras import Classifiers\n",
        "architecture,preprocess_inputs = Classifiers.get('resnet18')\n",
        "model = architecture(input_shape, classes=n_classes,weights=None,include_top=False)\n",
        "# classification head: global average pooling + softmax Dense with n_classes units\n",
        "gap_layer = GlobalAveragePooling2D()(model.output)\n",
        "out = Dense(n_classes,activation='softmax')(gap_layer)\n",
        "model = Model(model.input,out)\n",
        "# NOTE(review): categorical_crossentropy on softmax outputs expects one-hot labels;\n",
        "# TRAIN_WITH_LOGITS from get_dataset is not consulted here - confirm y_train's encoding.\n",
        "loss = 'categorical_crossentropy'\n",
        "opt = 'adam'\n",
        "model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "df19bf13",
      "metadata": {},
      "outputs": [],
      "source": [
        "# handle on the layer at index 3 of the ResNet model, for inspection\n",
        "l1 = model.layers[3]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eed80b56",
      "metadata": {},
      "outputs": [],
      "source": [
        "# NASWOT probe model: exposes the output of every layer whose activation\n",
        "# function is a ReLU, plus the network's final output.\n",
        "relu_layers = []\n",
        "for lay in model.layers:\n",
        "    if 'activation' in lay.__dict__ and 'relu' in str(lay.activation):\n",
        "        relu_layers.append(lay.output)\n",
        "model_naswot = Model(model.inputs, relu_layers + [model.layers[-1].output])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "id": "04f9b649",
      "metadata": {},
      "outputs": [],
      "source": [
        "# all-ones probe input with a batch dimension of 1\n",
        "inputs = tf.ones([1] + list(input_shape))\n",
        "# predict() needs the data to run on; inputs is a single in-memory batch\n",
        "output = model.predict(inputs, steps=1)\n",
        "# PyTorch reference:\n",
        "# output = net.forward(inputs)\n",
        "# torch.sum(output).backward() "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "id": "b5f9bc0d",
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "TensorShape([Dimension(1), Dimension(32), Dimension(32), Dimension(3)])"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# shape of the all-ones probe input: (1,) + input_shape\n",
        "inputs.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "85556352",
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_layer_metric_array(mdoel, metric, mode):\n",
        "    \"\"\"(PyTorch) Return [metric(layer)] for every Conv2d / Linear layer of `mdoel`.\n",
        "\n",
        "    Layers with a `dont_ch_prune` attribute are skipped when mode == 'channel'.\n",
        "    Requires torch.nn in scope as `nn`; this notebook does not import torch.\n",
        "    \"\"\"\n",
        "    metric_array = []\n",
        "\n",
        "    # walk the network passed in as argument\n",
        "    for layer in mdoel.modules():\n",
        "        if mode == 'channel' and hasattr(layer, 'dont_ch_prune'):\n",
        "            continue\n",
        "        if isinstance(layer, (nn.Conv2d, nn.Linear)):\n",
        "            metric_array.append(metric(layer))\n",
        "\n",
        "    return metric_array\n",
        "\n",
        "def compute_synflow_per_weight(model, inputs, targets, mode, split_data=1, loss_fn=None):\n",
        "    \"\"\"(PyTorch) Per-weight SynFlow scores |w * dR/dw| of the Conv2d/Linear layers.\n",
        "\n",
        "    R is the summed output of `model` for one all-ones input with the same\n",
        "    per-sample shape as `inputs`. `targets`, `split_data` and `loss_fn` are not\n",
        "    used. Weight signs are restored before returning, but model.double() leaves\n",
        "    the network in float64.\n",
        "    \"\"\"\n",
        "    # put the all-ones probe on the same device as `inputs`\n",
        "    device = inputs.device\n",
        "\n",
        "    # convert params to their abs. Keep sign for converting it back.\n",
        "    @torch.no_grad()\n",
        "    def linearize(net):\n",
        "        signs = {}\n",
        "        for name, param in net.state_dict().items():\n",
        "            signs[name] = torch.sign(param)\n",
        "            param.abs_()\n",
        "        return signs\n",
        "\n",
        "    # convert to orig values\n",
        "    @torch.no_grad()\n",
        "    def nonlinearize(net, signs):\n",
        "        for name, param in net.state_dict().items():\n",
        "            if 'weight_mask' not in name:\n",
        "                param.mul_(signs[name])\n",
        "\n",
        "    # keep signs of all params\n",
        "    signs = linearize(model)\n",
        "\n",
        "    # Compute gradients with input of 1s\n",
        "    model.zero_grad()\n",
        "    model.double()\n",
        "    input_dim = list(inputs[0,:].shape)\n",
        "    inputs = torch.ones([1] + input_dim).double().to(device)\n",
        "    output = model.forward(inputs)\n",
        "    torch.sum(output).backward()\n",
        "\n",
        "    # select the gradients that we want to use for search/prune\n",
        "    def synflow(layer):\n",
        "        if layer.weight.grad is not None:\n",
        "            return torch.abs(layer.weight * layer.weight.grad)\n",
        "        else:\n",
        "            return torch.zeros_like(layer.weight)\n",
        "\n",
        "    grads_abs = get_layer_metric_array(model, synflow, mode)\n",
        "\n",
        "    # apply signs of all params\n",
        "    nonlinearize(model, signs)\n",
        "\n",
        "    return grads_abs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6908f527",
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_layer_metric_array(mdoel, metric, mode):\n",
        "    \"\"\"(PyTorch) Return [metric(layer)] for every Conv2d / Linear layer of `mdoel`.\n",
        "\n",
        "    Layers with a `dont_ch_prune` attribute are skipped when mode == 'channel'.\n",
        "    Requires torch.nn in scope as `nn`; this notebook does not import torch.\n",
        "    \"\"\"\n",
        "    metric_array = []\n",
        "\n",
        "    # walk the network passed in as argument\n",
        "    for layer in mdoel.modules():\n",
        "        if mode == 'channel' and hasattr(layer, 'dont_ch_prune'):\n",
        "            continue\n",
        "        if isinstance(layer, (nn.Conv2d, nn.Linear)):\n",
        "            metric_array.append(metric(layer))\n",
        "\n",
        "    return metric_array\n",
        "\n",
        "\n",
        "# NASWOT score (Mellor et al., 2021): log|det K|, where K is built from the\n",
        "# binary ReLU activation codes of one mini-batch.\n",
        "# TODO: `train_gen` is not defined anywhere in this notebook - create it (or use\n",
        "# a slice of x_train) before running this cell.\n",
        "ds = train_gen[0]\n",
        "x_naswot = ds[0]\n",
        "y_naswot = ds[1]\n",
        "bs = len(x_naswot)\n",
        "naswot_score = 1\n",
        "preds = model_naswot.predict(x_naswot)\n",
        "# model_naswot may return one array or a list of arrays; handle both the same way\n",
        "pred_list = preds if isinstance(preds, list) else [preds]\n",
        "# named kernel_mat so the keras backend alias `K` is not overwritten\n",
        "kernel_mat = np.zeros((bs, bs))\n",
        "for l_o in pred_list:\n",
        "    # flatten each sample's activations and binarise them; cast to float so that\n",
        "    # x @ x.T counts co-active units (a bool matmul would only give 0/1)\n",
        "    x = (l_o.reshape(bs, -1) > 0).astype(np.float64)\n",
        "    kernel_mat += x @ x.T + (1. - x) @ (1. - x.T)\n",
        "model_naswot.K = kernel_mat\n",
        "if len(np.unique(model_naswot.K)) > 1:\n",
        "    s, naswot_score = np.linalg.slogdet(model_naswot.K)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "def5dfac",
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):\n",
        "    \"\"\"(PyTorch) Per-weight SynFlow scores |w * dR/dw| for the layers selected by\n",
        "    get_layer_metric_array, where R is the summed output for one all-ones input.\n",
        "\n",
        "    `inputs` only supplies the per-sample shape and the device; `targets`,\n",
        "    `split_data` and `loss_fn` are not used. Weight signs are restored before\n",
        "    returning, but `net.double()` leaves the network in float64.\n",
        "    Requires `torch` and `get_layer_metric_array` in scope.\n",
        "    \"\"\"\n",
        "\n",
        "    device = inputs.device\n",
        "\n",
        "    #convert params to their abs. Keep sign for converting it back.\n",
        "    @torch.no_grad()\n",
        "    def linearize(net):\n",
        "        signs = {}\n",
        "        for name, param in net.state_dict().items():\n",
        "            signs[name] = torch.sign(param)\n",
        "            param.abs_()\n",
        "        return signs\n",
        "\n",
        "    #convert to orig values (entries whose name contains 'weight_mask' are left as is)\n",
        "    @torch.no_grad()\n",
        "    def nonlinearize(net, signs):\n",
        "        for name, param in net.state_dict().items():\n",
        "            if 'weight_mask' not in name:\n",
        "                param.mul_(signs[name])\n",
        "\n",
        "    # keep signs of all params\n",
        "    signs = linearize(net)\n",
        "    \n",
        "    # Compute gradients with input of 1s (batch of one, same per-sample shape as inputs).\n",
        "    # presumably float64 so the summed output of many |w| products does not overflow\n",
        "    net.zero_grad()\n",
        "    net.double()\n",
        "    input_dim = list(inputs[0,:].shape)\n",
        "    inputs = torch.ones([1] + input_dim).double().to(device)\n",
        "    output = net.forward(inputs)\n",
        "    torch.sum(output).backward() \n",
        "\n",
        "    # select the gradients that we want to use for search/prune;\n",
        "    # a layer without a gradient scores zero everywhere\n",
        "    def synflow(layer):\n",
        "        if layer.weight.grad is not None:\n",
        "            return torch.abs(layer.weight * layer.weight.grad)\n",
        "        else:\n",
        "            return torch.zeros_like(layer.weight)\n",
        "\n",
        "    grads_abs = get_layer_metric_array(net, synflow, mode)\n",
        "\n",
        "    # apply signs of all params\n",
        "    nonlinearize(net, signs)\n",
        "\n",
        "    return grads_abs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1cd5ff26",
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python [conda env:tf2plat]",
      "language": "python",
      "name": "conda-env-tf2plat-py"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
