{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import functools\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from tests.metaworld.envs.mujoco.sawyer_xyz.utils import trajectory_summary\n",
    "from tests.metaworld.envs.mujoco.sawyer_xyz.test_scripted_policies import ALL_ENVS, test_cases_latest_nonoise\n",
    "\n",
    "sns.set()\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def sample_trajectories_from(env, policy, act_noise_pct, iters=100):\n",
    "    \"\"\"Collect reward/return statistics from repeated rollouts.\n",
    "\n",
    "    Calls `trajectory_summary(env, policy, act_noise_pct, end_on_success=False)`\n",
    "    `iters` times, one call after another, and gathers elements 1, 2 and 3 of\n",
    "    every summary it returns.\n",
    "\n",
    "    Args:\n",
    "        env: environment handed to `trajectory_summary` unchanged.\n",
    "        policy: policy handed to `trajectory_summary` unchanged.\n",
    "        act_noise_pct: action-noise argument handed to `trajectory_summary`.\n",
    "        iters: number of trajectories to sample.\n",
    "\n",
    "    Returns:\n",
    "        A `(rewards, returns, first_successes)` tuple. The first two are the\n",
    "        per-call values stacked with `np.vstack` (presumably one row per\n",
    "        trajectory -- confirm against `trajectory_summary`); the last is an\n",
    "        `np.array` holding one entry per call.\n",
    "    \"\"\"\n",
    "    summaries = [\n",
    "        trajectory_summary(env, policy, act_noise_pct, end_on_success=False)\n",
    "        for _ in range(iters)\n",
    "    ]\n",
    "\n",
    "    # Positions 1..3 of each summary hold rewards, returns and first success.\n",
    "    stacked_rewards = np.vstack([summary[1] for summary in summaries])\n",
    "    stacked_returns = np.vstack([summary[2] for summary in summaries])\n",
    "    first_success_steps = np.array([summary[3] for summary in summaries])\n",
    "\n",
    "    return stacked_rewards, stacked_returns, first_success_steps\n",
    "\n",
    "\n",
    "def plot(rewards, returns, first_successes, tag):\n",
    "    \"\"\"Plot mean reward and return curves side by side and save the figure.\n",
    "\n",
    "    Both panels mark the mean first-success step with dashed green guide\n",
    "    lines. The figure is written to `figures/{tag}_rewards_returns.jpg` and\n",
    "    then displayed.\n",
    "\n",
    "    Args:\n",
    "        rewards: 2-D array of rewards indexed as (trajectory, time step).\n",
    "        returns: 2-D array of returns with the same layout as `rewards`.\n",
    "        first_successes: array of first-success steps; only its mean is used.\n",
    "        tag: label used in the figure title and in the output file name.\n",
    "    \"\"\"\n",
    "    # Clamp to the last valid time step. Clamping to `rewards.shape[1]` would\n",
    "    # allow an index one past the end of the per-step means.\n",
    "    first_success = min(int(first_successes.mean()), rewards.shape[1] - 1)\n",
    "\n",
    "    fig, ax = plt.subplots(1, 2, figsize=(6.75, 4))\n",
    "    _plot_mean_curve(ax[0], rewards, first_success, ylabel='Reward', title='Rewards')\n",
    "    _plot_mean_curve(ax[1], returns, first_success, ylabel='Return', title='Returns')\n",
    "\n",
    "    plt.subplots_adjust(top=.85)\n",
    "    fig.suptitle(f'{tag} (n={rewards.shape[0]})')\n",
    "    fig.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
    "\n",
    "    # Write the file before showing so saving never depends on what the\n",
    "    # active backend does to the figure once it has been displayed.\n",
    "    os.makedirs('figures', exist_ok=True)\n",
    "    fig.savefig(f'figures/{tag}_rewards_returns.jpg')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def _plot_mean_curve(axis, samples, first_success, ylabel, title):\n",
    "    \"\"\"Draw the per-step mean of `samples` with a 95% CI and mark `first_success`.\"\"\"\n",
    "    marker_height = samples.mean(axis=0)[first_success]\n",
    "\n",
    "    # Long format: 'variable' is the column (time step) index, 'value' the sample.\n",
    "    long_df = pd.DataFrame(samples).melt()\n",
    "    sns.lineplot(x='variable', y='value', data=long_df, ax=axis, ci=95, lw=.5)\n",
    "    axis.set_xlabel('Time Steps')\n",
    "    axis.set_ylabel(ylabel)\n",
    "    axis.set_title(title)\n",
    "    axis.vlines(first_success, ymin=0, ymax=marker_height, linestyle='--', color='green')\n",
    "    axis.hlines(marker_height, xmin=0, xmax=first_success, linestyle='--', color='green')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "config = [\n",
    "    # env name, action noise, path length\n",
    "    ['pick-place-v2', np.zeros(4), 200],\n",
    "]\n",
    "\n",
    "for env_name, noise, path_length in config:\n",
    "    tag = env_name + '-noise-' + np.array2string(noise, precision=2, separator=',', suppress_small=True)\n",
    "\n",
    "    # Take the policy of the first test case registered under this name, and\n",
    "    # fail loudly when there is none rather than using an unrelated policy.\n",
    "    matching_policies = [case[1] for case in test_cases_latest_nonoise if case[0] == env_name]\n",
    "    if not matching_policies:\n",
    "        raise KeyError(f'no scripted policy registered for environment {env_name!r}')\n",
    "    policy = matching_policies[0]\n",
    "\n",
    "    env = ALL_ENVS[env_name]()\n",
    "    env.max_path_length = path_length\n",
    "    # NOTE(review): these private flags are set directly on the environment;\n",
    "    # presumably they let it run without an explicit task being set -- confirm.\n",
    "    env._partially_observable = False\n",
    "    env._freeze_rand_vec = False\n",
    "    env._set_task_called = True\n",
    "\n",
    "    sampled_rewards, sampled_returns, sampled_first_successes = sample_trajectories_from(env, policy, noise)\n",
    "    plot(sampled_rewards, sampled_returns, sampled_first_successes, tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Inspect one sampled trajectory; `i` must be smaller than the number of\n",
    "# trajectories held in `sampled_rewards`.\n",
    "i = 12\n",
    "trajectory_rewards = sampled_rewards[i]\n",
    "# Clamp to the last valid step so the reward lookup cannot index past the end.\n",
    "first_success = min(int(sampled_first_successes[i]), len(trajectory_rewards) - 1)\n",
    "first_success_reward = trajectory_rewards[first_success]\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(6.75, 4))\n",
    "ax.plot(np.arange(len(trajectory_rewards)), trajectory_rewards)\n",
    "ax.set_xlabel('Time Steps')\n",
    "ax.set_ylabel('Reward')\n",
    "ax.set_title(f'Rewards for trajectory {i}')\n",
    "\n",
    "# Dashed guides mark the first-success step and the reward at that step.\n",
    "ax.vlines(first_success, ymin=0, ymax=first_success_reward, linestyle='--', color='green')\n",
    "ax.hlines(first_success_reward, xmin=0, xmax=first_success, linestyle='--', color='green');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}