{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d0167d15-a6ff-4602-835d-c88851e83113",
   "metadata": {},
   "source": [
    "# Meta Motivo Tutorial\n",
    "This notebook provides a simple introduction on how to use the Meta Motivo api."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be0d8e4a-882c-467e-bce6-ef2f33b509e2",
   "metadata": {},
   "source": [
    "## All imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "03d3cd63-1d2e-4bda-b224-d2b2b73bf655",
   "metadata": {},
   "outputs": [],
   "source": [
    "from packaging.version import Version\n",
    "from metamotivo.fb_cpr.huggingface import FBcprModel\n",
    "from huggingface_hub import hf_hub_download\n",
    "from humenv import make_humenv\n",
    "import gymnasium\n",
    "from gymnasium.wrappers import FlattenObservation, TransformObservation\n",
    "from metamotivo.buffers.buffers import DictBuffer\n",
    "from humenv.env import make_from_name\n",
    "from humenv import rewards as humenv_rewards\n",
    "\n",
    "import torch\n",
    "import mediapy as media\n",
    "import math\n",
    "import h5py\n",
    "from pathlib import Path\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa35d241-fa2d-4ad3-aaa9-9e2b4175e742",
   "metadata": {},
   "source": [
    "## Model download\n",
    "The first step is to download the model. We show how to use HuggingFace hub for that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "00f7b632-864d-4b05-848c-f7a22b662a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = FBcprModel.from_pretrained(\"facebook/metamotivo-S-1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3609f0e1-2ff9-462b-9bea-15695a128da7",
   "metadata": {},
   "source": [
    "**Run a policy from Meta Motivo:**\n",
    "\n",
    "Now that we saw how to load a pre-trained Meta Motivo policy, we can prompt it and execute actions with it. \n",
    "\n",
    "The first step is to sample a context embedding `z` that needs to be passed to the policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2f4e85ca-1940-403b-adfd-126c244e39ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embedding size torch.Size([1, 256])\n",
      "z norm: 15.999999046325684\n",
      "z norm / sqrt(d): 0.9999999403953552\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n"
     ]
    }
   ],
   "source": [
    "device = \"cpu\"\n",
    "\n",
    "if Version(\"0.26\") <= Version(gymnasium.__version__) < Version(\"1.0\"):\n",
    "    transform_obs_wrapper = lambda env: TransformObservation(\n",
    "            env, lambda obs: torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device)\n",
    "        )\n",
    "else:\n",
    "    transform_obs_wrapper = lambda env: TransformObservation(\n",
    "            env, lambda obs: torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device), env.observation_space\n",
    "        )\n",
    "\n",
    "env, _ = make_humenv(\n",
    "    num_envs=1,\n",
    "    wrappers=[\n",
    "        FlattenObservation,\n",
    "        transform_obs_wrapper,\n",
    "    ],\n",
    "    state_init=\"Default\",\n",
    ")\n",
    "\n",
    "model.to(device)\n",
    "z = model.sample_z(1)\n",
    "print(f\"embedding size {z.shape}\")\n",
    "print(f\"z norm: {torch.norm(z)}\")\n",
    "print(f\"z norm / sqrt(d): {torch.norm(z) / math.sqrt(z.shape[-1])}\")\n",
    "observation, _ = env.reset()\n",
    "# frames = [env.render()]\n",
    "for i in range(30):\n",
    "    action = model.act(observation, z, mean=True)\n",
    "    observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())\n",
    "    print(reward)\n",
    "\n",
    "# media.show_video(frames, fps=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05338016-0890-4f9e-8741-73985dbc89b3",
   "metadata": {},
   "source": [
    "### Computing Q-functions\n",
    "\n",
    "FB-CPR provides a way of directly computing the action-value function of any policy embedding `z` on any task embedding `z_r`. Then, the Q function of a policy $z$ is given by\n",
    "\n",
    "$Q(s,a, z) = F(s,a,z) \\cdot z_r$\n",
    "\n",
    "The task embedding can be computed in the following way. Given a set of samples labeled with rewards $(s,a,s',r)$, the task embedding is given by: \n",
    "\n",
    "$z_r = \\mathrm{normalised}(\\sum_{i \\in \\mathrm{batch}} r_i B(s'_i))$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "98e3b022-22d1-4602-ab84-b056d621b702",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([408.6027, 189.3272, 426.6580, 477.1006, 305.5982, 329.3961, 514.0405,\n",
      "        550.2684, 414.1617, 546.8745])\n"
     ]
    }
   ],
   "source": [
    "def Qfunction(state, action, z_reward, z_policy):\n",
    "    F = model.forward_map(obs=state, z=z_policy.repeat(state.shape[0],1), action=action) # num_parallel x num_samples x z_dim\n",
    "    Q = F @ z_reward.ravel()\n",
    "    return Q.mean(axis=0)\n",
    "\n",
    "z_reward = model.sample_z(1)\n",
    "z_policy = model.sample_z(1)\n",
    "state = torch.rand((10, env.observation_space.shape[0]), device=model.cfg.device, dtype=torch.float32)\n",
    "action = torch.rand((10, env.action_space.shape[0]), device=model.cfg.device, dtype=torch.float32)*2 - 1\n",
    "Q = Qfunction(state, action, z_reward, z_policy)\n",
    "print(Q)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f168c33-a35f-4335-bc97-e6d4eeb5fce5",
   "metadata": {},
   "source": [
    "## Prompting the model\n",
    "\n",
    "We have seen that we can condition the model via the context variable `z`. We can control the task to execute via _prompting_ (or _policy inference_)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10d2b486-63c1-45f7-bfca-acd5308da1ec",
   "metadata": {},
   "source": [
    "### Reward prompts\n",
    "The first version of inference we investigate is the reward prompting, i.e., given a set of reward label samples we can infer in a zero-shot way the near-optimal policy for solving such task.\n",
    "\n",
    "First step, download the data for inference. We provide a buffer for inference of about 500k samples. This buffer has been generated by randomly subsampling the final replay buffer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "26312a39-fdb8-4843-a3e2-08f0dafc3db5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metamotivo-S-1-datasets/data/buffer_inference_500000.hdf5\n"
     ]
    }
   ],
   "source": [
    "local_dir = \"metamotivo-S-1-datasets\"\n",
    "dataset = \"buffer_inference_500000.hdf5\"\n",
    "buffer_path = hf_hub_download(\n",
    "        repo_id=\"facebook/metamotivo-S-1\",\n",
    "        filename=f\"data/{dataset}\",\n",
    "        repo_type=\"model\",\n",
    "        local_dir=local_dir,\n",
    "    )\n",
    "print(buffer_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac150fc7-0719-4428-8ef1-eac48ddf0d9a",
   "metadata": {},
   "source": [
    "Now that we have download the h5 file for inference, we can conveniently loaded it in a buffer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "20c39961-15e6-4739-96fd-bb7aa47e60f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<KeysViewHDF5 ['action', 'next_observation', 'next_qpos', 'next_qvel', 'observation', 'qpos', 'qvel']>\n",
      "action              : (500000, 69)\n",
      "next_observation    : (500000, 358)\n",
      "next_qpos           : (500000, 76)\n",
      "next_qvel           : (500000, 75)\n",
      "observation         : (500000, 358)\n",
      "qpos                : (500000, 76)\n",
      "qvel                : (500000, 75)\n"
     ]
    }
   ],
   "source": [
    "hf = h5py.File(buffer_path, \"r\")\n",
    "print(hf.keys())\n",
    "data = {}\n",
    "for k, v in hf.items():\n",
    "    print(f\"{k:20s}: {v.shape}\")\n",
    "    data[k] = v[:]\n",
    "buffer = DictBuffer(capacity=data[\"qpos\"].shape[0], device=\"cpu\")\n",
    "buffer.extend(data)\n",
    "del data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "af64888e-c992-4c0b-aa7b-9421942ee605",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "action              : torch.Size([5, 69])\n",
      "next_observation    : torch.Size([5, 358])\n",
      "next_qpos           : torch.Size([5, 76])\n",
      "next_qvel           : torch.Size([5, 75])\n",
      "observation         : torch.Size([5, 358])\n",
      "qpos                : torch.Size([5, 76])\n",
      "qvel                : torch.Size([5, 75])\n"
     ]
    }
   ],
   "source": [
    "batch = buffer.sample(5)\n",
    "for k, v in batch.items():\n",
    "    print(f\"{k:20s}: {v.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f1238a0-1ffc-4de9-a698-7853ea7fdd92",
   "metadata": {},
   "source": [
    "As you can see, the buffer does not provide a reward signal. We need to label this buffer with the desired reward function. We provide API for that but here we start looking into the basic steps:\n",
    "* Instantiate a reward function\n",
    "* Computing the reward from the batch data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dedac7ea-6365-41ce-b7cb-6b2f1ccbd4ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LocomotionReward(move_speed=2.0, stand_height=1.4, move_angle=0.0, egocentric_target=True, low_height=0.6, stay_low=False)\n"
     ]
    }
   ],
   "source": [
    "reward_fn = humenv_rewards.LocomotionReward(move_speed=2.0) # move ahead with speed 2\n",
    "# humenv provides also a name-base reward initialization. We could\n",
    "# get the same reward function in this way\n",
    "reward_fn = make_from_name(\"move-ego-0-2\") \n",
    "print(reward_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d95db165-5c5c-402c-9ddf-6cfed05f48a8",
   "metadata": {},
   "source": [
    "We can call the method `__call__` to obtain a reward value from the physics state. This function receives a mujoco model, qpos, qvel and the action. See the humenv tutorial for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caa8a3dd-159e-45a9-95b7-c52d6244b6f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# N = 100_000\n",
    "# batch = buffer.sample(N)\n",
    "# rewards = []\n",
    "# for i in range(N):\n",
    "#     rewards.append(\n",
    "#         reward_fn(\n",
    "#             env.unwrapped.model,\n",
    "#             qpos=batch[\"next_qpos\"][i],\n",
    "#             qvel=batch[\"next_qvel\"][i],\n",
    "#             ctrl=batch[\"action\"][i])\n",
    "#     )\n",
    "# rewards = np.stack(rewards).reshape(-1,1)\n",
    "# print(rewards.ravel())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1ccc931-9be2-4a88-b66f-70d369648c2c",
   "metadata": {},
   "source": [
    "**Note** that the reward functions implemented in humenv are functions of next state and action which means we need to use `next_qpos` and `next_qvel` that are the physical state of the system at the next state.\n",
    "\n",
    "We provide a multi-thread version for faster relabeling, see `metamotivo.wrappers.humenvbench.relabel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "10ca4068-e529-43d4-a110-5f224d118d2e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6.57521421e-01 9.64426044e-02 2.44013031e-01 4.91341914e-01\n",
      " 5.81574460e-04]\n"
     ]
    }
   ],
   "source": [
    "from metamotivo.wrappers.humenvbench import relabel\n",
    "rewards = relabel(\n",
    "    env,\n",
    "    qpos=batch[\"next_qpos\"],\n",
    "    qvel=batch[\"next_qvel\"],\n",
    "    action=batch[\"action\"],\n",
    "    reward_fn=reward_fn, \n",
    "    max_workers=4\n",
    ")\n",
    "print(rewards.ravel())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8327f700-746f-4d76-b0e2-4b66cd4b44de",
   "metadata": {},
   "source": [
    "We can now infer the context `z` for the selected task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5537121c-ce60-44e0-9fdf-5457bcad6c08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.3076, -0.2439, -0.7661, -0.6795, -0.7002, -0.6440, -0.2513, -0.2484,\n",
      "          0.3242, -0.5451, -0.6776, -1.4743, -0.2207,  1.3004,  0.0222, -0.7012,\n",
      "         -0.8221,  0.2887,  1.7149,  0.6139,  0.1733, -0.2412,  0.0740,  0.4092,\n",
      "          0.2089, -1.3360, -2.3101,  1.1554, -0.9571, -1.8237, -0.1461, -1.9122,\n",
      "         -0.2505, -0.0465, -0.0552,  0.3294,  2.0049,  1.2339,  0.2701,  0.5270,\n",
      "          0.2593,  2.0540, -0.7256, -0.8436, -0.2003, -0.1574, -0.7549,  0.4072,\n",
      "          0.0612,  0.5219,  1.4295, -0.5869,  0.9393, -0.7208,  1.3991,  0.3836,\n",
      "         -0.3250,  0.6880, -0.3892, -0.2345, -0.3116,  0.7587, -0.0341,  0.9872,\n",
      "          0.2301, -0.1386,  0.8478,  1.9818, -0.7292,  0.4505,  0.4112,  0.6671,\n",
      "          0.8671, -0.1077, -1.5943,  1.6093, -0.2688,  1.2927, -0.7521,  0.0457,\n",
      "          1.7979, -2.4199,  0.3253, -0.1710,  0.6316, -1.0871, -0.1804,  0.5481,\n",
      "         -0.4115,  2.4300,  1.8018,  0.9862,  0.0930,  0.5422,  0.2738, -0.7534,\n",
      "         -0.5232, -0.0445,  0.2829,  0.2265,  0.6510,  0.1943, -0.0687, -0.3826,\n",
      "          0.2477, -0.0348, -1.6214,  0.1279,  0.5272,  0.4666,  0.5274,  0.0146,\n",
      "          0.1194, -0.2838, -1.2758,  1.8623, -0.1040,  0.5865,  0.2879, -1.7184,\n",
      "         -1.0462,  0.1422,  0.3566, -1.1703,  0.3362,  0.3208,  0.4994, -0.1273,\n",
      "          1.2863, -0.1318, -1.4895,  1.5381, -0.0590,  0.0655,  1.0023, -0.1261,\n",
      "         -0.3673, -1.5232,  0.4366, -0.0533, -1.0344, -0.8035, -0.8976, -2.5898,\n",
      "          0.4201,  0.5180,  0.8725,  0.6807, -1.4769,  3.3322, -0.3845,  0.9815,\n",
      "         -0.1426, -0.0779,  1.2733, -2.8230,  1.0266,  0.4559,  0.6103, -0.1189,\n",
      "          0.2965,  0.4673, -0.1315,  0.5495,  2.8369, -0.0106, -1.2414, -0.5133,\n",
      "          0.2168, -0.3327, -0.4262,  1.8487,  0.8416, -1.4966,  0.0238,  0.4169,\n",
      "         -0.4190,  1.0502,  1.3189,  0.4225,  0.8545, -0.1738,  1.6700, -0.3599,\n",
      "          0.1532, -2.6332,  1.6673,  0.9300, -1.0590,  0.5029, -0.2246,  1.0630,\n",
      "          0.5746,  0.5086, -0.6506, -0.6065, -1.2411, -0.4758, -0.0963,  0.3045,\n",
      "         -1.5502, -0.3197, -1.1750, -0.3065, -0.0555,  0.0150, -0.7762,  0.8744,\n",
      "          1.0162,  0.6598,  0.9835,  2.3806,  2.1573,  0.1602,  0.9950, -0.8173,\n",
      "         -0.9022, -0.8009,  0.8450,  0.9039, -0.5597,  0.7646,  0.6675, -0.8146,\n",
      "         -0.2400,  0.2849, -1.0101,  0.4083,  0.0082, -1.6717, -2.4339,  0.6756,\n",
      "         -0.0091,  1.0022,  0.1004, -0.5101,  1.2122, -0.8115,  0.1318, -0.1791,\n",
      "         -2.1569, -1.2127,  1.6897,  1.6093, -1.5666,  1.0244, -0.5834, -0.4430,\n",
      "          2.0673,  0.6955,  0.9622,  1.1398, -1.4695,  0.5614, -0.9527, -0.6809]])\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n",
      "0.0\n"
     ]
    }
   ],
   "source": [
    "import gc\n",
    "import torch\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "with torch.no_grad():\n",
    "    z = model.reward_wr_inference(\n",
    "        next_obs=batch[\"next_observation\"],\n",
    "        reward=torch.tensor(rewards, device=model.cfg.device, dtype=torch.float32)\n",
    "    )\n",
    "\n",
    "    print(z)\n",
    "    \n",
    "    observation, _ = env.reset()\n",
    "    # frames = [env.render()]\n",
    "    for i in range(100):\n",
    "        action = model.act(observation, z, mean=True)\n",
    "        observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())\n",
    "        print(reward)\n",
    "        # frames.append(env.render())\n",
    "\n",
    "# media.show_video(frames, fps=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bd2550c-06a5-4353-b3f5-4eeb56c30d70",
   "metadata": {},
   "source": [
    "Let's compute the **Q-function** for this policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3719a3bf-8282-430a-84a6-c5067a4cfcd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([6303.0771, 8355.1621, 5579.8247, 6280.5664, 5039.2002])\n"
     ]
    }
   ],
   "source": [
    "z_reward = torch.sum(\n",
    "    model.backward_map(obs=batch[\"next_observation\"]) * torch.tensor(rewards, dtype=torch.float32, device=model.cfg.device),\n",
    "    dim=0\n",
    ")\n",
    "z_reward = model.project_z(z_reward)\n",
    "Q = Qfunction(batch[\"observation\"], batch[\"action\"], z_reward, z)\n",
    "print(Q)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42121e99-c920-4d69-a499-ec039e0b8e05",
   "metadata": {},
   "source": [
    "# Goal and Tracking prompts\n",
    "The model supports two other modalities, `goal` and `tracking`. These two modalities expose similar functions for context inference:\n",
    "- `def goal_inference(self, next_obs: torch.Tensor) -> torch.Tensor`\n",
    "- `def tracking_inference(self, next_obs: torch.Tensor) -> torch.Tensor`\n",
    "  \n",
    "We show an example on how to perform goal inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c3e1de2e-d792-446c-ba1f-7b07abf69cac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "goal pose\n"
     ]
    }
   ],
   "source": [
    "goal_qpos = np.array([0.13769039,-0.20029453,0.42305034,0.21707786,0.94573617,0.23868944\n",
    ",0.03856998,-1.05566834,-0.12680767,0.11718296,1.89464102,-0.01371153\n",
    ",-0.07981451,-0.70497424,-0.0478,-0.05700732,-0.05363342,-0.0657329\n",
    ",0.08163511,-1.06263979,0.09788937,-0.22008936,1.85898192,0.08773695\n",
    ",0.06200327,-0.3802791,0.07829525,0.06707749,0.14137152,0.08834448\n",
    ",-0.07649805,0.78328658,0.12580912,-0.01076061,-0.35937259,-0.13176489\n",
    ",0.07497022,-0.2331914,-0.11682692,0.04782308,-0.13571422,0.22827948\n",
    ",-0.23456622,-0.12406075,-0.04466465,0.2311667,-0.12232673,-0.25614032\n",
    ",-0.36237662,0.11197906,-0.08259534,-0.634934,-0.30822742,-0.93798716\n",
    ",0.08848668,0.4083417,-0.30910404,0.40950143,0.30815359,0.03266103\n",
    ",1.03959336,-0.19865537,0.25149713,0.3277561,0.16943092,0.69125975\n",
    ",0.21721349,-0.30871948,0.88890484,-0.08884043,0.38474549,0.30884107\n",
    ",-0.40933304,0.30889523,-0.29562966,-0.6271498])\n",
    "env.unwrapped.set_physics(qpos=goal_qpos, qvel=np.zeros(75))\n",
    "goal_obs = torch.tensor(env.unwrapped.get_obs()[\"proprio\"].reshape(1,-1), device=model.cfg.device, dtype=torch.float32)\n",
    "print(\"goal pose\")\n",
    "# media.show_image(env.render())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24755455-28f9-41ad-b7e6-26f5b10ae13e",
   "metadata": {},
   "outputs": [],
   "source": [
    "z = model.goal_inference(next_obs=goal_obs)\n",
    "\n",
    "\n",
    "observation, _ = env.reset()\n",
    "frames = [env.render()]\n",
    "for i in range(30):\n",
    "    action = model.act(observation, z, mean=True)\n",
    "    observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())\n",
    "    print(reward)\n",
    "    # frames.append(env.render())\n",
    "\n",
    "# media.show_video(frames, fps=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e7c53e-76a6-476e-966d-2ff1e020d33f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (metamotivo-py310)",
   "language": "python",
   "name": "metamotivo-py310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
