{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e714762-bf94-474a-a4b0-154d8715c9a0",
   "metadata": {},
   "source": [
    "# Meta Motivo benchmarking using HumEnv\n",
    "\n",
    "This notebook shows how to evaluate a Meta Motivo model using the benchmark proposed in HumEnv. It assumes that motions for tracking and poses for goal reaching have been processed by following the [instructions](https://github.com/facebookresearch/humenv/tree/main/data_preparation) in HumEnv and are available in the folder `MOTIONS_BASE_PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89840e8c-ac6d-4fd1-880f-53e034b56f36",
   "metadata": {},
   "outputs": [],
   "source": [
    "from metamotivo.fb_cpr.huggingface import FBcprModel\n",
    "from metamotivo.wrappers.humenvbench import RewardWrapper, TrackingWrapper, GoalWrapper\n",
    "from metamotivo.buffers.buffers import DictBuffer\n",
    "from huggingface_hub import hf_hub_download\n",
    "import h5py\n",
    "import json\n",
    "import numpy as np\n",
    "from humenv import STANDARD_TASKS\n",
    "from humenv.bench import (\n",
    "    RewardEvaluation,\n",
    "    GoalEvaluation,\n",
    "    TrackingEvaluation,\n",
    ")\n",
    "\n",
    "# locations of the files produced by HumEnv's data preparation scripts\n",
    "MOTIONS_BASE_PATH = \"humenv/data_preparation/humenv_amass\"\n",
    "MOTIONS_TRACKING = \"humenv/data_preparation/test_train_split/large1_small1_test_0.1.txt\"\n",
    "GOAL_POSES = \"humenv/data_preparation/goal_poses/goals.json\"\n",
    "\n",
    "# parse the goal file, then keep each goal's observation as a numpy array\n",
    "with open(GOAL_POSES, \"r\") as goals_file:\n",
    "    raw_goals = json.load(goals_file)\n",
    "GOAL_DICT = {\n",
    "    goal_name: np.array(goal_entry[\"observation\"])\n",
    "    for goal_name, goal_entry in raw_goals.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ef8eeea-ee50-464d-9ecc-45d44fc7a71d",
   "metadata": {},
   "source": [
    "Download the inference buffer from the Hugging Face Hub and load it into memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7690ac63-ac66-474a-95a0-9b6ccb33f0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download (or reuse the local copy of) the inference buffer shipped with the model\n",
    "buffer_path = hf_hub_download(\n",
    "    repo_id=\"facebook/metamotivo-S-1\",\n",
    "    filename=\"data/buffer_inference_500000.hdf5\",\n",
    "    repo_type=\"model\",\n",
    "    local_dir=\"metamotivo-S-1-datasets\",\n",
    ")\n",
    "\n",
    "# copy every dataset into memory inside a context manager so the HDF5\n",
    "# file handle is closed as soon as the arrays have been read\n",
    "with h5py.File(buffer_path, \"r\") as hf:\n",
    "    data = {k: v[:] for k, v in hf.items()}\n",
    "\n",
    "# size the buffer from the first dimension of the \"qpos\" array, then fill it\n",
    "buffer = DictBuffer(capacity=data[\"qpos\"].shape[0], device=\"cpu\")\n",
    "buffer.extend(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d4fa4f5-74fd-45b1-b5b3-13d6368fb794",
   "metadata": {},
   "source": [
    "Load model and prepare it for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e23b3e6c-eeb1-4fa8-bc77-f2ef04a3d495",
   "metadata": {},
   "outputs": [],
   "source": [
    "# it is normally faster to evaluate on cpu since tracking is parallelized\n",
    "device = \"cpu\"\n",
    "\n",
    "# pretrained model, then the HumEnv-bench wrappers stacked in order: reward -> goal -> tracking\n",
    "pretrained = FBcprModel.from_pretrained(\"facebook/metamotivo-S-1\").to(device)\n",
    "reward_wrapped = RewardWrapper(\n",
    "    model=pretrained,\n",
    "    inference_dataset=buffer,\n",
    "    num_samples_per_inference=100_000,\n",
    "    inference_function=\"reward_wr_inference\",\n",
    "    max_workers=80,\n",
    ")\n",
    "model = TrackingWrapper(model=GoalWrapper(model=reward_wrapped))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "279ca5af-54f2-4daa-9476-6c2ec00109a7",
   "metadata": {},
   "source": [
    "HumEnv provides 3 evaluation protocols:\n",
    "- reward based,\n",
    "- goal based,\n",
    "- tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc2f13f3-ee0e-4925-b87c-25e11a2f79f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_eval = RewardEvaluation(\n",
    "    tasks=STANDARD_TASKS,  # all the 45 tasks used in the paper\n",
    "    env_kwargs={\"state_init\": \"Fall\"},\n",
    "    num_contexts=1,\n",
    "    num_envs=50,\n",
    "    num_episodes=100,\n",
    ")\n",
    "\n",
    "reward_metrics = reward_eval.run(agent=model)\n",
    "print(reward_metrics)\n",
    "\n",
    "# collect the 'reward' entry of every task and report its mean\n",
    "task_rewards = np.array([task_metrics['reward'] for task_metrics in reward_metrics.values()])\n",
    "print(f\"reward averaged across {task_rewards.shape[0]} tasks: {task_rewards.mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e79882e-1e48-42b4-8a0d-4c1bdaf12148",
   "metadata": {},
   "outputs": [],
   "source": [
    "goal_eval = GoalEvaluation(\n",
    "    goals=GOAL_DICT,\n",
    "    env_kwargs={\"state_init\": \"Fall\"},\n",
    "    num_contexts=1,\n",
    "    num_envs=50,\n",
    "    num_episodes=100,\n",
    ")\n",
    "\n",
    "goal_metrics = goal_eval.run(agent=model)\n",
    "print(goal_metrics)\n",
    "\n",
    "# report each goal metric averaged over all entries of goal_metrics\n",
    "for metric_name in ('success', 'proximity'):\n",
    "    per_pose = np.array([pose_metrics[metric_name] for pose_metrics in goal_metrics.values()])\n",
    "    print(f\"goal {metric_name} averaged across {per_pose.shape[0]} poses: {per_pose.mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6ccae6b-510e-43c0-a92b-f0772c150ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "tracking_eval = TrackingEvaluation(\n",
    "    motions=MOTIONS_TRACKING,\n",
    "    motion_base_path=MOTIONS_BASE_PATH,\n",
    "    env_kwargs={\"state_init\": \"Default\"},\n",
    "    num_envs=50,\n",
    ")\n",
    "\n",
    "tracking_metrics = tracking_eval.run(agent=model)\n",
    "print(tracking_metrics)\n",
    "\n",
    "# report each tracking metric averaged over all entries of tracking_metrics\n",
    "for metric_name in ('success_phc_linf', 'emd'):\n",
    "    per_motion = np.array([motion_metrics[metric_name] for motion_metrics in tracking_metrics.values()])\n",
    "    print(f\"tracking {metric_name} averaged across {per_motion.shape[0]} motions: {per_motion.mean()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
