{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/4_train_airl.ipynb)\n",
    "# Train an Agent using Adversarial Inverse Reinforcement Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As usual, we first need an expert. \n",
    "Note that we now use a variant of the CartPole environment from the seals package, which has fixed episode durations. Read more about why we do this [here](https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "import seals  # needed to load environments\n",
    "\n",
    "env = gym.make(\"seals/CartPole-v0\")\n",
    "expert = PPO(\n",
    "    policy=MlpPolicy,\n",
    "    env=env,\n",
    "    seed=0,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    "    n_steps=64,\n",
    ")\n",
    "expert.learn(1000)  # Note: set to 100000 to train a proficient expert"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate expert trajectories, which the discriminator must learn to distinguish from the learner's trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "from imitation.util.util import make_vec_env\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)  # seeded for reproducibility\n",
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    make_vec_env(\n",
    "        \"seals/CartPole-v0\",\n",
    "        n_envs=5,\n",
    "        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n",
    "        rng=rng,\n",
    "    ),\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=60),\n",
    "    rng=rng,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to set up our AIRL trainer.\n",
    "Note that the `reward_net` is actually the network of the discriminator.\n",
    "We evaluate the learner before and after training to see whether it made any progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms.adversarial.airl import AIRL\n",
    "from imitation.rewards.reward_nets import BasicShapedRewardNet\n",
    "from imitation.util.networks import RunningNorm\n",
    "from imitation.util.util import make_vec_env\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "\n",
    "import seals  # needed to load environments\n",
    "\n",
    "venv = make_vec_env(\"seals/CartPole-v0\", n_envs=8, rng=rng)\n",
    "learner = PPO(\n",
    "    env=venv,\n",
    "    policy=MlpPolicy,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    ")\n",
    "reward_net = BasicShapedRewardNet(\n",
    "    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm\n",
    ")\n",
    "airl_trainer = AIRL(\n",
    "    demonstrations=rollouts,\n",
    "    demo_batch_size=1024,\n",
    "    gen_replay_buffer_capacity=2048,\n",
    "    n_disc_updates_per_round=4,\n",
    "    venv=venv,\n",
    "    gen_algo=learner,\n",
    "    reward_net=reward_net,\n",
    ")\n",
    "\n",
    "learner_rewards_before_training, _ = evaluate_policy(\n",
    "    learner, venv, 100, return_episode_rewards=True\n",
    ")\n",
    "airl_trainer.train(20000)  # Note: set to 300000 for better results\n",
    "learner_rewards_after_training, _ = evaluate_policy(\n",
    "    learner, venv, 100, return_episode_rewards=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The histograms of episode rewards before and after training show that the learner is not perfect yet, but that it has made some progress.\n",
    "If it has not, re-run the training cell above."
   ]
  },
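  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Print the mean episode reward before and after training, in that order\n",
    "print(\"Mean reward before training:\", np.mean(learner_rewards_before_training))\n",
    "print(\"Mean reward after training:\", np.mean(learner_rewards_after_training))\n",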
    "\n",
    "plt.hist(\n",
    "    [learner_rewards_before_training, learner_rewards_after_training],\n",
    "    label=[\"untrained\", \"trained\"],\n",
    ")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}