{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tabular experiments\n",
    "Run this notebook to obtain the data for the tabular experiments. Then use the `plot_results_tabular.ipynb` notebook to plot the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lzma\n",
    "import multiprocessing\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "from copy import deepcopy\n",
    "from typing import Any, List, NamedTuple, Sequence\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tabular.agents.agent import Experience\n",
    "from tabular.make_agent import make_agent, AgentParameters\n",
    "from tabular.simulation_parameters import make_env, SimulationParameters\n",
    "from tabular.utils.utils import Results\n",
    "from tabular.config import CONFIG\n",
    "\n",
    "\n",
    "class DataResults(NamedTuple):\n",
    "    \"\"\"Record pickled to disk for one environment/agent experiment.\"\"\"\n",
    "    # Environment and simulation settings the runs were generated with.\n",
    "    simulation_parameters: SimulationParameters\n",
    "    # NOTE(review): annotated as `str`, but the experiment cell of this notebook stores\n",
    "    # the agent configuration object here -- confirm what the plotting notebook expects.\n",
    "    agent_type: str\n",
    "    # One sequence of `Results` per simulation seed, in seed order.\n",
    "    data: Sequence[Sequence[Results]]\n",
    "\n",
    "\n",
    "def run(seed: int, agent_parameters: Any, p: SimulationParameters) -> List[Results]:\n",
    "    \"\"\"Run one seeded simulation and collect periodic evaluation snapshots.\n",
    "\n",
    "    Args:\n",
    "        seed (int): simulation seed, used to seed NumPy's global RNG\n",
    "        agent_parameters (Any): description of the agent, forwarded to `make_agent`\n",
    "        p (SimulationParameters): simulation parameters\n",
    "\n",
    "    Returns:\n",
    "        List[Results]: one entry per evaluation (every p.sim_parameters.freq_eval steps)\n",
    "    \"\"\"\n",
    "    np.set_printoptions(formatter={'float': lambda x: \"{0:0.3f}\".format(x)})\n",
    "    np.random.seed(seed)\n",
    "    env = make_env(env = p.env_parameters)\n",
    "    # Prints one draw from the freshly seeded global RNG. The draw advances the RNG\n",
    "    # state, so removing this line would change any later use of NumPy's global RNG.\n",
    "    print(np.random.uniform())\n",
    "\n",
    "    # Monotonic clock: elapsed times are not affected by system clock adjustments.\n",
    "    start_time = time.perf_counter()\n",
    "    s = env.reset()\n",
    "    discount_factor = p.sim_parameters.discount_factor\n",
    "    freq_eval = p.sim_parameters.freq_eval\n",
    "    agent = make_agent(agent_parameters)\n",
    "\n",
    "    results = []\n",
    "\n",
    "    # Rewards used at evaluation time: boundary rewards stacked on top of random ones.\n",
    "    R_basis = env.generate_boundary_rewards()\n",
    "    R_random = env.generate_random_rewards(N=p.sim_parameters.num_rewards)\n",
    "    R = np.vstack([R_basis, R_random])\n",
    "\n",
    "    for t in range(p.env_parameters.horizon):\n",
    "        a = agent.forward(s, t)\n",
    "        # The second value returned by env.step is not used here.\n",
    "        next_state, _ = env.step(a)\n",
    "        exp = Experience(s, a, next_state)\n",
    "        # agent.backward returns a flag asking for the environment to be reset.\n",
    "        reset = agent.backward(exp, t)\n",
    "\n",
    "        s = env.reset() if reset else next_state\n",
    "\n",
    "        # Evaluate the agent\n",
    "        if (t + 1) % freq_eval == 0:\n",
    "            V_res, pi_res, Q_res = env.eval_transition(\n",
    "                Phat=agent.empirical_transition(), R=R, discount_factor=discount_factor)\n",
    "            print(f'[{t}] {agent.U_t} {agent.Z_t} - {agent.beta} -  {V_res.mean()} - {pi_res.mean()} - {agent.state_action_visits}')\n",
    "            print('--------')\n",
    "\n",
    "            # Append results to be saved; deep copies decouple the stored snapshot\n",
    "            # from the agent's live objects.\n",
    "            results.append(\n",
    "                Results(step=t, omega=deepcopy(agent.omega), total_state_visits=deepcopy(agent.total_state_visits),\n",
    "                        last_visit=deepcopy(agent.last_visit), exp_visits=deepcopy(agent.exp_visits), V_res=V_res,\n",
    "                        Q_res=Q_res, pi_res=pi_res, elapsed_time=time.perf_counter() - start_time))\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = CONFIG\n",
    "NUM_CPUS = 10 # Change this parameter to define the number of vCPUs to use\n",
    "\n",
    "# NOTE: `run` is defined in this notebook. Per the multiprocessing documentation, pool\n",
    "# workers must be able to import the target function, which may fail for notebook-defined\n",
    "# functions when the start method is not 'fork' -- verify on Windows/macOS.\n",
    "\n",
    "# Loop through the configurations\n",
    "for env_params, agents in cfg.envs:\n",
    "    # The environment is instantiated here only to read its state/action dimensions.\n",
    "    env = make_env(env_params)\n",
    "    agent_parameters = AgentParameters(\n",
    "        dim_state_space=env.dim_state, dim_action_space=env.dim_action,\n",
    "        discount_factor=cfg.sim_parameters.discount_factor,\n",
    "        horizon=env_params.horizon,\n",
    "        frequency_evaluation=cfg.sim_parameters.freq_eval,\n",
    "        delta=cfg.sim_parameters.delta)\n",
    "    # Shared by every agent evaluated on this environment.\n",
    "    simulation_parameters = SimulationParameters(\n",
    "        env_parameters=env_params,\n",
    "        sim_parameters=cfg.sim_parameters\n",
    "    )\n",
    "\n",
    "    # Loop through the agents\n",
    "    for agent in agents:\n",
    "        print(f'> Evaluating {agent.type} on {env_params.env_type.value}({env_params.horizon})', end='... ')\n",
    "        agent = agent._replace(agent_parameters = agent_parameters)\n",
    "\n",
    "        # Create the output directory if needed; exist_ok avoids the race between\n",
    "        # checking for the directory and creating it.\n",
    "        path = f'./tabular/data/{env_params.env_type.value}/{env_params.horizon}'\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "\n",
    "        # One task per seed: (seed, agent description, simulation parameters)\n",
    "        iterations = [(seed, agent, simulation_parameters)\n",
    "                      for seed in range(simulation_parameters.sim_parameters.num_sims)]\n",
    "        start_time = time.perf_counter()\n",
    "\n",
    "        # Run simulations. starmap blocks until every task is done and returns the\n",
    "        # results in task order; chunksize=1 hands tasks to workers one at a time.\n",
    "        with multiprocessing.Pool(NUM_CPUS) as pool:\n",
    "            data_returned = pool.starmap(run, iterations, chunksize=1)\n",
    "\n",
    "        print(f'done in {np.round(time.perf_counter() - start_time, 2)} seconds.')\n",
    "\n",
    "        experiment_results = DataResults(simulation_parameters, agent, data_returned)\n",
    "\n",
    "        # Save compressed results\n",
    "        with lzma.open(os.path.join(path, f'{agent.type}.pkl.lzma'), 'wb') as f:\n",
    "            pickle.dump(experiment_results, f, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "reward-free-exploration",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
