{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KVpizesFCowb"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kQPMWv1pqmeP"
   },
   "outputs": [],
   "source": [
    "# Count, for each of the first 50 columns of the response table, how often each\n",
    "# label value occurs in the first 56 rows, then turn the counts into fractions.\n",
    "# NOTE(review): presumably rows are participants and columns are test items -- confirm.\n",
    "N_ITEMS, N_RESPONSES, N_LABELS = 50, 56, 5\n",
    "\n",
    "# The first six lines of the CSV are skipped; the rest is parsed as the response table.\n",
    "data = np.genfromtxt('data/bandw.csv', delimiter=',', skip_header=6)\n",
    "\n",
    "# humanpred[item, label] = number of responses that gave `label` for `item`.\n",
    "humanpred = np.zeros((N_ITEMS, N_LABELS))\n",
    "for item in range(N_ITEMS):\n",
    "    for row in range(N_RESPONSES):\n",
    "        label = int(data[row, item])  # parsed as float; cast so it can index a column\n",
    "        humanpred[item, label] += 1\n",
    "\n",
    "# Fraction of the 56 responses that chose each label, per item.\n",
    "prob = humanpred / N_RESPONSES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3o4bRkBQ2TwJ"
   },
   "outputs": [],
   "source": [
    "# Flatten the (50, 5) fraction table to 250 values, then restore the (50, 5) layout.\n",
    "outputprob = np.reshape(prob, 250)\n",
    "prob_human = np.reshape(outputprob, (50, 5))\n",
    "\n",
    "# Read whatever remains of the CSV after dropping its first line and last 60 lines.\n",
    "# NOTE(review): presumably this is the single row of correct labels -- confirm.\n",
    "correct = np.genfromtxt('data/bandw.csv', delimiter=',', skip_header=1, skip_footer=60)\n",
    "\n",
    "# Most frequently chosen label per item (argmax keeps the lowest label on ties).\n",
    "human_pred = np.argmax(prob_human, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XgGFzMjtyuQy"
   },
   "source": [
    "## Human (Overall)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3171,
     "status": "ok",
     "timestamp": 1599579991137,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "1a-a5vIsyzfM",
    "outputId": "af327317-8333-4467-a137-0eee5157fae9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 54\n"
     ]
    }
   ],
   "source": [
    "# human_pred has 50 entries, so every match is worth 100/50 = 2 percentage points.\n",
    "n_matches = np.sum(human_pred == correct)\n",
    "accuracy_pct = round(n_matches * 2, 2)\n",
    "print(\"Accuracy\", accuracy_pct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JZCvmEKquzlr"
   },
   "source": [
    "## Chance performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3153,
     "status": "ok",
     "timestamp": 1599579991138,
     "user": {
      "displayName": "Shashi Kant",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiHv2woo7B1czb0XKsLBEOsgMWsuYITOryCdNNv7A=s64",
      "userId": "09783338171816410510"
     },
     "user_tz": -330
    },
    "id": "jyxFaGqaB8oW",
    "outputId": "11e2fefa-9620-4485-9065-c5d4710a2359"
   },
   "outputs": [],
   "source": [
    "# Chance baseline: simulate N_TRIALS guessers that pick a label uniformly at random\n",
    "# for every item, and score each one against the ground truth and against the\n",
    "# human majority vote.\n",
    "np.random.seed(0)  # fixed seed so the printed numbers are reproducible\n",
    "\n",
    "N_TRIALS = 1000\n",
    "n_items = data.shape[1]        # one guess per column of the response table\n",
    "n_labels = humanpred.shape[1]  # one label per column of the human tally table\n",
    "\n",
    "# Row t holds the guesses of trial t. randint's upper bound is exclusive, so the\n",
    "# guesses cover 0 .. n_labels - 1, i.e. every label a human response can fall in.\n",
    "random_pred = np.random.randint(0, n_labels, size=(N_TRIALS, n_items))\n",
    "\n",
    "# Per-trial accuracy in percent; the (n_items,) targets broadcast across trials.\n",
    "acc_wrt_actual = 100 * np.sum(random_pred == correct, axis=1) / n_items\n",
    "acc_wrt_human = 100 * np.sum(random_pred == human_pred, axis=1) / n_items\n",
    "\n",
    "# Each line prints agreement with the human vote first, then with the ground truth.\n",
    "print(\"Mean Accuracy\", round(np.mean(acc_wrt_human), 2), round(np.mean(acc_wrt_actual), 2))\n",
    "print(\"Standard deviation\", round(np.std(acc_wrt_human), 2), round(np.std(acc_wrt_actual), 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "x0WkCD_WZZLK"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "One_Shot_Human.ipynb",
   "provenance": [
    {
     "file_id": "1UE_yDN7vICMTyp3wv4qOQxfoZYEpAE8f",
     "timestamp": 1595256575799
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
