{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "\n",
    "from VizDoom.VizDoom_src.utils import z_normalize, inverse_z_normalize\n",
    "from VizDoom.VizDoom_src.utils import env_vizdoom2\n",
    "from TMaze_new.TMaze_new_src.utils import set_seed\n",
    "\n",
    "env_args = {\n",
    "    'simulator':'doom', \n",
    "    'scenario':'custom_scenario{:003}.cfg', #custom_scenario{:003}.cfg\n",
    "    'test_scenario':'', \n",
    "    'screen_size':'320X180', \n",
    "    'screen_height':64, \n",
    "    'screen_width':112, \n",
    "    'num_environments':16,# 16\n",
    "    'limit_actions':True, \n",
    "    'scenario_dir':'../../VizDoom/VizDoom_src/env/', \n",
    "    'test_scenario_dir':'', \n",
    "    'show_window':False, \n",
    "    'resize':True, \n",
    "    'multimaze':True, \n",
    "    'num_mazes_train':16, \n",
    "    'num_mazes_test':1, # 64 \n",
    "    'disable_head_bob':False, \n",
    "    'use_shaping':False, \n",
    "    'fixed_scenario':False, \n",
    "    'use_pipes':False, \n",
    "    'num_actions':0, \n",
    "    'hidden_size':128, \n",
    "    'reload_model':'', \n",
    "    'model_checkpoint':'../3dcdrl/saved_models/two_col_p1_checkpoint_0198658048.pth.tar',\n",
    "    'conv1_size':16, \n",
    "    'conv2_size':32, \n",
    "    'conv3_size':16, \n",
    "    'learning_rate':0.0007, \n",
    "    'momentum':0.0, \n",
    "    'gamma':0.99, \n",
    "    'frame_skip':4, \n",
    "    'train_freq':4, \n",
    "    'train_report_freq':100, \n",
    "    'max_iters':5000000, \n",
    "    'eval_freq':1000, \n",
    "    'eval_games':50, \n",
    "    'model_save_rate':1000, \n",
    "    'eps':1e-05, \n",
    "    'alpha':0.99, \n",
    "    'use_gae':False, \n",
    "    'tau':0.95, \n",
    "    'entropy_coef':0.001, \n",
    "    'value_loss_coef':0.5, \n",
    "    'max_grad_norm':0.5, \n",
    "    'num_steps':128, \n",
    "    'num_stack':1, \n",
    "    'num_frames':200000000, \n",
    "    'use_em_loss':False, \n",
    "    'skip_eval':False, \n",
    "    'stoc_evals':False, \n",
    "    'model_dir':'', \n",
    "    'out_dir':'./', \n",
    "    'log_interval':100, \n",
    "    'job_id':12345, \n",
    "    'test_name':'test_000', \n",
    "    'use_visdom':False, \n",
    "    'visdom_port':8097, \n",
    "    'visdom_ip':'http://10.0.0.1'                 \n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_returns_VizDoom(seed, episode_timeout, frame_skip=2, num_actions=5):\n",
    "    \"\"\"Roll out one uniformly-random-action episode in VizDoom scene 0 and score it.\n",
    "\n",
    "    Args:\n",
    "        seed: passed to ``set_seed`` and to the environment constructor.\n",
    "        episode_timeout: maximum number of ``env.step`` calls.\n",
    "        frame_skip: frame skip given to the environment; the returned length is\n",
    "            scaled by it (default 2, the previously hardcoded value).\n",
    "        num_actions: actions are drawn uniformly from ``0 .. num_actions - 1``\n",
    "            (default 5, the previously hardcoded range).\n",
    "\n",
    "    Returns:\n",
    "        ``(episode_return, episode_length)``: the summed reward, and the number of\n",
    "        steps taken multiplied by ``frame_skip`` (presumably game frames elapsed;\n",
    "        confirm against env_vizdoom2).\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "\n",
    "    scene = 0\n",
    "    scenario = env_args['scenario_dir'] + env_args['scenario'].format(scene)\n",
    "\n",
    "    env = env_vizdoom2.DoomEnvironmentDisappear(\n",
    "        scenario=scenario,\n",
    "        show_window=False,\n",
    "        use_info=True,\n",
    "        use_shaping=False,  # NOTE(review): original note said shaping reward is +1/-1 in two_towers\n",
    "        frame_skip=frame_skip,\n",
    "        no_backward_movement=True,\n",
    "        seed=seed)\n",
    "\n",
    "    episode_return, episode_length = 0, 0\n",
    "    try:\n",
    "        env.reset()  # the initial observation is not needed by a random policy\n",
    "        for _ in range(episode_timeout):\n",
    "            # `high` is exclusive, so this draws from 0 .. num_actions - 1.\n",
    "            act = np.random.randint(low=0, high=num_actions)\n",
    "            _, reward, done, _ = env.step(act)\n",
    "            episode_return += reward\n",
    "            episode_length += 1\n",
    "            if done:\n",
    "                break\n",
    "    finally:\n",
    "        # Close the simulator even if reset/step raises, so the env is not leaked.\n",
    "        env.close()\n",
    "\n",
    "    # Counting steps explicitly (instead of reusing the loop index) makes\n",
    "    # episode_timeout == 0 return (0, 0) rather than raise NameError.\n",
    "    return episode_return, episode_length * frame_skip\n",
    "\n",
    "# * ##############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "reds = [2, 3, 6, 8, 9, 10, 11, 14, 15, 16, 17, 18, 20, 21, 25, 26, 27, 28, 29, 31, 38, 40, 41, 42, 45,\n",
    "        46, 49, 50, 51, 52, 53, 54, 55, 58, 59, 60, 61, 63, 64, 67, 68, 70, 72, 73, 74, 77, 80, 82, 84, \n",
    "        86, 88, 89, 90, 91, 92, 97, 98, 99, 100, 101, 103, 106, 108, 109, 113, 115, 116, 117, 120, \n",
    "        123, 124, 125, 126, 127, 128, 129, 133, 134, 136, 139, 140, 142, 144, 145, 147, 148, 151, 152, \n",
    "        153, 154, 156, 157, 158, 159, 161, 164, 165, 170, 171, 173]\n",
    "\n",
    "greens = [0, 1, 4, 5, 7, 12, 13, 19, 22, 23, 24, 30, 32, 33, 34, 35, 36, 37, 39, 43, 44, 47, 48, 56, 57,\n",
    "          62, 65, 66, 69, 71, 75, 76, 78, 79, 81, 83, 85, 87, 93, 94, 95, 96, 102, 104, 105, 107, 110, 111, \n",
    "          112, 114, 118, 119, 121, 122, 130, 131, 132, 135, 137, 138, 141, 143, 146, 149, 150, 155, 160, 162, \n",
    "          163, 166, 167, 168, 169, 172, 175, 176, 177, 182, 183, 187, 190, 192, 193, 195, 199, 204, 206, 208, \n",
    "          209, 210, 212, 215, 216, 218, 219, 220, 221, 223, 224, 225]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:58<00:00,  1.72it/s]\n",
      "100%|██████████| 100/100 [00:58<00:00,  1.70it/s]\n"
     ]
    }
   ],
   "source": [
    "episode_timeout = 4200  # max env.step calls per rollout\n",
    "\n",
    "\n",
    "def _rollout_seeds(seeds, timeout, desc=None):\n",
    "    \"\"\"Run one random-policy episode per seed with get_returns_VizDoom.\n",
    "\n",
    "    Returns two lists aligned with ``seeds``: episode returns and episode lengths.\n",
    "    \"\"\"\n",
    "    returns, lengths = [], []\n",
    "    for seed in tqdm(seeds, desc=desc):\n",
    "        episode_return, episode_length = get_returns_VizDoom(seed=seed, episode_timeout=timeout)\n",
    "        returns.append(episode_return)\n",
    "        lengths.append(episode_length)\n",
    "    return returns, lengths\n",
    "\n",
    "\n",
    "reds_returns, reds_lengths = _rollout_seeds(reds, episode_timeout, desc='reds')\n",
    "greens_returns, greens_lengths = _rollout_seeds(greens, episode_timeout, desc='greens')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total returns: 4.817399999999981\n",
      "Total lengths: 404.75\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Red returns: 4.657599999999982\n",
      "Red lengths: 395.78\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Green returns: 4.977199999999977\n",
      "Green lengths: 413.72\n"
     ]
    }
   ],
   "source": [
    "# Mean episode return and length: both groups pooled first, then each group on its own.\n",
    "separator = \"-\" * 100\n",
    "summary_groups = (\n",
    "    (\"Total\", reds_returns + greens_returns, reds_lengths + greens_lengths),\n",
    "    (\"Red\", reds_returns, reds_lengths),\n",
    "    (\"Green\", greens_returns, greens_lengths),\n",
    ")\n",
    "for idx, (label, group_returns, group_lengths) in enumerate(summary_groups):\n",
    "    if idx > 0:\n",
    "        print(separator)\n",
    "    print(f\"{label} returns:\", np.mean(group_returns))\n",
    "    print(f\"{label} lengths:\", np.mean(group_lengths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "BERSERK",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
