{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2d00677c",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7eb413e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os, sys\n",
    "sys.path.insert(0, \"..\") \n",
    "import numpy as np\n",
    "from tensorflow.keras.losses import BinaryCrossentropy\n",
    "# import tensorflow_addons as tfa\n",
    "from tensorflow.keras import activations\n",
    "from tensorflow.keras import utils\n",
    "from tensorflow.keras.losses import BinaryCrossentropy\n",
    "from tensorflow.keras.layers import GlobalAveragePooling2D, Dense\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras import backend as K\n",
    "import tensorflow as tf\n",
    "# import pandas as pd\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.utils import plot_model\n",
    "from datasets import get_dataset\n",
    "from dataloaders.datasetFromSequence import DatasetFromSequenceClass \n",
    "from auxiliary.viz_utils import *\n",
    "import pickle as pkl\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8c8c24c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single seed for the RNGs this notebook uses, so that layer initialisation (and hence\n",
    "# the SNIP / SynFlow scores computed below) is reproducible across Restart & Run All.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6435338e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow.keras.backend as K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d43b0570",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You selected Cifar10\n"
     ]
    }
   ],
   "source": [
    "# Load CIFAR-10 through the project's get_dataset helper and split it into train/test arrays.\n",
    "(dataset, input_shape, n_classes,\n",
    " TRAIN_WITH_GEN, TRAIN_WITH_LOGITS,\n",
    " batch_size, normalize) = get_dataset('cifar10')\n",
    "(x_train, y_train), (x_test, y_test) = dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "153de9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense,Input,ReLU,Flatten\n",
    "from tensorflow.keras import initializers\n",
    "\n",
    "# Keras port of the classic PyTorch CIFAR-10 LeNet. Reference PyTorch forward pass\n",
    "# (kept as a comment: as live code it was a never-called `def` that used `F` before import):\n",
    "#     x = self.pool(F.relu(self.conv1(x)))\n",
    "#     x = self.pool(F.relu(self.conv2(x)))\n",
    "#     x = x.view(-1, 16 * 5 * 5)\n",
    "#     x = F.relu(self.fc1(x))\n",
    "#     x = F.relu(self.fc2(x))\n",
    "#     x = self.fc3(x)\n",
    "#\n",
    "# NOTE(review): PyTorch initialises Conv2d/Linear weights with kaiming_uniform_(a=math.sqrt(5)),\n",
    "# i.e. U(-1/sqrt(fan_in), 1/sqrt(fan_in)); `a` there only sets the init gain. LecunUniform\n",
    "# uses the wider bound sqrt(3 / fan_in). For an exact match use\n",
    "# initializers.VarianceScaling(scale=1/3, mode='fan_in', distribution='uniform').\n",
    "init = initializers.LecunUniform()\n",
    "\n",
    "inp = Input(shape=input_shape)  # (32, 32, 3) for CIFAR-10\n",
    "conv1 = Conv2D(6, kernel_size=(5, 5), kernel_initializer=init, bias_initializer=init)\n",
    "pool = MaxPool2D((2, 2))  # weightless, so one instance is shared by both conv stages\n",
    "conv2 = Conv2D(16, kernel_size=(5, 5), kernel_initializer=init, bias_initializer=init)\n",
    "fc1 = Dense(120, kernel_initializer=init, bias_initializer=init)\n",
    "fc2 = Dense(84, kernel_initializer=init, bias_initializer=init)\n",
    "fc3 = Dense(n_classes, kernel_initializer=init, bias_initializer=init)  # raw logits, no softmax\n",
    "\n",
    "# Plain ReLU, matching F.relu. ReLU(negative_slope=math.sqrt(5)) is NOT a ReLU: it multiplies\n",
    "# negative pre-activations by ~2.24 instead of zeroing them (sqrt(5) is PyTorch's init gain\n",
    "# parameter, not an activation slope).\n",
    "x1 = conv1(inp)\n",
    "x2 = ReLU()(x1)\n",
    "x3 = pool(x2)\n",
    "x4 = conv2(x3)\n",
    "x5 = ReLU()(x4)\n",
    "x6 = pool(x5)\n",
    "x7 = Flatten()(x6)\n",
    "x8 = fc1(x7)\n",
    "x9 = ReLU()(x8)\n",
    "x10 = fc2(x9)\n",
    "x11 = ReLU()(x10)\n",
    "x12 = fc3(x11)\n",
    "tf_mod = Model(inp, x12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "9e146442",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0.1, 1. ],\n",
       "        [5. , 2. ],\n",
       "        [3. , 3. ]],\n",
       "\n",
       "       [[4. , 4. ],\n",
       "        [8. , 5. ],\n",
       "        [5. , 5. ]]])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import copy\n",
    "import types\n",
    "\n",
    "# PyTorch reference implementation of the SNIP zero-cost proxy, adapted to run inside this\n",
    "# notebook. The package-relative imports it was copied with (`from . import measure` and\n",
    "# `from ..p_utils import get_layer_metric_array`) raise ImportError outside a package, so the\n",
    "# registry decorator is dropped and the per-layer metric helper is defined locally.\n",
    "\n",
    "\n",
    "def get_layer_metric_array(net, metric, mode):\n",
    "    \"\"\"Apply `metric` to every Conv2d / Linear module of `net`, in `net.modules()` order.\n",
    "\n",
    "    Local stand-in for the upstream p_utils helper. `mode` is accepted for signature\n",
    "    compatibility and is currently ignored (TODO: confirm no channel-mode filtering is needed).\n",
    "    \"\"\"\n",
    "    return [metric(layer) for layer in net.modules()\n",
    "            if isinstance(layer, (nn.Conv2d, nn.Linear))]\n",
    "\n",
    "\n",
    "def snip_forward_conv2d(self, x):\n",
    "    \"\"\"Conv2d forward pass with the weight multiplied by the learnable `weight_mask`.\"\"\"\n",
    "    return F.conv2d(x, self.weight * self.weight_mask, self.bias,\n",
    "                    self.stride, self.padding, self.dilation, self.groups)\n",
    "\n",
    "\n",
    "def snip_forward_linear(self, x):\n",
    "    \"\"\"Linear forward pass with the weight multiplied by the learnable `weight_mask`.\"\"\"\n",
    "    return F.linear(x, self.weight * self.weight_mask, self.bias)\n",
    "\n",
    "\n",
    "def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):\n",
    "    \"\"\"Return per-layer SNIP saliencies |dL/dm| for every Conv2d / Linear layer of `net`.\n",
    "\n",
    "    A mask `m` of ones is multiplied into each weight, the loss is back-propagated over\n",
    "    `inputs` in `split_data` chunks (gradients accumulate across chunks), and |grad of m|\n",
    "    is returned per layer. NOTE: this mutates `net` in place (adds `weight_mask`, freezes\n",
    "    `weight`, overrides `forward`), so pass a copy if the original model is still needed.\n",
    "    `loss_fn` is called PyTorch-style as loss_fn(outputs, targets).\n",
    "    \"\"\"\n",
    "    for layer in net.modules():\n",
    "        if isinstance(layer, (nn.Conv2d, nn.Linear)):\n",
    "            layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))\n",
    "            layer.weight.requires_grad = False\n",
    "\n",
    "        # Override the forward methods so the mask takes part in the autograd graph.\n",
    "        if isinstance(layer, nn.Conv2d):\n",
    "            layer.forward = types.MethodType(snip_forward_conv2d, layer)\n",
    "\n",
    "        if isinstance(layer, nn.Linear):\n",
    "            layer.forward = types.MethodType(snip_forward_linear, layer)\n",
    "\n",
    "    # Compute gradients (but don't apply them); backward() accumulates over the chunks.\n",
    "    net.zero_grad()\n",
    "    N = inputs.shape[0]\n",
    "    for sp in range(split_data):\n",
    "        st = sp * N // split_data\n",
    "        en = (sp + 1) * N // split_data\n",
    "\n",
    "        outputs = net.forward(inputs[st:en])\n",
    "        loss = loss_fn(outputs, targets[st:en])\n",
    "        loss.backward()\n",
    "\n",
    "    # Select the gradients that we want to use for search/prune.\n",
    "    def snip(layer):\n",
    "        if layer.weight_mask.grad is not None:\n",
    "            return torch.abs(layer.weight_mask.grad)\n",
    "        return torch.zeros_like(layer.weight)\n",
    "\n",
    "    grads_abs = get_layer_metric_array(net, snip, mode)\n",
    "\n",
    "    return grads_abs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "ad2be3ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 3)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: stacking two length-3 vectors yields a (2, 3) array.\n",
    "a, b = np.array([1, 2, 3]), np.array([5, 6, 7])\n",
    "np.stack((a, b)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "37de420b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-11.131913602352142"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum of every parameter tensor of tf_mod: a quick fingerprint of the current weights.\n",
    "sum(np.sum(param) for param in tf_mod.get_weights())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "267b5913",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3207.1890330314636"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same weight fingerprint, re-evaluated after later cells may have modified tf_mod in place.\n",
    "sum(map(np.sum, tf_mod.get_weights()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "e7935c59",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "084eae36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `inps` was never defined at top level above (it only exists as a local inside the scoring\n",
    "# functions), so this cell failed on a fresh kernel. Assumed intent (TODO confirm): the\n",
    "# SynFlow-style single all-ones probe sample.\n",
    "inps = tf.ones([1] + list(input_shape))\n",
    "outs = tf_mod(inps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "44ebf993",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Product of every output value across the batch. NOTE: in float32 this overflowed to inf\n",
    "# in the recorded run.\n",
    "f_out = 1\n",
    "for row in outs:\n",
    "    f_out = f_out * tf.reduce_prod(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "a3895331",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=inf>"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the accumulated output product.\n",
    "f_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d22da31",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "ca99010e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tensorflow.python.keras.layers.convolutional.Conv2D at 0x224e5b74820>,\n",
       " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224812df100>,\n",
       " <tensorflow.python.keras.layers.pooling.MaxPooling2D at 0x224e5b741c0>,\n",
       " <tensorflow.python.keras.layers.convolutional.Conv2D at 0x224e5ba79d0>,\n",
       " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224e5b69fd0>,\n",
       " <tensorflow.python.keras.layers.core.Flatten at 0x224e3253ca0>,\n",
       " <tensorflow.python.keras.layers.core.Dense at 0x224e5b74d60>,\n",
       " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224e5b67670>,\n",
       " <tensorflow.python.keras.layers.core.Dense at 0x224e5ba73d0>,\n",
       " <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x224e326f370>,\n",
       " <tensorflow.python.keras.layers.core.Dense at 0x224e5ba7790>]"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every layer of tf_mod except the Input layer, in `tf_mod.layers` order.\n",
    "[layer for layer in tf_mod.layers if not layer.name.startswith('input')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c6ed1189",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _param_grad_products(tf_mod, gradients):\n",
    "    \"\"\"Return `variable * gradient` for every trainable variable of every layer.\n",
    "\n",
    "    `gradients` must mirror `[layer.trainable_variables for layer in tf_mod.layers]`, i.e. the\n",
    "    nested structure `tape.gradient` returns for that list. Layers without trainable variables\n",
    "    contribute nothing. A variable whose gradient is None (not connected to the target) gets an\n",
    "    all-zero score instead of raising a TypeError. Results are ordered by layer, then by each\n",
    "    layer's own variable order (kernel, bias, ...).\n",
    "    \"\"\"\n",
    "    products = []\n",
    "    for layer, layer_grads in zip(tf_mod.layers, gradients):\n",
    "        for variable, grad in zip(layer.trainable_variables, layer_grads):\n",
    "            if grad is None:\n",
    "                grad = tf.zeros_like(variable)\n",
    "            products.append(variable * grad)\n",
    "    return products\n",
    "\n",
    "\n",
    "def get_snip_scores(tf_mod, x_train, y_train, loss_fn, bs):\n",
    "    \"\"\"Per-parameter SNIP saliency |theta * dL/dtheta| on the first `bs` training samples.\n",
    "\n",
    "    Args:\n",
    "        tf_mod: Keras model to score; its weights are not modified.\n",
    "        x_train, y_train: training inputs / targets; only the first `bs` rows are used\n",
    "            (generators are not supported).\n",
    "        loss_fn: Keras-style loss, called as `loss_fn(y_true, y_pred)`.\n",
    "        bs: number of samples in the scoring minibatch.\n",
    "\n",
    "    Returns:\n",
    "        list of numpy arrays, one per trainable variable (layer order, then variable order).\n",
    "    \"\"\"\n",
    "    # Step 1 - take one minibatch of data\n",
    "    inps = x_train[0:bs]\n",
    "    targets = y_train[0:bs]\n",
    "    # Step 2 - forward pass and loss. Keras losses take (y_true, y_pred); passing\n",
    "    # (outs, targets) would treat the model outputs as the labels.\n",
    "    with tf.GradientTape() as tape:\n",
    "        outs = tf_mod(inps)\n",
    "        loss = loss_fn(targets, outs)\n",
    "    # Step 3 - gradients of the minibatch loss w.r.t. every trainable variable\n",
    "    gradients = tape.gradient(loss, [layer.trainable_variables for layer in tf_mod.layers])\n",
    "    # Step 4 - saliency = |theta * grad|\n",
    "    return [np.abs(product) for product in _param_grad_products(tf_mod, gradients)]\n",
    "\n",
    "\n",
    "def get_synflow_scores(tf_mod, input_shape=None):\n",
    "    \"\"\"Per-parameter SynFlow scores |theta| * dR/dtheta, with R = sum(f(1; |theta|)).\n",
    "\n",
    "    The model's weights are temporarily replaced by their absolute values, a single all-ones\n",
    "    sample is pushed through, and each |theta| is multiplied by the gradient of the summed\n",
    "    output. The original signed weights are always restored afterwards, even on error\n",
    "    (re-applying the absolute weights here would permanently destroy the weight signs).\n",
    "\n",
    "    Args:\n",
    "        tf_mod: Keras model to score.\n",
    "        input_shape: per-sample input shape without the batch dimension. Defaults to the\n",
    "            model's own `tf_mod.input_shape[1:]`.\n",
    "\n",
    "    Returns:\n",
    "        list of tf.Tensor, one per trainable variable (layer order, then variable order).\n",
    "    \"\"\"\n",
    "    if input_shape is None:\n",
    "        input_shape = tf_mod.input_shape[1:]\n",
    "    original_weights = tf_mod.get_weights()\n",
    "    try:\n",
    "        # Step 1 - linearise the network: replace every weight by its absolute value\n",
    "        tf_mod.set_weights([np.abs(w) for w in original_weights])\n",
    "        # Step 2 - a single data point filled with ones\n",
    "        inps = tf.ones([1] + list(input_shape))\n",
    "        # Steps 3-4 - forward pass, then sum all outputs into a single pseudo-loss R\n",
    "        with tf.GradientTape() as tape:\n",
    "            rsf = tf.reduce_sum(tf_mod(inps))\n",
    "        # Step 5 - gradients of R w.r.t. every trainable variable\n",
    "        gradients = tape.gradient(rsf, [layer.trainable_variables for layer in tf_mod.layers])\n",
    "        # Step 6 - multiply each (absolute) weight by its back-propagated signal; computed\n",
    "        # eagerly here, so restoring the weights below does not change these tensors\n",
    "        scores = _param_grad_products(tf_mod, gradients)\n",
    "    finally:\n",
    "        # Step 7 - restore the original signed weights\n",
    "        tf_mod.set_weights(original_weights)\n",
    "    return scores\n",
    "\n",
    "\n",
    "scores = get_synflow_scores(tf_mod)\n",
    "synflow_score = np.log(sum(map(lambda x: np.sum(x),scores)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "499faa5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tf_mod ends in a plain Dense layer with no softmax, so its outputs are logits; with the\n",
    "# default from_logits=False they would be misread as probabilities. The loss stays\n",
    "# per-sample (unreduced), exactly like tf.keras.losses.categorical_crossentropy.\n",
    "# Assumes y_train is one-hot encoded (TODO confirm against get_dataset).\n",
    "def loss_fn(y_true, y_pred):\n",
    "    \"\"\"Per-sample categorical cross-entropy computed from raw logits `y_pred`.\"\"\"\n",
    "    return tf.keras.losses.categorical_crossentropy(y_true, y_pred, from_logits=True)\n",
    "\n",
    "\n",
    "bs = 128  # samples in the single SNIP scoring minibatch\n",
    "(x_train,y_train), (x_test,y_test) = dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e1f9fb49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-parameter SNIP saliencies on the first `bs` training samples.\n",
    "snip_scores = get_snip_scores(tf_mod, x_train, y_train, loss_fn=loss_fn, bs=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "a45f460a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate SNIP score: log of the total saliency over all parameters.\n",
    "snip_score = np.log(sum(np.sum(s) for s in snip_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "4d723dcf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "22.438517175421747"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "snip_score  # display the aggregate (log) SNIP score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a7b52860",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-20-f33382e745fc>:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  list(map(lambda x: np.abs(x) if(len(x)>0) else x,gr))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[[],\n",
       " array([<tf.Tensor: shape=(5, 5, 3, 6), dtype=float32, numpy=\n",
       " array([[[[24498382., 24147292., 23762852., 23569360., 23917458.,\n",
       "           23083016.],\n",
       "          [23598200., 23207916., 22836944., 22647042., 22997706.,\n",
       "           22210448.],\n",
       "          [21845816., 21468162., 21113608., 20943962., 21270676.,\n",
       "           20566174.]],\n",
       " \n",
       "         [[24697844., 24411318., 24080478., 23730154., 23987396.,\n",
       "           23404778.],\n",
       "          [23771716., 23480460., 23147116., 22807978., 23059702.,\n",
       "           22544646.],\n",
       "          [21993112., 21710924., 21399644., 21072726., 21313840.,\n",
       "           20866402.]],\n",
       " \n",
       "         [[24523338., 24363860., 24124442., 23611554., 24065948.,\n",
       "           23363274.],\n",
       "          [23591872., 23445376., 23173034., 22694006., 23137784.,\n",
       "           22492608.],\n",
       "          [21828112., 21689278., 21429952., 20981150., 21388652.,\n",
       "           20821778.]],\n",
       " \n",
       "         [[24388536., 24355100., 24124358., 23725024., 24151712.,\n",
       "           23325328.],\n",
       "          [23472012., 23427422., 23191758., 22812652., 23243664.,\n",
       "           22453204.],\n",
       "          [21719116., 21671296., 21458200., 21118490., 21494232.,\n",
       "           20786380.]],\n",
       " \n",
       "         [[24178896., 24140246., 23997244., 23573176., 23996562.,\n",
       "           22915624.],\n",
       "          [23270372., 23212576., 23070408., 22669784., 23097632.,\n",
       "           22056360.],\n",
       "          [21523286., 21464016., 21348704., 20970540., 21368208.,\n",
       "           20407184.]]],\n",
       " \n",
       " \n",
       "        [[[24675748., 24202836., 23979826., 23816604., 24122236.,\n",
       "           23307240.],\n",
       "          [23677402., 23184432., 22970310., 22817692., 23118130.,\n",
       "           22357394.],\n",
       "          [21882334., 21402060., 21199692., 21069196., 21341720.,\n",
       "           20666852.]],\n",
       " \n",
       "         [[24879032., 24377522., 24339458., 23937908., 24136104.,\n",
       "           23582420.],\n",
       "          [23853050., 23358298., 23317980., 22932032., 23123270.,\n",
       "           22639178.],\n",
       "          [22039956., 21552782., 21525970., 21154948., 21338502.,\n",
       "           20911772.]],\n",
       " \n",
       "         [[24778638., 24373892., 24314094., 23817390., 24235826.,\n",
       "           23563906.],\n",
       "          [23736452., 23357624., 23258036., 22813272., 23214988.,\n",
       "           22605870.],\n",
       "          [21941926., 21575546., 21477362., 21061620., 21440382.,\n",
       "           20889164.]],\n",
       " \n",
       "         [[24597942., 24423928., 24195604., 23883376., 24315770.,\n",
       "           23543296.],\n",
       "          [23569562., 23406704., 23158124., 22895564., 23319720.,\n",
       "           22591890.],\n",
       "          [21780962., 21628390., 21395640., 21164958., 21542610.,\n",
       "           20881890.]],\n",
       " \n",
       "         [[24371234., 24134152., 24000866., 23672460., 24024282.,\n",
       "           23058496.],\n",
       "          [23346518., 23117010., 22976104., 22703630., 23034060.,\n",
       "           22122400.],\n",
       "          [21561554., 21342620., 21230212., 20973480., 21277138.,\n",
       "           20434378.]]],\n",
       " \n",
       " \n",
       "        [[[24660570., 24189150., 24144294., 23743308., 24058402.,\n",
       "           23165686.],\n",
       "          [23576060., 23092714., 23052222., 22672940., 22969360.,\n",
       "           22141380.],\n",
       "          [21745488., 21282836., 21246758., 20884974., 21165368.,\n",
       "           20415660.]],\n",
       " \n",
       "         [[24792410., 24266126., 24453376., 23759150., 24196204.,\n",
       "           23499882.],\n",
       "          [23681662., 23163230., 23356816., 22672322., 23104504.,\n",
       "           22473124.],\n",
       "          [21834274., 21320414., 21521678., 20861704., 21284564.,\n",
       "           20705958.]],\n",
       " \n",
       "         [[24692146., 24177008., 24363718., 23578414., 24242532.,\n",
       "           23600190.],\n",
       "          [23571056., 23067094., 23238184., 22489322., 23147626.,\n",
       "           22564532.],\n",
       "          [21737458., 21235594., 21407474., 20694420., 21325152.,\n",
       "           20792430.]],\n",
       " \n",
       "         [[24696958., 24171802., 24199326., 23631364., 24309784.,\n",
       "           23470678.],\n",
       "          [23587932., 23059390., 23081042., 22562728., 23231876.,\n",
       "           22442288.],\n",
       "          [21749052., 21240512., 21266262., 20787214., 21405758.,\n",
       "           20684986.]],\n",
       " \n",
       "         [[24512668., 23929412., 23824860., 23511172., 24038550.,\n",
       "           22994328.],\n",
       "          [23406992., 22821446., 22731152., 22455470., 22971398.,\n",
       "           21979686.],\n",
       "          [21554454., 21003800., 20929754., 20679418., 21161848.,\n",
       "           20241484.]]],\n",
       " \n",
       " \n",
       "        [[[24646524., 24108602., 23934504., 23451698., 23882622.,\n",
       "           22941040.],\n",
       "          [23479322., 22948704., 22779472., 22319126., 22700710.,\n",
       "           21836122.],\n",
       "          [21618156., 21118616., 20960140., 20517120., 20873526.,\n",
       "           20092762.]],\n",
       " \n",
       "         [[24672166., 24244934., 24321758., 23555532., 24144022.,\n",
       "           23253224.],\n",
       "          [23484988., 23078014., 23153518., 22402270., 22961730.,\n",
       "           22143216.],\n",
       "          [21605386., 21211168., 21293352., 20572344., 21105380.,\n",
       "           20358836.]],\n",
       " \n",
       "         [[24484122., 24153000., 24191376., 23417054., 24189408.,\n",
       "           23382390.],\n",
       "          [23292642., 22974454., 22994572., 22248626., 22998734.,\n",
       "           22270976.],\n",
       "          [21432024., 21109334., 21132434., 20424952., 21124586.,\n",
       "           20470518.]],\n",
       " \n",
       "         [[24545642., 23948056., 23992180., 23372974., 24236902.,\n",
       "           23220232.],\n",
       "          [23363114., 22770660., 22798230., 22228452., 23069094.,\n",
       "           22119712.],\n",
       "          [21492650., 20923104., 20944898., 20423754., 21197468.,\n",
       "           20340524.]],\n",
       " \n",
       "         [[24294638., 23693768., 23681138., 23393402., 23948480.,\n",
       "           22749806.],\n",
       "          [23107718., 22520268., 22502216., 22262208., 22803868.,\n",
       "           21657242.],\n",
       "          [21210298., 20665288., 20658310., 20447850., 20949894.,\n",
       "           19889150.]]],\n",
       " \n",
       " \n",
       "        [[[24121886., 23575854., 23278922., 22985938., 23522388.,\n",
       "           22638128.],\n",
       "          [22892054., 22383248., 22071096., 21825364., 22291068.,\n",
       "           21482962.],\n",
       "          [21007572., 20538256., 20225904., 20020158., 20429606.,\n",
       "           19726492.]],\n",
       " \n",
       "         [[24228838., 23786506., 23804950., 23139748., 23744908.,\n",
       "           22781772.],\n",
       "          [22978180., 22584738., 22577742., 21943608., 22506834.,\n",
       "           21615320.],\n",
       "          [21071172., 20714314., 20693556., 20111912., 20617026.,\n",
       "           19825128.]],\n",
       " \n",
       "         [[24057630., 23720142., 23780054., 22996274., 23787546.,\n",
       "           22762872.],\n",
       "          [22811184., 22509876., 22534302., 21783428., 22542972.,\n",
       "           21595716.],\n",
       "          [20916930., 20641660., 20641340., 19949212., 20638598.,\n",
       "           19787034.]],\n",
       " \n",
       "         [[24032460., 23440062., 23454676., 22944258., 23803210.,\n",
       "           22666026.],\n",
       "          [22791518., 22231244., 22217852., 21750864., 22579038.,\n",
       "           21501044.],\n",
       "          [20890312., 20368512., 20339114., 19926774., 20678636.,\n",
       "           19703754.]],\n",
       " \n",
       "         [[23663940., 23144320., 23203420., 22909078., 23388884.,\n",
       "           22265268.],\n",
       "          [22420290., 21932626., 21978934., 21727626., 22191014.,\n",
       "           21109894.],\n",
       "          [20502022., 20054572., 20111104., 19892704., 20308984.,\n",
       "           19325192.]]]], dtype=float32)>,\n",
       "        <tf.Tensor: shape=(6,), dtype=float32, numpy=\n",
       " array([45269450., 44528690., 44459650., 43541932., 44494228., 42811040.],\n",
       "       dtype=float32)>], dtype=object),\n",
       " [],\n",
       " [],\n",
       " array([<tf.Tensor: shape=(5, 5, 6, 16), dtype=float32, numpy=\n",
       " array([[[[7100042. , 6987892. , 6977036. , ..., 6980013. , 7164174. ,\n",
       "           7027491. ],\n",
       "          [6913127.5, 6804306. , 6791624.5, ..., 6796315. , 6973281. ,\n",
       "           6841496. ],\n",
       "          [6317820.5, 6218257.5, 6208734.5, ..., 6211361.5, 6375785. ,\n",
       "           6253781. ],\n",
       "          [6970580. , 6860221.5, 6848070. , ..., 6852493.5, 7032829. ,\n",
       "           6899529.5],\n",
       "          [6309227.5, 6208789. , 6200990.5, ..., 6201986. , 6368947.5,\n",
       "           6245506. ],\n",
       "          [6394779. , 6295136. , 6285504.5, ..., 6286752. , 6455287.5,\n",
       "           6332660.5]],\n",
       " \n",
       "         [[7169085.5, 7046027.5, 7021502.5, ..., 7051331.5, 7211975.5,\n",
       "           7088851. ],\n",
       "          [6974077. , 6855509. , 6830244.5, ..., 6859826.5, 7015385.5,\n",
       "           6896286.5],\n",
       "          [6376620.5, 6267417. , 6244660. , ..., 6272362.5, 6414235. ,\n",
       "           6305855. ],\n",
       "          [7035865. , 6914003.5, 6887472. , ..., 6919755. , 7075603.5,\n",
       "           6957199. ],\n",
       "          [6370453.5, 6257934.5, 6234734. , ..., 6265677. , 6405873. ,\n",
       "           6297942. ],\n",
       "          [6453387.5, 6342598.5, 6320098. , ..., 6347207. , 6491867. ,\n",
       "           6382932. ]],\n",
       " \n",
       "         [[7192715. , 7055124. , 7013507. , ..., 7071575. , 7204405. ,\n",
       "           7090023.5],\n",
       "          [6995559. , 6863906.5, 6821909. , ..., 6878412.5, 7007516. ,\n",
       "           6897419. ],\n",
       "          [6395990. , 6272251.5, 6233479. , ..., 6287983. , 6404206. ,\n",
       "           6304281.5],\n",
       "          [7053875. , 6918662. , 6877375.5, ..., 6933723. , 7064129.5,\n",
       "           6952901. ],\n",
       "          [6389187. , 6262889. , 6224604. , ..., 6279587. , 6395384. ,\n",
       "           6294405. ],\n",
       "          [6476423.5, 6351577.5, 6311642. , ..., 6366684.5, 6485524.5,\n",
       "           6385221. ]],\n",
       " \n",
       "         [[7186146. , 7046940.5, 6992055. , ..., 7064258. , 7180653.5,\n",
       "           7066383. ],\n",
       "          [6987573. , 6855481. , 6800829. , ..., 6870551. , 6984948.5,\n",
       "           6874266.5],\n",
       "          [6387582. , 6262724. , 6214495. , ..., 6278414.5, 6381163. ,\n",
       "           6278568. ],\n",
       "          [7043916. , 6910311.5, 6855190. , ..., 6925553.5, 7039991. ,\n",
       "           6928585. ],\n",
       "          [6375198.5, 6250748.5, 6201442. , ..., 6266149. , 6368347. ,\n",
       "           6265113. ],\n",
       "          [6467707. , 6342573. , 6292953.5, ..., 6357360. , 6463471. ,\n",
       "           6360430.5]],\n",
       " \n",
       "         [[7112261.5, 6979179. , 6922196. , ..., 6990925. , 7096732. ,\n",
       "           6978280. ],\n",
       "          [6922241.5, 6795403. , 6738799.5, ..., 6806000. , 6909824. ,\n",
       "           6795111. ],\n",
       "          [6317842. , 6199574. , 6148619.5, ..., 6211003.5, 6303027. ,\n",
       "           6197564. ],\n",
       "          [6968298. , 6838926.5, 6781846. , ..., 6849878. , 6954591. ,\n",
       "           6838697. ],\n",
       "          [6300259. , 6181430. , 6129837.5, ..., 6191951. , 6283003. ,\n",
       "           6177920.5],\n",
       "          [6402116. , 6283328. , 6229431. , ..., 6294244. , 6388057. ,\n",
       "           6282293. ]]],\n",
       " \n",
       " \n",
       "        [[[7061513. , 6944882. , 6944812. , ..., 6946583.5, 7116867.5,\n",
       "           6976002. ],\n",
       "          [6887326. , 6773385.5, 6772516. , ..., 6774591. , 6940901.5,\n",
       "           6805576.5],\n",
       "          [6286843.5, 6181708.5, 6182063.5, ..., 6184149.5, 6335277.5,\n",
       "           6209743. ],\n",
       "          [6937215. , 6823112. , 6821368. , ..., 6824158. , 6990788.5,\n",
       "           6854030.5],\n",
       "          [6272176.5, 6168343. , 6169162. , ..., 6170946. , 6321852.5,\n",
       "           6196729. ],\n",
       "          [6361892.5, 6257443. , 6259122.5, ..., 6258629.5, 6415128. ,\n",
       "           6287043. ]],\n",
       " \n",
       "         [[7138408. , 7017624.5, 6996009.5, ..., 7030854. , 7176179.5,\n",
       "           7046197. ],\n",
       "          [6957557.5, 6839623. , 6817505.5, ..., 6852099. , 6992245.5,\n",
       "           6868746. ],\n",
       "          [6350249.5, 6243388. , 6221898. , ..., 6256014. , 6382631. ,\n",
       "           6268316. ],\n",
       "          [7013060.5, 6892879. , 6870198. , ..., 6907333.5, 7048653. ,\n",
       "           6922609. ],\n",
       "          [6339041.5, 6230783. , 6207936. , ..., 6244778.5, 6369993. ,\n",
       "           6256376. ],\n",
       "          [6428465.5, 6319320.5, 6300296. , ..., 6331905. , 6463018.5,\n",
       "           6347177. ]],\n",
       " \n",
       "         [[7167683.5, 7033425.5, 6987990.5, ..., 7053654. , 7171245. ,\n",
       "           7055058. ],\n",
       "          [6984882.5, 6853352.5, 6809318.5, ..., 6873656. , 6988711. ,\n",
       "           6875167. ],\n",
       "          [6374940. , 6253786.5, 6211493. , ..., 6273367.5, 6377401. ,\n",
       "           6272484. ],\n",
       "          [7038295. , 6903716. , 6859814.5, ..., 6924361.5, 7040078. ,\n",
       "           6924953. ],\n",
       "          [6363728.5, 6241877. , 6199632.5, ..., 6260819. , 6363618. ,\n",
       "           6259930.5],\n",
       "          [6459607. , 6335572.5, 6293587.5, ..., 6356554. , 6461653. ,\n",
       "           6357570.5]],\n",
       " \n",
       "         [[7160550. , 7020328. , 6968156. , ..., 7039214. , 7143314.5,\n",
       "           7029212.5],\n",
       "          [6976180. , 6840175. , 6789307. , ..., 6858304. , 6960320.5,\n",
       "           6847870. ],\n",
       "          [6364331. , 6238857. , 6192701. , ..., 6254807. , 6346238.5,\n",
       "           6244557. ],\n",
       "          [7027156. , 6889141. , 6838294. , ..., 6907847. , 7008701. ,\n",
       "           6896865.5],\n",
       "          [6348263. , 6222938.5, 6176302. , ..., 6238698. , 6328301.5,\n",
       "           6227881. ],\n",
       "          [6450190.5, 6320286. , 6274147. , ..., 6338359.5, 6431368.5,\n",
       "           6327805. ]],\n",
       " \n",
       "         [[7069279. , 6932785.5, 6882242. , ..., 6942328. , 7034800. ,\n",
       "           6924369.5],\n",
       "          [6891861. , 6760100. , 6709549.5, ..., 6769591.5, 6860230. ,\n",
       "           6752105. ],\n",
       "          [6279650. , 6158492. , 6113539.5, ..., 6166592. , 6247794. ,\n",
       "           6150367.5],\n",
       "          [6933564.5, 6799638.5, 6748878.5, ..., 6809402. , 6899824. ,\n",
       "           6790650. ],\n",
       "          [6257244.5, 6136131.5, 6090506.5, ..., 6143596. , 6222703.5,\n",
       "           6126544. ],\n",
       "          [6368189.5, 6244652. , 6198661.5, ..., 6254171.5, 6335080.5,\n",
       "           6235781. ]]],\n",
       " \n",
       " \n",
       "        [[[6985690. , 6874393. , 6864716.5, ..., 6880625. , 7022390. ,\n",
       "           6892261. ],\n",
       "          [6824157.5, 6714741.5, 6705789.5, ..., 6720909. , 6861543.5,\n",
       "           6732621. ],\n",
       "          [6220151. , 6120620. , 6112338.5, ..., 6126931.5, 6253112.5,\n",
       "           6137292. ],\n",
       "          [6868185.5, 6758822.5, 6748887.5, ..., 6765052.5, 6906121.5,\n",
       "           6776355. ],\n",
       "          [6202966.5, 6104986.5, 6097225. , ..., 6111326.5, 6236802. ,\n",
       "           6121193.5],\n",
       "          [6292030.5, 6190253.5, 6185447.5, ..., 6196521.5, 6327147. ,\n",
       "           6208033. ]],\n",
       " \n",
       "         [[7052848.5, 6949504. , 6917190. , ..., 6958546. , 7087777.5,\n",
       "           6962579. ],\n",
       "          [6888350. , 6784337.5, 6752372.5, ..., 6794752. , 6919964.5,\n",
       "           6797829. ],\n",
       "          [6275054. , 6183088.5, 6151425. , ..., 6192228.5, 6303931. ,\n",
       "           6194253. ],\n",
       "          [6933527.5, 6830416. , 6796945. , ..., 6840595. , 6965991. ,\n",
       "           6843500. ],\n",
       "          [6260127.5, 6168856. , 6135649.5, ..., 6178141. , 6289881. ,\n",
       "           6179798. ],\n",
       "          [6350455. , 6257315. , 6227588. , ..., 6266230. , 6381780.5,\n",
       "           6268714. ]],\n",
       " \n",
       "         [[7074167. , 6968279.5, 6909609. , ..., 6974645.5, 7086555.5,\n",
       "           6974594. ],\n",
       "          [6907424. , 6799991. , 6743839.5, ..., 6807612.5, 6916871. ,\n",
       "           6806843. ],\n",
       "          [6288189. , 6193770.5, 6139098. , ..., 6198849. , 6296875. ,\n",
       "           6198454.5],\n",
       "          [6950479. , 6843355. , 6785819.5, ..., 6849883. , 6957877. ,\n",
       "           6849419.5],\n",
       "          [6273104. , 6179644. , 6125178. , ..., 6184344. , 6281766.5,\n",
       "           6184406.5],\n",
       "          [6370408.5, 6272646.5, 6218535.5, ..., 6280160.5, 6379601.5,\n",
       "           6278571. ]],\n",
       " \n",
       "         [[7054905.5, 6941614. , 6882665. , ..., 6945794. , 7041795. ,\n",
       "           6945250. ],\n",
       "          [6885462. , 6772918.5, 6716915. , ..., 6777767. , 6872952. ,\n",
       "           6776032. ],\n",
       "          [6266313.5, 6166249. , 6113791.5, ..., 6169331. , 6253080. ,\n",
       "           6168062. ],\n",
       "          [6926481.5, 6813712.5, 6756625. , ..., 6817627.5, 6911324. ,\n",
       "           6816427. ],\n",
       "          [6249076.5, 6148684.5, 6096008.5, ..., 6151100.5, 6232481. ,\n",
       "           6150671. ],\n",
       "          [6354649. , 6249175. , 6196124. , ..., 6253944. , 6339168. ,\n",
       "           6252005.5]],\n",
       " \n",
       "         [[6965099. , 6848389.5, 6791370. , ..., 6845332. , 6927464. ,\n",
       "           6842132. ],\n",
       "          [6795289. , 6680622. , 6626635. , ..., 6679066.5, 6759705. ,\n",
       "           6673464. ],\n",
       "          [6178824. , 6075844.5, 6025660.5, ..., 6072661. , 6144349. ,\n",
       "           6069538. ],\n",
       "          [6829336. , 6713331. , 6658464.5, ..., 6711385. , 6791686. ,\n",
       "           6705355. ],\n",
       "          [6153429.5, 6050376.5, 5999430. , ..., 6046475. , 6116814.5,\n",
       "           6043334.5],\n",
       "          [6269482. , 6162481.5, 6112454.5, ..., 6160157. , 6231859. ,\n",
       "           6155518. ]]],\n",
       " \n",
       " \n",
       "        [[[6874438. , 6771810. , 6750538.5, ..., 6782688.5, 6894070.5,\n",
       "           6786028. ],\n",
       "          [6717910. , 6617294.5, 6596637. , ..., 6627401.5, 6737034. ,\n",
       "           6630946. ],\n",
       "          [6118845.5, 6027550.5, 6008658. , ..., 6037335.5, 6134888.5,\n",
       "           6040133. ],\n",
       "          [6755626. , 6655924. , 6633853. , ..., 6666157. , 6776613. ,\n",
       "           6669728.5],\n",
       "          [6092024.5, 6003306. , 5984251. , ..., 6013669. , 6109671. ,\n",
       "           6015382. ],\n",
       "          [6186391. , 6093231. , 6076628. , ..., 6103650. , 6203990.5,\n",
       "           6106568. ]],\n",
       " \n",
       "         [[6924833.5, 6835622. , 6795615. , ..., 6843912. , 6951624.5,\n",
       "           6843510. ],\n",
       "          [6763459. , 6676306. , 6635152.5, ..., 6684075. , 6788554. ,\n",
       "           6682761.5],\n",
       "          [6158902. , 6081108. , 6042789. , ..., 6088932. , 6183315. ,\n",
       "           6087052.5],\n",
       "          [6802508.5, 6715321.5, 6674151. , ..., 6723273.5, 6828341.5,\n",
       "           6722467. ],\n",
       "          [6133054. , 6057036. , 6017219.5, ..., 6063921. , 6158167.5,\n",
       "           6062550.5],\n",
       "          [6230312.5, 6150838. , 6113498. , ..., 6158576.5, 6256056. ,\n",
       "           6157471. ]],\n",
       " \n",
       "         [[6927821. , 6847712.5, 6783432.5, ..., 6845542. , 6947209. ,\n",
       "           6850839. ],\n",
       "          [6770751. , 6689359.5, 6627020. , ..., 6688160.5, 6788063. ,\n",
       "           6691219. ],\n",
       "          [6161476. , 6089158. , 6031050. , ..., 6086681. , 6177870.5,\n",
       "           6091487.5],\n",
       "          [6805893. , 6726261. , 6662117. , ..., 6723645. , 6822084. ,\n",
       "           6728351. ],\n",
       "          [6134276. , 6064268. , 6005481. , ..., 6061405. , 6148978. ,\n",
       "           6066665.5],\n",
       "          [6233586.5, 6161526.5, 6102545.5, ..., 6159527. , 6252168.5,\n",
       "           6164167.5]],\n",
       " \n",
       "         [[6896306. , 6813114. , 6745592. , ..., 6804918.5, 6888320.5,\n",
       "           6816396. ],\n",
       "          [6738225. , 6653708. , 6590366. , ..., 6646419. , 6729887.5,\n",
       "           6656620. ],\n",
       "          [6124137. , 6050926. , 5991716.5, ..., 6042779. , 6116148.5,\n",
       "           6054734. ],\n",
       "          [6770560.5, 6687118. , 6621405. , ..., 6678527.5, 6761083.5,\n",
       "           6690002.5],\n",
       "          [6094982. , 6022500.5, 5960409.5, ..., 6012860. , 6081913.5,\n",
       "           6024668.5],\n",
       "          [6204805. , 6128511. , 6067698.5, ..., 6121896. , 6195497. ,\n",
       "           6131279. ]],\n",
       " \n",
       "         [[6806102. , 6715629. , 6646224. , ..., 6701266. , 6775774. ,\n",
       "           6716257.5],\n",
       "          [6653599. , 6562931. , 6497293.5, ..., 6550093.5, 6623825. ,\n",
       "           6562267. ],\n",
       "          [6042271.5, 5961658. , 5899145.5, ..., 5948336. , 6013114.5,\n",
       "           5961635. ],\n",
       "          [6678070.5, 6588056.5, 6521622. , ..., 6574765. , 6649165. ,\n",
       "           6587739.5],\n",
       "          [6006101.5, 5926440. , 5863144. , ..., 5912300. , 5976493. ,\n",
       "           5926335.5],\n",
       "          [6123470. , 6040483. , 5977854.5, ..., 6026551. , 6091788.5,\n",
       "           6040501. ]]],\n",
       " \n",
       " \n",
       "        [[[6703606.5, 6608599.5, 6580063. , ..., 6617163.5, 6714137. ,\n",
       "           6615619.5],\n",
       "          [6570280. , 6477151. , 6449504. , ..., 6485707. , 6581853.5,\n",
       "           6484294. ],\n",
       "          [5961083. , 5877608. , 5851905. , ..., 5885028.5, 5969895.5,\n",
       "           5882549. ],\n",
       "          [6594474. , 6501955.5, 6473897. , ..., 6511095.5, 6607074. ,\n",
       "           6509050. ],\n",
       "          [5926919. , 5845172. , 5819714. , ..., 5853110.5, 5937766. ,\n",
       "           5850091. ],\n",
       "          [6032198.5, 5946532. , 5922918.5, ..., 5955115.5, 6041000.5,\n",
       "           5952771. ]],\n",
       " \n",
       "         [[6729733. , 6644800. , 6601670.5, ..., 6652584.5, 6748848.5,\n",
       "           6646285. ],\n",
       "          [6594548. , 6512620.5, 6469730. , ..., 6518813.5, 6615724. ,\n",
       "           6514412. ],\n",
       "          [5980947. , 5907592. , 5866895.5, ..., 5913816. , 6000481. ,\n",
       "           5909339. ],\n",
       "          [6620743. , 6538640. , 6494485. , ..., 6545297.5, 6640058.5,\n",
       "           6540624. ],\n",
       "          [5947079. , 5874338. , 5832899. , ..., 5880242.5, 5965479.5,\n",
       "           5875339. ],\n",
       "          [6055423. , 5981288. , 5941150. , ..., 5988206.5, 6075607. ,\n",
       "           5982879.5]],\n",
       " \n",
       "         [[6726517.5, 6649869.5, 6586153. , ..., 6646174. , 6741500.5,\n",
       "           6653723. ],\n",
       "          [6589044.5, 6513698. , 6451657. , ..., 6509913. , 6603664. ,\n",
       "           6517707.5],\n",
       "          [5973829.5, 5905385.5, 5848063. , ..., 5902130. , 5984394. ,\n",
       "           5908833. ],\n",
       "          [6612852. , 6538555. , 6474780. , ..., 6534121.5, 6625760. ,\n",
       "           6542948.5],\n",
       "          [5938222.5, 5872776. , 5813986.5, ..., 5868192. , 5948770. ,\n",
       "           5876323.5],\n",
       "          [6051573.5, 5984004. , 5925686. , ..., 5980578.5, 6065631. ,\n",
       "           5987561.5]],\n",
       " \n",
       "         [[6686759. , 6610033.5, 6541072.5, ..., 6598458. , 6676255.5,\n",
       "           6615689.5],\n",
       "          [6550049. , 6474622.5, 6406666.5, ..., 6462746.5, 6538754. ,\n",
       "           6479695.5],\n",
       "          [5935585. , 5868273. , 5805216. , ..., 5856056. , 5923510.5,\n",
       "           5873133. ],\n",
       "          [6572841. , 6497615.5, 6429021. , ..., 6484801. , 6560543. ,\n",
       "           6502478. ],\n",
       "          [5899642. , 5832932.5, 5768794. , ..., 5819904. , 5885465.5,\n",
       "           5837671. ],\n",
       "          [6011882. , 5945558.5, 5881379. , ..., 5933561. , 6000859.5,\n",
       "           5950564.5]],\n",
       " \n",
       "         [[6615532. , 6534769. , 6462097.5, ..., 6517269. , 6593755.5,\n",
       "           6536953. ],\n",
       "          [6477722. , 6397447. , 6327246. , ..., 6380449. , 6455351. ,\n",
       "           6399730.5],\n",
       "          [5869150. , 5796730. , 5731775. , ..., 5780709.5, 5847201. ,\n",
       "           5797937.5],\n",
       "          [6495200. , 6416379. , 6345335. , ..., 6398844. , 6473716. ,\n",
       "           6418002. ],\n",
       "          [5830576. , 5759159. , 5694130. , ..., 5742299. , 5809943. ,\n",
       "           5760133. ],\n",
       "          [5947932.5, 5875654.5, 5808024. , ..., 5858042. , 5923723.5,\n",
       "           5877275. ]]]], dtype=float32)>,\n",
       "        <tf.Tensor: shape=(16,), dtype=float32, numpy=\n",
       " array([1568006.2, 1544428. , 1534542.8, 1571875.5, 1546273. , 1552475.5,\n",
       "        1578633.6, 1562073.5, 1562487.5, 1561693.5, 1551420.1, 1584267.6,\n",
       "        1564062.9, 1544908. , 1570540. , 1545939.6], dtype=float32)>],\n",
       "       dtype=object),\n",
       " [],\n",
       " [],\n",
       " array([<tf.Tensor: shape=(400, 120), dtype=float32, numpy=\n",
       " array([[602041.7 , 676307.9 , 586654.25, ..., 558899.25, 558983.  ,\n",
       "         594481.56],\n",
       "        [592690.7 , 665809.1 , 577548.75, ..., 550224.56, 550302.9 ,\n",
       "         585252.1 ],\n",
       "        [569912.3 , 640219.4 , 555347.75, ..., 529074.56, 529150.  ,\n",
       "         562757.3 ],\n",
       "        ...,\n",
       "        [534656.  , 600559.3 , 521033.1 , ..., 496500.06, 496469.9 ,\n",
       "         527990.94],\n",
       "        [517854.03, 581691.06, 504661.2 , ..., 480896.97, 480862.47,\n",
       "         511397.5 ],\n",
       "        [513635.  , 576950.9 , 500548.62, ..., 476981.22, 476949.1 ,\n",
       "         507234.4 ]], dtype=float32)>,\n",
       "        <tf.Tensor: shape=(120,), dtype=float32, numpy=\n",
       " array([12380.019 , 13907.33  , 12061.978 , 12148.26  , 11274.277 ,\n",
       "        11390.834 , 11734.613 , 13092.017 , 10770.453 , 10568.196 ,\n",
       "        11801.369 , 12183.264 , 11896.854 , 12358.086 , 11833.978 ,\n",
       "        11690.589 , 11899.969 , 11068.198 , 12608.973 , 12448.198 ,\n",
       "        12452.189 , 12455.535 , 11150.9   , 12661.892 , 11584.066 ,\n",
       "        11882.144 , 12925.558 , 10955.146 , 12521.211 , 12303.146 ,\n",
       "        11293.101 , 12506.583 , 12416.003 , 12473.28  , 11935.315 ,\n",
       "        11450.875 , 11876.287 , 13238.082 , 12922.55  , 13694.152 ,\n",
       "        11994.059 , 11955.037 , 11413.105 , 12087.508 , 12557.483 ,\n",
       "        12654.906 , 13796.367 , 12361.298 , 11870.755 , 11766.352 ,\n",
       "        12054.647 , 13114.776 , 12026.314 , 11425.9795, 12425.837 ,\n",
       "        11092.023 , 10841.483 , 11638.298 , 10834.917 , 11994.991 ,\n",
       "        11393.69  , 12880.693 , 12210.715 , 11918.676 , 11925.006 ,\n",
       "        11736.655 , 11441.564 , 10983.687 , 13085.129 , 12116.849 ,\n",
       "        12974.47  , 11229.95  , 12149.088 , 10573.964 , 11720.536 ,\n",
       "        13011.145 , 11860.201 , 12078.446 , 11783.206 , 12708.741 ,\n",
       "        11903.624 ,  9963.631 , 11512.652 , 12516.899 , 11259.1   ,\n",
       "        12931.355 , 12547.682 , 12350.257 , 12005.122 , 12482.722 ,\n",
       "        12407.752 , 12279.178 , 11295.037 , 13141.008 , 12141.764 ,\n",
       "        13787.384 , 11994.213 ,  9327.378 , 11301.446 , 11733.148 ,\n",
       "        12842.002 , 12665.655 , 11288.749 , 11904.007 , 12086.814 ,\n",
       "        11793.307 , 11127.932 , 12997.7295, 11782.104 , 11396.587 ,\n",
       "        12116.484 , 12764.517 , 11992.292 , 12794.85  , 11540.082 ,\n",
       "        11262.76  , 11733.904 , 11497.808 , 11496.64  , 12225.8125],\n",
       "       dtype=float32)>], dtype=object),\n",
       " [],\n",
       " array([<tf.Tensor: shape=(120, 84), dtype=float32, numpy=\n",
       " array([[1481074.1 ,  779702.25, 1353147.8 , ..., 1353168.9 , 1685998.  ,\n",
       "         1245911.8 ],\n",
       "        [1381998.5 ,  727600.06, 1262681.8 , ..., 1262740.8 , 1573211.5 ,\n",
       "         1162668.2 ],\n",
       "        [1371783.1 ,  722193.44, 1253302.6 , ..., 1253330.6 , 1561503.6 ,\n",
       "         1154096.9 ],\n",
       "        ...,\n",
       "        [1352219.2 ,  711935.  , 1235440.9 , ..., 1235537.  , 1539193.4 ,\n",
       "         1137587.  ],\n",
       "        [1392642.5 ,  733255.1 , 1272303.8 , ..., 1272505.8 , 1585160.6 ,\n",
       "         1171523.1 ],\n",
       "        [1427173.2 ,  751431.3 , 1303939.  , ..., 1304039.  , 1624525.1 ,\n",
       "         1200657.  ]], dtype=float32)>,\n",
       "        <tf.Tensor: shape=(84,), dtype=float32, numpy=\n",
       " array([1769.7874 ,  930.69714, 1618.6388 , 2687.8357 , 2246.0857 ,\n",
       "        2358.1533 , 1986.4025 , 2095.6455 ,  908.2037 , 2145.608  ,\n",
       "        2087.5317 , 1387.2139 , 1947.48   , 1555.6362 , 2051.7495 ,\n",
       "        2042.405  , 1647.5135 , 1521.5493 , 1475.4414 , 2206.7688 ,\n",
       "        1421.4209 , 1850.1775 , 1965.9397 , 1674.004  , 2073.2974 ,\n",
       "        2296.7466 , 1540.822  , 1910.7069 , 1535.0823 , 1431.0554 ,\n",
       "        2165.805  , 2149.0303 , 1936.1979 , 1628.8271 , 1634.9775 ,\n",
       "        1677.4008 , 1642.1497 , 1900.2384 , 1987.6996 , 2151.0276 ,\n",
       "        1439.4268 , 1847.967  , 2315.3127 , 2150.8223 , 1947.4426 ,\n",
       "        2172.8367 , 1592.873  , 1704.5781 , 1814.2656 , 1548.1526 ,\n",
       "        1757.8169 , 2246.0742 , 1300.0217 , 2161.9236 , 1569.8846 ,\n",
       "        2274.2463 , 1679.132  , 1601.769  , 1150.8745 , 1969.1875 ,\n",
       "        1999.8989 , 1958.2102 , 1727.1497 , 1462.7046 , 1966.0791 ,\n",
       "        1727.5178 , 2118.4985 , 2198.6245 , 1652.6825 , 1105.0933 ,\n",
       "        1844.282  , 1971.0325 , 1662.9114 , 1284.4413 , 1615.4094 ,\n",
       "        1558.7397 , 1973.8125 , 1670.9474 , 1564.4335 , 1587.5037 ,\n",
       "        1599.1396 , 1617.1758 , 2018.4188 , 1494.553  ], dtype=float32)>],\n",
       "       dtype=object),\n",
       " [],\n",
       " array([<tf.Tensor: shape=(84, 10), dtype=float32, numpy=\n",
       " array([[14379436., 12900452., 13392522., 13789520., 14184518., 14579000.,\n",
       "         14109568., 13979343., 14474638., 13183284.],\n",
       "        [14921946., 13386975., 13898266., 14309712., 14719803., 15129056.,\n",
       "         14641952., 14506824., 15020696., 13680708.],\n",
       "        [13598858., 12200433., 12665542., 13040739., 13414323., 13787133.,\n",
       "         13343438., 13220178., 13688706., 12468172.],\n",
       "        [15051914., 13503656., 14019010., 14434326., 14847836., 15260571.,\n",
       "         14769274., 14633050., 15151502., 13799840.],\n",
       "        [14638468., 13132652., 13634010., 14037586., 14440066., 14841256.,\n",
       "         14363505., 14231022., 14735283., 13420386.],\n",
       "        [13479790., 12093492., 12554602., 12926896., 13297256., 13666712.,\n",
       "         13226950., 13104852., 13568930., 12358071.],\n",
       "        [14452580., 12966112., 13460634., 13859502., 14256693., 14652852.,\n",
       "         14181099., 14050228., 14548104., 13250364.],\n",
       "        [14391498., 12910965., 13403810., 13800807., 14196192., 14591126.,\n",
       "         14121308., 13990888., 14486506., 13194120.],\n",
       "        [14446000., 12959856., 13454313., 13853246., 14250372., 14646466.,\n",
       "         14174842., 14044166., 14541460., 13243592.],\n",
       "        [13905006., 12474332., 12950664., 13334601., 13716538., 14097797.,\n",
       "         13643975., 13518394., 13996854., 12747876.],\n",
       "        [14230312., 12766808., 13253847., 13646523., 14037522., 14427747.,\n",
       "         13963347., 13834412., 14324612., 13046931.],\n",
       "        [13024292., 11684884., 12130386., 12489844., 12847820., 13204827.,\n",
       "         12779836., 12661737., 13110334., 11941078.],\n",
       "        [12660447., 11358128., 11791438., 12140706., 12488490., 12835564.,\n",
       "         12422442., 12307890., 12743974., 11607549.],\n",
       "        [15207616., 13643169., 14164264., 14583772., 15001668., 15418725.,\n",
       "         14922268., 14784626., 15308236., 13942449.],\n",
       "        [14120566., 12667832., 13151840., 13541582., 13929581., 14316581.,\n",
       "         13855728., 13727890., 14214155., 12946020.],\n",
       "        [13424224., 12043408., 12502583., 12873168., 13241753., 13609919.,\n",
       "         13171835., 13050252., 13512653., 12307600.],\n",
       "        [13806773., 12386612., 12859397., 13240431., 13620046., 13998467.,\n",
       "         13547870., 13422804., 13898170., 12657512.],\n",
       "        [13702574., 12293184., 12762680., 13140456., 13516942., 13892655.,\n",
       "         13445283., 13321508., 13793132., 12562794.],\n",
       "        [13160967., 11807564., 12257838., 12620908., 12982689., 13343373.,\n",
       "         12913738., 12794542., 13247848., 12065628.],\n",
       "        [15015987., 13471083., 13985535., 14399883., 14812618., 15224386.,\n",
       "         14734186., 14598220., 15115252., 13767267.],\n",
       "        [14601832., 13099821., 13599567., 14002498., 14403946., 14804169.,\n",
       "         14327708., 14195418., 14698260., 13386910.],\n",
       "        [14812490., 13288742., 13796098., 14204900., 14612088., 15018116.,\n",
       "         14534882., 14400592., 14910594., 13579185.],\n",
       "        [13337858., 11965943., 12422796., 12790414., 13157064., 13522650.,\n",
       "         13087276., 12966725., 13426030., 12228200.],\n",
       "        [14773402., 13253654., 13759334., 14166909., 14572808., 14977932.,\n",
       "         14495666., 14361957., 14870733., 13545064.],\n",
       "        [14359700., 12882327., 13374075., 13770556., 14164974., 14558876.,\n",
       "         14090218., 13960444., 14454514., 13165030.],\n",
       "        [14393240., 12912771., 13405293., 13802678., 14198192., 14592738.,\n",
       "         14123114., 13992694., 14488377., 13195539.],\n",
       "        [14265272., 12797896., 13286355., 13679934., 14071836., 14462964.,\n",
       "         13997468., 13868145., 14359700., 13078600.],\n",
       "        [13990404., 12551345., 13030322., 13416322., 13800774., 14184485.,\n",
       "         13727825., 13601147., 14082768., 12826115.],\n",
       "        [13995597., 12555376., 13035192., 13421096., 13805838., 14189484.,\n",
       "         13732630., 13606017., 14087961., 12830792.],\n",
       "        [13545194., 12151800., 12615490., 12989462., 13361756., 13733146.,\n",
       "         13291064., 13168384., 13634526., 12417734.],\n",
       "        [14735476., 13220178., 13723988., 14130918., 14535720., 14939748.,\n",
       "         14458900., 14325386., 14832872., 13509718.],\n",
       "        [14294168., 12823438., 13313638., 13707862., 14100732., 14492570.,\n",
       "         14025976., 13896590., 14388854., 13104723.],\n",
       "        [15394408., 13810998., 14337963., 14762889., 15185816., 15607774.,\n",
       "         15105448., 14966000., 15496125., 14113568.],\n",
       "        [13469954., 12084365., 12545604., 12917480., 13287742., 13657004.,\n",
       "         13217372., 13095274., 13559093., 12348750.],\n",
       "        [13537066., 12144770., 12608008., 12981722., 13353628., 13724826.,\n",
       "         13283066., 13160322., 13626528., 12411412.],\n",
       "        [13872240., 12445242., 12920350., 13303125., 13684158., 14064580.,\n",
       "         13611854., 13486208., 13964088., 12718400.],\n",
       "        [13952027., 12516966., 12994589., 13379622., 13762913., 14145269.,\n",
       "         13690092., 13563802., 14044520., 12791350.],\n",
       "        [14399238., 12918254., 13410969., 13808289., 14204061., 14598736.,\n",
       "         14128596., 13998306., 14494376., 13201860.],\n",
       "        [15188073., 13626141., 14146076., 14565068., 14982512., 15398859.,\n",
       "         14903048., 14765662., 15288758., 13923938.],\n",
       "        [14760696., 13242624., 13747917., 14155234., 14560746., 14965419.,\n",
       "         14483604., 14350024., 14858220., 13532810.],\n",
       "        [14511920., 13019390., 13515846., 13916456., 14315194., 14713224.,\n",
       "         14239600., 14108150., 14607831., 13304673.],\n",
       "        [13435962., 12053728., 12514000., 12884326., 13253944., 13622110.,\n",
       "         13183638., 13061862., 13524650., 12317532.],\n",
       "        [15056042., 13507332., 14022558., 14438390., 14852028., 15265086.,\n",
       "         14773660., 14637308., 15155565., 13803645.],\n",
       "        [13295094., 11927436., 12382548., 12749328., 13114752., 13479306.,\n",
       "         13045480., 12925122., 13382879., 12189113.],\n",
       "        [15919310., 14282235., 14826680., 15266118., 15703557., 16139835.,\n",
       "         15620416., 15476194., 16024574., 14595060.],\n",
       "        [13883948., 12455724., 12931089., 13314542., 13695930., 14076609.,\n",
       "         13623432., 13497786., 13975602., 12727914.],\n",
       "        [14675492., 13166062., 13668582., 14073578., 14476896., 14879312.,\n",
       "         14400206., 14267142., 14772564., 13454313.],\n",
       "        [14531850., 13037127., 13534486., 13935741., 14334996., 14733412.,\n",
       "         14259080., 14127564., 14628084., 13323249.],\n",
       "        [13962089., 12526512., 13003877., 13389232., 13772846., 14155524.,\n",
       "         13699768., 13573218., 14054388., 12800057.],\n",
       "        [12819536., 11500962., 11939692., 12293313., 12645838., 12997104.,\n",
       "         12578628., 12462722., 12904224., 11753222.],\n",
       "        [14055936., 12610169., 13091146., 13478952., 13865339., 14250920.,\n",
       "         13792002., 13664680., 14148816., 12886874.],\n",
       "        [13754399., 12339978., 12810442., 13190186., 13568058., 13944996.,\n",
       "         13496141., 13371785., 13845666., 12610234.],\n",
       "        [13771911., 12355104., 12826922., 13206891., 13585570., 13963024.,\n",
       "         13513524., 13388846., 13862920., 12625810.],\n",
       "        [14396206., 12915609., 13408196., 13805386., 14201094., 14595382.,\n",
       "         14125564., 13995274., 14491408., 13197796.],\n",
       "        [13549805., 12156412., 12619844., 12993912., 13366174., 13737758.,\n",
       "         13295610., 13172674., 13639460., 12422474.],\n",
       "        [14589126., 13088469., 13587828., 13990114., 14391304., 14791269.,\n",
       "         14315130., 14183163., 14685618., 13375623.],\n",
       "        [14235214., 12771258., 13258362., 13650974., 14042230., 14432391.,\n",
       "         13967926., 13838798., 14329320., 13050866.],\n",
       "        [14268303., 12800541., 13289064., 13682966., 14074996., 14466189.,\n",
       "         14000628., 13871306., 14362666., 13081696.],\n",
       "        [14624536., 13120784., 13620658., 14024042., 14426006., 14826873.,\n",
       "         14350024., 14216832., 14721093., 13408131.],\n",
       "        [14500116., 13008360., 13504816., 13905490., 14303842., 14701485.,\n",
       "         14228120., 14096926., 14595898., 13293450.],\n",
       "        [15039981., 13493206., 14007723., 14423103., 14836290., 15248703.,\n",
       "         14757794., 14621570., 15139440., 13788810.],\n",
       "        [13591504., 12193725., 12658770., 13033902., 13407422., 13779974.,\n",
       "         13336536., 13213276., 13681418., 12460174.],\n",
       "        [13388748., 12011674., 12469559., 12839370., 13207310., 13574380.,\n",
       "         13137392., 13016196., 13477178., 12274834.],\n",
       "        [13453764., 12069917., 12530640., 12901870., 13271552., 13640363.,\n",
       "         13201505., 13079406., 13542710., 12333593.],\n",
       "        [13435995., 12053824., 12513903., 12884714., 13253976., 13622142.,\n",
       "         13183800., 13062153., 13524812., 12317888.],\n",
       "        [13488692., 12101232., 12563052., 12935346., 13306092., 13675935.,\n",
       "         13235852., 13113624., 13577960., 12366262.],\n",
       "        [15142730., 13585184., 14103634., 14521594., 14937813., 15352677.,\n",
       "         14858672., 14721416., 15242704., 13882270.],\n",
       "        [14905498., 13372527., 13882593., 14293910., 14703420., 15112028.,\n",
       "         14625633., 14490506., 15003926., 13665680.],\n",
       "        [13731212., 12318984., 12788996., 13167998., 13545258., 13921680.,\n",
       "         13473663., 13349049., 13822092., 12588594.],\n",
       "        [14838160., 13312026., 13819706., 14229216., 14636856., 15043722.,\n",
       "         14559456., 14424974., 14936136., 13604146.],\n",
       "        [13996822., 12557118., 13036353., 13422579., 13807515., 14191096.,\n",
       "         13734436., 13607565., 14089380., 12832210.],\n",
       "        [13854729., 12429408., 12903806., 13286420., 13666970., 14047068.,\n",
       "         13594730., 13469406., 13946190., 12702308.],\n",
       "        [15311526., 13736372., 14260692., 14683038., 15103772., 15523538.,\n",
       "         15023727., 14885374., 15412662., 14038038.],\n",
       "        [13936870., 12503744., 12980464., 13365045., 13748207., 14130240.,\n",
       "         13675580., 13549160., 14029169., 12777546.],\n",
       "        [14367182., 12889680., 13380848., 13777716., 14172456., 14566164.,\n",
       "         14097314., 13967346., 14462319., 13172254.],\n",
       "        [14297264., 12827050., 13315832., 13710442., 14103248., 14495150.,\n",
       "         14028621., 13898912., 14391627., 13107690.],\n",
       "        [13470728., 12085268., 12545798., 12917608., 13287742., 13657198.,\n",
       "         13217566., 13095596., 13559544., 12350814.],\n",
       "        [14614023., 13110657., 13611177., 14014366., 14416072., 14816618.,\n",
       "         14339769., 14207544., 14710708., 13398198.],\n",
       "        [13375430., 11999580., 12457336., 12826792., 13194314., 13561190.,\n",
       "         13124460., 13003458., 13463924., 12262740.],\n",
       "        [14736831., 13221597., 13725020., 14131821., 14536816., 14940586.,\n",
       "         14459868., 14326353., 14834162., 13511460.],\n",
       "        [14575388., 13075956., 13575057., 13977408., 14377888., 14777530.,\n",
       "         14301778., 14169876., 14671815., 13362594.],\n",
       "        [13703186., 12293796., 12763034., 13141359., 13517878., 13893654.,\n",
       "         13446476., 13322249., 13793938., 12562310.],\n",
       "        [12766517., 11453490., 11890284., 12242552., 12593399., 12943376.,\n",
       "         12526770., 12411251., 12850818., 11704202.],\n",
       "        [14522949., 13029064., 13526230., 13927098., 14326030., 14724189.,\n",
       "         14250178., 14118534., 14618796., 13314284.]], dtype=float32)>,\n",
       "        <tf.Tensor: shape=(10,), dtype=float32, numpy=\n",
       " array([1934.1714, 1740.7544, 1789.1086, 1837.4628, 1885.8171, 1934.1714,\n",
       "        1869.699 , 1853.5809, 1950.2893, 1772.9905], dtype=float32)>],\n",
       "       dtype=object)]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert each tensor held in `gr` (presumably gradients) to a numpy array;\n",
    "# empty entries are returned unchanged. map() needs a function *and* an iterable.\n",
    "[list(map(np.asarray, grads)) if len(grads) > 0 else grads for grads in gr]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "ecb2f128",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean value of each entry of `scores`, one number per entry.\n",
    "avg_scores = [np.average(sc) for sc in scores]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "id": "f0b92ed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sum_arr(arr):\n",
    "    \"\"\"Return the sum of every element of every array in ``arr``.\n",
    "\n",
    "    Args:\n",
    "        arr: iterable of array-likes; each item is reduced with ``np.sum``.\n",
    "\n",
    "    Returns:\n",
    "        The grand total as a plain Python scalar (``0.0`` when ``arr`` is empty).\n",
    "    \"\"\"\n",
    "    total = 0.\n",
    "    for a in arr:\n",
    "        total += np.sum(a)\n",
    "    # `total` is a numpy scalar once anything has been added, but is still a plain\n",
    "    # Python float when `arr` is empty -- and Python floats have no `.item()`.\n",
    "    # `np.asarray(...).item()` turns both cases into a Python scalar.\n",
    "    return np.asarray(total).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "id": "a9926f1b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15.558528548993964"
      ]
     },
     "execution_count": 182,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score on a log scale: natural log of the grand total of `scores`.\n",
    "synflow_score = np.log(sum_arr(scores))\n",
    "synflow_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 210,
   "id": "58288bca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1087970.1,\n",
       " 53070.24,\n",
       " 1141039.0,\n",
       " 2303.957,\n",
       " 1143345.1,\n",
       " 54.785492,\n",
       " 1143357.2,\n",
       " 8.138836,\n",
       " 1143412.8,\n",
       " 2.5569668]"
      ]
     },
     "execution_count": 210,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total of each entry of `scores`, one number per entry.\n",
    "[np.sum(entry) for entry in scores]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "id": "e9756300",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5714563.926607132"
      ]
     },
     "execution_count": 211,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a51f267a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: exp(synflow_score) should recover sum_arr(scores).\n",
    "# Use np.exp rather than a hand-typed approximation of e (2.7185 != e).\n",
    "np.exp(synflow_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "3def439c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradients of `rsf` w.r.t. every layer's trainable variables, nested per layer;\n",
    "# `t` is presumably the GradientTape that recorded `rsf` -- confirm.\n",
    "gradients = t.gradient(rsf, [layer.trainable_variables for layer in tf_mod.layers])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "caf124d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot of the model's current weights (Keras `get_weights()` returns a list of numpy arrays).\n",
    "old_weights = tf_mod.get_weights()\n",
    "# Alias for the original misspelled name so cells that still use it keep working.\n",
    "old_weigths = old_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "id": "59ffec33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6,)"
      ]
     },
     "execution_count": 168,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "old_weigths[1].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa9d0148",
   "metadata": {},
   "source": [
    "## SynFlow in TensorFlow: linearized weights and an all-ones input\n",
    "\n",
    "Record `sign(w)` for every weight and replace it by `|w|`, feed a single input of ones, and restore the signs afterwards (see the PyTorch reference `compute_synflow_per_weight` further down)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d31f7d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linearize tf_mod: remember sign(w) for every weight, then overwrite w with |w|.\n",
    "# tf.Variable has no in-place `abs_()` (that is the PyTorch API); read/write the\n",
    "# values through the Keras backend, which works in eager and graph mode alike.\n",
    "# Restore later with: K.set_value(w, K.get_value(w) * signs[w.name]) for each w.\n",
    "signs = {}\n",
    "for weight in tf_mod.weights:\n",
    "    values = K.get_value(weight)\n",
    "    signs[weight.name] = np.sign(values)\n",
    "    K.set_value(weight, np.abs(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "ef8aea4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single all-ones sample shaped like the dataset images (SynFlow probe input).\n",
    "inputs = tf.ones([1] + list(input_shape))\n",
    "output = tf_mod.predict(inputs, steps=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "5891eb50",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  60.419643, -160.96631 ,  -33.074963,  -22.291834,  -16.564161,\n",
       "         -28.718136,  -63.27598 , -108.74212 ,  -23.881102,  -89.41574 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af934831",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the forward pass with the model's *current* weights.\n",
    "# A fresh tf.compat.v1.Session + global_variables_initializer() re-runs every\n",
    "# variable initializer, so it scores a re-randomized network instead of tf_mod\n",
    "# (and tf.global_variables_initializer does not exist in TF2).\n",
    "res = K.eval(tf_mod(inputs))\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4a755b95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\ddima\\anaconda3\\envs\\tfDML\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    }
   ],
   "source": [
    "from classification_models.tfkeras import Classifiers\n",
    "\n",
    "# Randomly initialised ResNet-18 backbone without its classifier head ...\n",
    "architecture, preprocess_inputs = Classifiers.get('resnet18')\n",
    "model = architecture(input_shape, classes=n_classes, weights=None, include_top=False)\n",
    "\n",
    "# ... topped with global average pooling and a softmax over n_classes.\n",
    "gap_layer = GlobalAveragePooling2D()(model.output)\n",
    "out = Dense(n_classes, activation='softmax')(gap_layer)\n",
    "model = Model(inputs=model.input, outputs=out)\n",
    "\n",
    "loss, opt = 'categorical_crossentropy', 'adam'\n",
    "model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "883f31b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "l1 = model.layers[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1dfe1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NASWOT probe: expose the output of every ReLU-activated layer plus the final\n",
    "# output, so a single predict() returns all activations needed for the kernel.\n",
    "relu_layers = []\n",
    "for lay in model.layers:\n",
    "    activation = getattr(lay, 'activation', None)\n",
    "    if activation is not None and 'relu' in str(activation):\n",
    "        relu_layers.append(lay.output)\n",
    "model_naswot = Model(model.inputs, relu_layers + [model.layers[-1].output])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd45603d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All-ones probe input with the dataset's sample shape; predict() needs the input\n",
    "# passed explicitly (steps=1: one batch containing the single probe sample).\n",
    "inputs = tf.ones([1] + list(input_shape))\n",
    "output = model.predict(inputs, steps=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "aa74c105",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([Dimension(1), Dimension(32), Dimension(32), Dimension(3)])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bb25033",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch reference helper for the SynFlow port.\n",
    "# compute_synflow_per_weight is defined once, in a later cell.\n",
    "def get_layer_metric_array(mdoel, metric, mode):\n",
    "    \"\"\"Apply `metric` to every Conv2d / Linear module of a PyTorch network.\n",
    "\n",
    "    Args:\n",
    "        mdoel: torch.nn.Module to walk (parameter name kept for keyword callers).\n",
    "        metric: callable taking one layer and returning its score.\n",
    "        mode: when 'channel', modules carrying a `dont_ch_prune` attribute are skipped.\n",
    "\n",
    "    Returns:\n",
    "        list of `metric(layer)` results, in `mdoel.modules()` order.\n",
    "    \"\"\"\n",
    "    from torch import nn  # torch is not imported at notebook level\n",
    "\n",
    "    metric_array = []\n",
    "    # walk the module that was passed in (not a notebook-global network)\n",
    "    for layer in mdoel.modules():\n",
    "        if mode == 'channel' and hasattr(layer, 'dont_ch_prune'):\n",
    "            continue\n",
    "        if isinstance(layer, (nn.Conv2d, nn.Linear)):\n",
    "            metric_array.append(metric(layer))\n",
    "    return metric_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a86d893",
   "metadata": {},
   "outputs": [],
   "source": [
    "def naswot_hamming_kernel(layer_outputs, batch_size):\n",
    "    \"\"\"Sum the NASWOT binary-code kernels of a list of activation arrays.\n",
    "\n",
    "    Every array is flattened to (batch_size, -1) and binarised as `a > 0`. For\n",
    "    codes c the kernel adds c @ c.T + (1 - c) @ (1 - c).T: for each pair of\n",
    "    samples, the number of units on which both are active or both inactive.\n",
    "    \"\"\"\n",
    "    kernel = np.zeros((batch_size, batch_size))\n",
    "    for layer_out in layer_outputs:\n",
    "        # Cast to float: numpy's `bool @ bool` is a logical product whose entries\n",
    "        # are capped at True, so it would not count agreements.\n",
    "        codes = (np.reshape(layer_out, (batch_size, -1)) > 0).astype(np.float64)\n",
    "        kernel += codes @ codes.T + (1. - codes) @ (1. - codes).T\n",
    "    return kernel\n",
    "\n",
    "\n",
    "ds = train_gen[0]\n",
    "x_naswot = ds[0]\n",
    "y_naswot = ds[1]\n",
    "bs = len(x_naswot)\n",
    "preds = model_naswot.predict(x_naswot)\n",
    "# predict() returns a list for multi-output models and a bare array otherwise\n",
    "layer_outputs = preds if isinstance(preds, list) else [preds]\n",
    "model_naswot.K = naswot_hamming_kernel(layer_outputs, bs)\n",
    "# keep the default score when all kernel entries are equal (singular kernel)\n",
    "naswot_score = 1\n",
    "if len(np.unique(model_naswot.K)) > 1:\n",
    "    s, naswot_score = np.linalg.slogdet(model_naswot.K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4625df14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):\n",
    "    \"\"\"PyTorch SynFlow saliency |w * dR/dw| for every Conv2d / Linear weight of `net`.\n",
    "\n",
    "    R is the sum of the network output on one all-ones input, computed in float64\n",
    "    with every state_dict entry replaced by its absolute value; the original signs\n",
    "    are multiplied back in before returning.\n",
    "\n",
    "    Args:\n",
    "        net: torch.nn.Module; `net.double()` converts it to float64 in place and it\n",
    "            stays that way after the call.\n",
    "        inputs: example batch; only its per-sample shape and its device are used.\n",
    "        targets, split_data, loss_fn: unused here.\n",
    "        mode: forwarded to get_layer_metric_array.\n",
    "\n",
    "    Returns:\n",
    "        list of tensors, one per layer visited by get_layer_metric_array,\n",
    "        each shaped like that layer's weight.\n",
    "    \"\"\"\n",
    "    import torch  # torch is not imported at notebook level\n",
    "\n",
    "    device = inputs.device\n",
    "\n",
    "    # Convert params to their abs. Keep the signs for converting them back.\n",
    "    @torch.no_grad()\n",
    "    def linearize(net):\n",
    "        signs = {}\n",
    "        for name, param in net.state_dict().items():\n",
    "            signs[name] = torch.sign(param)\n",
    "            param.abs_()\n",
    "        return signs\n",
    "\n",
    "    # Multiply the signs back in; entries whose name contains 'weight_mask' are skipped.\n",
    "    @torch.no_grad()\n",
    "    def nonlinearize(net, signs):\n",
    "        for name, param in net.state_dict().items():\n",
    "            if 'weight_mask' not in name:\n",
    "                param.mul_(signs[name])\n",
    "\n",
    "    signs = linearize(net)\n",
    "\n",
    "    # Gradients of sum(output) for a single input of ones.\n",
    "    net.zero_grad()\n",
    "    net.double()\n",
    "    input_dim = list(inputs[0, :].shape)\n",
    "    inputs = torch.ones([1] + input_dim).double().to(device)\n",
    "    output = net.forward(inputs)\n",
    "    torch.sum(output).backward()\n",
    "\n",
    "    # Per-weight saliency |w * grad|; zeros where no gradient reached the layer.\n",
    "    def synflow(layer):\n",
    "        if layer.weight.grad is not None:\n",
    "            return torch.abs(layer.weight * layer.weight.grad)\n",
    "        return torch.zeros_like(layer.weight)\n",
    "\n",
    "    grads_abs = get_layer_metric_array(net, synflow, mode)\n",
    "\n",
    "    nonlinearize(net, signs)\n",
    "    return grads_abs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca717620",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2plat]",
   "language": "python",
   "name": "conda-env-tf2plat-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
