{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/6_train_mce.ipynb)\n",
    "# Learn a Reward Function using Maximum Causal Entropy Inverse Reinforcement Learning\n",
    "\n",
    "Here, we're going to take a tabular environment with a pre-defined reward function, Cliffworld, and solve for the optimal policy. We then generate demonstrations from this policy, and use them to learn an approximation to the true reward function with MCE IRL. Finally, we directly compare the learned reward to the ground-truth reward (which we have access to in this example)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cliffworld is a POMDP: the agent's observations are only partial, while the full environment state is hidden. We use `ExposePOMDPStateWrapper` to expose the hidden state as the observation, turning the environment into a fully observable MDP so that computing the optimal policy is easy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "from seals import base_envs\n",
    "from seals.diagnostics.cliff_world import CliffWorldEnv\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from imitation.algorithms.mce_irl import (\n",
    "    MCEIRL,\n",
    "    mce_occupancy_measures,\n",
    "    mce_partition_fh,\n",
    "    TabularPolicy,\n",
    ")\n",
    "from imitation.data import rollout\n",
    "from imitation.rewards import reward_nets\n",
    "\n",
    "env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)\n",
    "env_single = env_creator()\n",
    "\n",
    "state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())\n",
    "\n",
    "# This is just a vectorized environment because `generate_trajectories` expects one\n",
    "state_venv = DummyVecEnv([state_env_creator] * 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we derive an expert policy using Bellman backups. We analytically compute the occupancy measures, and also sample some expert trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, pi = mce_partition_fh(env_single)\n",
    "\n",
    "_, om = mce_occupancy_measures(env_single, pi=pi)\n",
    "\n",
    "rng = np.random.default_rng()\n",
    "expert = TabularPolicy(\n",
    "    state_space=env_single.state_space,\n",
    "    action_space=env_single.action_space,\n",
    "    pi=pi,\n",
    "    rng=rng,\n",
    ")\n",
    "\n",
    "expert_trajs = rollout.generate_trajectories(\n",
    "    policy=expert,\n",
    "    venv=state_venv,\n",
    "    sample_until=rollout.make_min_timesteps(5000),\n",
    "    rng=rng,\n",
    ")\n",
    "\n",
    "print(\"Expert stats: \", rollout.rollout_stats(expert_trajs))"
   ]
  },
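  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Bellman backups used above are *soft* backups: in the maximum causal entropy setting, the max over actions is replaced by a log-sum-exp. The cell below is a minimal NumPy sketch of that idea on a made-up three-state MDP. It is an illustration under stated assumptions (state-only reward, no discounting, zero value after the final step), not the implementation of `mce_partition_fh`; `soft_value_iteration` and every `toy_`-prefixed name are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def soft_value_iteration(transition, reward, horizon):\n",
    "    # Illustrative sketch only; assumes a state-only reward and no discounting.\n",
    "    # transition has shape (n_states, n_actions, n_states); reward has shape (n_states,).\n",
    "    n_states, n_actions, _ = transition.shape\n",
    "    v_next = np.zeros(n_states)  # assumed value after the final step\n",
    "    policy = np.zeros((horizon, n_states, n_actions))\n",
    "    for t in reversed(range(horizon)):\n",
    "        # Soft Q-value: immediate reward plus expected next-step soft value.\n",
    "        q = reward[:, None] + transition @ v_next\n",
    "        # Soft value: log-sum-exp over actions instead of a hard max.\n",
    "        v = np.logaddexp.reduce(q, axis=1)\n",
    "        # The policy is a softmax over the soft Q-values.\n",
    "        policy[t] = np.exp(q - v[:, None])\n",
    "        v_next = v\n",
    "    return policy\n",
    "\n",
    "\n",
    "# Hypothetical toy MDP: 3 states, 2 actions, random transition probabilities.\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_transition = toy_rng.dirichlet(np.ones(3), size=(3, 2))\n",
    "toy_reward = np.array([0.0, -1.0, 1.0])\n",
    "toy_policy = soft_value_iteration(toy_transition, toy_reward, horizon=4)\n",
    "\n",
    "# Each row of the time-dependent policy is a distribution over actions.\n",
    "assert np.allclose(toy_policy.sum(axis=-1), 1.0)\n",
    "print(toy_policy[0])"
   ]
  },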
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the reward function\n",
    "\n",
    "The true reward here is not linear in the reduced feature space (i.e., the $(x, y)$ coordinates), so no linear reward model can represent it exactly, but an MLP should Just Work™."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch as th\n",
    "\n",
    "\n",
    "def train_mce_irl(demos, hidden_sizes, lr=0.01, **kwargs):\n",
    "    reward_net = reward_nets.BasicRewardNet(\n",
    "        env_single.observation_space,\n",
    "        env_single.action_space,\n",
    "        hid_sizes=hidden_sizes,\n",
    "        use_action=False,\n",
    "        use_done=False,\n",
    "        use_next_state=False,\n",
    "    )\n",
    "\n",
    "    mce_irl = MCEIRL(\n",
    "        demos,\n",
    "        env_single,\n",
    "        reward_net,\n",
    "        log_interval=250,\n",
    "        optimizer_kwargs=dict(lr=lr),\n",
    "        rng=rng,\n",
    "    )\n",
    "    occ_measure = mce_irl.train(**kwargs)\n",
    "\n",
    "    imitation_trajs = rollout.generate_trajectories(\n",
    "        policy=mce_irl.policy,\n",
    "        venv=state_venv,\n",
    "        sample_until=rollout.make_min_timesteps(5000),\n",
    "        rng=rng,\n",
    "    )\n",
    "    print(\"Imitation stats: \", rollout.rollout_stats(imitation_trajs))\n",
    "\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    plt.subplot(1, 2, 1)\n",
    "    env_single.draw_value_vec(occ_measure)\n",
    "    plt.title(\"Occupancy for learned reward\")\n",
    "    plt.xlabel(\"Gridworld x-coordinate\")\n",
    "    plt.ylabel(\"Gridworld y-coordinate\")\n",
    "    plt.subplot(1, 2, 2)\n",
    "    _, true_occ_measure = mce_occupancy_measures(env_single)\n",
    "    env_single.draw_value_vec(true_occ_measure)\n",
    "    plt.title(\"Occupancy for true reward\")\n",
    "    plt.xlabel(\"Gridworld x-coordinate\")\n",
    "    plt.ylabel(\"Gridworld y-coordinate\")\n",
    "    plt.show()\n",
    "\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    plt.subplot(1, 2, 1)\n",
    "    env_single.draw_value_vec(\n",
    "        reward_net(th.as_tensor(env_single.observation_matrix), None, None, None)\n",
    "        .detach()\n",
    "        .numpy()\n",
    "    )\n",
    "    plt.title(\"Learned reward\")\n",
    "    plt.xlabel(\"Gridworld x-coordinate\")\n",
    "    plt.ylabel(\"Gridworld y-coordinate\")\n",
    "    plt.subplot(1, 2, 2)\n",
    "    env_single.draw_value_vec(env_single.reward_matrix)\n",
    "    plt.title(\"True reward\")\n",
    "    plt.xlabel(\"Gridworld x-coordinate\")\n",
    "    plt.ylabel(\"Gridworld y-coordinate\")\n",
    "    plt.show()\n",
    "\n",
    "    return mce_irl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, a linear reward model (`hidden_sizes=[]`, i.e. no hidden layers) cannot fit the data. Even though we're training the model on analytically computed occupancy measures for the optimal policy, the learned reward and the resulting occupancy frequencies diverge sharply from the true ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mce_irl(om, hidden_sizes=[])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's try using a very simple nonlinear reward model: an MLP with a single hidden layer. We first train it on the analytically computed occupancy measures. This should give a very precise result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mce_irl(om, hidden_sizes=[256])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we train it on trajectories sampled from the expert. This gives a stochastic approximation to the occupancy measure, so performance is a little worse. Using more expert trajectories should improve performance. Try it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mce_irl_from_trajs = train_mce_irl(expert_trajs[0:10], hidden_sizes=[256])"
   ]
  },
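  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why sampled trajectories only approximate the occupancy measure, here is a minimal NumPy sketch that estimates per-episode state visit counts from simulated episodes and compares them with the exact values. It assumes, purely for illustration, that states are drawn independently from a fixed distribution, which is not how Cliffworld behaves. `empirical_occupancy` and the `toy_`-prefixed names are hypothetical and not part of `imitation`. The printed error typically shrinks as the number of episodes grows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def empirical_occupancy(state_seqs, n_states):\n",
    "    # Average number of visits to each state per episode.\n",
    "    counts = np.zeros(n_states)\n",
    "    for seq in state_seqs:\n",
    "        counts += np.bincount(seq, minlength=n_states)\n",
    "    return counts / len(state_seqs)\n",
    "\n",
    "\n",
    "# Hypothetical toy setup: 3 states visited i.i.d. over an 8-step episode.\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_probs = np.array([0.5, 0.3, 0.2])\n",
    "toy_horizon = 8\n",
    "toy_exact = toy_horizon * toy_probs  # exact expected visit counts\n",
    "\n",
    "for toy_n_episodes in (10, 1000):\n",
    "    toy_seqs = toy_rng.choice(3, size=(toy_n_episodes, toy_horizon), p=toy_probs)\n",
    "    toy_estimate = empirical_occupancy(toy_seqs, n_states=3)\n",
    "    # Largest absolute gap between estimated and exact visit counts.\n",
    "    print(toy_n_episodes, np.abs(toy_estimate - toy_exact).max())"
   ]
  },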
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While the learned reward function is quite different from the true reward function, it induces a virtually identical occupancy measure over the states. In particular, states below the top row get almost the same reward as top-row states. This is because Cliffworld has an upward-blowing wind that pushes the agent toward the top row with probability 0.3 at every timestep.\n",
    "\n",
    "Even though the agent only gets reward in the top-row squares, with the maximum reward in the top right-hand square, the reward model considers it almost as good to end up in one of the squares below the top right-hand square, since the wind will eventually blow the agent to the goal square."
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "439158cd89905785fcc749928062ade7bfccc3f087fab145e5671f895c635937"
  },
  "kernelspec": {
   "display_name": "Python 3.9.13 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
