{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/2_train_dagger.ipynb)\n",
    "# Train an Agent using the DAgger Algorithm\n",
    "\n",
    "The DAgger algorithm is an extension of behavior cloning.\n",
    "In behavior cloning, the training trajectories are recorded directly from an expert.\n",
    "In DAgger, the learner generates the trajectories, and an expert labels each visited state with the action it would take there.\n",
    "As a result, the training data covers the states that the learner's own policy actually visits, rather than only the states the expert would visit."
   ]
  },
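  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data-collection loop described above can be sketched without any RL library.\n",
    "The following cell is a minimal illustration, not the `imitation` implementation: the 1-D walk environment, the `expert_action` rule, and the tabular learner are all hypothetical stand-ins chosen for brevity.\n",
    "It also omits the mixing of expert and learner actions that DAgger typically uses in early rounds.\n",
    "Note that every state in `dataset` is reached by following the learner's actions, while every label comes from the expert."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "\n",
    "def expert_action(state):\n",
    "    # Hypothetical expert: always step toward the origin (0 = left, 1 = right).\n",
    "    return 0 if state > 0 else 1\n",
    "\n",
    "\n",
    "def env_step(state, action):\n",
    "    # Toy 1-D walk; positions are clipped to the range [-5, 5].\n",
    "    return int(np.clip(state + (1 if action == 1 else -1), -5, 5))\n",
    "\n",
    "\n",
    "def learner_action(table, state):\n",
    "    # Tabular learner: majority expert label for a known state, else a random action.\n",
    "    if state in table:\n",
    "        return int(np.mean(table[state]) >= 0.5)\n",
    "    return int(rng.integers(2))\n",
    "\n",
    "\n",
    "def fit(dataset):\n",
    "    # Behavior cloning step: group the expert labels by state.\n",
    "    table = {}\n",
    "    for state, action in dataset:\n",
    "        table.setdefault(state, []).append(action)\n",
    "    return table\n",
    "\n",
    "\n",
    "dataset, table = [], {}\n",
    "for round_idx in range(5):\n",
    "    state = int(rng.integers(-5, 6))\n",
    "    for _ in range(20):\n",
    "        # The expert labels the visited state ...\n",
    "        dataset.append((state, expert_action(state)))\n",
    "        # ... but the learner's action decides which state is visited next.\n",
    "        state = env_step(state, learner_action(table, state))\n",
    "    # Retrain on the aggregated data from all rounds so far.\n",
    "    table = fit(dataset)\n",
    "\n",
    "print(len(dataset), sorted(table))"
   ]
  },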
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need an expert to learn from:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "expert = PPO(\n",
    "    policy=MlpPolicy,\n",
    "    env=env,\n",
    "    seed=0,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    "    n_steps=64,\n",
    ")\n",
    "expert.learn(1000)  # Note: set to 100000 to train a proficient expert"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can construct a DAgger trainer and use it to train the policy on the CartPole environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "import gym\n",
    "import numpy as np\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "\n",
    "from imitation.algorithms import bc\n",
    "from imitation.algorithms.dagger import SimpleDAggerTrainer\n",
    "\n",
    "# Seeded generator, shared by both trainers, for more repeatable runs.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# DAgger collects rollouts from a vectorized environment.\n",
    "venv = DummyVecEnv([lambda: gym.make(\"CartPole-v1\")])\n",
    "\n",
    "# The behavior cloning trainer fits the learner's policy to the collected data.\n",
    "bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    rng=rng,\n",
    ")\n",
    "\n",
    "# Demonstrations are written to a temporary scratch directory.\n",
    "with tempfile.TemporaryDirectory(prefix=\"dagger_example_\") as tmpdir:\n",
    "    print(tmpdir)\n",
    "    dagger_trainer = SimpleDAggerTrainer(\n",
    "        venv=venv,\n",
    "        scratch_dir=tmpdir,\n",
    "        expert_policy=expert,\n",
    "        bc_trainer=bc_trainer,\n",
    "        rng=rng,\n",
    "    )\n",
    "\n",
    "    dagger_trainer.train(2000)  # train for (at least) 2000 environment timesteps"
   ]
  },
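  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we evaluate the trained policy. A mean reward close to 500, the maximum episode reward in CartPole-v1, shows that the policy solves the environment. Reaching it requires a proficient expert (see the note on the expert's training steps above)."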
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "# Mean episode reward over 10 evaluation episodes (the standard deviation is discarded).\n",
    "reward, _ = evaluate_policy(dagger_trainer.policy, env, n_eval_episodes=10)\n",
    "print(reward)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
