{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "98224456-92f8-4dce-b416-cb05658fedc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dowel_wrapper\n",
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.neighbors import KernelDensity\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from envs.mujoco.ant_env import AntEnv\n",
    "import io\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Circle\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.cm as cm\n",
    "from matplotlib.lines import Line2D\n",
    "import scipy.stats as stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5d12018e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    def find_class(self, module, name):\n",
    "        if module == 'torch.storage' and name == '_load_from_bytes':\n",
    "            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')\n",
    "        else: return super().find_class(module, name)\n",
    "\n",
    "def _load_experiment(exp_root, chkpt):\n",
    "    \"\"\"Load one experiment's algorithm snapshot and option policy onto the CPU.\n",
    "\n",
    "    Args:\n",
    "        exp_root: experiment directory containing ``itr_{chkpt}.pkl`` and\n",
    "            ``option_policy{chkpt}.pt``.\n",
    "        chkpt: training iteration whose files are loaded.\n",
    "\n",
    "    Returns:\n",
    "        ``(algo, option_policy)``: the algorithm has its device set to ``'cpu'`` and its\n",
    "        option-policy module moved to the CPU; the standalone option policy is moved\n",
    "        to the CPU and put in eval mode.\n",
    "    \"\"\"\n",
    "    # NOTE: unpickling executes arbitrary code -- only load trusted experiment files.\n",
    "    with open(os.path.join(exp_root, f'itr_{chkpt}.pkl'), 'rb') as f:\n",
    "        itr = CPU_Unpickler(f).load()\n",
    "\n",
    "    algo = itr['algo']\n",
    "    algo.device = 'cpu'\n",
    "    algo.option_policy._module.to('cpu')\n",
    "\n",
    "    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects\n",
    "    # pickled Python objects; pass weights_only=False (trusted files only) if needed.\n",
    "    option_data = torch.load(os.path.join(exp_root, f'option_policy{chkpt}.pt'), map_location=torch.device('cpu'))\n",
    "    option_policy = option_data['policy']\n",
    "    option_policy.to('cpu')\n",
    "    option_policy.eval()\n",
    "    return algo, option_policy\n",
    "\n",
    "\n",
    "METRA_EXP_ROOT = 'exp/ant_multi_goals_metra_chkpt_30k/sd000_s_21377600.0.1721411268_ant_nav_prime_sac'\n",
    "METRA_SF_TD_EXP_ROOT = 'exp/ant_multi_goals_metra_sf_td_chkpt_30k/sd000_s_21377629.0.1721411543_ant_nav_prime_sac'\n",
    "\n",
    "\n",
    "def load_models(chkpt, metra_exp_root=METRA_EXP_ROOT, metra_sf_td_exp_root=METRA_SF_TD_EXP_ROOT):\n",
    "    \"\"\"Load the METRA and METRA SF TD algorithms and option policies at iteration `chkpt`.\n",
    "\n",
    "    Returns:\n",
    "        ``(metra_algo, metra_option_policy, metra_sf_td_algo, metra_sf_td_option_policy)``.\n",
    "    \"\"\"\n",
    "    metra_algo, metra_option_policy = _load_experiment(metra_exp_root, chkpt)\n",
    "    metra_sf_td_algo, metra_sf_td_option_policy = _load_experiment(metra_sf_td_exp_root, chkpt)\n",
    "    return metra_algo, metra_option_policy, metra_sf_td_algo, metra_sf_td_option_policy\n",
    "\n",
    "\n",
    "CHKPT = 60_000\n",
    "metra_algo, metra_option_policy, metra_sf_td_algo, metra_sf_td_option_policy = load_models(CHKPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4dfa816e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "from garagei.experiment.option_local_runner import OptionLocalRunner\n",
    "from garaged.src.garage.experiment.experiment import ExperimentContext\n",
    "from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler\n",
    "from iod.utils import get_normalizer_preset\n",
    "from garagei.envs.consistent_normalized_env import consistent_normalize\n",
    "from garagei.envs.child_policy_env import ChildPolicyEnv\n",
    "\n",
    "def make_env(cp_path):\n",
    "    \"\"\"Build a normalized AntNavPrime environment wrapped by a frozen child policy.\n",
    "\n",
    "    Args:\n",
    "        cp_path: path to the child-policy checkpoint loaded with ``torch.load``. If no\n",
    "            file exists at that path it is treated as a glob pattern and the first\n",
    "            match (in sorted order) is used.\n",
    "\n",
    "    Returns:\n",
    "        The ``ChildPolicyEnv``-wrapped environment.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if ``cp_path`` neither exists nor matches any file.\n",
    "    \"\"\"\n",
    "    from envs.mujoco.ant_nav_prime_env import AntNavPrimeEnv\n",
    "\n",
    "    env = AntNavPrimeEnv(\n",
    "        max_path_length=200,\n",
    "        goal_range=7.5,\n",
    "        num_goal_steps=50,\n",
    "        reward_type='esparse',\n",
    "    )\n",
    "    # Trailing observation dims (presumably the 2-D goal) hidden from the child policy.\n",
    "    cp_num_truncate_obs = 2\n",
    "\n",
    "    # Pad the ant preset statistics with mean 0 / std 1 for the extra trailing dims,\n",
    "    # which (assuming (x - mean) / std normalization) leaves those dims unchanged.\n",
    "    normalizer_mean, normalizer_std = get_normalizer_preset('ant_preset')\n",
    "    if cp_num_truncate_obs > 0:\n",
    "        normalizer_mean = np.concatenate([normalizer_mean, np.zeros(cp_num_truncate_obs)])\n",
    "        normalizer_std = np.concatenate([normalizer_std, np.ones(cp_num_truncate_obs)])\n",
    "    env = consistent_normalize(env, normalize_obs=True, mean=normalizer_mean, std=normalizer_std)\n",
    "\n",
    "    if not os.path.exists(cp_path):\n",
    "        import glob\n",
    "        matches = sorted(glob.glob(cp_path))\n",
    "        if not matches:\n",
    "            raise FileNotFoundError(f'No child-policy checkpoint matches {cp_path!r}')\n",
    "        cp_path = matches[0]\n",
    "    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects\n",
    "    # pickled Python objects; pass weights_only=False (trusted files only) if needed.\n",
    "    cp_dict = torch.load(cp_path, map_location='cpu')\n",
    "\n",
    "    return ChildPolicyEnv(\n",
    "        env,\n",
    "        cp_dict,\n",
    "        cp_action_range=1.5,\n",
    "        cp_unit_length=1,\n",
    "        cp_multi_step=25,\n",
    "        cp_num_truncate_obs=cp_num_truncate_obs,\n",
    "    )\n",
    "    \n",
    "METRA_CP_PATH = 'exp/ant_metra/sd000_s_56955647.0.1718292963_ant_metra/option_policy30000.pt'\n",
    "METRA_SF_TD_CP_PATH = 'exp/ant_metra_sf_td/sd000_s_56969167.0.1718294685_ant_metra_sf/option_policy30000.pt'\n",
    "\n",
    "\n",
    "def _setup_runner(algo, env, contextualized_make_env):\n",
    "    \"\"\"Create and set up an ``OptionLocalRunner`` for `algo` with one single-threaded worker.\n",
    "\n",
    "    Args:\n",
    "        algo: loaded algorithm whose policies the runner samples.\n",
    "        env: environment instance passed to ``runner.setup``.\n",
    "        contextualized_make_env: zero-argument env factory used by the sampler workers.\n",
    "\n",
    "    Returns:\n",
    "        The configured runner.\n",
    "    \"\"\"\n",
    "    runner = OptionLocalRunner(ExperimentContext(\n",
    "        snapshot_dir='.',\n",
    "        snapshot_mode='last',\n",
    "        snapshot_gap=1,\n",
    "    ))\n",
    "    runner.setup(\n",
    "        algo=algo,\n",
    "        env=env,\n",
    "        make_env=contextualized_make_env,\n",
    "        sampler_cls=OptionMultiprocessingSampler,\n",
    "        sampler_args=dict(n_thread=1),\n",
    "        n_workers=1,\n",
    "    )\n",
    "    return runner\n",
    "\n",
    "\n",
    "metra_env = make_env(METRA_CP_PATH)\n",
    "metra_contextualized_make_env = functools.partial(make_env, cp_path=METRA_CP_PATH)\n",
    "\n",
    "metra_sf_td_env = make_env(METRA_SF_TD_CP_PATH)\n",
    "metra_sf_td_contextualized_make_env = functools.partial(make_env, cp_path=METRA_SF_TD_CP_PATH)\n",
    "\n",
    "metra_runner = _setup_runner(metra_algo, metra_env, metra_contextualized_make_env)\n",
    "metra_sf_td_runner = _setup_runner(metra_sf_td_algo, metra_sf_td_env, metra_sf_td_contextualized_make_env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "129ef4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_trajectories(runner,\n",
    "                        sampler_key,\n",
    "                        batch_size=None,\n",
    "                        extras=None,\n",
    "                        update_stats=False,\n",
    "                        worker_update=None,\n",
    "                        env_update=None,\n",
    "                         option_policy=None):\n",
    "    if batch_size is None:\n",
    "        batch_size = len(extras)\n",
    "    policy_sampler_key = sampler_key[6:] if sampler_key.startswith('local_') else sampler_key\n",
    "    time_get_trajectories = [0.0]\n",
    "\n",
    "    trajectories, infos = runner.obtain_exact_trajectories(\n",
    "        runner.step_itr,\n",
    "        sampler_key=sampler_key,\n",
    "        batch_size=batch_size,\n",
    "        agent_update=_get_policy_param_values({'option_policy':option_policy}, policy_sampler_key),\n",
    "        env_update=env_update,\n",
    "        worker_update=worker_update,\n",
    "        extras=extras,\n",
    "        update_stats=update_stats,\n",
    "    )\n",
    "    print(f'_get_trajectories({sampler_key}) {time_get_trajectories[0]}s')\n",
    "\n",
    "    for traj in trajectories:\n",
    "        for key in ['ori_obs', 'next_ori_obs', 'coordinates', 'next_coordinates']:\n",
    "            if key not in traj['env_infos']:\n",
    "                continue\n",
    "\n",
    "    return trajectories\n",
    "\n",
    "def _get_policy_param_values(policy, key):\n",
    "    param_dict = policy[key].get_param_values()\n",
    "    for k in param_dict.keys():\n",
    "        param_dict[k] = param_dict[k].detach().cpu()\n",
    "    return param_dict\n",
    "\n",
    "def _generate_option_extras(options):\n",
    "    return [{'option': option} for option in options]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a0ecee59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "METRA SUCCESS RATES\n",
      "goal_1 0.93 (0.8791182485051084, 0.9808817514948917))\n",
      "goal_2 0.37 (0.27371853383960454, 0.46628146616039545))\n",
      "goal_3 0.37 (0.2737185338396045, 0.4662814661603955))\n",
      "goal_4 0.35 (0.25488209883009155, 0.4451179011699084))\n",
      "\n",
      "METRA SUCCESS RATES\n",
      "goal_1 0.93 (0.8791182485051084, 0.9808817514948917))\n",
      "goal_2 0.37 (0.27371853383960454, 0.46628146616039545))\n",
      "goal_3 0.37 (0.2737185338396045, 0.4662814661603955))\n",
      "goal_4 0.35 (0.25488209883009155, 0.4451179011699084))\n",
      "\n",
      "METRA SF TD SUCCESS RATES\n",
      "goal_1 0.95 ((0.9065371337811178, 0.9934628662188821))\n",
      "goal_2 0.38 ((0.2832036009440236, 0.4767963990559764))\n",
      "goal_3 0.38 ((0.28320360094402364, 0.47679639905597637))\n",
      "goal_4 0.44 ((0.3410098664856729, 0.5389901335143271))\n",
      "\n",
      "METRA SF TD SUCCESS RATES\n",
      "goal_1 0.95 ((0.9065371337811178, 0.9934628662188821))\n",
      "goal_2 0.38 ((0.2832036009440236, 0.4767963990559764))\n",
      "goal_3 0.38 ((0.28320360094402364, 0.47679639905597637))\n",
      "goal_4 0.44 ((0.3410098664856729, 0.5389901335143271))\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Trajectory count for the disabled `_get_trajectories` example below; not used elsewhere in this cell.\n",
    "NUM_RANDOM_TRAJECTORIES = 1\n",
    "\n",
    "# Disabled example: sample NUM_RANDOM_TRAJECTORIES trajectories from the METRA runner with a\n",
    "# deterministic policy and a non-deterministic initial state (_action_noise_std=None presumably\n",
    "# disables action noise). Uncomment to use.\n",
    "# metra_random_trajectories = _get_trajectories(\n",
    "#     metra_runner,\n",
    "#     sampler_key='option_policy',\n",
    "#     extras=[{} for _ in range(NUM_RANDOM_TRAJECTORIES)],\n",
    "#     worker_update=dict(\n",
    "#         _render=False,\n",
    "#         _deterministic_initial_state=False,\n",
    "#         _deterministic_policy=True, \n",
    "#     ),\n",
    "#     env_update=dict(_action_noise_std=None),\n",
    "#     option_policy=metra_option_policy\n",
    "# )\n",
    "\n",
    "\n",
    "def evaluate_goals(env, option_policy):\n",
    "    obs = env.reset()\n",
    "    step = 0\n",
    "    done = False\n",
    "    reward_total = 0\n",
    "    attempted_goals = []\n",
    "    all_goals = [obs[-2:]]\n",
    "    xs = []\n",
    "    ys = []\n",
    "    success = {\n",
    "        'goal_1': 0,\n",
    "        'goal_2': 0,\n",
    "        'goal_3': 0,\n",
    "        'goal_4': 0\n",
    "    }\n",
    "    while not done:\n",
    "        attempted_goals.append(obs[-2:])\n",
    "        action, agent_info = option_policy.get_action(obs)\n",
    "        next_obs, reward, done, info = env.step(action, debug=True)\n",
    "        obs = next_obs\n",
    "        \n",
    "\n",
    "        step += 1\n",
    "        reward_total += reward\n",
    "        xs.append(info['coordinates'][:, 0])\n",
    "        ys.append(info['coordinates'][:, 1])\n",
    "        for _obs in info['original_next_observations']:\n",
    "            if not np.allclose(all_goals[-1], _obs[-2:]):\n",
    "                all_goals.append(_obs[-2:])\n",
    "                \n",
    "        for i in range(1, 5):\n",
    "            success[f'goal_{i}'] = 1 if success[f'goal_{i}'] else int(info[f'goal_{i}'] > 0)\n",
    "        \n",
    "#     print(reward_total)\n",
    "    # drop the very last goal, since we don't try to reach it\n",
    "    all_goals.pop()\n",
    "\n",
    "    return xs, ys, attempted_goals, reward_total, all_goals, success\n",
    "    \n",
    "def plot_multi_goal(env, option_policy, title=None):\n",
    "    \"\"\"Roll out one episode with ``evaluate_goals`` and plot the goals and the agent's path.\n",
    "\n",
    "    Each goal is drawn as a numbered star with a circle of radius 3 around it. The\n",
    "    agent's coordinates are scattered in the color of the goal that was active when\n",
    "    the corresponding step started. The per-goal success dict is printed.\n",
    "\n",
    "    Args:\n",
    "        env: environment passed to ``evaluate_goals``.\n",
    "        option_policy: policy passed to ``evaluate_goals``.\n",
    "        title: optional axes title (e.g. the method name).\n",
    "    \"\"\"\n",
    "    xs, ys, attempted_goals, _, all_goals, success = evaluate_goals(env, option_policy)\n",
    "    print(success)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains supported.\n",
    "    cmap = plt.get_cmap('tab10', 10)\n",
    "\n",
    "    for goal_num, goal in enumerate(all_goals, 1):\n",
    "        color = cmap(goal_num)\n",
    "        ax.scatter(goal[0], goal[1], color=color, marker='*', s=700)\n",
    "        ax.text(goal[0], goal[1], str(goal_num), fontsize=10, ha='center', color='white', verticalalignment='center')\n",
    "        ax.add_patch(plt.Circle((goal[0], goal[1]), 3, color=color, fill=False))\n",
    "\n",
    "    for step_xs, step_ys, goal in zip(xs, ys, attempted_goals):\n",
    "        # evaluate_goals drops the final goal from all_goals, so an attempted goal may be\n",
    "        # missing; draw such steps in gray instead of calling cmap(None).\n",
    "        goal_num = next((i for i, known in enumerate(all_goals, 1) if np.allclose(goal, known)), None)\n",
    "        color = cmap(goal_num) if goal_num is not None else 'gray'\n",
    "        ax.scatter(step_xs, step_ys, color=color, s=0.5)\n",
    "\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('y')\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "    ax.set_aspect('equal', 'box')\n",
    "    plt.show()\n",
    "    \n",
    "# Uncomment to visualise a single METRA episode:\n",
    "# plot_multi_goal(metra_env, metra_option_policy, title='Metra')\n",
    "\n",
    "NUM_ROLLOUTS = 100\n",
    "CONFIDENCE = 0.95\n",
    "\n",
    "\n",
    "def success_rates(env, option_policy, num_rollouts=NUM_ROLLOUTS, confidence=CONFIDENCE):\n",
    "    \"\"\"Estimate per-goal success rates of `option_policy` over `num_rollouts` episodes.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each ``goal_{i}`` key to ``(mean, (ci_low, ci_high))``, where the\n",
    "        interval is a Student-t confidence interval over the per-rollout 0/1 outcomes.\n",
    "    \"\"\"\n",
    "    outcomes = defaultdict(list)\n",
    "    for _rollout in range(num_rollouts):\n",
    "        success = evaluate_goals(env, option_policy)[-1]\n",
    "        for key, reached in success.items():\n",
    "            outcomes[key].append(reached)\n",
    "\n",
    "    summary = {}\n",
    "    for key, values in outcomes.items():\n",
    "        mean = np.mean(values)\n",
    "        sem = stats.sem(values)\n",
    "        if len(values) > 1 and sem == 0:\n",
    "            # All rollouts agree; scipy returns NaN for a t-interval with scale=0.\n",
    "            interval = (mean, mean)\n",
    "        else:\n",
    "            interval = stats.t.interval(confidence, len(values) - 1, loc=mean, scale=sem)\n",
    "        summary[key] = (mean, interval)\n",
    "    return summary\n",
    "\n",
    "\n",
    "def print_success_rates(name, summary, confidence=CONFIDENCE):\n",
    "    \"\"\"Print one line per goal with the mean success rate and its confidence interval.\"\"\"\n",
    "    print(f'{name} SUCCESS RATES')\n",
    "    for key, (mean, (low, high)) in summary.items():\n",
    "        print(f'{key} {mean:.2f} ({confidence:.0%} CI: {low:.3f}, {high:.3f})')\n",
    "    print()\n",
    "\n",
    "\n",
    "print_success_rates('METRA', success_rates(metra_env, metra_option_policy))\n",
    "print_success_rates('METRA SF TD', success_rates(metra_sf_td_env, metra_sf_td_option_policy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "887854d3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
