{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import functools\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# %matplotlib tk\n",
    "from matplotlib import cm\n",
    "import seaborn as sns\n",
    "\n",
    "from tests.metaworld.envs.mujoco.sawyer_xyz.utils import trajectory_summary\n",
    "from tests.metaworld.envs.mujoco.sawyer_xyz.test_scripted_policies import ALL_ENVS, test_cases_latest_nonoise\n",
    "\n",
    "sns.set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def sample_trajectories_from(env, policy, act_noise_pct, iters=100):\n",
    "    \"\"\"Roll out `policy` in `env` `iters` times and stack the per-rollout results.\n",
    "\n",
    "    Each rollout is one call to `trajectory_summary` with\n",
    "    `end_on_success=False`. Entries 1, 2 and 3 of the value it returns are\n",
    "    collected as rewards, returns and first-success values respectively.\n",
    "\n",
    "    Args:\n",
    "        env: Environment forwarded unchanged to `trajectory_summary`.\n",
    "        policy: Policy forwarded unchanged to `trajectory_summary`.\n",
    "        act_noise_pct: Action-noise setting forwarded to `trajectory_summary`.\n",
    "        iters: Number of rollouts to sample.\n",
    "\n",
    "    Returns:\n",
    "        A `(rewards, returns, first_successes)` tuple. The first two are the\n",
    "        per-rollout entries combined with `np.vstack`; the third is an\n",
    "        `np.array` built from the collected values.\n",
    "    \"\"\"\n",
    "    rollouts = [\n",
    "        trajectory_summary(env, policy, act_noise_pct, end_on_success=False)\n",
    "        for _ in range(iters)\n",
    "    ]\n",
    "\n",
    "    rewards = np.vstack([summary[1] for summary in rollouts])\n",
    "    returns = np.vstack([summary[2] for summary in rollouts])\n",
    "    first_successes = np.array([summary[3] for summary in rollouts])\n",
    "\n",
    "    return rewards, returns, first_successes\n",
    "\n",
    "\n",
    "def _draw_colored_surface(ax, X, Y, values, zlabel, dim):\n",
    "    \"\"\"Draw `values` over the (X, Y) grid on the 3D axes `ax`.\n",
    "\n",
    "    Colors come from the viridis colormap applied to `values` normalized to\n",
    "    its own min/max; the z axis uses a symlog scale and is labelled `zlabel`.\n",
    "    \"\"\"\n",
    "    norm = plt.Normalize(values.min(), values.max())\n",
    "    colors = cm.viridis(norm(values))\n",
    "    # One surface sample per grid point, so nothing is down-sampled.\n",
    "    rcount, ccount, _ = colors.shape\n",
    "\n",
    "    surf = ax.plot_surface(X, Y, values, rcount=rcount, ccount=ccount, facecolors=colors, shade=False)\n",
    "\n",
    "    # Override the face color with a fully transparent RGBA value.\n",
    "    surf.set_facecolor((0, 0, 0, 0))\n",
    "    ax.set_xlabel(f'Noise Percent in Action Dim {dim}')\n",
    "    ax.set_ylabel('Time Steps')\n",
    "    ax.set_zlabel(zlabel)\n",
    "    ax.set_zscale('symlog')\n",
    "\n",
    "\n",
    "def plot(rewards, returns, tag, dim):\n",
    "    \"\"\"Plot rewards and returns as side-by-side 3D surfaces and save the figure.\n",
    "\n",
    "    Args:\n",
    "        rewards: 2D array; axis 0 is spread evenly over [0, 1] on the\n",
    "            'Noise Percent' axis and axis 1 is plotted as time steps.\n",
    "        returns: 2D array drawn over the same grid as `rewards`, so it is\n",
    "            expected to have the same shape.\n",
    "        tag: Figure title; also used as the prefix of the output file name.\n",
    "        dim: Action dimension index shown in the x-axis label.\n",
    "\n",
    "    The figure is written to `figures/{tag}_vary_noise_rewards_returns.jpg`\n",
    "    (the directory is created if needed), then shown and closed.\n",
    "    \"\"\"\n",
    "    x = np.linspace(0, 1, rewards.shape[0])\n",
    "    y = np.arange(rewards.shape[1])\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "\n",
    "    fig = plt.figure(figsize=(12, 5))\n",
    "\n",
    "    # meshgrid yields arrays of shape (len(y), len(x)), hence the transposes.\n",
    "    ax0 = fig.add_subplot(121, projection='3d')\n",
    "    _draw_colored_surface(ax0, X, Y, rewards.T, 'Rewards', dim)\n",
    "\n",
    "    ax1 = fig.add_subplot(122, projection='3d')\n",
    "    _draw_colored_surface(ax1, X, Y, returns.T, 'Returns', dim)\n",
    "\n",
    "    fig.subplots_adjust(top=.85)\n",
    "    fig.suptitle(f'{tag}')\n",
    "    fig.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
    "\n",
    "    ax0.view_init(30, +32)\n",
    "    ax1.view_init(30, -45)\n",
    "\n",
    "    # Save before showing: with some backends the figure is torn down once\n",
    "    # plt.show() returns, and a later savefig would write an empty image.\n",
    "    os.makedirs('figures', exist_ok=True)\n",
    "    fig.savefig(f'figures/{tag}_vary_noise_rewards_returns.jpg')\n",
    "    plt.show()\n",
    "    # plot() is called once per configuration; release the figure each time.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Each entry is [environment name, index of the action dimension whose noise is swept].\n",
    "config = [\n",
    "#     ['button-press-topdown-v1', 3],\n",
    "    ['pick-place-v2', 3],\n",
    "#     ['reach-v2', 3],\n",
    "#     ['window-open-v2', 3],\n",
    "#     ['sweep-v1', 3],\n",
    "#     ['sweep-into-v1', 3],\n",
    "#     ['shelf-place-v2', 3],\n",
    "#     ['push-v2', 3],\n",
    "#     ['peg-insert-side-v2', 3],\n",
    "#     ['lever-pull-v2', 3],\n",
    "]\n",
    "\n",
    "for env_name, axis in config:\n",
    "    tag = env_name + '-vary-axis-' + str(axis)\n",
    "\n",
    "    # Pick the first test case registered for this environment. Fail loudly if\n",
    "    # there is none rather than silently running another environment's policy.\n",
    "    matching_cases = [case for case in test_cases_latest_nonoise if case[0] == env_name]\n",
    "    if not matching_cases:\n",
    "        raise KeyError(f'no scripted-policy test case found for {env_name!r}')\n",
    "    policy = matching_cases[0][1]\n",
    "\n",
    "    env = ALL_ENVS[env_name]()\n",
    "    env._partially_observable = False\n",
    "    env._freeze_rand_vec = False\n",
    "    env._set_task_called = True\n",
    "\n",
    "    sampled_rewards, sampled_returns = [], []\n",
    "    # Every action dimension starts at 75% noise; only `axis` is varied below.\n",
    "    # Keep the tag suffix in sync with this value if it is changed.\n",
    "    noise = np.full(4, .75)\n",
    "    tag = tag + '-others-75-percent'\n",
    "\n",
    "    for i in np.linspace(0, 1, 10):\n",
    "        noise[axis] = i\n",
    "\n",
    "        rew, ret, _ = sample_trajectories_from(env, policy, noise)\n",
    "        # Average over the sampled rollouts to get one curve per noise level.\n",
    "        sampled_rewards.append(rew.mean(axis=0))\n",
    "        sampled_returns.append(ret.mean(axis=0))\n",
    "\n",
    "    sampled_rewards = np.vstack(sampled_rewards)\n",
    "    sampled_returns = np.vstack(sampled_returns)\n",
    "\n",
    "    plot(sampled_rewards, sampled_returns, tag, axis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
