{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "51909177",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"   # see issue #152\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "#import imitation.custom_scripts.custom_ant_env\n",
    "\n",
    "import gym\n",
    "# import seals.classic_control\n",
    "from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv, VecFrameStack, VecEnv, VecMonitor, is_vecenv_wrapped, VecVideoRecorder\n",
    "from stable_baselines3.ppo import PPO\n",
    "from stable_baselines3.common.utils import obs_as_tensor\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import random\n",
    "import gym\n",
    "from matplotlib.pyplot import imshow, imsave\n",
    "from statistics import mean, stdev\n",
    "import time\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import collections\n",
    "import cv2\n",
    "import warnings\n",
    "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n",
    "import gym\n",
    "import numpy as np\n",
    "np.random.seed(42)\n",
    "th.manual_seed(42)\n",
    "random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "18c550d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_policy(\n",
    "    model: \"base_class.BaseAlgorithm\",\n",
    "    env: Union[gym.Env, VecEnv],\n",
    "    n_eval_episodes: int = 10,\n",
    "    deterministic: bool = True,\n",
    "    render: bool = False,\n",
    "    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,\n",
    "    reward_threshold: Optional[float] = None,\n",
    "    return_episode_rewards: bool = False,\n",
    "    warn: bool = True,\n",
    "    max_steps: int = int(1e5),\n",
    "    masksemble_mode=\"MODE\",\n",
    ") -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:\n",
    "    \"\"\"\n",
    "    Runs policy for ``n_eval_episodes`` episodes and returns average reward.\n",
    "    If a vector env is passed in, this divides the episodes to evaluate onto the\n",
    "    different elements of the vector env. This static division of work is done to\n",
    "    remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more\n",
    "    details and discussion.\n",
    "\n",
    "    .. note::\n",
    "        If environment has not been wrapped with ``Monitor`` wrapper, reward and\n",
    "        episode lengths are counted as it appears with ``env.step`` calls. If\n",
    "        the environment contains wrappers that modify rewards or episode lengths\n",
    "        (e.g. reward scaling, early episode reset), these will affect the evaluation\n",
    "        results as well. You can avoid this by wrapping environment with ``Monitor``\n",
    "        wrapper before anything else.\n",
    "\n",
    "    :param model: The RL agent you want to evaluate.\n",
    "    :param env: The gym environment or ``VecEnv`` environment.\n",
    "    :param n_eval_episodes: Number of episode to evaluate the agent\n",
    "    :param deterministic: Whether to use deterministic or stochastic actions\n",
    "    :param render: Whether to render the environment or not\n",
    "    :param callback: callback function to do additional checks,\n",
    "        called after each step. Gets locals() and globals() passed as parameters.\n",
    "    :param reward_threshold: Minimum expected reward per episode,\n",
    "        this will raise an error if the performance is not met\n",
    "    :param return_episode_rewards: If True, a list of rewards and episode lengths\n",
    "        per episode will be returned instead of the mean.\n",
    "    :param warn: If True (default), warns user about lack of a Monitor wrapper in the\n",
    "        evaluation environment.\n",
    "    :return: Mean reward per episode, std of reward per episode.\n",
    "        Returns ([float], [int]) when ``return_episode_rewards`` is True, first\n",
    "        list containing per-episode rewards and second containing per-episode lengths\n",
    "        (in number of steps).\n",
    "    \"\"\"\n",
    "    is_monitor_wrapped = False\n",
    "    # Avoid circular import\n",
    "    from stable_baselines3.common.monitor import Monitor\n",
    "\n",
    "    if not isinstance(env, VecEnv):\n",
    "        env = DummyVecEnv([lambda: env])\n",
    "\n",
    "    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]\n",
    "\n",
    "    if not is_monitor_wrapped and warn:\n",
    "        warnings.warn(\n",
    "            \"Evaluation environment is not wrapped with a ``Monitor`` wrapper. \"\n",
    "            \"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. \"\n",
    "            \"Consider wrapping environment first with ``Monitor`` wrapper.\",\n",
    "            UserWarning,\n",
    "        )\n",
    "\n",
    "    n_envs = env.num_envs\n",
    "    episode_rewards = []\n",
    "    episode_lengths = []\n",
    "\n",
    "    episode_counts = np.zeros(n_envs, dtype=\"int\")\n",
    "    # Divides episodes among different sub environments in the vector as evenly as possible\n",
    "    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype=\"int\")\n",
    "\n",
    "    current_rewards = np.zeros(n_envs)\n",
    "    current_lengths = np.zeros(n_envs, dtype=\"int\")\n",
    "    observations = env.reset()\n",
    "    states = None\n",
    "    while (episode_counts < episode_count_targets).any():\n",
    "        if masksemble_mode is None:\n",
    "            actions, states = model.predict(observations, state=states, deterministic=True)\n",
    "        else:\n",
    "            actions, states, probs = model.policy.predict(observations, state=states, deterministic=True, masksemble_mode=masksemble_mode)\n",
    "        observations, rewards, dones, infos = env.step(actions)\n",
    "        current_rewards += rewards\n",
    "        current_lengths += 1\n",
    "        # Break when max. episode length is reached\n",
    "        if any(current_lengths > max_steps):\n",
    "            episode_rewards.append(current_rewards[i])\n",
    "            episode_lengths.append(current_lengths[i])\n",
    "            episode_counts[i] += 1\n",
    "        for i in range(n_envs):\n",
    "            if episode_counts[i] < episode_count_targets[i]:\n",
    "\n",
    "                # unpack values so that the callback can access the local variables\n",
    "                reward = rewards[i]\n",
    "                done = dones[i]\n",
    "                info = infos[i]\n",
    "\n",
    "                if callback is not None:\n",
    "                    callback(locals(), globals())\n",
    "\n",
    "                if dones[i]:\n",
    "                    if is_monitor_wrapped:\n",
    "                        # Atari wrapper can send a \"done\" signal when\n",
    "                        # the agent loses a life, but it does not correspond\n",
    "                        # to the true end of episode\n",
    "                        if \"episode\" in info.keys():\n",
    "                            # Do not trust \"done\" with episode endings.\n",
    "                            # Monitor wrapper includes \"episode\" key in info if environment\n",
    "                            # has been wrapped with it. Use those rewards instead.\n",
    "                            episode_rewards.append(info[\"episode\"][\"r\"])\n",
    "                            episode_lengths.append(info[\"episode\"][\"l\"])\n",
    "                            print(\"rew\", info[\"episode\"][\"r\"])\n",
    "                            # Only increment at the real end of an episode\n",
    "                            episode_counts[i] += 1\n",
    "                    else:\n",
    "                        episode_rewards.append(current_rewards[i])\n",
    "                        episode_lengths.append(current_lengths[i])\n",
    "                        episode_counts[i] += 1\n",
    "                    current_rewards[i] = 0\n",
    "                    current_lengths[i] = 0\n",
    "                    if states is not None:\n",
    "                        states[i] *= 0\n",
    "\n",
    "        if render:\n",
    "            env.render()\n",
    "\n",
    "    mean_reward = np.mean(episode_rewards)\n",
    "    std_reward = np.std(episode_rewards)\n",
    "    if reward_threshold is not None:\n",
    "        assert mean_reward > reward_threshold, \"Mean reward below threshold: \" f\"{mean_reward:.2f} < {reward_threshold:.2f}\"\n",
    "    if return_episode_rewards:\n",
    "        return episode_rewards, episode_lengths\n",
    "    return mean_reward, std_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d0bcd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_random(env_name, benchmark_steps=int(2.5e4)): \n",
    "    env = gym.make(env_name)\n",
    "\n",
    "    rews = []\n",
    "    start_time = time.time()\n",
    "    obs = env.reset()\n",
    "    rews.append([])\n",
    "    for _ in range(benchmark_steps):\n",
    "        obs, rew, done, _ = env.step(env.action_space.sample())# env.step(action)\n",
    "        rews[-1].append(rew)\n",
    "        if done:\n",
    "            obs = env.reset()\n",
    "            rews.append([])\n",
    "    end_time = time.time()\n",
    "\n",
    "    print(\"episode rewards, length\", [str(sum(r)) + \" \" + str(len(r)) for r in rews])\n",
    "    print(\"Mean reward: \", mean([sum(r) for r in rews if len(r) != 0]))\n",
    "    print(\"Std reward (non 0 length episodes): \", stdev([sum(r) for r in rews if len(r) != 0]))\n",
    "    print(\"Execution Time\", end_time - start_time)\n",
    "    print(\"Execution Time per 1k\", (end_time - start_time) * (1000 / benchmark_steps))\n",
    "    print(\"# of Episodes: \", len(rews))\n",
    "\n",
    "\n",
    "def benchmark_with_loaded_model(env_name, model_path, model_name=\"model\", masksemble_mode=\"AVERAGE\", benchmark_episodes=int(1.5e3), record_video=False, eval_mode=False):\n",
    "    env = gym.make(env_name)\n",
    "    if record_video:\n",
    "        env = DummyVecEnv([VideoWrapper(env, os.path.join(\"agent_videos\", model_path.split(os.sep)[1]+\"_1\"))])\n",
    "    else:\n",
    "        env = DummyVecEnv([lambda: env])\n",
    "    if os.path.isfile(os.path.join(model_path, \"vec_normalize.pkl\")):\n",
    "        env = VecNormalize.load(os.path.join(model_path, \"vec_normalize.pkl\"), env) \n",
    "\n",
    "\n",
    "    model = PPO.load(os.path.join(model_path, model_name), env=env)\n",
    "    \n",
    "    if eval_mode:\n",
    "        n_eval_episodes = 10\n",
    "        mean_rew, rew_std = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, masksemble_mode=masksemble_mode)\n",
    "        print(f\"Mean Rew.: {mean_rew} Mean Std.: {rew_std} for {n_eval_episodes} evaluation episodes\")\n",
    "        return mean_rew\n",
    "\n",
    "    rews = []\n",
    "    start_time = time.time()\n",
    "    obs = model.env.reset()\n",
    "    render_frame = env.render(\"rgb_array\")\n",
    "    rews.append([])\n",
    "\n",
    "    fps = 25\n",
    "    render_out = cv2.VideoWriter(model_name+'_gen_render_output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (render_frame.shape[1], render_frame.shape[0]), isColor=True)\n",
    "    \n",
    "    for _ in range(benchmark_steps):\n",
    "        action, _ = model.predict(obs, deterministic=True)\n",
    "        obs, rew, done, _ = env.step(action)\n",
    "        rews[-1].append(rew[0])\n",
    "        if done[0]:\n",
    "            rews.append([])\n",
    "            \n",
    "    end_time = time.time()\n",
    "\n",
    "    if isinstance(env, VecNormalize):\n",
    "        print(\"episode rewards, length\", [str(sum(env.unnormalize_reward(np.array(r))))+\" \"+str(len(r)) for r in rews])\n",
    "        print(\"Mean reward (non 0 length episodes): \", mean([sum(env.unnormalize_reward(np.array(r))) for r in rews if len(r) != 0]))\n",
    "        print(\"Std reward (non 0 length episodes): \", stdev([sum(r) for r in rews if len(r) != 0]))\n",
    "        print(\"Execution Time\", end_time - start_time)\n",
    "        print(\"Execution Time per 1k\", (end_time - start_time) * (1000 / benchmark_steps))\n",
    "        print(\"# of Episodes: \", len(rews))\n",
    "    else:\n",
    "        print(\"episode rewards, length\", [str(sum(r)) + \" \" + str(len(r)) for r in rews])\n",
    "        print(\"Mean reward (non 0 length episodes): \", mean([sum(r) for r in rews if len(r) != 0]))\n",
    "        print(\"Std reward (non 0 length episodes): \", stdev([sum(r) for r in rews if len(r) != 0]))\n",
    "        print(\"Execution Time\", end_time - start_time)\n",
    "        print(\"Execution Time per 1k\", (end_time - start_time) * (1000 / benchmark_steps))\n",
    "        print(\"# of Episodes: \", len(rews))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6e06387b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "objc[5395]: Class GLFWApplicationDelegate is implemented in both /Users/yannick/.mujoco/mujoco200/bin/libglfw.3.dylib (0x1345aa778) and /Users/yannick/Documents/VARL/va-for-rl/venv/lib/python3.9/site-packages/glfw/libglfw.3.dylib (0x13476f740). One of the two will be used. Which one is undefined.\n",
      "objc[5395]: Class GLFWWindowDelegate is implemented in both /Users/yannick/.mujoco/mujoco200/bin/libglfw.3.dylib (0x1345aa700) and /Users/yannick/Documents/VARL/va-for-rl/venv/lib/python3.9/site-packages/glfw/libglfw.3.dylib (0x13476f768). One of the two will be used. Which one is undefined.\n",
      "objc[5395]: Class GLFWContentView is implemented in both /Users/yannick/.mujoco/mujoco200/bin/libglfw.3.dylib (0x1345aa7a0) and /Users/yannick/Documents/VARL/va-for-rl/venv/lib/python3.9/site-packages/glfw/libglfw.3.dylib (0x13476f7b8). One of the two will be used. Which one is undefined.\n",
      "objc[5395]: Class GLFWWindow is implemented in both /Users/yannick/.mujoco/mujoco200/bin/libglfw.3.dylib (0x1345aa818) and /Users/yannick/Documents/VARL/va-for-rl/venv/lib/python3.9/site-packages/glfw/libglfw.3.dylib (0x13476f830). One of the two will be used. Which one is undefined.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"LD_LIBRARY_PATH\"]=\"\" # PUT MUJOCO PATH HERE\n",
    "import mujoco_py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "18bdd4d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=============================\n",
      "Evaluation Series for HalfCheetah-v3\n",
      "=============================\n",
      "MODE:  eval\n",
      "=============================\n",
      "Experiment Results for Baseline - Mode:('Baseline', 'HalfCheetah-v3_6')\n",
      "=============================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/kp/3bh91wns7nz7wx71lyc4rwwr0000gn/T/ipykernel_5395/67505183.py:57: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean Rew.: 3074.2709488635037 Mean Std.: 880.4416217916283 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Baseline - Mode:('Baseline', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "Mean Rew.: 2402.3487694757814 Mean Std.: 1007.7875347503585 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Baseline - Mode:('Baseline', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "Mean Rew.: 3693.0972411147786 Mean Std.: 909.1011459780036 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Baseline 3056.5723198180217\n",
      "=============================\n",
      "Experiment Results for Dropout_Average - Mode:('Dropout_Average', 'HalfCheetah-v3_6')\n",
      "=============================\n",
      "Mean Rew.: 2520.5814927120723 Mean Std.: 16.620079403248965 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Dropout_Average - Mode:('Dropout_Average', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "Mean Rew.: 2257.585815250414 Mean Std.: 11.589912296130082 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Dropout_Average - Mode:('Dropout_Average', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "Mean Rew.: 1265.7992177987937 Mean Std.: 18.703082519152886 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Dropout_Average 2014.6555085870932\n",
      "=============================\n",
      "Experiment Results for Dropout_Single - Mode:('Dropout_Single', 'HalfCheetah-v3_6')\n",
      "=============================\n",
      "Mean Rew.: 2528.7936592706947 Mean Std.: 20.86379481219067 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Dropout_Single - Mode:('Dropout_Single', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "Mean Rew.: 2242.923498537834 Mean Std.: 9.008986619201973 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Dropout_Single - Mode:('Dropout_Single', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "Mean Rew.: 1254.783290580884 Mean Std.: 14.69236337296118 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Dropout_Single 2008.8334827964711\n",
      "=============================\n",
      "Experiment Results for Masksemble_Average - Mode:('Masksemble_Average', 'HalfCheetah-v3_6')\n",
      "=============================\n",
      "384 384 384\n",
      "Mean Rew.: 2532.471581114197 Mean Std.: 727.1682660690926 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Masksemble_Average - Mode:('Masksemble_Average', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "384 384 384\n",
      "Mean Rew.: 2193.0962979055776 Mean Std.: 69.41141008066985 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Masksemble_Average - Mode:('Masksemble_Average', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "384 384 384\n",
      "Mean Rew.: 3295.349649605242 Mean Std.: 865.6962985437347 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Masksemble_Average 2673.639176208339\n",
      "=============================\n",
      "Experiment Results for Masksemble_Single - Mode:('Masksemble_Single', 'HalfCheetah-v3_6')\n",
      "=============================\n",
      "384 384 384\n",
      "Mean Rew.: 2007.7339595266385 Mean Std.: 572.736783704924 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Masksemble_Single - Mode:('Masksemble_Single', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "384 384 384\n",
      "Mean Rew.: 2202.583803121888 Mean Std.: 81.17131472134619 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Masksemble_Single - Mode:('Masksemble_Single', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "384 384 384\n",
      "Mean Rew.: 3754.451218031917 Mean Std.: 200.2000492274401 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Masksemble_Single 2654.922993560148\n",
      "=============================\n",
      "Experiment Results for Ensemble_Average - Mode:('Ensemble_Average', 'HalfCheetah-v3_6')\n",
      "=============================\n",
      "Mean Rew.: 254.75516523570695 Mean Std.: 276.0790662239064 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Ensemble_Average - Mode:('Ensemble_Average', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "Mean Rew.: 898.1313105693625 Mean Std.: 43.23988004852571 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Ensemble_Average - Mode:('Ensemble_Average', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "Mean Rew.: 987.4803675712177 Mean Std.: 70.16160266668011 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Ensemble_Average 713.4556144587624\n",
      "=============================\n",
      "Experiment Results for Ensemble_Single - Mode:('Ensemble_Single', 'HalfCheetah-v3_6')\n",
      "=============================\n",
      "Mean Rew.: 74.10263116667666 Mean Std.: 238.5072219852073 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Ensemble_Single - Mode:('Ensemble_Single', 'HalfCheetah-v3_4')\n",
      "=============================\n",
      "Mean Rew.: 857.4930792643981 Mean Std.: 87.7762936609202 for 10 evaluation episodes\n",
      "=============================\n",
      "Experiment Results for Ensemble_Single - Mode:('Ensemble_Single', 'HalfCheetah-v3_5')\n",
      "=============================\n",
      "Mean Rew.: 984.1133372782111 Mean Std.: 113.56421648528858 for 10 evaluation episodes\n",
      "================> MEAN REWARD:  Ensemble_Single 638.569682569762\n",
      "Complete run time for the evaluation series in s:  464.07239031791687\n"
     ]
    }
   ],
   "source": [
    "experiment_list = dict(\n",
    "    Baseline={\n",
    "        \"path\": \"/Users/yannick/Documents/neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Baseline/\",\n",
    "        \"masksemble_mode\": None,\n",
    "    },\n",
    "    Dropout_Average={\n",
    "        \"path\": \"/Users/yannick/Documents/neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Dropout/\",\n",
    "        \"masksemble_mode\": \"AVERAGE\",\n",
    "    },\n",
    "    Dropout_Single={\n",
    "        \"path\": \"/Users/yannick/Documents/neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Dropout/\",\n",
    "        \"masksemble_mode\": \"INITIAL\",\n",
    "    },\n",
    "    Masksemble_Average={\n",
    "        \"path\": \"../neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Masksemble/\",\n",
    "        \"masksemble_mode\": \"AVERAGE\",\n",
    "    },\n",
    "    Masksemble_Single={\n",
    "        \"path\": \"../neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Masksemble/\",\n",
    "        \"masksemble_mode\": \"INITIAL\",\n",
    "    },\n",
    "    Ensemble_Average={\n",
    "        \"path\": \"../neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Ensemble/\",\n",
    "        \"masksemble_mode\": \"AVERAGE\",\n",
    "    },\n",
    "    Ensemble_Single={\n",
    "        \"path\": \"../neurips_benchmarks/neurips_benchmarks/Cheetah_zip/Ensemble/\",\n",
    "        \"masksemble_mode\": \"INITIAL\",\n",
    "    },\n",
    ")\n",
    "\n",
    "\n",
    "import time\n",
    "env_name = \"HalfCheetah-v3\"\n",
    "    \n",
    "eval_mode = True\n",
    "\n",
    "print(\"=============================\")\n",
    "print(f\"Evaluation Series for {env_name}\")\n",
    "print(\"=============================\")\n",
    "\n",
    "print(\"MODE: \", \"eval\" if eval_mode is True else \"video\")\n",
    "\n",
    "complete_start_time = time.time()\n",
    "for key, exp in experiment_list.items():\n",
    "    \n",
    "    mean_rews = []\n",
    "    for subdir in os.listdir(exp[\"path\"]):\n",
    "        \n",
    "        if \"DS\" in subdir:\n",
    "            continue\n",
    "    \n",
    "        print(\"=============================\")\n",
    "        print(f\"Experiment Results for {key} - Mode:{key, subdir}\")\n",
    "        print(\"=============================\")\n",
    "\n",
    "        model_path = os.path.join(exp[\"path\"], subdir)\n",
    "        model_name=env_name+\".zip\"\n",
    "\n",
    "        mean_rews.append(benchmark_with_loaded_model(env_name, \n",
    "                                      model_path, \n",
    "                                      eval_mode=eval_mode, \n",
    "                                      masksemble_mode=exp[\"masksemble_mode\"],\n",
    "                                      model_name=model_name))\n",
    "        \n",
    "    print(\"================> MEAN REWARD: \", key, np.mean(mean_rews))\n",
    "            \n",
    "print(\"Complete run time for the evaluation series in s: \", time.time() - complete_start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b12711",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (va-for-rl)",
   "language": "python",
   "name": "venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
