{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "8b544240",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from scipy.stats import spearmanr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "6bfe63e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "perf = pd.read_csv(\"../../Implicit-Language-Q-Learning/outputs/visual_dialogue/performances.txt\", delimiter='\\t', header=None)\n",
    "perf = perf[~perf[0].duplicated()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "edfceb69",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3370711/2728996184.py:16: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  target_policies = target_policies.iloc[np.round(np.linspace(0, target_policies.shape[0]-1, 8)).astype(np.int)]\n"
     ]
    }
   ],
   "source": [
    "low = perf.loc[perf[1] < perf[1].quantile(.2)].sort_values(1)\n",
    "low_sampled = low.iloc[np.round(np.linspace(0, low.shape[0]-1, 10)).astype(np.int64)]\n",
    "\n",
    "mid = perf.loc[np.logical_and(perf[1] >= perf[1].quantile(.2), perf[1] < perf[1].quantile(.6))].sort_values(1)\n",
    "mid_sampled = mid.iloc[np.round(np.linspace(0, mid.shape[0]-1, 10)).astype(np.int64)]\n",
    "\n",
    "high = perf.loc[perf[1] >= perf[1].quantile(.6)].sort_values(1)\n",
    "high_sampled = high.iloc[np.round(np.linspace(0, high.shape[0]-1, 10)).astype(np.int64)]\n",
    "\n",
    "target_policies = pd.concat([\n",
    "    low.iloc[np.round(np.linspace(1, low.shape[0]-2, 5)).astype(np.int64)],\n",
    "    mid.iloc[np.round(np.linspace(1, mid.shape[0]-2, 3)).astype(np.int64)],\n",
    "    high.iloc[np.round(np.linspace(1, high.shape[0]-2, 2)).astype(np.int64)],\n",
    "])\n",
    "\n",
    "target_policies = target_policies.iloc[np.round(np.linspace(0, target_policies.shape[0]-1, 8)).astype(np.int)]\n",
    "\n",
    "target_policies[5] = target_policies[0].map(lambda x : \"/\".join(x.split(\"/\")[:-2]) + \"/{}.pkl\".format(x.split(\"/\")[-1].split(\".\")[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "e523826e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3370711/109546601.py:1: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  low_sampled = low_sampled.iloc[np.round(np.linspace(0, low_sampled.shape[0]-1, 5)).astype(np.int)]\n"
     ]
    }
   ],
   "source": [
    "low_sampled = low_sampled.iloc[np.round(np.linspace(0, low_sampled.shape[0]-1, 5)).astype(np.int)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "19a62f4f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['model_131071.pkl', 'model_32767.pkl', 'model_65535.pkl',\n",
       "       'model_32767.pkl', 'model_114687.pkl', 'model_98303.pkl',\n",
       "       'model.pkl', 'model_229375.pkl'], dtype=object)"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_policies[5].map(lambda x : \"/\".join(x.split(\"/\")[-1:])).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "dade9d52",
   "metadata": {},
   "outputs": [],
   "source": [
    "iw_results = [i for i in os.listdir(\"./iw/low\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "98860389",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_policies[0] = target_policies[5].map(lambda x : \"/\".join(x.split(\"/\")[-1:]).replace(\"/\",\"_\").replace(\".pkl\",\"\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "b8aebdb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "del target_policies[5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "e937c353",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Must have equal len keys and value when setting with an iterable",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3370711/2879564588.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miw_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0mcurrent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"_\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"_\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m     \u001b[0mtarget_policies\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcurrent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloadtxt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"./iw/low/\"\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m.005\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ope_py37/lib/python3.8/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m    714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    715\u001b[0m         \u001b[0miloc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"iloc\"\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 716\u001b[0;31m         \u001b[0miloc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setitem_with_indexer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    717\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    718\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_validate_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ope_py37/lib/python3.8/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_setitem_with_indexer\u001b[0;34m(self, indexer, value, name)\u001b[0m\n\u001b[1;32m   1686\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mtake_split_path\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1687\u001b[0m             \u001b[0;31m# We have to operate column-wise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1688\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setitem_with_indexer_split_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1689\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1690\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setitem_single_block\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ope_py37/lib/python3.8/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_setitem_with_indexer_split_path\u001b[0;34m(self, indexer, value, name)\u001b[0m\n\u001b[1;32m   1768\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1769\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1770\u001b[0;31m                 raise ValueError(\n\u001b[0m\u001b[1;32m   1771\u001b[0m                     \u001b[0;34m\"Must have equal len keys and value \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1772\u001b[0m                     \u001b[0;34m\"when setting with an iterable\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Must have equal len keys and value when setting with an iterable"
     ]
    }
   ],
   "source": [
    "for i in range(2,11):\n",
    "    target_policies[i] = np.nan\n",
    "\n",
    "target_policies.index = target_policies[0]\n",
    "del target_policies[0]\n",
    "\n",
    "for f in iw_results:\n",
    "    current = \"_\".join(f.split(\"_\")[:-5])\n",
    "    target_policies.loc[current, 8] = np.loadtxt(\"./iw/low/\"+f)/.005\n",
    "\n",
    "\n",
    "target_policies.loc[~target_policies[8].isna(), 1] *= 100\n",
    "target_policies.loc[~target_policies[8].isna(), 1] += 25\n",
    "\n",
    "\n",
    "truth = target_policies[1].loc[~target_policies[8].isna()].values\n",
    "pred = target_policies.loc[:, 8].loc[~target_policies[8].isna()].values\n",
    "\n",
    "maes = np.asarray([np.abs((truth - pred) / truth).mean()])\n",
    "ranks = np.asarray([spearmanr(truth, pred)[0] ])\n",
    "regrets = np.asarray([(truth.max() - truth[np.argmax(pred)]) / truth.max() ])\n",
    "\n",
    "print(\"MAE\", maes.mean(), maes.std())\n",
    "\n",
    "print(\"Rank\", ranks.mean(), ranks.std())\n",
    "\n",
    "print(\"Regret@1\", regrets.mean(), regrets.std())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
