{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 4577,
     "status": "ok",
     "timestamp": 1599580494380,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "KVpizesFCowb"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import scipy\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "from keras.preprocessing.image import load_img\n",
    "from keras.preprocessing.image import img_to_array\n",
    "\n",
    "from tensorflow.keras.applications.mobilenet import MobileNet\n",
    "from keras.applications.mobilenet import preprocess_input\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f4U93p9iqX_6"
   },
   "source": [
    "## Extracting features from MobileNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 12971,
     "status": "ok",
     "timestamp": 1599580502788,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "1M4GZzmBk2XJ",
    "outputId": "2a4bfb81-1284-4c57-cdd6-dd5329c8bebb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_1_0_224_tf.h5\n",
      "17227776/17225924 [==============================] - 1s 0us/step\n",
      "Model: \"mobilenet_1.00_224\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, 224, 224, 3)]     0         \n",
      "_________________________________________________________________\n",
      "conv1_pad (ZeroPadding2D)    (None, 225, 225, 3)       0         \n",
      "_________________________________________________________________\n",
      "conv1 (Conv2D)               (None, 112, 112, 32)      864       \n",
      "_________________________________________________________________\n",
      "conv1_bn (BatchNormalization (None, 112, 112, 32)      128       \n",
      "_________________________________________________________________\n",
      "conv1_relu (ReLU)            (None, 112, 112, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)      288       \n",
      "_________________________________________________________________\n",
      "conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32)      128       \n",
      "_________________________________________________________________\n",
      "conv_dw_1_relu (ReLU)        (None, 112, 112, 32)      0         \n",
      "_________________________________________________________________\n",
      "conv_pw_1 (Conv2D)           (None, 112, 112, 64)      2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64)      256       \n",
      "_________________________________________________________________\n",
      "conv_pw_1_relu (ReLU)        (None, 112, 112, 64)      0         \n",
      "_________________________________________________________________\n",
      "conv_pad_2 (ZeroPadding2D)   (None, 113, 113, 64)      0         \n",
      "_________________________________________________________________\n",
      "conv_dw_2 (DepthwiseConv2D)  (None, 56, 56, 64)        576       \n",
      "_________________________________________________________________\n",
      "conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64)        256       \n",
      "_________________________________________________________________\n",
      "conv_dw_2_relu (ReLU)        (None, 56, 56, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv_pw_2 (Conv2D)           (None, 56, 56, 128)       8192      \n",
      "_________________________________________________________________\n",
      "conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_pw_2_relu (ReLU)        (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_3 (DepthwiseConv2D)  (None, 56, 56, 128)       1152      \n",
      "_________________________________________________________________\n",
      "conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_dw_3_relu (ReLU)        (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_3 (Conv2D)           (None, 56, 56, 128)       16384     \n",
      "_________________________________________________________________\n",
      "conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_pw_3_relu (ReLU)        (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_pad_4 (ZeroPadding2D)   (None, 57, 57, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_4 (DepthwiseConv2D)  (None, 28, 28, 128)       1152      \n",
      "_________________________________________________________________\n",
      "conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128)       512       \n",
      "_________________________________________________________________\n",
      "conv_dw_4_relu (ReLU)        (None, 28, 28, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_4 (Conv2D)           (None, 28, 28, 256)       32768     \n",
      "_________________________________________________________________\n",
      "conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_pw_4_relu (ReLU)        (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_5 (DepthwiseConv2D)  (None, 28, 28, 256)       2304      \n",
      "_________________________________________________________________\n",
      "conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_dw_5_relu (ReLU)        (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_5 (Conv2D)           (None, 28, 28, 256)       65536     \n",
      "_________________________________________________________________\n",
      "conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_pw_5_relu (ReLU)        (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_pad_6 (ZeroPadding2D)   (None, 29, 29, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_6 (DepthwiseConv2D)  (None, 14, 14, 256)       2304      \n",
      "_________________________________________________________________\n",
      "conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256)       1024      \n",
      "_________________________________________________________________\n",
      "conv_dw_6_relu (ReLU)        (None, 14, 14, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_6 (Conv2D)           (None, 14, 14, 512)       131072    \n",
      "_________________________________________________________________\n",
      "conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_6_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_7 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_7_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_7 (Conv2D)           (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_7_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_8 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_8_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_8 (Conv2D)           (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_8_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_9 (DepthwiseConv2D)  (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_9_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_9 (Conv2D)           (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_9_relu (ReLU)        (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_10_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_10 (Conv2D)          (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_10_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512)       4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_11_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pw_11 (Conv2D)          (None, 14, 14, 512)       262144    \n",
      "_________________________________________________________________\n",
      "conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512)       2048      \n",
      "_________________________________________________________________\n",
      "conv_pw_11_relu (ReLU)       (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_pad_12 (ZeroPadding2D)  (None, 15, 15, 512)       0         \n",
      "_________________________________________________________________\n",
      "conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512)         4608      \n",
      "_________________________________________________________________\n",
      "conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512)         2048      \n",
      "_________________________________________________________________\n",
      "conv_dw_12_relu (ReLU)       (None, 7, 7, 512)         0         \n",
      "_________________________________________________________________\n",
      "conv_pw_12 (Conv2D)          (None, 7, 7, 1024)        524288    \n",
      "_________________________________________________________________\n",
      "conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024)        4096      \n",
      "_________________________________________________________________\n",
      "conv_pw_12_relu (ReLU)       (None, 7, 7, 1024)        0         \n",
      "_________________________________________________________________\n",
      "conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024)        9216      \n",
      "_________________________________________________________________\n",
      "conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096      \n",
      "_________________________________________________________________\n",
      "conv_dw_13_relu (ReLU)       (None, 7, 7, 1024)        0         \n",
      "_________________________________________________________________\n",
      "conv_pw_13 (Conv2D)          (None, 7, 7, 1024)        1048576   \n",
      "_________________________________________________________________\n",
      "conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096      \n",
      "_________________________________________________________________\n",
      "conv_pw_13_relu (ReLU)       (None, 7, 7, 1024)        0         \n",
      "_________________________________________________________________\n",
      "global_average_pooling2d (Gl (None, 1024)              0         \n",
      "_________________________________________________________________\n",
      "reshape_1 (Reshape)          (None, 1, 1, 1024)        0         \n",
      "_________________________________________________________________\n",
      "dropout (Dropout)            (None, 1, 1, 1024)        0         \n",
      "_________________________________________________________________\n",
      "conv_preds (Conv2D)          (None, 1, 1, 1000)        1025000   \n",
      "_________________________________________________________________\n",
      "reshape_2 (Reshape)          (None, 1000)              0         \n",
      "_________________________________________________________________\n",
      "predictions (Activation)     (None, 1000)              0         \n",
      "=================================================================\n",
      "Total params: 4,253,864\n",
      "Trainable params: 4,231,976\n",
      "Non-trainable params: 21,888\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Side length of the square RGB input used by the pretrained mobilenet_1_0_224 weights.\n",
    "img_size = 224\n",
    "\n",
    "# `pooling` only takes effect when include_top=False. With the classification head kept,\n",
    "# the network ends in its own global average pooling, so the argument is not passed.\n",
    "mobilenet = MobileNet(\n",
    "    weights='imagenet',\n",
    "    include_top=True,\n",
    "    input_shape=(img_size, img_size, 3),\n",
    ")\n",
    "mobilenet.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 12968,
     "status": "ok",
     "timestamp": 1599580502790,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "xFFpLJ4FNLY9"
   },
   "outputs": [],
   "source": [
    "# Feature extractor: same input as MobileNet, output taken at 'reshape_2', i.e. the\n",
    "# 1000-d class scores just before the final 'predictions' activation.\n",
    "ip = mobilenet.input\n",
    "op = mobilenet.get_layer('reshape_2').output\n",
    "\n",
    "basemodel = keras.Model(inputs=ip, outputs=op)\n",
    "\n",
    "# The pretrained weights are only used for prediction, so mark them non-trainable.\n",
    "basemodel.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 32458,
     "status": "ok",
     "timestamp": 1599580522284,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "kQPMWv1pqmeP"
   },
   "outputs": [],
   "source": [
    "# Load the 50 stimulus photos (data/1.jpg ... data/50.jpg), resized to the network input size.\n",
    "images = np.empty((50, img_size, img_size, 3))\n",
    "for i in range(50):\n",
    "  image = load_img(\"data/\" + str(i+1) +\".jpg\", target_size=(img_size, img_size))\n",
    "  image = img_to_array(image)\n",
    "\n",
    "  # Grey-scale conversion: every pixel becomes the mean of its three channels.\n",
    "  # One array operation replaces the Python loop over all img_size*img_size pixels;\n",
    "  # the (H, W, 1) result broadcasts back onto the three channels of images[i].\n",
    "  images[i] = image.sum(axis=-1, keepdims=True) / 3\n",
    "\n",
    "# Apply MobileNet's input scaling, then run the frozen extractor on all photos at once.\n",
    "images = preprocess_input(images)\n",
    "features = basemodel.predict(images)\n",
    "\n",
    "# Load the five colour swatches (data/C0.png ... data/C4.png) at the network input size.\n",
    "colors = np.empty((5, img_size, img_size, 3))\n",
    "for i in range(5):\n",
    "  swatch_path = \"data/C\" + str(i) + \".png\"\n",
    "  image = img_to_array(load_img(swatch_path, target_size=(img_size, img_size)))\n",
    "  colors[i] = image\n",
    "\n",
    "# Scale the swatches for MobileNet and extract their features.\n",
    "colors = preprocess_input(colors)\n",
    "colorfeatures = basemodel.predict(colors)\n",
    "\n",
    "# One flat feature vector per photo / per swatch.\n",
    "colorfeatures = colorfeatures.reshape(5, -1)\n",
    "features = features.reshape(50, -1)\n",
    "\n",
    "# Human colour judgements; the first 6 lines of the CSV are skipped as header.\n",
    "data = np.genfromtxt('data/bandw.csv', skip_header = 6, delimiter=',')\n",
    "\n",
    "# humanpred[i, c] counts how often colour index c appears in column i over the first\n",
    "# 56 rows (presumably one row per participant and one column per image - confirm).\n",
    "humanpred = np.zeros((50,5))\n",
    "for j in range(56):\n",
    "  for i in range(50):\n",
    "    color = int(data[j, i])\n",
    "    humanpred[i, color] += 1\n",
    "\n",
    "# Convert counts into the fraction of the 56 responses that chose each colour.\n",
    "prob = humanpred/56"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 32458,
     "status": "ok",
     "timestamp": 1599580522287,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "3o4bRkBQ2TwJ"
   },
   "outputs": [],
   "source": [
    "# Build all 50 x 5 (image, colour) pairs: row 5*i + j pairs image i with colour j.\n",
    "inputimages = np.zeros((250,features.shape[1]))\n",
    "inputcolors = np.zeros((250,features.shape[1]))\n",
    "\n",
    "for row in range(250):\n",
    "  img_idx, col_idx = divmod(row, 5)\n",
    "  inputimages[row, :] = features[img_idx, :]\n",
    "  inputcolors[row] = colorfeatures[col_idx, :]\n",
    "\n",
    "# Flatten the (50, 5) human probabilities into the same row order as the pairs.\n",
    "outputprob = prob.reshape(250)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Aqm4sK5LQRUs"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 32067,
     "status": "ok",
     "timestamp": 1599580522288,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "ZRZR9zFuO9-2"
   },
   "outputs": [],
   "source": [
    "def get_model(op_features=75):\n",
    "  \"\"\"Build and compile the two-input similarity model.\n",
    "\n",
    "  Both inputs (image features and colour features) go through the same ReLU Dense\n",
    "  layer, and the output is the normalised dot product (cosine similarity) of the\n",
    "  two projections. The input width is read from the global `features` array.\n",
    "\n",
    "  Args:\n",
    "    op_features: number of units in the shared Dense layer.\n",
    "\n",
    "  Returns:\n",
    "    A compiled keras.Model with inputs \"imageinputs\" and \"colorinputs\" and output \"dot\".\n",
    "  \"\"\"\n",
    "  # shape must be a tuple; `(n)` without the trailing comma is just an int.\n",
    "  feature_dim = features.shape[1]\n",
    "  imageinputs = keras.Input(shape=(feature_dim,), name=\"imageinputs\")\n",
    "  colorinputs = keras.Input(shape=(feature_dim,), name=\"colorinputs\")\n",
    "\n",
    "  # A single layer instance is applied to both inputs so the weights are shared.\n",
    "  dense = layers.Dense(op_features, name = \"linearlayer\", activation='relu')\n",
    "  imageresults = dense(imageinputs)\n",
    "  colorresults = dense(colorinputs)\n",
    "  result = tf.keras.layers.Dot(axes=1, normalize=True, name=\"dot\")([imageresults, colorresults])\n",
    "\n",
    "  model = keras.Model(\n",
    "      inputs=[imageinputs, colorinputs],\n",
    "      outputs=[result],\n",
    "  )\n",
    "  # `learning_rate` replaces the deprecated `lr` keyword.\n",
    "  model.compile(\n",
    "      loss='mean_squared_error',\n",
    "      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),\n",
    "      metrics=['mean_squared_error'],\n",
    "  )\n",
    "\n",
    "  return model\n",
    "\n",
    "def get_corr(output, outputprob, vs, ve):\n",
    "  x = np.copy(output.reshape(50,5))\n",
    "\n",
    "  for i in range(50):\n",
    "    rowsum = np.sum(x[i,:])\n",
    "    for j in range(5):\n",
    "      x[i,j] = x[i,j]/rowsum\n",
    "\n",
    "  y = np.copy(outputprob.reshape(50,5))\n",
    "  testx = x.reshape(-1)[vs:ve]\n",
    "  testy = y.reshape(-1)[vs:ve]\n",
    "\n",
    "  return scipy.stats.pearsonr(testx, testy)[0], scipy.stats.pearsonr(testx, testy)[1], testx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 66,
     "referenced_widgets": [
      "84edd15ba25647e2a34cd6edb9b1ce17",
      "1f295e141b714f31a18527a34a442f68",
      "250fa54ef03c458fa0decf326920106a",
      "53c7a2cb61124df1baee57324767577b",
      "9f35ed32f74b41c7a88c5662ddfcdbf1",
      "0ffdb23d6b944eb2b46abeba25dfd00a",
      "50384992e2e84073bbab361d52281f56",
      "36fd627e0d374f798b0dca8f3fe2fd6b"
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 474064,
     "status": "ok",
     "timestamp": 1599580965475,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "ygMZBSStrMBV",
    "outputId": "e8012f3f-e30b-473b-bd41-dc6fc9776ca3"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84edd15ba25647e2a34cd6edb9b1ce17",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "Rp = []      # pooled out-of-fold Pearson r, one entry per repetition\n",
    "pvalue = []  # p-value matching each entry of Rp\n",
    "\n",
    "fold_starts = [0, 50, 100, 150, 200]\n",
    "\n",
    "# Repeat the whole cross-validation 50 times; a new model is built for every fold.\n",
    "for i in tqdm(range(50)):\n",
    "  corr = [] #This contains correlation corresponding to specific validation data\n",
    "  pred = np.zeros(250) #This will contain the final prediction for each of the loop. Total 5 subset of validation 50 each.\n",
    "\n",
    "  # 5-fold cross validation: each fold holds out 50 consecutive rows [vs, ve).\n",
    "  for vs in fold_starts:\n",
    "    ve = vs + 50\n",
    "    # Training rows are everything outside the held-out block; build the index once per fold.\n",
    "    train_rows = [*range(vs)] + [*range(ve, 250)]\n",
    "\n",
    "    # 250 models are created by this cell; release the previous model's Keras state first.\n",
    "    keras.backend.clear_session()\n",
    "    model = get_model(op_features=75)\n",
    "    history = model.fit(\n",
    "        {\"imageinputs\": inputimages[train_rows], \"colorinputs\": inputcolors[train_rows]},\n",
    "        {\"dot\": outputprob[train_rows]},\n",
    "        epochs=30,\n",
    "        batch_size=10,\n",
    "        shuffle=True,\n",
    "        verbose=0,\n",
    "    )\n",
    "\n",
    "    # Predict every row, but keep only the normalised scores of the held-out block.\n",
    "    output = model.predict({\"imageinputs\": inputimages, \"colorinputs\": inputcolors})\n",
    "    R, _fold_p, pred[vs:ve] = get_corr(output, outputprob, vs, ve)\n",
    "    corr.append(R)\n",
    "\n",
    "  # Score the pooled out-of-fold predictions of this repetition over all 250 rows.\n",
    "  temp = get_corr(pred, outputprob, 0, 250)\n",
    "  Rp.append(temp[0])\n",
    "  pvalue.append(temp[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 472582,
     "status": "ok",
     "timestamp": 1599580965478,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "_VfRgUs2UIKq",
    "outputId": "de01d452-9220-4b47-e910-8431cc1e47b4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5327758527035148\n",
      "0.027333013270097555\n"
     ]
    }
   ],
   "source": [
    "# Mean and standard deviation of the pooled correlation across repetitions.\n",
    "rp_values = np.asarray(Rp)\n",
    "print(rp_values.mean())\n",
    "print(rp_values.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 867
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 472379,
     "status": "ok",
     "timestamp": 1599580965479,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "R7bCK8y_fXn2",
    "outputId": "66101f5b-5368-4674-be7d-e1804eff9f64"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.4643335900208154,\n",
       " 0.5916653521154563,\n",
       " 0.5185870310582797,\n",
       " 0.5194349553736115,\n",
       " 0.566549185732745,\n",
       " 0.5323723131839688,\n",
       " 0.5301795700501558,\n",
       " 0.5292724826814267,\n",
       " 0.5083769604787874,\n",
       " 0.5109697858227084,\n",
       " 0.5275895929812386,\n",
       " 0.561507418772449,\n",
       " 0.5524841454085143,\n",
       " 0.5770596118034835,\n",
       " 0.562160777933095,\n",
       " 0.5527223198328242,\n",
       " 0.5155137575772851,\n",
       " 0.5424156093872838,\n",
       " 0.5151967030571446,\n",
       " 0.5112222698507539,\n",
       " 0.5204235755909852,\n",
       " 0.5201449275886896,\n",
       " 0.540954860066318,\n",
       " 0.4895335400393026,\n",
       " 0.5228167268189827,\n",
       " 0.5671034327875943,\n",
       " 0.5151137165417249,\n",
       " 0.5505974043268088,\n",
       " 0.5774490916412615,\n",
       " 0.5415809113311139,\n",
       " 0.5333043858823484,\n",
       " 0.5293366932808451,\n",
       " 0.5271670261274783,\n",
       " 0.5696856992092784,\n",
       " 0.5236238527117749,\n",
       " 0.5313123367647407,\n",
       " 0.5005779520847097,\n",
       " 0.4867698910310284,\n",
       " 0.5486593933219697,\n",
       " 0.5211652397182673,\n",
       " 0.5518766496361599,\n",
       " 0.548957560186373,\n",
       " 0.5653071893482361,\n",
       " 0.5428451255826859,\n",
       " 0.5383467223028208,\n",
       " 0.5587276412948893,\n",
       " 0.49227404447644313,\n",
       " 0.47167140926289575,\n",
       " 0.5006652398815824,\n",
       " 0.5611869632164047]"
      ]
     },
     "execution_count": 10,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pooled out-of-fold Pearson r of every repetition, in run order.\n",
    "Rp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 471005,
     "status": "ok",
     "timestamp": 1599580965480,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "LqlaIHSfg_gH",
    "outputId": "0d90e420-b1f2-4a3f-b609-5fe5cd73cc1b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.525612731288252e-16\n",
      "1.3156687381115543e-15\n"
     ]
    }
   ],
   "source": [
    "# Mean and standard deviation of the p-values that accompany Rp.\n",
    "p_values = np.asarray(pvalue)\n",
    "print(p_values.mean())\n",
    "print(p_values.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LOYeXNHQhACc"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "MobilNet.ipynb",
   "provenance": [
    {
     "file_id": "1tveJ219EYVMLwbiPBUAMc1wD0sBPQH4g",
     "timestamp": 1599513973760
    },
    {
     "file_id": "1o0FxaUOCCc6jeP1tPY3D1R9F8slNYkRo",
     "timestamp": 1595668597586
    },
    {
     "file_id": "1UE_yDN7vICMTyp3wv4qOQxfoZYEpAE8f",
     "timestamp": 1595256575799
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "0ffdb23d6b944eb2b46abeba25dfd00a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "1f295e141b714f31a18527a34a442f68": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "250fa54ef03c458fa0decf326920106a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0ffdb23d6b944eb2b46abeba25dfd00a",
      "max": 50,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_9f35ed32f74b41c7a88c5662ddfcdbf1",
      "value": 50
     }
    },
    "36fd627e0d374f798b0dca8f3fe2fd6b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "50384992e2e84073bbab361d52281f56": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "53c7a2cb61124df1baee57324767577b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_36fd627e0d374f798b0dca8f3fe2fd6b",
      "placeholder": "​",
      "style": "IPY_MODEL_50384992e2e84073bbab361d52281f56",
      "value": " 50/50 [07:23&lt;00:00,  8.86s/it]"
     }
    },
    "84edd15ba25647e2a34cd6edb9b1ce17": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_250fa54ef03c458fa0decf326920106a",
       "IPY_MODEL_53c7a2cb61124df1baee57324767577b"
      ],
      "layout": "IPY_MODEL_1f295e141b714f31a18527a34a442f68"
     }
    },
    "9f35ed32f74b41c7a88c5662ddfcdbf1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
