{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/1_train_bc.ipynb)\n",
    "# Train an Agent using Behavior Cloning\n",
    "\n",
    "Behavior cloning is the simplest approach to imitation learning.\n",
    "We take the transitions from trajectories generated by an expert and use them as supervised training samples for a new policy.\n",
    "The method has well-known drawbacks: errors compound once the learned policy drifts into states the expert never visited, so it often fails on harder tasks.\n",
    "However, in this example, where we train an agent for the CartPole-v1 environment, it is feasible."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need an expert for CartPole-v1 from which we can sample trajectories.\n",
    "For convenience, we train one with PPO from the stable-baselines3 library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "expert = PPO(\n",
    "    policy=MlpPolicy,\n",
    "    env=env,\n",
    "    seed=0,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    "    n_steps=64,\n",
    ")\n",
    "expert.learn(1000)  # Note: set to 100000 to train a proficient expert"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly check if the expert is any good.\n",
    "A proficient expert should usually reach a mean episode reward of 500, which is the maximum achievable value in CartPole-v1; after the short training run above, expect a lower number."
   ]
  },
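  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "# Mean episode reward over 10 evaluation episodes; the discarded second value is the standard deviation.\n",
    "reward, _ = evaluate_policy(expert, env, 10)\n",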
    "print(reward)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use the expert to sample some trajectories.\n",
    "Behavior cloning only needs the individual transitions, so we flatten the trajectories right away.\n",
    "`imitation` comes with a number of helper functions that make collecting those transitions easy: first we collect rollouts until at least 50 episodes are complete, then we flatten them into the transitions that we need for training.\n",
    "Note that the rollout function requires a vectorized environment, and each of its environments must be wrapped in a `RolloutInfoWrapper`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng()\n",
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
    "    rng=rng,\n",
    ")\n",
    "transitions = rollout.flatten_trajectories(rollouts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a quick look at what we just generated using those library functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    f\"\"\"The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.\n",
    "After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.\n",
    "The transitions object contains arrays for: {', '.join(transitions.__dict__.keys())}.\n",
    "\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have collected our transitions, it's time to set up our behavior cloning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms import bc\n",
    "\n",
    "bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    demonstrations=transitions,\n",
    "    rng=rng,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before any training, the policy should only achieve a low reward:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n",
    "print(f\"Reward before training: {reward_before_training}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, the cloned policy should come close to the reward of the expert (up to 500 if the expert was trained to proficiency):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_trainer.train(n_epochs=1)\n",
    "reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n",
    "print(f\"Reward after training: {reward_after_training}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}