{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3871,
     "status": "ok",
     "timestamp": 1599588028779,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "KVpizesFCowb"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import scipy\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "from tensorflow.keras.applications.vgg16 import VGG16\n",
    "from keras.preprocessing.image import load_img\n",
    "from keras.preprocessing.image import img_to_array\n",
    "from keras.applications.vgg16 import preprocess_input\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f4U93p9iqX_6"
   },
   "source": [
    "## Extracting features from VGG16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 969
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 12588,
     "status": "ok",
     "timestamp": 1599588038187,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "1M4GZzmBk2XJ",
    "outputId": "fddb4f4c-86e7-438f-f930-c19d35cce00c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5\n",
      "553467904/553467096 [==============================] - 3s 0us/step\n",
      "Model: \"vgg16\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, 224, 224, 3)]     0         \n",
      "_________________________________________________________________\n",
      "block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      \n",
      "_________________________________________________________________\n",
      "block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     \n",
      "_________________________________________________________________\n",
      "block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         \n",
      "_________________________________________________________________\n",
      "block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     \n",
      "_________________________________________________________________\n",
      "block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    \n",
      "_________________________________________________________________\n",
      "block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    \n",
      "_________________________________________________________________\n",
      "block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    \n",
      "_________________________________________________________________\n",
      "block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    \n",
      "_________________________________________________________________\n",
      "block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   \n",
      "_________________________________________________________________\n",
      "block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 25088)             0         \n",
      "_________________________________________________________________\n",
      "fc1 (Dense)                  (None, 4096)              102764544 \n",
      "_________________________________________________________________\n",
      "fc2 (Dense)                  (None, 4096)              16781312  \n",
      "_________________________________________________________________\n",
      "predictions (Dense)          (None, 1000)              4097000   \n",
      "=================================================================\n",
      "Total params: 138,357,544\n",
      "Trainable params: 138,357,544\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Load the full ImageNet-pretrained VGG16, classifier head included.\n",
    "# `pooling` is omitted: Keras only applies it when include_top=False, so the\n",
    "# former pooling='max' argument had no effect on this model.\n",
    "img_size = 224\n",
    "vgg16 = VGG16(weights='imagenet', include_top=True, input_shape=(img_size, img_size, 3))\n",
    "vgg16.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 11203,
     "status": "ok",
     "timestamp": 1599588038194,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "xFFpLJ4FNLY9"
   },
   "outputs": [],
   "source": [
    "# Cut VGG16 off at its second fully connected layer (`fc2`): predict() on the\n",
    "# resulting model returns that layer's activations rather than class scores.\n",
    "ip = vgg16.input\n",
    "op = vgg16.get_layer('fc2').output\n",
    "basemodel = keras.Model(inputs=ip, outputs=op)\n",
    "\n",
    "# Mark the extractor as frozen.\n",
    "basemodel.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 31263,
     "status": "ok",
     "timestamp": 1599588058630,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "kQPMWv1pqmeP"
   },
   "outputs": [],
   "source": [
    "# Dataset dimensions used throughout this cell.\n",
    "N_IMAGES = 50     # stimuli data/1.jpg ... data/50.jpg\n",
    "N_COLORS = 5      # color patches data/C0.png ... data/C4.png\n",
    "N_SUBJECTS = 56   # response rows read from data/bandw.csv\n",
    "\n",
    "# Load every stimulus, make it grayscale by giving all three channels the mean\n",
    "# of R, G and B, and stack the results into one batch.\n",
    "images = np.empty((N_IMAGES, img_size, img_size, 3))\n",
    "for i in range(N_IMAGES):\n",
    "  image = load_img(\"data/\" + str(i+1) + \".jpg\", target_size=(img_size, img_size))\n",
    "  image = img_to_array(image)\n",
    "\n",
    "  # Whole-image (R + G + B) / 3 as one array expression instead of a Python\n",
    "  # loop over img_size * img_size pixels; broadcasting fills the 3 channels.\n",
    "  gray = (image[:, :, 0] + image[:, :, 1] + image[:, :, 2]) / 3\n",
    "  images[i] = gray[:, :, np.newaxis]\n",
    "\n",
    "images = preprocess_input(images)\n",
    "features = basemodel.predict(images)\n",
    "\n",
    "# The color patches go through the same extractor, without grayscale conversion.\n",
    "colors = np.empty((N_COLORS, img_size, img_size, 3))\n",
    "for i in range(N_COLORS):\n",
    "  image = load_img(\"data/C\" + str(i) + \".png\", target_size=(img_size, img_size))\n",
    "  colors[i] = img_to_array(image)\n",
    "\n",
    "colors = preprocess_input(colors)\n",
    "colorfeatures = basemodel.predict(colors)\n",
    "\n",
    "# One flat feature vector per image / per color.\n",
    "features = features.reshape(N_IMAGES, -1)\n",
    "colorfeatures = colorfeatures.reshape(N_COLORS, -1)\n",
    "\n",
    "# Human responses. After the 6 skipped header lines, data[j, i] is read as the\n",
    "# color index chosen in response row j for image i.\n",
    "# NOTE(review): row = subject, column = image is inferred from the indexing below -- confirm against the CSV.\n",
    "data = np.genfromtxt('data/bandw.csv', skip_header = 6, delimiter=',')\n",
    "humanpred = np.zeros((N_IMAGES, N_COLORS))\n",
    "for i in range(N_IMAGES):\n",
    "  for j in range(N_SUBJECTS):\n",
    "    color = int(data[j, i])\n",
    "    humanpred[i, color] += 1\n",
    "\n",
    "# Fraction of the N_SUBJECTS responses that picked each color, per image.\n",
    "prob = humanpred / N_SUBJECTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 29869,
     "status": "ok",
     "timestamp": 1599588058631,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "3o4bRkBQ2TwJ"
   },
   "outputs": [],
   "source": [
    "# Build one row per (image, color) pair in image-major order:\n",
    "# row num_colors*i + j pairs image i with color j.\n",
    "num_images, num_colors = features.shape[0], colorfeatures.shape[0]\n",
    "\n",
    "# Each image vector repeated once per color; float64 as before.\n",
    "inputimages = np.repeat(features, num_colors, axis=0).astype(np.float64)\n",
    "# The full block of color vectors repeated once per image.\n",
    "inputcolors = np.tile(colorfeatures, (num_images, 1)).astype(np.float64)\n",
    "\n",
    "# Targets flattened in the same order: outputprob[num_colors*i + j] == prob[i, j].\n",
    "outputprob = prob.reshape(num_images * num_colors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Aqm4sK5LQRUs"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 29123,
     "status": "ok",
     "timestamp": 1599588058633,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "ZRZR9zFuO9-2"
   },
   "outputs": [],
   "source": [
    "def get_model(op_features=200, color_transform=True, input_dim=None):\n",
    "  \"\"\"Build and compile the two-input similarity model.\n",
    "\n",
    "  The model takes an image feature vector and a color feature vector and\n",
    "  outputs their L2-normalized dot product, trained with mean squared error.\n",
    "\n",
    "  Args:\n",
    "    op_features: width of the shared dense layer; only used when\n",
    "      color_transform is True.\n",
    "    color_transform: if True, both inputs go through one shared ReLU dense\n",
    "      layer before the dot product; if False, the dot product is taken on the\n",
    "      inputs directly and the model has no dense layer.\n",
    "    input_dim: length of each input vector. Defaults to features.shape[1]\n",
    "      (the notebook-level feature matrix).\n",
    "\n",
    "  Returns:\n",
    "    A compiled keras.Model with inputs \"imageinputs\" and \"colorinputs\" and a\n",
    "    single output layer named \"dot\".\n",
    "  \"\"\"\n",
    "  if input_dim is None:\n",
    "    input_dim = features.shape[1]\n",
    "\n",
    "  # `shape` must be a tuple: (n,) is a 1-tuple, whereas (n) is just the int n.\n",
    "  imageinputs = keras.Input(shape=(input_dim,), name=\"imageinputs\")\n",
    "  colorinputs = keras.Input(shape=(input_dim,), name=\"colorinputs\")\n",
    "\n",
    "  if color_transform:\n",
    "    # One layer object applied to both inputs, so both share the same weights.\n",
    "    dense = layers.Dense(op_features, name = \"linearlayer\", activation='relu')\n",
    "    imageresults = dense(imageinputs)\n",
    "    colorresults = dense(colorinputs)\n",
    "    result = tf.keras.layers.Dot(axes=1, normalize=True, name=\"dot\")([imageresults, colorresults])\n",
    "  else:\n",
    "    result = tf.keras.layers.Dot(axes=1, normalize=True, name=\"dot\")([imageinputs, colorinputs])\n",
    "\n",
    "  model = keras.Model(\n",
    "      inputs=[imageinputs, colorinputs],\n",
    "      outputs=[result],\n",
    "  )\n",
    "\n",
    "  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.\n",
    "  model.compile(loss = 'mean_squared_error', optimizer = tf.keras.optimizers.Adam(learning_rate=0.001), metrics = ['mean_squared_error'])\n",
    "\n",
    "  return model\n",
    "\n",
    "def get_corr(output, outputprob, vs, ve, seq=(0, 1, 2, 3, 4), n_colors=5):\n",
    "  \"\"\"Row-normalize model scores and correlate a slice of them with the targets.\n",
    "\n",
    "  Args:\n",
    "    output: model scores, one per (image, color) pair; any shape that can be\n",
    "      reshaped to (-1, n_colors).\n",
    "    outputprob: target values in the same flattened order as the scores.\n",
    "    vs: start (inclusive) of the flattened slice to evaluate.\n",
    "    ve: end (exclusive) of the flattened slice to evaluate.\n",
    "    seq: column indices selected (and ordered) from the score matrix before\n",
    "      normalization; the targets are not reordered.\n",
    "    n_colors: number of scores per image, i.e. the row length of the matrix.\n",
    "\n",
    "  Returns:\n",
    "    (pearson_r, p_value, normalized_scores[vs:ve]).\n",
    "\n",
    "  Note:\n",
    "    A row whose selected scores sum to zero is divided by zero; no guard is\n",
    "    applied.\n",
    "  \"\"\"\n",
    "  # Fancy indexing returns a new array, so the caller's `output` is not modified.\n",
    "  scores = np.reshape(output, (-1, n_colors))[:, list(seq)]\n",
    "\n",
    "  # Make every row sum to 1 so it is comparable to a probability distribution.\n",
    "  scores = scores / np.sum(scores, axis=1, keepdims=True)\n",
    "\n",
    "  testx = scores.reshape(-1)[vs:ve]\n",
    "  testy = np.reshape(outputprob, -1)[vs:ve]\n",
    "\n",
    "  # Compute the correlation once and reuse both parts of the result.\n",
    "  r, p_value = scipy.stats.pearsonr(testx, testy)\n",
    "  return r, p_value, testx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FWWngjhQJZOP"
   },
   "source": [
    "# Raw: normalized dot product of untransformed VGG16 features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 28887,
     "status": "ok",
     "timestamp": 1599588063491,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "ZHrgyRAOqEBI",
    "outputId": "9bd6392c-6dee-4b9c-b9ab-9e0edc5ac616"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall R: 0.33 and p-value: 0.0\n",
      "24 54 40\n"
     ]
    }
   ],
   "source": [
    "# With color_transform=False the model output is the normalized dot product of\n",
    "# the two inputs themselves: no dense layer lies between the inputs and the\n",
    "# output, so there are no weights to fit and every cross-validation fold would\n",
    "# yield the same predictions. One forward pass over all 250 pairs is therefore\n",
    "# equivalent to training and predicting per fold.\n",
    "model = get_model(op_features=75, color_transform=False)\n",
    "output = model.predict({\"imageinputs\": inputimages, \"colorinputs\": inputcolors})\n",
    "\n",
    "# Row-normalized score for each of the 250 (image, color) pairs.\n",
    "pred = np.zeros(250)\n",
    "_, _, pred[:] = get_corr(output, outputprob, 0, 250)\n",
    "\n",
    "# Correlation over all pairs, computed once and reused for both printed values.\n",
    "R, p_value, _ = get_corr(pred, outputprob, 0, 250)\n",
    "print(\"Overall R:\", round(R, 2), \"and p-value:\", round(p_value, 6))\n",
    "\n",
    "prob_model = pred.reshape((50, 5))\n",
    "prob_human = outputprob.reshape((50, 5))\n",
    "# NOTE(review): presumably the CSV row holding the true color index per image -- confirm against the file layout.\n",
    "correct = np.genfromtxt('data/bandw.csv', skip_header = 1, skip_footer = 60, delimiter=',')\n",
    "human_pred = np.argmax(prob_human, axis=1)\n",
    "model_pred = np.argmax(prob_model, axis=1)\n",
    "\n",
    "# 50 images in total, so each match counts for 2 percentage points.\n",
    "acc_model = np.sum(model_pred == correct)*2\n",
    "acc_human = np.sum(human_pred == correct)*2\n",
    "model_wrt_human = np.sum(model_pred == human_pred)*2\n",
    "\n",
    "print(acc_model, acc_human, model_wrt_human)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "g6_tlKOCJezW"
   },
   "source": [
    "# Transformed: normalized dot product after a learned shared dense layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 901
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 495783,
     "status": "ok",
     "timestamp": 1599579063104,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "jyxFaGqaB8oW",
    "outputId": "4d0d4fce-48d6-45f7-e861-811ccd81366d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 | Overall R: 0.63 and p-value: 0.0\n",
      "1 | Overall R: 0.62 and p-value: 0.0\n",
      "2 | Overall R: 0.61 and p-value: 0.0\n",
      "3 | Overall R: 0.63 and p-value: 0.0\n",
      "4 | Overall R: 0.62 and p-value: 0.0\n",
      "5 | Overall R: 0.64 and p-value: 0.0\n",
      "6 | Overall R: 0.62 and p-value: 0.0\n",
      "7 | Overall R: 0.63 and p-value: 0.0\n",
      "8 | Overall R: 0.64 and p-value: 0.0\n",
      "9 | Overall R: 0.63 and p-value: 0.0\n",
      "10 | Overall R: 0.62 and p-value: 0.0\n",
      "11 | Overall R: 0.64 and p-value: 0.0\n",
      "12 | Overall R: 0.63 and p-value: 0.0\n",
      "13 | Overall R: 0.64 and p-value: 0.0\n",
      "14 | Overall R: 0.66 and p-value: 0.0\n",
      "15 | Overall R: 0.61 and p-value: 0.0\n",
      "16 | Overall R: 0.62 and p-value: 0.0\n",
      "17 | Overall R: 0.62 and p-value: 0.0\n",
      "18 | Overall R: 0.62 and p-value: 0.0\n",
      "19 | Overall R: 0.63 and p-value: 0.0\n",
      "20 | Overall R: 0.62 and p-value: 0.0\n",
      "21 | Overall R: 0.64 and p-value: 0.0\n",
      "22 | Overall R: 0.6 and p-value: 0.0\n",
      "23 | Overall R: 0.64 and p-value: 0.0\n",
      "24 | Overall R: 0.64 and p-value: 0.0\n",
      "25 | Overall R: 0.64 and p-value: 0.0\n",
      "26 | Overall R: 0.64 and p-value: 0.0\n",
      "27 | Overall R: 0.65 and p-value: 0.0\n",
      "28 | Overall R: 0.65 and p-value: 0.0\n",
      "29 | Overall R: 0.65 and p-value: 0.0\n",
      "30 | Overall R: 0.62 and p-value: 0.0\n",
      "31 | Overall R: 0.61 and p-value: 0.0\n",
      "32 | Overall R: 0.64 and p-value: 0.0\n",
      "33 | Overall R: 0.63 and p-value: 0.0\n",
      "34 | Overall R: 0.65 and p-value: 0.0\n",
      "35 | Overall R: 0.6 and p-value: 0.0\n",
      "36 | Overall R: 0.65 and p-value: 0.0\n",
      "37 | Overall R: 0.65 and p-value: 0.0\n",
      "38 | Overall R: 0.61 and p-value: 0.0\n",
      "39 | Overall R: 0.64 and p-value: 0.0\n",
      "40 | Overall R: 0.6 and p-value: 0.0\n",
      "41 | Overall R: 0.62 and p-value: 0.0\n",
      "42 | Overall R: 0.63 and p-value: 0.0\n",
      "43 | Overall R: 0.65 and p-value: 0.0\n",
      "44 | Overall R: 0.63 and p-value: 0.0\n",
      "45 | Overall R: 0.63 and p-value: 0.0\n",
      "46 | Overall R: 0.65 and p-value: 0.0\n",
      "47 | Overall R: 0.61 and p-value: 0.0\n",
      "48 | Overall R: 0.62 and p-value: 0.0\n",
      "49 | Overall R: 0.64 and p-value: 0.0\n",
      "Actual: 40.2 2.891366458960192\n",
      "wrt Human: 56.12 3.210233636357329\n"
     ]
    }
   ],
   "source": [
    "# Reference answers do not depend on the trained model, so compute them once\n",
    "# instead of re-reading the CSV on every repetition.\n",
    "prob_human = outputprob.reshape((50, 5))\n",
    "# NOTE(review): presumably the CSV row holding the true color index per image -- confirm against the file layout.\n",
    "correct = np.genfromtxt('data/bandw.csv', skip_header = 1, skip_footer = 60, delimiter=',')\n",
    "human_pred = np.argmax(prob_human, axis=1)\n",
    "acc_human = np.sum(human_pred == correct)*2\n",
    "\n",
    "acc_model = []\n",
    "model_wrt_human = []\n",
    "\n",
    "# Repeat the whole cross-validation 50 times to measure run-to-run variation.\n",
    "for i in range(50):\n",
    "  pred = np.zeros(250) # Held-out prediction for every pair; each fold fills its own 50 entries.\n",
    "\n",
    "  # 5-fold cross validation over contiguous blocks of 50 rows\n",
    "  for vs in [0, 50, 100, 150, 200]:\n",
    "    ve = vs + 50\n",
    "    # Training rows are everything outside [vs, ve); built once per fold.\n",
    "    train_idx = [*range(vs)] + [*range(ve, 250)]\n",
    "\n",
    "    # 250 models are created in total; release the previous model's global Keras state.\n",
    "    tf.keras.backend.clear_session()\n",
    "    model = get_model(op_features=75, color_transform=True)\n",
    "\n",
    "    history = model.fit(\n",
    "        {\"imageinputs\": inputimages[train_idx], \"colorinputs\": inputcolors[train_idx]},\n",
    "        {\"dot\": outputprob[train_idx]},\n",
    "        epochs=30,\n",
    "        batch_size=10,\n",
    "        shuffle=True,\n",
    "        verbose=0,\n",
    "    )\n",
    "\n",
    "    output = model.predict({\"imageinputs\": inputimages, \"colorinputs\": inputcolors})\n",
    "\n",
    "    # Keep only the held-out slice of the row-normalized predictions.\n",
    "    R, _, pred[vs:ve] = get_corr(output, outputprob, vs, ve)\n",
    "\n",
    "  overall_r, overall_p, _ = get_corr(pred, outputprob, 0, 250)\n",
    "  print(i, \"| Overall R:\", round(overall_r, 2), \"and p-value:\", round(overall_p, 6))\n",
    "\n",
    "  prob_model = pred.reshape((50, 5))\n",
    "  model_pred = np.argmax(prob_model, axis=1)\n",
    "\n",
    "  # 50 images in total, so each match counts for 2 percentage points.\n",
    "  acc_model.append(np.sum(model_pred == correct)*2)\n",
    "  model_wrt_human.append(np.sum(model_pred == human_pred)*2)\n",
    "\n",
    "acc_model_mean = np.mean(acc_model)\n",
    "acc_model_std = np.std(acc_model)\n",
    "model_wrt_human_mean = np.mean(model_wrt_human)\n",
    "model_wrt_human_std = np.std(model_wrt_human)\n",
    "\n",
    "print(\"Actual:\", acc_model_mean, acc_model_std)\n",
    "print(\"wrt Human:\", model_wrt_human_mean, model_wrt_human_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1681,
     "status": "ok",
     "timestamp": 1599579064807,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "TdGaQSXR0YzH",
    "outputId": "4b1aebf9-2fa2-4d85-ef91-3f9d16871123"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(46, 64)"
      ]
     },
     "execution_count": 10,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest value of each score: (match with `correct`, match with the human argmax).\n",
    "np.asarray(acc_model).max(), np.asarray(model_wrt_human).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1672,
     "status": "ok",
     "timestamp": 1599579064808,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "gnRm72E30bwz",
    "outputId": "ac0bed7b-afe8-4b62-a9e1-fb8da4be8d98"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(34, 50)"
      ]
     },
     "execution_count": 11,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smallest value of each score: (match with `correct`, match with the human argmax).\n",
    "np.asarray(acc_model).min(), np.asarray(model_wrt_human).min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "fGWK6wFxlgUy"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "One_Shot_Color_Similarity.ipynb",
   "provenance": [
    {
     "file_id": "1UE_yDN7vICMTyp3wv4qOQxfoZYEpAE8f",
     "timestamp": 1595256575799
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
