{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle, os, csv, math, time, joblib\n",
    "from joblib import Parallel, delayed\n",
    "import datetime as dt\n",
    "from datetime import date, datetime, timedelta\n",
    "from collections import Counter\n",
    "import copy as cp\n",
    "import tqdm\n",
    "from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier\n",
    "from lightgbm import LGBMRegressor, LGBMClassifier\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
    "from sklearn.metrics import log_loss, f1_score, precision_score, recall_score, accuracy_score\n",
    "#import matplotlib.pyplot as plt\n",
    "#import matplotlib.ticker as ticker\n",
    "import collections \n",
    "#import shap\n",
    "import seaborn as sns\n",
    "import random\n",
    "from sklearn.linear_model import LinearRegression\n",
    "np.seterr(all=\"ignore\")\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "import math\n",
    "import statsmodels.api as sm\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import numpy as np\n",
    "import json\n",
    "import util as util_fqi\n",
    "import sys\n",
    "sys.path.append('models/')\n",
    "from lmmfqi import LMMFQIagent\n",
    "from fqi import FQIagent\n",
    "from cfqi import CFQIagent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:01<00:00, 66.87it/s]\n"
     ]
    }
   ],
   "source": [
    "# Simulate the transition tuples used to train and then evaluate the agents.\n",
    "(train_tuples, test_tuples) = util_fqi.generate_tuples()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/aishwaryamandyam/anaconda3/envs/research/lib/python3.7/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
      "  warnings.warn(msg, FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:ylabel='Density'>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAD8CAYAAABthzNFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAjR0lEQVR4nO3de3Bc53nf8e+zNwALEgAv0MWkaFKiZJdKHNdmJLd2kk5dO3JudMZSLbmp1USNkkn0TzLplIlbjaI4M5VnKk0zVusoI3UUporkKI3LJkxVy0qdSWzTomTZEmUzAiWZJESJF9wIYBeL3X36xzkHXEMLYLE4Zy/E7zMDcfecs7vv0WL3h/d9z3mOuTsiIiKLpdrdABER6UwKCBERqUsBISIidSkgRESkLgWEiIjUpYAQEZG6Eg0IM7vJzI6Z2YiZ7a+zvsfMngjXHzaznTXr3mNmXzezo2b2opn1JtlWERH5QYkFhJmlgQeBjwF7gNvMbM+ize4Axt19N/AAcF/42AzwJ8Cvuvv1wD8D5pNqq4iIvF2SPYgbgBF3f9XdS8DjwL5F2+wDHg1vPwl82MwM+CjwHXf/NoC7n3f3SoJtFRGRRTIJPvc24GTN/VPAjUtt4+5lM5sEtgDXAW5mTwHDwOPu/rnlXmzr1q2+c+fOmJouIrI+PPfcc+fcfbjeuiQDYi0ywIeAHwVmga+Y2XPu/pXajczsTuBOgB07dnDkyJGWN1REpJuZ2feXWpfkENMocFXN/e3hsrrbhPMOg8B5gt7G37r7OXefBQ4B71v8Au7+kLvvdfe9w8N1A1BERJqUZEA8C1xrZrvMLAfcChxctM1B4Pbw9s3AMx5UD3wK+GEzy4fB8RPAywm2VUREFklsiCmcU7iL4Ms+DTzi7kfN7F7giLsfBB4GDpjZCDBGECK4+7iZ3U8QMg4ccve/SqqtIiLydnaplPveu3evaw5CRGR1wvndvfXW6UxqERGpSwEhIiJ1KSBERKQuBYSIiNSlgBDpMl87fo69n32aL7/8VrubIpc4BYRIl/mb753h3PQcv3LgCH/1ndPtbo5cwhQQIl3mpdEp3n3FRq7anOfPnju58gNEmqSAEOki7s5Lb0zy/nduYs+VA5wYm213k+QSpoAQ6SInxwpcKJb5oW2D7Nic59R4gWr10jjZVTpPp1ZzFZFFHjt8gpdGJwE4OTbL6ckipXKVMxfmuGJQF1yU+KkHIdJF3pgokDK4fKCXzf05AE6Oa5hJkqGAEOkib0wWuHygl2w6xeZ8GBCah5CEKCBEusjpiSJXhsNJQ/ksBpqolsQoIES6RKlc5cJcma0begDIpFNs7M1wcqzQ5pbJpUoBIdIlxmZLAAtzD9FtDTFJUhQQIl1ifObtAbEpn9MktSRGASHSJcaigMjXBER/jjenisyVK+1qllzCFBAiXWJspkRPJkVfLr2wbHM+hzuMjmseQuKngBDpEmMzJTb35zCzhWUbezML60TipoAQ6RJjs6UfmH8AFnoTE7Pz7WiSXOIUECJdwN0ZnymxKb8oILJhQBQUEBI/BYRIFzh7YY5y1d/Wg8jngiGmiVkNMUn8FBAiXSA6W3pxQPRkU5jBpHoQkgAFhEgXWAiIRUNMKTMG+7Kag5BEKCBEukB0GOtgPvu2dUN9Wc1BSCIUECJd4M2pIvlcmmz67R/ZwXxOcxCSCAWESBd4a6rIYN/bew8Q9CCm1IOQBCQaEGZ2k5kdM7MRM9tfZ32PmT0Rrj9sZjvD5TvNrGBmL4Q/X0iynSKd7s2pIgO9SwREXkNMkozELjlqZmngQeAjwCngWTM76O4v12x2BzDu7rvN7FbgPuCT4brj7v7epNon0k3enJxj19Z83XWapJakJNmDuAEYcfdX3b0EPA7sW7TNPuDR8PaTwIetto6AiFAqVzk/M7d0D6Ivy1RxnkrVW9wyudQlGRDbgJM190+Fy+pu4+5lYBLYEq7bZWbfMrOv
mtmPJdhOkY525kIRdxhYYg5iMCzYd6GoXoTEK7EhpjU6Dexw9/Nm9n7gS2Z2vbtP1W5kZncCdwLs2LGjDc0USd5bU0WAZXsQENRjGlp0noTIWiTZgxgFrqq5vz1cVncbM8sAg8B5d59z9/MA7v4ccBy4bvELuPtD7r7X3fcODw8nsAsi7ffm5BwAA331/54bCs+N0ES1xC3JgHgWuNbMdplZDrgVOLhom4PA7eHtm4Fn3N3NbDic5MbMrgauBV5NsK0iHevNlXoQUUDoXAiJWWJDTO5eNrO7gKeANPCIux81s3uBI+5+EHgYOGBmI8AYQYgA/Dhwr5nNA1XgV919LKm2inSyt6aK5DIp8jUXCqo12BcMK6kek8Qt0TkIdz8EHFq07O6a20XgljqP+3Pgz5Nsm0i3eHOyyOUDPSx1gF/Ug1BASNx0JrVIh3tzqsgVA71Lrh+smaQWiZMCQqTDvTVV5PJlAiKbTtGfSysgJHYKCJEOt1JAAAzlc0wUNEkt8VJAiHSwuXKF4nyVTXXKfNdSuQ1JggJCpINdKJaBpc+ijgz0ZXQmtcROASHSwaIy3kudAxHZ2JtdCBORuCggRDrYVPilv7F3+SPSN/ZmFBASOwWESAeLho1WHGLq1UWDJH4KCJEONlUI5yBWGGIa6M0wXSpTVclviZECQqSDTS30IFYaYsriDtMlDTNJfBQQIh0sGmLauFIPIgwQDTNJnBQQIh1sqlAmZdC/RKG+SBQgmqiWOCkgRDrYVHGegb7skoX6ItFRTgoIiZMCQqSDXSiWVzzEFWp7EBpikvgoIEQ62FRhfsUjmCA4igkuTmqLxEEBIdLBpoqNBYTmICQJCgiRDjZVaHSISXMQEj8FhEgHuxBOUq+kN5sml0lpiElipYAQ6WBTxXJDQ0wQzENEZ16LxEEBIdKhKlVneq6xISaIKrqqByHxUUCIdKjpBq8FERlQRVeJWWN/mohISz12+ARjM8ElRF9+Y4rHDp9Y8THqQUjc1IMQ6VDF+QoAvdnGPqYbezML148QiYMCQqRDXQyI5eswRQbUg5CYKSBEOlQUEH0NBoSuKidx0xyESIcqzFeBlXsQ0fzEa+dnmC1VOPD175NOGZ+6cUfibZRLm3oQIh1qtXMQvZkgSObCx4mslQJCpENFAdGTaWyIKRqKKparibVJ1pdEA8LMbjKzY2Y2Ymb766zvMbMnwvWHzWznovU7zGzazH4ryXaKdKLifIWeTIp0avlrQUSinkZRPQiJSWIBYWZp4EHgY8Ae4DYz27NoszuAcXffDTwA3Ldo/f3AXyfVRpFOVpyvNnwEE1ycqygoICQmSfYgbgBG3P1Vdy8BjwP7Fm2zD3g0vP0k8GELL51lZh8HXgOOJthGkY5VmK80PP8A0BdelrRQUkBIPJIMiG3AyZr7p8Jldbdx9zIwCWwxsw3Avwd+d7kXMLM7zeyImR05e/ZsbA0X6QTF+cqqehALcxDqQUhMOnWS+h7gAXefXm4jd3/I3fe6+97h4eHWtEykRYrlysKRSY3o0xCTxCzJ8yBGgatq7m8Pl9Xb5pSZZYBB4DxwI3CzmX0OGAKqZlZ0988n2F6RjlKcr3LZxsYDIpdJkTINMUl8kgyIZ4FrzWwXQRDcCnxq0TYHgduBrwM3A8+4uwM/Fm1gZvcA0woHWW8KpdXNQZgZvdk0s+pBSEwSCwh3L5vZXcBTQBp4xN2Pmtm9wBF3Pwg8DBwwsxFgjCBERNY9d2dulUNMEAwzqQchcUm01Ia7HwIOLVp2d83tInDLCs9xTyKNE+lgpUqVqjdeqC/Sl0trklpi06mT1CLrWjGsw9Roob5IXzatSWqJjQJCpAMt1GHKrb4HoSEmiYsCQqQDLQREZnUfUfUgJE4KCJEOVFjlxYIifdlgDiI4GFBkbRQQIh2o6TmIXJqqw5wqukoMFBAiHWih1PcqzoMAnU0t8VJAiHSg1V6POqKCfRInBYRIByrMV8ikjGxaPQhpHwWE
SAda7bUgIupBSJwUECIdaLWlviPqQUicFBAiHag4X6FvlRPUUBMQ6kFIDBQQIh2o2R7EQslv9SAkBgoIkQ5UaHIOIir5rYCQOCggRDpQsz0IUMlviU9DAWFm/9PMftrMFCgiLRAERHMfN5X8lrg0+hv4XwmuBveKmf0nM3tXgm0SWdfmyhXKVV91mY2ICvZJXBoKCHd/2t3/FfA+4HXgaTP7mpn9opllk2ygyHozWZgHVn8WdUQlvyUuDfdhzWwL8G+Afwt8C/gvBIHx5URaJrJOTRXKwOoL9UXUg5C4NHTJUTP7C+BdwAHgZ939dLjqCTM7klTjRNajqAfRt8qLBUWiOQh3x8zibJqsM41ek/qPwutLLzCzHnefc/e9CbRLZN2aKq5xiCkblPyeniuzsVcjwNK8RoeYPltn2dfjbIiIBKaiHsQaAgIu9kREmrVsD8LMrgC2AX1m9o+BqL86AOQTbpvIunRxkrr5w1yj59m+KbZmyTq00hDTTxJMTG8H7q9ZfgH4nYTaJLKuTa11DiLqQcyqByFrs2xAuPujwKNm9gl3//MWtUlkXZsszJNNG5nU2nsQImux0hDTL7j7nwA7zew3F6939/vrPExE1mCqUG56/gE0ByHxWWmIqT/8d0PSDRGRwGRhvukjmEABIfFZaYjpD8N/f7c1zRGRycL8mnoQUclvBYSsVaPF+j5nZgNmljWzr5jZWTP7hQYed5OZHTOzETPbX2d9j5k9Ea4/bGY7w+U3mNkL4c+3zeznV71nIl1qqjjf9AQ1XCz5rYCQtWp0Fuyj7j4F/AxBLabdwL9b7gFmlgYeBD4G7AFuM7M9iza7Axh3993AA8B94fKXgL3u/l7gJuAPzazRk/pEutpaexAQDDMpIGStGg2I6Mv5p4E/c/fJBh5zAzDi7q+6ewl4HNi3aJt9wKPh7SeBD5uZufusu5fD5b2AN9hOka43tcY5CAiOZFJAyFo1GhB/aWbfA94PfMXMhoHiCo/ZBpysuX8qXFZ3mzAQJoEtAGZ2o5kdBV4EfrUmMBaY2Z1mdsTMjpw9e7bBXRHpXNWqc2GuvKYhJoB8Lr1wPoVIsxot970f+KcEwz7zwAxv7w3Eyt0Pu/v1wI8Cv21mvXW2ecjd97r73uHh4SSbI9ISF4pl3JuvwxTRHITEYTXj+u8mOB+i9jF/vMz2o8BVNfe3h8vqbXMqfN5B4HztBu7+XTObBn4IUOVYuaRFhfrimIM4OTYbR5NkHWu03PcB4BrgBSAqNO8sHxDPAtea2S6CILiV4Kp0tQ4CtxMU/rsZeMbdPXzMSXcvm9k7CcLp9UbaKtLNJtdYqC8SzUFUq04qpZLf0pxGexB7gT3u3vBkcfjlfhfwFJAGHnH3o2Z2L3DE3Q8CDwMHzGwEGCMIEYAPAfvNbB6oAr/m7ucafW2RbhXNG/Tm1nb594WS36UyAyr5LU1qNCBeAq4ATq+0Ya3wGhKHFi27u+Z2EbilzuMOEFycSGRdia0HUVOwTwEhzWo0ILYCL5vZN4G5aKG7/1wirRJZp+IcYoqe76oVthVZSqMBcU+SjRCRQJyT1IAOdZU1aSgg3P2r4WTxte7+tJnlCeYVRCRGk4V50ikjl1njHETYg5hQQMgaNFqL6ZcJznT+w3DRNuBLCbVJZN2aKpQZ6M1gtrYjj/K54G+/8dlSHM2SdarRP1N+HfggMAXg7q8AlyXVKJH1arIwz2Df2ieV82EPYnxGASHNazQg5sJ6SgCEJ7WpPpJIzKaK8wzEEBDZdIp8Ls24Ljsqa9BoQHzVzH4H6DOzjwB/Bvzv5Jolsj7F1YMA2JTPaYhJ1qTRgNgPnCUonPcrBOc2/IekGiWyXk0W4jtvYVN/VkNMsiaNHsVUNbMvAV9yd5VNFUnIVKEcyxATRD0IDTFJ85btQVjgHjM7BxwDjoVXk7t7uceJyOq5O1MaYpIOstIQ028QHL30o+6+2d03AzcCHzSz30i8dSLryFy5SqlSZaAv
nosnbspriEnWZqWA+NfAbe7+WrTA3V8FfgH4dJINE1lvojIbsfUg+nNMFcuUK9VYnk/Wn5UCIluvimo4D6EKYCIxigIitknqfA7Q2dTSvJUCYrn+qfquIjGaSqAHATCheQhp0kqDnT9iZlN1lhvwtkuAikjzaoeYTo0X1vx8m/JB0IzNqAchzVk2INxdBflEWiSq5BrnYa6gekzSvLWVjBSR2EzOJjPEpCOZpFkKCJEOMVkoA7CxN77DXAGdLCdNU0CIdIip4jz9uTTZdDwfy75smp5MSkNM0jQFhEiHmCzEU8k1YmbB2dQaYpImKSBEOkScZTYim/pVj0map4AQ6RBx9yAgLLehISZpkgJCpEPEWeo7ooJ9shYKCJEOcaFYTmCISQX7pHkKCJEOEQwxxXOIa2RTPsdkYZ5KVVcIltVTQIh0gHKlyvRcAj2IfI6qX6zzJLIaCgiRDnChGJwkF/scRH90spyGmWT1Eg0IM7vJzI6Z2YiZ7a+zvsfMngjXHzazneHyj5jZc2b2YvjvP0+ynSLtFn2BR1/ocblYj0k9CFm9xALCzNLAg8DHgD3AbWa2Z9FmdwDj7r4beAC4L1x+DvhZd/9h4HbgQFLtFOkE0Rf4UPiFHpeFgNBEtTQhyR7EDcCIu7/q7iXgcWDfom32AY+Gt58EPmxm5u7fcvc3wuVHgT4z60mwrSJtFV2zYVNSAaEhJmlCkgGxDThZc/9UuKzuNu5eBiaBLYu2+QTwvLvPLX4BM7vTzI6Y2ZGzZ8/G1nCRVot6EJvjDgjNQcgadPQktZldTzDs9Cv11rv7Q+6+1933Dg8Pt7ZxIjGKehBDMc9BbOjJkEmZ5iCkKUkGxChwVc397eGyutuYWQYYBM6H97cDfwF82t2PJ9hOkbYbmymRSRkbe+I9D8LMgnpMmoOQJiQZEM8C15rZLjPLAbcCBxdtc5BgEhrgZuAZd3czGwL+Ctjv7n+fYBtFOsL47DxD+SxmFvtzqx6TNCuxgAjnFO4CngK+C3zR3Y+a2b1m9nPhZg8DW8xsBPhNIDoU9i5gN3C3mb0Q/lyWVFtF2m1ithT7BHVkKK+KrtKcePuzi7j7IeDQomV319wuArfUedxngc8m2TaRTjKeYEBszuc4fnY6keeWS1tHT1KLrBcT4RBTEjb1Z9WDkKYoIEQ6wNhMcj2ITfkcE7Ml3FWwT1Yn0SEmEVmZuwc9iJgPcX3s8AkAXjs3Q7nq/Pe/f53ebJpP3bgj1teRS5d6ECJtNluqUKpUYz9JLtKfC/4OnJkrJ/L8culSQIi02XhCZTYi+VwaCIJIZDUUECJtNrFQqC+ZSWoFhDRLcxAibfTY4RO88tYFAI68Ps656fhPaMuHZ2fPlDTEJKujHoRIm0V/2ed70ok8/4YezUFIcxQQIm02G/5ln88l06HvyaTIpIxpBYSskgJCpM2iHkRfNpkehJnR35NhuqiAkNVRQIi02WypQm82RToVf6G+yIaejOYgZNUUECJtNlMqL5yrkJQN6kFIExQQIm02PVdemEhOyoaejOYgZNUUECJtNjNXpj/hgOjvyTAzV1E9JlkVBYRIm03PVZLvQfRmqLhTnK8m+jpyaVFAiLRR1Z3ZFvQgNoTnWGiYSVZDASHSRrOlCs7FL/CkbOgJyngoIGQ1FBAibRR9YW/oTaYOU6RfPQhpggJCpI2i8hf9ifcggiEsBYSshgJCpI0WehAJnweRz2UwVI9JVkcBIdJG0Rd20kcxpVNGPpfWyXKyKgoIkTaaniuTMujNJTvEBMG5EBpiktVQQIi00XQxKLORsuTqMEU29CogZHUUECJt1IqzqCMqtyGrpYAQaaPpuTIbelsTEAO9WaYK8yq3IQ1TQIi00Uwp+TIbkaF8lnLVGZuJ/7KmcmlSQIi00fRcmf4WTFADDPXlABidKLTk9aT7JRoQZnaTmR0zsxEz219nfY+ZPRGuP2xmO8PlW8zsb8xs2sw+n2QbRdqlUKpQKldb2oMA
eEMBIQ1KLCDMLA08CHwM2APcZmZ7Fm12BzDu7ruBB4D7wuVF4D8Cv5VU+0Ta7dz0HEDLJqmH+oKAGJ0otuT1pPsl2YO4ARhx91fdvQQ8DuxbtM0+4NHw9pPAh83M3H3G3f+OIChELknRX/KDfcnWYYr05dLk0in1IKRhSQbENuBkzf1T4bK627h7GZgEtiTYJpGOEc0FDOVzLXk9M2Mwn2V0XAEhjenqSWozu9PMjpjZkbNnz7a7OSKrEn1RR3MDrTDUl+WNSQWENCbJgBgFrqq5vz1cVncbM8sAg8D5Rl/A3R9y973uvnd4eHiNzRVprdGJAv09GbLp1v2dNpTPaYhJGpbkb+azwLVmtsvMcsCtwMFF2xwEbg9v3ww84zqLR9aJ0YkCm1rYe4Cgt3JuukRxvtLS15XulNjhE+5eNrO7gKeANPCIux81s3uBI+5+EHgYOGBmI8AYQYgAYGavAwNAzsw+DnzU3V9Oqr0irTY6XmjZ/EMkOpLpjYkCVw9vaOlrS/dJ9Pg6dz8EHFq07O6a20XgliUeuzPJtom0k7szOlHghp2bW/q6gwvnQhQVELKirp6kFulW56ZLzJWrLZ2gBti0cDb1bEtfV7qTAkKkDU6NB1/QrR5iGujLkkuneO2cAkJWpoAQaYOL50C0tgeRThnv3JLn+Nnplr6udCcFhEgbROdAbGpxDwLgmuENvKqAkAYoIETaYHSiwEBvht5sayq51rrmsn6+f36W+Uq15a8t3UUBIdIGJ8dm2b4p35bXvnrrBspV58SY5iFkeQoIkTYYOTvNNZe15zDT6HWPn9EwkyxPASHSYsX5CqfGC+xu03kIVw/3A3D87ExbXl+6hwJCpMWOn53GPZgLaIeB3iyXbezRRLWsSAEh0mLRX+672zTEBEEvQoe6ykoUECItNnJmmpTBrq3t6UFAcKjrK2emUW1MWY4CQqTFjp+ZZsfmPD2Z1h/iGrn+HYNcKJY5OabS37I0BYRIi42cmW7r8BLAD28bBODF0cm2tkM6mwJCpIUqVee1czNc0+ZKqtddsYFcOsV3Rifa2g7pbAoIkRY6OTZLqVJt2zkQkZ5MmndfuZEXT6kHIUtL9HoQIvKDoiGdPVcOtK0Njx0+AUBvNs3zJ8b5k298n5QZn7pxR9vaJJ1JPQiRFnr+xDi92RTvvmJju5vC9qE+ivNVxmZK7W6KdCgFhEgLfevEBO/ZPkQm3f6P3juG+oCLlWVFFtMQk0gLPHb4BOVKlRdHJ/ngNVsWhnna6fKBXnLpFK+fn+FHrhpqd3OkA7X/zxiRdeKNySKVqnPV5vZUcV0snTKuHu7nFRXtkyUoIERa5GRYXrtTAgKCch9jMyXOT8+1uynSgRQQIi3y+vkZhvJZBnpbe5nR5Vx3WTBZrl6E1KOAEGmBycI83zt9gXdf0b7DW+vZsiHHUD6rgJC6FBAiLfC1kXM4zod2b213U36AmXHtZRt59ew0hVKl3c2RDqOAEEnY2EyJb74+xg9tG2Rzf67dzXmb9141xFy5ypPPn2p3U6TDKCBEEnRmqshtD32DctX5ieuG292cunZuybN9Ux+P/N1rVKsq/y0XKSBEEvL98zN84gtf4+T4LLf/k51cOdjX7ibVZWZ8aPdWXjs3w1+/9Ga7myMdJNGAMLObzOyYmY2Y2f4663vM7Ilw/WEz21mz7rfD5cfM7CeTbKdI3L57eoqbv/B1potlHvvlD7S9vPdKrn/HIO+6fCO/+cUXOPTiaV1ISIAEz6Q2szTwIPAR4BTwrJkddPeXaza7Axh3991mditwH/BJM9sD3ApcD7wDeNrMrnN3zaJJR3v0a69z/Ow0Xzxyklw6xS9+cBcvvzHV7matKJ0y/vTOD3DHo8/ya//jea4Z7mffe7fxM++5kl1b+zGzdjdR2iDJUhs3ACPu/iqAmT0O7ANqA2IfcE94+0ng8xb8Ju4DHnf3OeA1MxsJn+/rCbb3khT9JRj9Qej11i3cj7b5wcew
aH3tNpGUWfgT3k7F84UyX6lyoVjmQnGeC8Uy03NlcpkU/bkM+VyafC5Nf0+GnkwKdyiWK8zNV6nU7PdcuUKhVGG2VKHqTjadIp0yCvMVxqZLjM2UKMxXSKeMgb4sm/PBoZ+ZdLBPBphBuepMF8tcmCszO1ehXK1SqToTs/McfWOKl0YneeXMBaoOW/pz/NKHdrEp33mT0kv5Py+9ycffu413bu7nhZMT3P/lf+D+L/8DA70Z9rxjgHdfMcD2TX28Y6iPywd6yKXT2ML7/YO/A5lUilwmRU8mRU82RU8mTbrmd6JcqVKqVJktXXxvZktlCqUKc5UqA72ZhSvumYFhwb/hc/dmU/Rm08FPJpV4bavaz5EvXrZwP/wXX/KzE31uylVnvlwlm0mRSwf/n5YLYXen6lCYr1AqV8mmjZ5MmmzaEg3vJANiG3Cy5v4p4MaltnH3splNAlvC5d9Y9NhtSTTyxVOTfPKhIHeW+wJcSvD1scS6Fd63pb6QF155mS/slX5JO0E6ZaRrvjzqfYB84T9vXx59KBph1t5939CTYdtQHz9x3TDbhvq4engDvdn2XVK0Wdl0iht2beaGXZuZmC1x7K0LnJ4oMjpe4PkTE5TK1aafO5My0iljvlJt+H1tVDZtZNOphj5TS62r/Sy143cpl06RSRvV8Pe+WvWF28s+LpPiZ95zJff/y/fG3qauLtZnZncCd4Z3p83sWEIvtRU4l9Bzt5v2LSZHW/VCAb1v3Sex/XoAeOCTTT/8nUutSDIgRoGrau5vD5fV2+aUmWWAQeB8g4/F3R8CHoqxzXWZ2RF335v067SD9q07ad+6TzfuV5IDd88C15rZLjPLEUw6H1y0zUHg9vD2zcAzHvTzDgK3hkc57QKuBb6ZYFtFRGSRxHoQ4ZzCXcBTQBp4xN2Pmtm9wBF3Pwg8DBwIJ6HHCEKEcLsvEkxol4Ff1xFMIiKtlegchLsfAg4tWnZ3ze0icMsSj/194PeTbN8qJD6M1Ubat+6kfes+XbdfphNiRESkHpXaEBGRuhQQi5jZLWZ21MyqZra3ZvlOMyuY2Qvhzxdq1r3fzF4MS4P8gXXoaadL7Vu4rm5pk5XKpXQiM7vHzEZr3qufqlnX1SVcuvH9WI6ZvR5+dl4wsyPhss1m9mUzeyX8d1O729kIM3vEzM6Y2Us1y+ruiwX+IHwfv2Nm72tfy5fh7vqp+QH+EfAu4P8Be2uW7wReWuIx3wQ+ABjw18DH2r0fq9y3PcC3gR5gF3Cc4MCCdHj7aiAXbrOn3fvRwH7eA/xWneV197Pd7V3FfnXl+7HCPr0ObF207HPA/vD2fuC+drezwX35ceB9td8TS+0L8FPhd4WF3x2H293+ej/qQSzi7t9194ZPuDOzK4EBd/+GB+/8HwMfT6p9a7HMvi2UNnH314CotMlCuRR3LwFRuZRutdR+dotL7f1Yyj7g0fD2o3To52kxd/9bgqMxay21L/uAP/bAN4Ch8LukoyggVmeXmX3LzL5qZj8WLttGUAokklhZkATVK4uybZnl3eCusOv+SM0QRTfvD3R/++tx4P+a2XNhZQSAy939dHj7TeDy9jQtFkvtS1e8l11daqNZZvY0cEWdVZ9x9/+1xMNOAzvc/byZvR/4kpldn1gjm9TkvnWd5fYT+G/A7xF8+fwe8J+BX2pd62QVPuTuo2Z2GfBlM/te7Up3dzO7JA617MZ9WZcB4e7/oonHzAFz4e3nzOw4cB1BCZDtNZvWLQvSKs3sG8uXNlmx5Ek7NLqfZvZHwF+Gdxsq4dLBur39b+Puo+G/Z8zsLwiG0d4ysyvd/XQ47HKmrY1cm6X2pSveSw0xNcjMhsNrXGBmVxOU/3g17D5OmdkHwqOXPg1021/qS5U2aaRcSsdZNJb780B0VEm3l3DpyvdjKWbWb2Ybo9vARwneq9oSPLfTfZ+nWkvty0Hg0+HRTB8AJmuGojpHu2fJO+2H4AvlFEFv4S3gqXD5JwgKdr4A
PA/8bM1j9hL8Yh8HPk94AmKn/Sy1b+G6z4TtP0bNUVgER1v8Q7juM+3ehwb38wDwIvAdgg/ilSvtZ7f8dOP7scy+XE1wJNa3w8/WZ8LlW4CvAK8ATwOb293WBvfnTwmGoufDz9kdS+0LwdFLD4bv44vUHFXYST86k1pEROrSEJOIiNSlgBARkboUECIiUpcCQkRE6lJAiIhIXQoIERGpSwEhIiJ1KSBERKSu/w+T8GOKnBCISwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Empirical distribution of the rewards stored in the training tuples\n",
    "# (assumes element 3 of each tuple is the reward -- TODO confirm against util_fqi).\n",
    "rewards = [x[3] for x in train_tuples]\n",
    "# `distplot` is deprecated (see the FutureWarning it emitted); `histplot` with a KDE\n",
    "# overlay and density normalisation is its documented replacement. Flatten first so\n",
    "# scalar or 1-element-array rewards both give a 1-D vector.\n",
    "ax = sns.histplot(x=np.asarray(rewards, dtype=float).ravel(), kde=True, stat=\"density\")\n",
    "ax.set(title=\"Training reward distribution\", xlabel=\"Reward\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the baseline FQI agent and plot its Q estimate against FQI iteration.\n",
    "fqi_agent = FQIagent(train_tuples=train_tuples, test_tuples=test_tuples)\n",
    "Q_dist = fqi_agent.runFQI(repeats=1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(Q_dist, label= \"FQI\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Q Estimate\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the CFQI agent and plot its Q estimate against FQI iteration.\n",
    "cfqi_agent = CFQIagent(train_tuples=train_tuples, test_tuples=test_tuples)\n",
    "Q_dist = cfqi_agent.runFQI(repeats=1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(Q_dist, label= \"CFQI\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Q Estimate\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the LMMFQI agent and plot its Q estimate against FQI iteration.\n",
    "lmm_agent = LMMFQIagent(train_tuples=train_tuples, test_tuples=test_tuples)\n",
    "Q_dist = lmm_agent.runFQI(repeats=1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(Q_dist, label= \"LMMFQI\")\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Q Estimate\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every policy on freshly simulated test trajectories.\n",
    "# Policies: FQI, CFQI, LMMFQI, an oracle that takes the best one-step reward\n",
    "# under util_fqi.reward_function, and a uniformly random policy.\n",
    "algos = ['fqi', 'cfqi', 'lmmfqi', 'oracle', 'random']\n",
    "algo_labels = {'fqi': 'FQI', 'cfqi': 'CFQI', 'lmmfqi': 'LMMFQI', 'oracle': 'Oracle', 'random': 'Random'}\n",
    "overall_reward = {alg: [] for alg in algos}  # final cumulative reward of each trajectory\n",
    "mu, sigma = 0, 4  # mean / std of the random initial state\n",
    "EVAL_SEED = None  # set to an int to make this evaluation reproducible\n",
    "if EVAL_SEED is not None:\n",
    "    np.random.seed(EVAL_SEED)\n",
    "\n",
    "# Raw (pre-rounding) policy outputs, kept for the strip plots further down.\n",
    "raw_fqi = []\n",
    "raw_lmmfqi = []\n",
    "raw_cfqi = []\n",
    "\n",
    "# Learned policies, in the order they are queried at each step.\n",
    "learned_agents = {\n",
    "    'fqi': (fqi_agent, raw_fqi),\n",
    "    'cfqi': (cfqi_agent, raw_cfqi),\n",
    "    'lmmfqi': (lmm_agent, raw_lmmfqi),\n",
    "}\n",
    "n_actions = len(util_fqi.actions)\n",
    "\n",
    "\n",
    "def _action_index(prediction):\n",
    "    \"\"\"Map a regressed policy output to a valid index into util_fqi.actions.\n",
    "\n",
    "    Rounds half-to-even (like the built-in round) and clips to\n",
    "    [0, n_actions - 1]. Without the clip a too-large prediction raises\n",
    "    IndexError and a negative one silently wraps around to the last action.\n",
    "    \"\"\"\n",
    "    return int(np.clip(np.rint(prediction[0]), 0, n_actions - 1))\n",
    "\n",
    "\n",
    "def _step_reward(state_col, action):\n",
    "    \"\"\"Return (reward, s_a) for taking `action` in the column state `state_col`.\n",
    "\n",
    "    `s_a` stacks the state on top of the action reshaped to (2, 1); the\n",
    "    reward is np.dot(util_fqi.reward_function.T, s_a).\n",
    "    \"\"\"\n",
    "    action_col = np.reshape(np.asarray(action), (2, 1))\n",
    "    s_a = np.concatenate((state_col, action_col))\n",
    "    return np.dot(util_fqi.reward_function.T, s_a), s_a\n",
    "\n",
    "\n",
    "for k in tqdm.tqdm(range(util_fqi.num_patients)):\n",
    "    ds = 'foreground' if np.random.choice(2) == 0 else 'background'\n",
    "    t_m = util_fqi.transition_foreground if ds == 'foreground' else util_fqi.transition_background\n",
    "\n",
    "    # Generate a random initial state, a (10, 1) column vector\n",
    "    s = np.random.normal(mu, sigma, (10, 1))\n",
    "\n",
    "    val_rewards = {alg: [] for alg in algos}\n",
    "\n",
    "    for step in range(util_fqi.num_samples):\n",
    "        s_row = s.T  # the policies predict on a single (1, 10) row\n",
    "\n",
    "        # FQI, CFQI and LMMFQI agents\n",
    "        for alg, (agent, raw_log) in learned_agents.items():\n",
    "            prediction = agent.piE.predict(s_row)\n",
    "            raw_log.append(prediction)\n",
    "            reward, _ = _step_reward(s, util_fqi.actions[_action_index(prediction)])\n",
    "            val_rewards[alg].append(reward[0])\n",
    "\n",
    "        # Oracle: best achievable one-step reward over all actions\n",
    "        all_rewards = np.asarray([_step_reward(s, a)[0] for a in util_fqi.actions])\n",
    "        val_rewards['oracle'].append(np.max(all_rewards))\n",
    "\n",
    "        # Random action, drawn from *all* actions (previously hard-coded to the first 3)\n",
    "        random_reward, random_s_a = _step_reward(s, util_fqi.actions[np.random.choice(n_actions)])\n",
    "        val_rewards['random'].append(random_reward[0])\n",
    "\n",
    "        # NOTE(review): the next state is driven by the random policy's (s, a), so\n",
    "        # every policy is scored along the same random-policy trajectory.\n",
    "        # Confirm this is the intended evaluation protocol.\n",
    "        ns = np.matmul(random_s_a.T, t_m) / np.linalg.norm(np.matmul(random_s_a.T, t_m), ord=2)\n",
    "        ns = np.add(ns, np.random.normal(0, 0.5, (1, 10))) # Add noise\n",
    "        s = ns.T\n",
    "\n",
    "    cumulative = {alg: util_fqi.cumulative_reward(val_rewards[alg]) for alg in algos}\n",
    "    for alg in algos:\n",
    "        overall_reward[alg].append(cumulative[alg][-1])\n",
    "\n",
    "    x = list(range(util_fqi.num_samples))\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_title(\"Rewards for \" + ds + \" trajectory: \" + str(k))\n",
    "    ax.set_xlabel(\"Step\")\n",
    "    ax.set_ylabel(\"Cumulative Reward\")\n",
    "    for alg in algos:\n",
    "        # Each curve carries its own label (CFQI/LMMFQI were previously swapped)\n",
    "        ax.plot(x, cumulative[alg], label=algo_labels[alg])\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plots\n",
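    "* The strip plots below show the raw (pre-rounding) action predictions logged during evaluation. LMMFQI, CFQI and FQI all frequently predict values between 1 and 1.5, i.e. they often sit between two discrete actions before being rounded to an action index."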
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw LMMFQI predictions; each logged entry is a predict() output, so flatten to 1-D.\n",
    "# Pass the data as `x=` (positional data arguments are deprecated in seaborn >= 0.11).\n",
    "ax = sns.stripplot(x=np.ravel(raw_lmmfqi))\n",
    "ax.set(title=\"LMMFQI raw action predictions\", xlabel=\"Predicted action (before rounding)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw CFQI predictions; each logged entry is a predict() output, so flatten to 1-D.\n",
    "# Pass the data as `x=` (positional data arguments are deprecated in seaborn >= 0.11).\n",
    "ax = sns.stripplot(x=np.ravel(raw_cfqi))\n",
    "ax.set(title=\"CFQI raw action predictions\", xlabel=\"Predicted action (before rounding)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw FQI predictions; each logged entry is a predict() output, so flatten to 1-D.\n",
    "# Pass the data as `x=` (positional data arguments are deprecated in seaborn >= 0.11).\n",
    "ax = sns.stripplot(x=np.ravel(raw_fqi))\n",
    "ax.set(title=\"FQI raw action predictions\", xlabel=\"Predicted action (before rounding)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def subtract(a, b):\n",
    "    \"\"\"Element-wise difference ``b - a`` of two sequences, as a list.\n",
    "\n",
    "    Note the order: the *second* argument minus the first. Items are\n",
    "    paired with zip, so extra items in the longer input are ignored.\n",
    "    \"\"\"\n",
    "    differences = []\n",
    "    for left, right in zip(a, b):\n",
    "        differences.append(right - left)\n",
    "    return differences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-trajectory gap between the oracle's cumulative reward and each policy's.\n",
    "# `subtract(a, b)` returns b - a, so the oracle goes in the *second* slot for the\n",
    "# plotted values to match the \"Oracle - ...\" labels (they were previously inverted).\n",
    "plt.title(\"Cumulative Reward across patients and algorithms\")\n",
    "sns.stripplot(x=np.ravel(subtract(overall_reward['fqi'], overall_reward['oracle'])), color='r', label='Oracle - FQI')\n",
    "sns.stripplot(x=np.ravel(subtract(overall_reward['cfqi'], overall_reward['oracle'])), color='g', label='Oracle - CFQI')\n",
    "sns.stripplot(x=np.ravel(subtract(overall_reward['lmmfqi'], overall_reward['oracle'])), color='y', label='Oracle - LMMFQI')\n",
    "sns.stripplot(x=np.ravel(subtract(overall_reward['random'], overall_reward['oracle'])), color='b', label='Oracle - Random')\n",
    "plt.legend()\n",
    "plt.xlabel(\"Cumulative Reward (Oracle minus policy)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research",
   "language": "python",
   "name": "research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
