{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ba290c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "\n",
    "import torch\n",
    "from IPython.display import clear_output\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from abstraction import DiscreteSequenceEncoder, EnergyModel, GoalAbstraction, HierarchicalPolicy, sample_goal\n",
    "from buffer import EpisodicBuffer\n",
    "from crl import ContrastiveQf, DiscretePolicy, StableContrastiveRL\n",
    "from gridworld.persuadable_gw import PersuadableGWEnvironment\n",
    "from visualization import plot_goal, plot_trajectory\n",
    "\n",
    "# Fall back to the CPU so the notebook also runs on machines without a GPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "torch.set_default_device(device)\n",
    "\n",
    "# Seed torch's RNGs so runs are reproducible (map_config below carries its own map seed).\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "env_config = {\n",
    "    \"step_size\": 10.0,\n",
    "    \"num_rotation_steps\": 4,\n",
    "    \"scan_resolution\": 64,\n",
    "    \"scan_range\": 100,\n",
    "    \"scan_angle_range\": 2 * torch.pi,\n",
    "    \"observation_types\": [\"lidar\"],\n",
    "    \"scale\": 0.5,\n",
    "    \"action_space\": \"dir\",\n",
    "    \"map_config\": {\n",
    "        \"seed\": 7,\n",
    "        \"num_extra_passages\": 1,\n",
    "        \"num_rooms\": 9,\n",
    "        \"size\": 400,\n",
    "        \"room_scale\": 5,\n",
    "        \"passage_scale\": 5,\n",
    "    },\n",
    "}\n",
    "\n",
    "observation_shape = 64  # length of one observation vector; presumably equals env_config[\"scan_resolution\"] - TODO confirm\n",
    "latents = 16\n",
    "categories = 16\n",
    "goal_shape = latents * categories  # size of a flattened (latents x categories) goal encoding\n",
    "state_shape = observation_shape + goal_shape  # policy/Q input: encoded observation history concatenated with the current observation\n",
    "num_actions = 4  # assumes the \"dir\" action space has 4 discrete actions - TODO confirm\n",
    "episode_length = 200  # used as the agent's trajectory_length and the buffer's max_episode_length\n",
    "\n",
    "goal_abstraction_config = {\n",
    "    \"batch_size\": 256,\n",
    "    \"kl_warmup\": 50000,\n",
    "    \"kl_max\": 5.0,\n",
    "    \"kl_target\": 5.0,\n",
    "    \"min_traj_len\": 6,\n",
    "    \"max_traj_len\": 50,\n",
    "    \"latent_space\": [latents, categories],\n",
    "    \"negative_batch\": True,\n",
    "    \"max_info_rec_steps\": 100000,\n",
    "    \"merged_trajectories\": True,\n",
    "    \"symmetric\": False,  # don't care about order in subset\n",
    "    \"symmetric_negatives\": True,  # actively use symmetry as negative examples\n",
    "    \"zero_goal\": True,\n",
    "    \"one_goal\": True,\n",
    "    \"individual\": True,\n",
    "}\n",
    "\n",
    "## Environment\n",
    "env = PersuadableGWEnvironment(device=device, **env_config)\n",
    "\n",
    "## Abstraction networks\n",
    "# Note: The observation decoder is not required for this environment, making training fully reconstruction free.\n",
    "goal_encoder = DiscreteSequenceEncoder(input_shape=observation_shape, categoricals=categories, latents=latents).to(device)\n",
    "subset_energy = EnergyModel(goal_shape=goal_shape).to(device)\n",
    "goal_abstraction = GoalAbstraction(device=device, subset_energy=subset_energy, goal_encoder=goal_encoder, observation_decoder=None, config=goal_abstraction_config).to(device)\n",
    "\n",
    "## Contrastive RL networks\n",
    "\n",
    "# Note: Here we use the encoded observation history as abstract state for the policy and Q-function.\n",
    "# In the paper we used an additional RNN (i.e. a recurrent policy and a recurrent Q function) to encode the history of observations to tackle the POMDP setting, for comparability with observation based goals.\n",
    "# For simplicity we use our representation which is able to perform this task.\n",
    "\n",
    "policy = DiscretePolicy(obs_shape=state_shape, goal_shape=goal_shape, action_shape=num_actions)\n",
    "qf = ContrastiveQf(hidden_sizes=[256, 256], representation_dim=32, action_dim=num_actions, goal_dim=goal_shape, obs_dim=state_shape)\n",
    "agent = StableContrastiveRL(device=device, goal_rep=goal_abstraction, policy=policy, qf=qf, use_kl_reg=True, kl_target=0.5, relabel_steps=10, use_adaptive_entropy_reg=False, trajectory_length=episode_length, batch_size=2048)\n",
    "goal_sampler = HierarchicalPolicy(device=device, input_shape=observation_shape, goal_shape=goal_shape, goal_rep=goal_abstraction, categoricals=categories, latents=latents, batch_size=1024, hidden_dim=256)\n",
    "\n",
    "# Replay buffer to store experiences.\n",
    "buffer_shapes = {\n",
    "    \"step\": [1],\n",
    "    \"observation\": [observation_shape],\n",
    "    \"state\": [state_shape],\n",
    "    \"action\": [num_actions],\n",
    "    \"reward\": [1],\n",
    "    \"desired_goal\": [goal_shape],\n",
    "} | env.env_state.get_state_shape()\n",
    "buffer = EpisodicBuffer(device=device, num_episodes=1000, max_episode_length=episode_length, shapes=buffer_shapes)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d235a1cb",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b8643b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "In the paper, the agent and the abstraction are pre-trained, and rewards are introduced later.\n",
    "Here, we directly collect rewards for an example reward function (room centers) and train the abstract goal simultaneously.\n",
    "Nonetheless, the agent is still tasked with the original, goal-conditioned task.\n",
    "Note that the data used to learn the goal is biased by the original task, resulting in under-represented areas.\n",
    "\"\"\"\n",
    "\n",
    "MAX_EPISODES = 5000\n",
    "MAX_STEPS_PER_EPISODE = 200\n",
    "UPDATE_EVERY = 16     # environment steps between training updates\n",
    "PLOT_EVERY = 20       # episodes between progress plots\n",
    "SUCCESS_RADIUS = 25   # distance to the goal position that counts as reaching it\n",
    "HISTORY_LENGTH = 10   # number of recent observations encoded into the abstract state\n",
    "\n",
    "step = 0\n",
    "successes = deque(maxlen=100)  # rolling window of episode outcomes (1 = goal reached)\n",
    "losses = {}  # latest training losses; empty until the first update has run\n",
    "\n",
    "for episode in range(MAX_EPISODES):\n",
    "    last_obs = deque(maxlen=HISTORY_LENGTH)\n",
    "\n",
    "    goal_dict = env.sample_new_goal(1)  # environment provides goals\n",
    "    result = env.reset()\n",
    "    last_obs.append(result[\"observation\"])\n",
    "    with torch.no_grad():\n",
    "        encoded_goal = goal_abstraction.encode(goal_dict[\"observation\"]).detach()\n",
    "    trajectory = []\n",
    "    success = False\n",
    "    reward_function = env.world.get_room_center_reward(radius=5)\n",
    "\n",
    "    for _ in range(MAX_STEPS_PER_EPISODE):\n",
    "        # Acting only needs inference, so skip building autograd graphs.\n",
    "        with torch.no_grad():\n",
    "            encoded_state = goal_abstraction.encode(torch.stack(list(last_obs), dim=1)).detach()\n",
    "            # Abstract state = encoded observation history followed by the current observation.\n",
    "            state = torch.cat([encoded_state, result[\"observation\"]], dim=-1)\n",
    "            action = agent.get_action(state=state, goal=encoded_goal)\n",
    "        result = env.step(action)\n",
    "        reward = env.get_reward(reward_function, env_state=env.env_state)\n",
    "        last_obs.append(result[\"observation\"])\n",
    "\n",
    "        buffer.append({\n",
    "            \"step\": torch.tensor([step], dtype=torch.float32).unsqueeze(0),\n",
    "            \"observation\": result[\"observation\"],\n",
    "            \"state\": state,\n",
    "            \"action\": action.detach(),\n",
    "            \"reward\": reward,\n",
    "            \"desired_goal\": encoded_goal.detach(),\n",
    "        } | env.env_state.get_state_dict())\n",
    "\n",
    "        if step % UPDATE_EVERY == 0:\n",
    "            abstraction_losses = goal_abstraction.update(buffer, step=step)\n",
    "            agent_losses = agent.update(buffer, step=step)\n",
    "            goal_sampler_losses = goal_sampler.update(buffer, step=step)\n",
    "            losses = abstraction_losses | agent_losses | goal_sampler_losses\n",
    "        # Clone in case the environment updates its position tensor in place (TODO confirm).\n",
    "        trajectory.append(env.env_state.position.clone())\n",
    "\n",
    "        step += 1\n",
    "        if torch.norm(env.env_state.position - goal_dict[\"position\"]) <= SUCCESS_RADIUS:\n",
    "            success = True\n",
    "            break\n",
    "    successes.append(int(success))\n",
    "\n",
    "    if episode % PLOT_EVERY == 0:\n",
    "        clear_output(wait=True)\n",
    "        # Three panels side by side: trajectory, tasked goal, and a goal drawn from the goal sampler.\n",
    "        goal_img = plot_goal(env, goal_abstraction, buffer, encoded_goal, goal_position=goal_dict[\"position\"], blur_sigma=4)\n",
    "\n",
    "        plt.figure(figsize=(15, 5))\n",
    "        plt.subplot(1, 3, 1)\n",
    "        trajectory_img = plot_trajectory(env, torch.stack(trajectory, dim=0), goal_dict[\"position\"].reshape((-1, 2)))\n",
    "        plt.imshow(trajectory_img)\n",
    "        plt.axis(\"off\")\n",
    "        plt.title(\"Trajectory\")\n",
    "        plt.subplot(1, 3, 2)\n",
    "        plt.imshow(goal_img)\n",
    "        plt.axis(\"off\")\n",
    "        plt.title(\"Tasked Goal\")\n",
    "        plt.subplot(1, 3, 3)\n",
    "        rew_img = plot_goal(env, goal_abstraction, buffer, goal_sampler.sample(1), goal_position=None, blur_sigma=4)\n",
    "        plt.imshow(rew_img)\n",
    "        plt.axis(\"off\")\n",
    "        plt.title(\"Encoded reward (room centers)\")\n",
    "        plt.show()\n",
    "        if losses:\n",
    "            print(f\"{episode}: success_rate: {sum(successes) / len(successes):.3f} | bin_acc : {losses['bin_acc']:.3f}\")\n",
    "    buffer.new_episode()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ced080b8",
   "metadata": {},
   "source": [
    "## Original goal-oriented task\n",
    "First, we try to reach a few goals that the environment provides as single observations. The resulting trajectories and the encoded goals are shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b560a266",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 5  # Number of evaluation episodes (also the number of rows in the compositionality plot below)\n",
    "trajectories = []\n",
    "goal_positions = []\n",
    "goal_encodings = []\n",
    "\n",
    "for _ in range(N):\n",
    "    last_obs = deque(maxlen=10)  # same observation-history length as in training\n",
    "    goal_dict = env.sample_new_goal(1)\n",
    "    result = env.reset()\n",
    "    last_obs.append(result[\"observation\"])\n",
    "    trajectory = []\n",
    "    # Evaluation rollouts are pure inference, so no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        encoded_goal = goal_abstraction.encode(goal_dict[\"observation\"]).detach()\n",
    "        for _ in range(MAX_STEPS_PER_EPISODE):\n",
    "            encoded_state = goal_abstraction.encode(torch.stack(list(last_obs), dim=1)).detach()\n",
    "            state = torch.cat([encoded_state, result[\"observation\"]], dim=-1)\n",
    "            action = agent.get_action(state=state, goal=encoded_goal)\n",
    "            result = env.step(action)\n",
    "            last_obs.append(result[\"observation\"])\n",
    "            # Clone in case the environment updates its position tensor in place (TODO confirm).\n",
    "            trajectory.append(env.env_state.position.clone())\n",
    "            # 25 is the same success radius as in the training loop.\n",
    "            if torch.norm(env.env_state.position - goal_dict[\"position\"]) <= 25:\n",
    "                break\n",
    "    trajectories.append(torch.stack(trajectory, dim=0))\n",
    "    goal_positions.append(goal_dict[\"position\"].reshape((-1, 2)))\n",
    "    goal_encodings.append(encoded_goal)\n",
    "\n",
    "# Top row: trajectories; bottom row: the corresponding encoded goals.\n",
    "plt.figure(figsize=(4 * N, 8))\n",
    "for i in range(N):\n",
    "    plt.subplot(2, N, i + 1)\n",
    "    traj_img = plot_trajectory(env, trajectories[i], goal_positions[i])\n",
    "    plt.imshow(traj_img)\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(f\"Trajectory {i+1}\")\n",
    "\n",
    "    plt.subplot(2, N, N + i + 1)\n",
    "    goal_img = plot_goal(env, goal_abstraction, buffer, goal_encodings[i], goal_position=goal_positions[i], blur_sigma=4)\n",
    "    plt.imshow(goal_img)\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(f\"Goal {i+1}\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f48501cf",
   "metadata": {},
   "source": [
    "## Encoded reward function\n",
    "Next, we plot the encoded reward function (room center reward). Note that the goal encoding uses a discrete VAE, so we can sample it multiple times to obtain slightly different goals:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f172f0bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "M = 5  # How many goals to draw from the goal sampler and plot side by side\n",
    "plt.figure(figsize=(4 * M, 4))\n",
    "for panel in range(1, M + 1):\n",
    "    # Each call draws a fresh goal; render it before selecting the subplot it goes into.\n",
    "    goal_sample = goal_sampler.sample(1)\n",
    "    goal_heatmap = plot_goal(env, goal_abstraction, buffer, goal_sample, goal_position=None, blur_sigma=4)\n",
    "    plt.subplot(1, M, panel)\n",
    "    plt.imshow(goal_heatmap)\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(f\"Sample {panel}\")\n",
    "plt.tight_layout()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "961775cd",
   "metadata": {},
   "source": [
    "## Compositionality\n",
    "Let's try to combine two goals in goal space by optimizing for a more abstract goal that combines both of them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d8001d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_goal_panel(panel_index, goal_code, goal_position, title):\n",
    "    \"\"\"Render `goal_code` with plot_goal into subplot `panel_index` of the N x 3 grid.\"\"\"\n",
    "    plt.subplot(N, 3, panel_index)\n",
    "    plt.imshow(plot_goal(env, goal_abstraction, buffer, goal_code, goal_position=goal_position, blur_sigma=4))\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(title)\n",
    "\n",
    "\n",
    "plt.figure(figsize=(10, 4 * N))\n",
    "for row in range(N):\n",
    "    # Sample two short observation sequences from the replay buffer and encode both with\n",
    "    # goal_abstraction.encode (as everywhere else in this notebook) so they share one goal space.\n",
    "    goal_transition = buffer.sample(1, 5, to_device=device)\n",
    "    goal = goal_abstraction.encode(goal_transition[\"observation\"]).detach()\n",
    "    goal_transition2 = buffer.sample(1, 5, to_device=device)\n",
    "    goal2 = goal_abstraction.encode(goal_transition2[\"observation\"]).detach()\n",
    "    # Optimize a goal for the pair; only the goal (first return value) is used here.\n",
    "    specific_goal = sample_goal(goal_abstraction, torch.cat([goal, goal2], dim=0), dir=-1, steps=500, lr=1, latents=latents, categories=categories, device=device)[0]\n",
    "\n",
    "    show_goal_panel(row * 3 + 1, goal, goal_transition[\"position\"], \"Goal 1\")\n",
    "    show_goal_panel(row * 3 + 2, goal2, goal_transition2[\"position\"], \"Goal 2\")\n",
    "    show_goal_panel(row * 3 + 3, specific_goal, None, \"Combined Goal\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "620488a3",
   "metadata": {},
   "source": [
    "As we can see, some combinations are possible, especially if the goals share features or similarities, while others result in a lot of activity (converging closer to the most abstract goal)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv (3.10.12)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
