{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.env_export import create_json\n",
    "from src.utils.khalili_env import get_env\n",
    "from src.policy.jas_voc_policy import JAS_voc_policy\n",
    "from src.policy.jas_policy import RandomPolicy\n",
    "from src.utils.mouselab_jas import MouselabJas\n",
    "import numpy as np\n",
    "import random\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "import json\n",
    "from simulation import run_episode\n",
    "from src.utils.data_classes import EpisodeResult\n",
    "import pandas as pd\n",
    "import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random baseline for experiment environments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Not referenced by any other cell in this notebook; presumably a fitted cost weight - TODO confirm.\n",
    "cost_weight = 0.5798921379230035\n",
    "# Environment (with its config) and the random policy used for the experiment baseline.\n",
    "env, config = get_env()\n",
    "policy = RandomPolicy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>reward</th>\n",
       "      <th>actions</th>\n",
       "      <th>seed</th>\n",
       "      <th>runtime</th>\n",
       "      <th>true_reward</th>\n",
       "      <th>expected_reward</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.509576</td>\n",
       "      <td>4.9222</td>\n",
       "      <td>12.0</td>\n",
       "      <td>0.001303</td>\n",
       "      <td>3.894847</td>\n",
       "      <td>3.509576</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.527022</td>\n",
       "      <td>4.9086</td>\n",
       "      <td>13.0</td>\n",
       "      <td>0.001295</td>\n",
       "      <td>3.921715</td>\n",
       "      <td>3.527022</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.498825</td>\n",
       "      <td>4.9246</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0.001308</td>\n",
       "      <td>3.610339</td>\n",
       "      <td>3.498825</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.510792</td>\n",
       "      <td>4.9127</td>\n",
       "      <td>15.0</td>\n",
       "      <td>0.001473</td>\n",
       "      <td>3.574298</td>\n",
       "      <td>3.510792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.499757</td>\n",
       "      <td>4.9218</td>\n",
       "      <td>16.0</td>\n",
       "      <td>0.001420</td>\n",
       "      <td>3.222883</td>\n",
       "      <td>3.499757</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3.572623</td>\n",
       "      <td>4.9146</td>\n",
       "      <td>17.0</td>\n",
       "      <td>0.001420</td>\n",
       "      <td>4.224300</td>\n",
       "      <td>3.572623</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>3.544299</td>\n",
       "      <td>4.9184</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.001367</td>\n",
       "      <td>3.676118</td>\n",
       "      <td>3.544299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>3.537451</td>\n",
       "      <td>4.9148</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.001414</td>\n",
       "      <td>3.952736</td>\n",
       "      <td>3.537451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>3.517815</td>\n",
       "      <td>4.9179</td>\n",
       "      <td>20.0</td>\n",
       "      <td>0.001444</td>\n",
       "      <td>3.435968</td>\n",
       "      <td>3.517815</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>3.502410</td>\n",
       "      <td>4.9111</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.001369</td>\n",
       "      <td>3.539911</td>\n",
       "      <td>3.502410</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     reward  actions  seed   runtime  true_reward  expected_reward\n",
       "0  3.509576   4.9222  12.0  0.001303     3.894847         3.509576\n",
       "1  3.527022   4.9086  13.0  0.001295     3.921715         3.527022\n",
       "2  3.498825   4.9246  14.0  0.001308     3.610339         3.498825\n",
       "3  3.510792   4.9127  15.0  0.001473     3.574298         3.510792\n",
       "4  3.499757   4.9218  16.0  0.001420     3.222883         3.499757\n",
       "5  3.572623   4.9146  17.0  0.001420     4.224300         3.572623\n",
       "6  3.544299   4.9184  18.0  0.001367     3.676118         3.544299\n",
       "7  3.537451   4.9148  19.0  0.001414     3.952736         3.537451\n",
       "8  3.517815   4.9179  20.0  0.001444     3.435968         3.517815\n",
       "9  3.502410   4.9111  21.0  0.001369     3.539911         3.502410"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random-policy baseline on the experiment environments.\n",
    "# Each seed is rolled out `runs_per_env` times and only the per-seed column means are kept,\n",
    "# so `df` ends up with one summary row per seed.\n",
    "start_seed, end_seed = 12, 22  # half-open range, i.e. seeds 12..21\n",
    "runs_per_env = 10000\n",
    "columns = EpisodeResult._fields\n",
    "data = []\n",
    "for seed in range(start_seed, end_seed):\n",
    "    # run_episode returns (result, actions); only the result is aggregated here\n",
    "    res_data = [run_episode(env, policy, seed=seed)[0] for _ in range(runs_per_env)]\n",
    "    res_df = pd.DataFrame(res_data, columns=columns)\n",
    "    data.append(res_df.mean().to_list())\n",
    "df = pd.DataFrame(data, columns=columns)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export disabled: uncomment to write the per-seed experiment baseline to CSV.\n",
    "# df.to_csv(\"./data/experiment_results/random_baseline.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "reward              3.522057\n",
       "actions             4.916670\n",
       "seed               16.500000\n",
       "runtime             0.001381\n",
       "true_reward         3.705311\n",
       "expected_reward     3.522057\n",
       "dtype: float64"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Grand mean over the per-seed rows; the `seed` entry is only the average of the seed ids.\n",
    "df.mean(axis=0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random baseline for simulation environments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the environment (with its config) and the random policy for the simulation sweep.\n",
    "env, config = get_env()\n",
    "policy = RandomPolicy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5000000/5000000 [4:21:16<00:00, 318.96it/s]  \n"
     ]
    }
   ],
   "source": [
    "# Roll out the random policy on every simulation seed and keep each individual episode result.\n",
    "start_seed, end_seed = 5000, 10000  # half-open seed range\n",
    "runs_per_env = 1000\n",
    "columns = EpisodeResult._fields\n",
    "total_episodes = runs_per_env * (end_seed - start_seed)\n",
    "data = []\n",
    "with tqdm.tqdm(total=total_episodes) as pbar:\n",
    "    for seed in range(start_seed, end_seed):\n",
    "        for _ in range(runs_per_env):\n",
    "            episode_result, _actions = run_episode(env, policy, seed=seed)\n",
    "            data.append(episode_result)\n",
    "            # one tick per finished episode\n",
    "            pbar.update(1)\n",
    "\n",
    "df = pd.DataFrame(data, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist every episode row; the row index is written too, as the first (unnamed) CSV column.\n",
    "df.to_csv(path_or_buf=\"./data/simulation_results/random_baseline_all.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The file was saved together with its row index, so read that first column back as the index;\n",
    "# otherwise it appears as an extra \"Unnamed: 0\" data column that a per-seed mean would average.\n",
    "df = pd.read_csv(\"./data/simulation_results/random_baseline_all.csv\", index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the per-episode rows into one row of column means per seed, then store the summary.\n",
    "mean_df = df.groupby(\"seed\").mean().reset_index()\n",
    "mean_df.to_csv(\"./data/simulation_results/random_baseline.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jas",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
