{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d862383-81f4-47c2-a558-822e638db158",
   "metadata": {},
   "source": [
    "# LLaRA: VIMA dataset conversion\n",
    "\n",
    "This notebook:\n",
    "1. converts VIMA, a behavior cloning dataset, into multiple instruction tuning datasets.\n",
    "2. generates multiple auxiliary datasets based on VIMA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0ca2ebd-b84a-4400-a2cb-fa35e42c0993",
   "metadata": {},
   "outputs": [],
   "source": [
    "# please assign a directory to save images\n",
    "original_vima_zip = 'vima.zip'\n",
    "unzip_image_destination = '/mnt/dist/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e4de67f-e3e6-4be5-b7e5-d74d2ceacfac",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Step 0: Convert the original dataset\n",
    "\n",
    "You only need to complete this step once.\n",
    "Please download the following files:\n",
    "1. The original VIMA dataset [vima.zip](https://huggingface.co/datasets/VIMA/VIMA-Data) (660k episodes)\n",
    "2. Three subset splits (already included in this repo):\n",
    "   * `vima-0d8k.json` (806 episodes)\n",
    "   * `vima-8k.json` (7995 episodes)\n",
    "   * `vima-80k.json` (80002 episodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ddc3d1-7143-48f2-9856-994a921b5f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import zipfile\n",
    "import os\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import pickle\n",
    "import re\n",
    "from collections import defaultdict\n",
    "import random\n",
    "from tqdm import tqdm \n",
    "from PIL import Image, ImageDraw\n",
    "from scipy.spatial.transform import Rotation as R\n",
    "import math\n",
    "import copy\n",
    "import io"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ead74d1a-68a1-4b4d-9477-c1ecc9b33f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset splits; each 'train' entry is a list of episode keys.\n",
    "# Context managers close the split files instead of leaking the handles.\n",
    "with open('vima-0d8k.json', 'r') as f:\n",
    "    train_0d8k = json.load(f)['train']\n",
    "with open('vima-8k.json', 'r') as f:\n",
    "    train_8k = json.load(f)['train']\n",
    "with open('vima-80k.json', 'r') as f:\n",
    "    train_80k = json.load(f)['train']\n",
    "\n",
    "print('Number of episodes: ', len(train_0d8k), len(train_8k), len(train_80k))\n",
    "\n",
    "# union of all splits: the episodes copied into the converted archive\n",
    "selected_episodes = set(train_0d8k + train_8k + train_80k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ff85f7-06ce-4b63-ae85-9cd4578e25a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# File names used by the converted dataset -- leave these as they are.\n",
    "target_vima_zip = 'cvt_vima.zip'  # converted archive written at the end of Step 0\n",
    "vima_meta = 'vima_meta.json'      # aggregated task metadata\n",
    "vima_ep = 'vima_episodes.json'    # episode key -> list of files in that episode\n",
    "\n",
    "# Open the original archive. ZipFile only indexes the members here; their contents are\n",
    "# read on demand later. Indexing the full dataset still takes a while.\n",
    "zip_file = zipfile.ZipFile(original_vima_zip, 'r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "046375e5-25df-409a-8a02-4ea988cdfb6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the archive: members with fewer than four path components are only kept if they are\n",
    "# a metadata.pkl; every deeper member is grouped under its first three path components\n",
    "# (the episode directory), storing the remainder of its path.\n",
    "episodes: dict[str, list[str]] = {}\n",
    "metas: list[str] = []\n",
    "\n",
    "for zfile in tqdm(zip_file.namelist()):\n",
    "    parts = zfile.split('/')\n",
    "    if len(parts) >= 4:\n",
    "        episode_dir = osp.join(*parts[:3])\n",
    "        episodes.setdefault(episode_dir, []).append(osp.join(*parts[3:]))\n",
    "    elif parts[-1] == 'metadata.pkl':\n",
    "        metas.append(zfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c354fd34-1209-476f-b38b-e49ceb1a9f0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm the known dataset quirk: when the end effector is a spatula (obs['ee'][0] != 0),\n",
    "# every pose rotation is the identity quaternion [0, 0, 0, 1], i.e. no rotation data.\n",
    "identity_quat = np.array([0, 0, 0, 1])\n",
    "for dkey in tqdm(episodes):\n",
    "    with zip_file.open(osp.join(dkey, 'obs.pkl'), 'r') as f:\n",
    "        obs = pickle.load(f)\n",
    "    if obs['ee'][0] == 0:\n",
    "        continue  # suction cup end effector: nothing to check\n",
    "    with zip_file.open(osp.join(dkey, 'action.pkl'), 'r') as f:\n",
    "        action = pickle.load(f)\n",
    "    for pose_key in ('pose0_rotation', 'pose1_rotation'):\n",
    "        for rot_vec in action[pose_key]:\n",
    "            assert np.linalg.norm(rot_vec - identity_quat) < 1e-7, action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e3af108-905e-4781-a5dc-a154b6beaf11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect per-task metadata, keyed by the archive path with '/metadata.pkl' stripped off\n",
    "meta_info = {}\n",
    "for meta in metas:\n",
    "    with zip_file.open(meta, 'r') as f:\n",
    "        meta_info[meta[:-len('/metadata.pkl')]] = pickle.load(f)\n",
    "\n",
    "# Write the aggregated metadata, then the episode index, to local json files\n",
    "for out_path, payload in ((vima_meta, meta_info), (vima_ep, episodes)):\n",
    "    with open(out_path, 'w') as f:\n",
    "        json.dump(payload, f, indent=2)\n",
    "    \n",
    "\n",
    "def trajectory_to_json(trajectory):\n",
    "    # skip prompt_assets\n",
    "    meta_exp = {k: v for k, v in trajectory.items() if k not in ['prompt_assets', 'obj_id_to_info']}\n",
    "\n",
    "    # modalities => list\n",
    "    meta_exp['modalities'] = list(meta_exp['modalities'])\n",
    "    # action_bounds=> [numpy array] => list\n",
    "    meta_exp['action_bounds'] = {k: v.tolist() for k, v in meta_exp['action_bounds'].items()}\n",
    "    return meta_exp\n",
    "\n",
    "def prune_info(info):\n",
    "    drop = ['obj_size_range', 'obj_replace_fn', 'obj_profile']\n",
    "    return {k: {kk: vv for kk, vv in v.items() if kk not in drop} for k, v in info.items() }\n",
    "\n",
    "def prune_imgs(info):\n",
    "    res = {}\n",
    "    imgs = {}\n",
    "    img_keys = ['rgb', 'segm']\n",
    "    views = ['top', 'front']\n",
    "    for referring, value in info.items():\n",
    "        res[referring] = {k: {kk: vv for kk, vv in v.items() if kk not in views} if isinstance(v, dict) else v for k, v in value.items() if k not in img_keys}\n",
    "        res[referring]['segm_obj_info'] = value['segm']['obj_info']\n",
    "        for moda in img_keys:\n",
    "            for view in value[moda]:\n",
    "                if isinstance(value[moda][view], np.ndarray):\n",
    "                    imgs[f'{moda}_{view}/a_{referring}.png'] = value[moda][view]\n",
    "    return res, imgs\n",
    "\n",
    "def prune_obs(info):\n",
    "    res = {}\n",
    "    imgs = {}\n",
    "    img_keys = ['rgb', 'segm']\n",
    "    views = ['top', 'front']\n",
    "    for k, value in info.items():\n",
    "        if k in img_keys:\n",
    "            moda = k\n",
    "            for view in value:\n",
    "                if isinstance(value[view], np.ndarray):\n",
    "                    for t, img in enumerate(value[view]):\n",
    "                        imgs[f'{moda}_{view}/{t}.png'] = img\n",
    "        else:\n",
    "            res[k] = value\n",
    "    return res, imgs\n",
    "\n",
    "\n",
    "def img_to_bytes(img_tmp):\n",
    "    \"\"\"Encode an image array as PNG bytes; arrays with more than two dims are treated as channel-first.\"\"\"\n",
    "    pixels = np.moveaxis(img_tmp, 0, -1) if img_tmp.ndim > 2 else img_tmp\n",
    "    buffer = io.BytesIO()\n",
    "    Image.fromarray(pixels).save(buffer, format='PNG')\n",
    "    return buffer.getvalue()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292690ed-44ff-4cf7-981c-64e1af642e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_episodes = {}\n",
    "\n",
    "\n",
    "def rgb_frame_to_jpeg_bytes(img_tmp):\n",
    "    \"\"\"Encode a (channel-first if 3-D) RGB frame as JPEG bytes at PIL's default JPEG quality.\"\"\"\n",
    "    if img_tmp.ndim > 2:\n",
    "        img_tmp = np.moveaxis(img_tmp, 0, -1)\n",
    "    buffer = io.BytesIO()\n",
    "    Image.fromarray(img_tmp).save(buffer, format='JPEG')\n",
    "    return buffer.getvalue()\n",
    "\n",
    "\n",
    "with zipfile.ZipFile(target_vima_zip, 'w') as zip_target:\n",
    "    # pack meta information\n",
    "    zip_target.write(vima_meta)\n",
    "\n",
    "    for dkey in tqdm(selected_episodes):\n",
    "        new_files = []\n",
    "        for file in episodes[dkey]:\n",
    "            path = osp.join(dkey, file)\n",
    "            if file == 'trajectory.pkl':\n",
    "                with zip_file.open(path, 'r') as f:\n",
    "                    traj = pickle.load(f)\n",
    "                # write critical meta information into a json file; prompt asset images are split out\n",
    "                traj_json = trajectory_to_json(traj)\n",
    "                traj_json['obj_id_to_info'] = prune_info(traj['obj_id_to_info'])\n",
    "                res, imgs = prune_imgs(traj['prompt_assets'])\n",
    "                traj_json['prompt_assets'] = res\n",
    "                zip_target.writestr(osp.join(dkey, 'trajectory.json'), json.dumps(traj_json, indent=2))\n",
    "                new_files.append('trajectory.json')\n",
    "\n",
    "                # dump prompt asset images as PNG (read later as `rgb_{view}/a_{name}.png`)\n",
    "                for k, v in imgs.items():\n",
    "                    zip_target.writestr(osp.join(dkey, k), img_to_bytes(v))\n",
    "                    new_files.append(k)\n",
    "\n",
    "            elif file == 'obs.pkl':\n",
    "                # rewrite obs.pkl without its image arrays; the frames become separate files\n",
    "                with zip_file.open(path, 'r') as f:\n",
    "                    obs = pickle.load(f)\n",
    "                res, imgs = prune_obs(obs)\n",
    "\n",
    "                for k, v in imgs.items():\n",
    "                    if k.startswith('rgb_'):\n",
    "                        # trajectory RGB frames are read back as `rgb_{view}/{t}.jpg` (get_image_corr_action,\n",
    "                        # prepare_dt_obj_dec), so store them as JPEG under that name\n",
    "                        k = osp.splitext(k)[0] + '.jpg'\n",
    "                        img_bytes = rgb_frame_to_jpeg_bytes(v)\n",
    "                    else:\n",
    "                        # segmentation masks must stay lossless PNG\n",
    "                        img_bytes = img_to_bytes(v)\n",
    "                    zip_target.writestr(osp.join(dkey, k), img_bytes)\n",
    "                    new_files.append(k)\n",
    "\n",
    "                zip_target.writestr(osp.join(dkey, file), pickle.dumps(res))\n",
    "                new_files.append(file)\n",
    "            else:\n",
    "                # directly copy files\n",
    "                with zip_file.open(path, 'r') as f:\n",
    "                    zip_target.writestr(path, f.read())\n",
    "                new_files.append(file)\n",
    "        new_episodes[dkey] = new_files\n",
    "    zip_target.writestr(vima_ep, json.dumps(new_episodes, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2fd0716-99f0-417f-ba7b-b2fe10ba424a",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Step 1: Load the converted dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61cc77d-15aa-4e1c-9443-ac8a6883118f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import zipfile\n",
    "import os\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import pickle\n",
    "import re\n",
    "from collections import defaultdict\n",
    "import random\n",
    "from tqdm import tqdm \n",
    "from PIL import Image, ImageDraw\n",
    "from scipy.spatial.transform import Rotation as R\n",
    "import math\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c88575b-25e8-4a6e-86b4-7867f60c057a",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_vima_zip = 'cvt_vima.zip'\n",
    "vima_meta = 'vima_meta.json'\n",
    "vima_ep = 'vima_episodes.json'\n",
    "\n",
    "# Load dataset splits (each 'train' entry is a list of episode keys); `with` closes the files\n",
    "with open('vima-0d8k.json', 'r') as f:\n",
    "    train_0d8k = json.load(f)['train']\n",
    "with open('vima-8k.json', 'r') as f:\n",
    "    train_8k = json.load(f)['train']\n",
    "with open('vima-80k.json', 'r') as f:\n",
    "    train_80k = json.load(f)['train']\n",
    "\n",
    "print('Number of episodes: ', len(train_0d8k), len(train_8k), len(train_80k))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a83707-f9b5-424d-b893-4c5f99bd228b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the converted archive produced in Step 0 and read its two index files\n",
    "zip_file = zipfile.ZipFile(target_vima_zip, 'r')\n",
    "\n",
    "with zip_file.open(vima_ep, 'r') as ep_file:\n",
    "    episodes = json.load(ep_file)  # episode key -> list of files stored for that episode\n",
    "with zip_file.open(vima_meta, 'r') as meta_file:\n",
    "    task_meta = json.load(meta_file)\n",
    "\n",
    "dkeys = [*episodes]\n",
    "print(f'Total number of episodes: {len(dkeys)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db3fc8ba-de77-4b40-afbd-3751f15d0a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions to read every file of one episode from the converted archive\n",
    "\n",
    "def get_all_files_for_episode(dkey):\n",
    "    \"\"\"Load the json/pkl/jpg/png files of episode `dkey`, keyed by their path inside the episode.\n",
    "\n",
    "    json -> parsed object, pkl -> unpickled object, jpg/png -> fully loaded PIL image;\n",
    "    files with any other extension are skipped.\n",
    "    NOTE: pickle.load can execute arbitrary code -- only use archives you trust.\n",
    "    \"\"\"\n",
    "    result = {}\n",
    "    for name in episodes[dkey]:\n",
    "        if name.endswith('json'):\n",
    "            loader = json.load\n",
    "        elif name.endswith('pkl'):\n",
    "            loader = pickle.load\n",
    "        elif name.endswith(('jpg', 'png')):\n",
    "            loader = lambda fh: Image.open(fh).copy()  # copy() loads the pixels before the member closes\n",
    "        else:\n",
    "            continue\n",
    "        with zip_file.open(osp.join(dkey, name), 'r') as fh:\n",
    "            result[name] = loader(fh)\n",
    "    return result\n",
    "\n",
    "def get_all_files_for_episode_by_idx(index):\n",
    "    \"\"\"Print the key of the `index`-th entry of dkeys and return all files of that episode.\"\"\"\n",
    "    episode_key = dkeys[index]\n",
    "    print(episode_key)\n",
    "    return get_all_files_for_episode(episode_key)\n",
    "\n",
    "# Demo: list the files stored for two sample episodes\n",
    "for demo_i, demo_idx in enumerate((114, 514)):\n",
    "    if demo_i > 0:\n",
    "        print('=' * 20)\n",
    "    episode = get_all_files_for_episode_by_idx(demo_idx)\n",
    "    print('This is a demo to list all the files in an episode.\\n' + '-' * 10)\n",
    "    print(episode.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5251625b-bd87-4160-9a90-e075aa9f0934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other helper functions\n",
    "# pix: image pixel coordinates\n",
    "# pos: robot position in action space\n",
    "def pix2pos_front(px, py):\n",
    "    j, i = float(px - 3) / 251 - 0.5, float(py - 34) / 178 + 0.25\n",
    "    return np.clip(i, 0.25, 0.75), np.clip(j, -0.5, 0.5)\n",
    "\n",
    "def pos2pix_front(i, j):\n",
    "    return int((j + 0.5) * 251 + 3), int((i - 0.25) * 178 + 34)\n",
    "    \n",
    "def pix2pos(px, py):\n",
    "    return np.clip(float(py - 1) / 252 + 0.25, 0.25, 0.75), np.clip(float(px - 2) / 251 - 0.5, -0.5, 0.5)\n",
    "\n",
    "def pos2pix(i, j):\n",
    "    return int((j + 0.5) * 251 + 2), int((i - 0.25) * 252 + 1)\n",
    "\n",
    "# more helper functions\n",
    "def get_bounding_box(arr):\n",
    "    # format: y1, x1, y2, x2\n",
    "    assert np.sum(arr) > 0\n",
    "    rows, cols = np.where(arr)\n",
    "    min_row, max_row = np.min(rows), np.max(rows)\n",
    "    min_col, max_col = np.min(cols), np.max(cols)\n",
    "    return min_row, min_col, max_row, max_col\n",
    "\n",
    "def get_center_bbox(x1, y1, x2, y2, iw, ih):\n",
    "    cx, cy, w, h = (x1 + x2) // 2, (y1 + y2) // 2, x2 - x1, y2 - y1\n",
    "    return cx / iw, cy / ih, w / iw, h / ih \n",
    "\n",
    "def get_center_bbox_from_obs(mask):\n",
    "    y1, x1, y2, x2 = get_bounding_box(mask)\n",
    "    iw = mask.shape[1]\n",
    "    ih = mask.shape[0]\n",
    "    return get_center_bbox(x1, y1, x2, y2, iw, ih)\n",
    "\n",
    "def draw_norm_bbox(img, bbox, color='red', dp=False):\n",
    "    \"\"\"Outline a normalised (cx, cy, w, h) box on a PIL image in place; display the image if `dp`.\"\"\"\n",
    "    draw = ImageDraw.Draw(img)\n",
    "    img_w, img_h = img.size\n",
    "    cx, cy, bw, bh = bbox\n",
    "    left, right = (cx - bw / 2) * img_w, (cx + bw / 2) * img_w\n",
    "    top, bottom = (cy - bh / 2) * img_h, (cy + bh / 2) * img_h\n",
    "    draw.rectangle([left, top, right, bottom], outline=color)\n",
    "    if dp:\n",
    "        display(img)\n",
    "\n",
    "def draw_point(img, px, py, color='red', dp=False, size=5):\n",
    "    \"\"\"Draw a filled dot of diameter `size` px centred on (px, py) in place; display the image if `dp`.\"\"\"\n",
    "    draw = ImageDraw.Draw(img)\n",
    "    # use the same half-size on both sides (the old mix of size/2 and size//2 shifted the dot)\n",
    "    r = size / 2\n",
    "    draw.ellipse([px - r, py - r, px + r, py + r], fill=color)\n",
    "    if dp:\n",
    "        display(img)\n",
    "    \n",
    "\n",
    "def bbox_format(bbox):\n",
    "    cx, cy, iw, ih = bbox\n",
    "    return f'<b>({cx:.3f}, {cy:.3f}), {{{iw:.3f}, {ih:.3f}}}</b>'\n",
    "\n",
    "def obj_format(obj):\n",
    "    if 'obj_color' in obj:\n",
    "        return f\"<p>{obj['obj_color']} {obj['obj_name']}</p>\"\n",
    "    return f\"<p>{obj['texture_name']} {obj['obj_name']}</p>\"\n",
    "\n",
    "def get_action_prompt(spatula=False):\n",
    "    p = \"Every action you take must include two locations in the format of <b>(x, y)</b> and one clockwise rotation angle in the format of <r>[r]</r>. \"\n",
    "    if spatula:\n",
    "        p += \"The first location is the image coordinate where you start to sweep the object using a spatula, and the second location is where you stop sweeping. \"\n",
    "        p += \"The image coordinate ranges from 0 to 1. The rotation angle indicates how many degrees you rotate the spatula clockwise, and it ranges from -359 to 359.\"\n",
    "    else:\n",
    "        p += \"The first location is the image coordinate where you use a suction cup to pick up the object, and the second location is where you place the object.\"\n",
    "        p += \"The image coordinate ranges from 0 to 1. The rotation angle indicates how many degrees you rotate the object clockwise, and it ranges from -359 to 359.\"\n",
    "    return p\n",
    "\n",
    "def get_obj_info_from_action(episode, start_x, start_y, end_x, end_y, obs_i, obj_id_to_info, view_mode):\n",
    "    \"\"\"Guess which object the expert moved at step `obs_i`.\n",
    "\n",
    "    For each object, the normalised box centre is taken from the segmentation masks at steps\n",
    "    obs_i and obs_i + 1. Score = |action start - centre before| + |action end - centre after|\n",
    "    - |distance the object moved|; the lowest score wins (first one on ties).\n",
    "    Returns (position of the object in obj_id_to_info, its info dict).\n",
    "    \"\"\"\n",
    "    def distance(a, b):\n",
    "        return np.linalg.norm(np.array(a) - np.array(b))\n",
    "\n",
    "    scored = []\n",
    "    for position, obj_id in enumerate(obj_id_to_info):\n",
    "        before = get_center_bbox_from_obs(np.array(episode[f'segm_{view_mode}/{obs_i}.png']) == int(obj_id))[:2]\n",
    "        after = get_center_bbox_from_obs(np.array(episode[f'segm_{view_mode}/{obs_i + 1}.png']) == int(obj_id))[:2]\n",
    "        moved = distance(before, after)\n",
    "        score = distance((start_x, start_y), before) + distance((end_x, end_y), after) - moved\n",
    "        scored.append((obj_id, score, position))\n",
    "    # stable sort keeps dictionary order on ties; the expert may sometimes fail to move an\n",
    "    # object, in which case this is only a best guess\n",
    "    scored.sort(key=lambda item: item[1])\n",
    "    best_id, _, best_position = scored[0]\n",
    "    return best_position, obj_id_to_info[best_id]\n",
    "\n",
    "\n",
    "def get_image_corr_action(episode, action, step, view_mode):\n",
    "    \"\"\"Map the expert action at `step` to normalised image coordinates.\n",
    "\n",
    "    Returns (start_x, start_y, end_x, end_y, rotation_degree): the pose0/pose1 positions are\n",
    "    projected with pos2pix_front when 'front' is in view_mode (pos2pix otherwise) and divided by\n",
    "    the size of the step-0 rgb image; rotation_degree is the difference of the last 'xyz' Euler\n",
    "    angle (pose1 minus pose0), truncated to an int number of degrees.\n",
    "    \"\"\"\n",
    "    pos_start = action['pose0_position'][step]\n",
    "    pos_end = action['pose1_position'][step]\n",
    "    euler_start = R.from_quat(action['pose0_rotation'][step]).as_euler('xyz', degrees=True)\n",
    "    euler_end = R.from_quat(action['pose1_rotation'][step]).as_euler('xyz', degrees=True)\n",
    "    rotation_degree = int(euler_end[-1] - euler_start[-1])\n",
    "\n",
    "    to_pix = pos2pix_front if 'front' in view_mode else pos2pix\n",
    "    px_start, py_start = to_pix(pos_start[0], pos_start[1])\n",
    "    px_end, py_end = to_pix(pos_end[0], pos_end[1])\n",
    "\n",
    "    w, h = episode[f'rgb_{view_mode}/0.jpg'].size\n",
    "    return px_start / w, py_start / h, px_end / w, py_end / h, rotation_degree\n",
    "\n",
    "\n",
    "def get_action_info(episode, action, step, obj_id_to_info, view_mode, spatula):\n",
    "    \"\"\"Describe the expert action at `step` in natural language.\n",
    "\n",
    "    Returns (step, action_prompt, obj_idx), where obj_idx is the position of the moved object in\n",
    "    obj_id_to_info. The rotation is written with its sign flipped (-rotation_degree).\n",
    "    \"\"\"\n",
    "    start_x, start_y, end_x, end_y, rotation_degree = get_image_corr_action(episode, action, step, view_mode)\n",
    "    moved = get_obj_info_from_action(episode, start_x, start_y, end_x, end_y, step, obj_id_to_info, view_mode)\n",
    "    obj_idx, obj_desc = moved[0], obj_format(moved[1])\n",
    "\n",
    "    if not spatula:\n",
    "        action_prompt = f'Pick up the {obj_desc} at <b>({start_x:.3f}, {start_y:.3f})</b>, rotate <r>[{-rotation_degree}]</r> degrees, and drop it at <b>({end_x:.3f}, {end_y:.3f})</b>.'\n",
    "    else:\n",
    "        action_prompt = f'Sweep the {obj_desc} at <b>({start_x:.3f}, {start_y:.3f})</b>, rotate <r>[{-rotation_degree}]</r> degrees, and stop at <b>({end_x:.3f}, {end_y:.3f})</b>.'\n",
    "    return step, action_prompt, obj_idx\n",
    "\n",
    "\n",
    "def get_obj_list_from_segm_mask(segm_mask, obj_id_to_info):\n",
    "    \"\"\"List (normalised bbox, formatted description) for every object id in obj_id_to_info.\"\"\"\n",
    "    objects = []\n",
    "    for obj_id, obj_info in obj_id_to_info.items():\n",
    "        bbox = get_center_bbox_from_obs(np.array(segm_mask) == int(obj_id))\n",
    "        objects.append((bbox, obj_format(obj_info)))\n",
    "    return objects\n",
    "\n",
    "def write_dataset(data, dataset_name, split_name, view_mode, only_trajectory):\n",
    "    trajectory = 'limited_' if only_trajectory else ''\n",
    "    filename = f'aux-{trajectory}{dataset_name}-{split_name}-{view_mode}.json'\n",
    "    with open(filename, 'w') as f:\n",
    "        json.dump(data, f, indent=2)\n",
    "    print(f'Write {len(data)} samples to {filename}.')\n",
    "\n",
    "\n",
    "def expand_bbox(bbox_xywh):\n",
    "    obj_ch = bbox_xywh[0]\n",
    "    obj_cv = bbox_xywh[1]\n",
    "    obj_l = obj_ch - bbox_xywh[2] / 2\n",
    "    obj_r = obj_ch + bbox_xywh[2] / 2\n",
    "    obj_t = obj_cv - bbox_xywh[3] / 2\n",
    "    obj_b = obj_cv + bbox_xywh[3] / 2\n",
    "    return obj_l, obj_t, obj_r, obj_b, obj_ch, obj_cv\n",
    "\n",
    "def calc_dist(obj_ego_c, obj_ref_c):\n",
    "    dist = math.dist(obj_ego_c, obj_ref_c)\n",
    "    dist_2d = np.array(obj_ego_c) - np.array(obj_ref_c)\n",
    "    return dist, dist_2d\n",
    "\n",
    "def get_obj_loc_desc(obj_desc, bbox):\n",
    "    \"\"\"One scene line: '<object description> at <formatted bbox>.'\"\"\"\n",
    "    return '{} at {}.'.format(obj_desc, bbox_format(bbox))\n",
    "\n",
    "def get_scene_desc(obj_list):\n",
    "    \"\"\"Join one line per (bbox, obj_desc) pair with newlines and wrap it in <scene>...</scene>.\"\"\"\n",
    "    lines = (get_obj_loc_desc(obj_desc, bbox) for bbox, obj_desc in obj_list)\n",
    "    return '<scene>' + '\\n'.join(lines) + '</scene>'\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce51343d-12c6-4b9e-a9a0-838f244c2818",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Step 2: Unzip all the images into a folder (`unzip_image_destination`)\n",
    "\n",
    "This could take a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd362857-78c7-4e75-92b3-70f147466d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unzip_all_images(dst, dkeys, validate=False):\n",
    "    \"\"\"Extract every .png/.jpg of the given episodes from `zip_file` into `dst`, keeping archive paths.\n",
    "\n",
    "    With validate=True nothing is written or created; instead each target file must already\n",
    "    exist and have size (256, 128) (PIL width x height).\n",
    "    \"\"\"\n",
    "    for dkey in tqdm(dkeys):\n",
    "        # only file names are needed, so use the episode index instead of decoding every file\n",
    "        for file_key in episodes[dkey]:\n",
    "            if not file_key.endswith(('png', 'jpg')):\n",
    "                continue\n",
    "            img_path = osp.join(dkey, file_key)\n",
    "            target = osp.join(dst, img_path)\n",
    "            if validate:\n",
    "                assert osp.exists(target), f'{target} not found.'\n",
    "                with Image.open(target) as img:  # close the file handle after checking\n",
    "                    assert img.size == (256, 128), f'Size of {target} not correct.'\n",
    "            else:\n",
    "                os.makedirs(osp.dirname(target), exist_ok=True)\n",
    "                with open(target, 'wb') as w, zip_file.open(img_path, 'r') as r:\n",
    "                    w.write(r.read())\n",
    "\n",
    "\n",
    "unzip_all_images(unzip_image_destination, set(train_0d8k + train_8k + train_80k))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f952eff-6f86-405e-a305-ea9f6d1d412a",
   "metadata": {},
   "source": [
    "## Step 3: Prepare the auxiliary datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a26f749b-17ab-41cf-a918-1418b88f8441",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Aux. Dataset 1: Object detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da0f5504-9adb-4aff-8d93-44bb92778a79",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Object-detection instructions; prepare_dt_obj_dec samples one per example with random.choice.\n",
    "# `{w, h}` is literal text here -- these strings are only concatenated, never str.format-ed.\n",
    "# NOTE(review): no random seed is set in this notebook, so the sampled prompts differ between runs.\n",
    "prompt_obj_det = [\n",
    "    \"Identify and describe each object in the image. For each object, list it in the format <b>(x, y), {w, h}</b>, where x and y represent the coordinates of the bounding box center, and w and h represent the width and height of the bounding box. The image coordinates should start from the top left corner and be normalized between 0 and 1.\",\n",
    "    \"Catalog all the objects present in the image. For every object, use the format <b>(x, y), {w, h}</b>, with x and y indicating the center of the object's bounding box coordinates, and w and h specifying the width and height. The coordinates are normalized from the top left corner, ranging from 0 to 1.\",\n",
    "    \"List each object in the image and describe it. Use the format <b>(x, y), {w, h}</b> for each object, where x and y denote the center coordinates of the bounding box, and w and h are the width and height of the bounding box. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Provide descriptions for all objects within the image. Each object should be listed using the format <b>(x, y), {w, h}</b>, where x and y are the coordinates of the bounding box center, and w and h are the width and height. The coordinates should be normalized, starting from the top left corner, within a range of 0 to 1.\",\n",
    "    \"Enumerate and describe every object found in the image. For each object, utilize the format <b>(x, y), {w, h}</b>, where x, y are the bounding box center coordinates and w, h are the dimensions (width and height) of the bounding box. The coordinates begin at the top left corner and are normalized between 0 and 1.\",\n",
    "    \"Detail all the objects within the image, listing each one using the format <b>(x, y), {w, h}</b>. Here, x and y represent the coordinates of the bounding box center, while w and h indicate the width and height. The coordinates start from the top left corner and are normalized to the range of 0 to 1.\",\n",
    "    \"Document each object present in the image. For each object, use the format <b>(x, y), {w, h}</b>, where x and y are the coordinates of the center of the bounding box, and w and h are the width and height. The coordinates should be normalized, starting from the top left corner, and range from 0 to 1.\",\n",
    "    \"For each object in the image, provide a description using the format <b>(x, y), {w, h}</b>. Here, x and y denote the coordinates of the bounding box center, and w and h represent the width and height of the bounding box. The coordinates are normalized to a scale of 0 to 1, starting from the top left corner.\",\n",
    "    \"Describe all the objects seen in the image, and list them using the format <b>(x, y), {w, h}</b>. The x and y values are the coordinates for the center of the bounding box, while w and h represent its width and height. The coordinates should be normalized from the top left corner, within a range of 0 to 1.\",\n",
    "    \"Identify and list each object found in the image. For each one, use the format <b>(x, y), {w, h}</b>. In this format, x and y are the coordinates for the bounding box center, and w and h are the width and height. The coordinates are to be normalized starting from the top left corner, ranging from 0 to 1.\",\n",
    "    \"List and describe each object in the image using the format <b>(x, y), {w, h}</b>. Here, x and y correspond to the coordinates of the bounding box center, and w and h specify the width and height of the bounding box. The coordinates should start from the top left corner and be normalized to the range of 0 to 1.\",\n",
    "    \"Provide a description for each object in the image, formatted as <b>(x, y), {w, h}</b>. The x and y values indicate the center coordinates of the bounding box, while w and h represent the width and height. The coordinates start from the top left corner and are normalized between 0 and 1.\",\n",
    "    \"Catalog each object within the image, using the format <b>(x, y), {w, h}</b> for each one. In this format, x and y are the coordinates for the center of the bounding box, and w and h are the width and height. The coordinates should be normalized, beginning at the top left corner and ranging from 0 to 1.\",\n",
    "    \"Enumerate all the objects in the image, providing descriptions for each using the format <b>(x, y), {w, h}</b>. The x and y values represent the center coordinates of the bounding box, while w and h indicate its width and height. The coordinates are normalized starting from the top left corner, within a range of 0 to 1.\",\n",
    "    \"Describe each object in the image, listing them in the format <b>(x, y), {w, h}</b>. Here, x and y denote the center coordinates of the bounding box, and w and h specify the width and height. The coordinates should be normalized from the top left corner, ranging from 0 to 1.\"\n",
    "]\n",
    "\n",
    "def prepare_dt_obj_dec(dkeys, split_name, view_mode='front', only_trajectory=False, debug=False):\n",
    "    def gen_example(local_image_path, obj_list):\n",
    "        return {\n",
    "            'id': f'obj_det/{dkey}/{local_image_path}',\n",
    "            'image': [f'{dkey}/{local_image_path}'],\n",
    "            'conversations': [\n",
    "                {\n",
    "                    'from': 'human',\n",
    "                    'value': '<image>\\n' + random.choice(prompt_obj_det),\n",
    "                },\n",
    "                {\n",
    "                    'from': 'gpt',\n",
    "                    'value': get_scene_desc(obj_list)\n",
    "                }\n",
    "            ]\n",
    "        }\n",
    "        \n",
    "    data = []\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        obj_id_to_info  = traj_meta['obj_id_to_info']\n",
    "        obs = episode['obs.pkl']\n",
    "        num_of_steps = obs['ee'].shape[0]\n",
    "\n",
    "        # iterate over all trajectory images\n",
    "        for obs_i in range(num_of_steps):\n",
    "            local_image_path = f'rgb_{view_mode}/{obs_i}.jpg'\n",
    "            obj_list = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{obs_i}.png'], obj_id_to_info)\n",
    "            data.append(gen_example(local_image_path, obj_list))\n",
    "            \n",
    "            if debug:\n",
    "                print(local_image_path)\n",
    "                print('\\n'.join([f'{obj_desc} at {bbox_format(bbox)}.' for bbox, obj_desc in obj_list]))\n",
    "                for t in range(len(obj_list)):\n",
    "                    draw_norm_bbox(episode[local_image_path], obj_list[t][0], dp=t == len(obj_list) - 1)\n",
    "\n",
    "        if not only_trajectory:\n",
    "            # iterate over all trajectory images\n",
    "            assets = traj_meta['prompt_assets']\n",
    "            for k in assets:\n",
    "                local_image_path = f'rgb_{view_mode}/a_{k}.png'\n",
    "                obj_info = assets[k]['segm_obj_info']\n",
    "                if not isinstance(obj_info, list):\n",
    "                    obj_info = [obj_info]\n",
    "\n",
    "                asset_obj_id_to_info = {i['obj_id']: i for i in obj_info}\n",
    "                obj_list = get_obj_list_from_segm_mask( episode[f'segm_{view_mode}/a_{k}.png'], asset_obj_id_to_info)\n",
    "                data.append(gen_example(local_image_path, obj_list))\n",
    "                \n",
    "                if debug:\n",
    "                    print(local_image_path)\n",
    "                    print('\\n'.join([f'{obj_desc} at {bbox_format(bbox)}.' for bbox, obj_desc in obj_list]))\n",
    "                    for t in range(len(obj_list)):\n",
    "                        draw_norm_bbox(episode[local_image_path], obj_list[t][0], dp=t == len(obj_list) - 1)\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        write_dataset(data, 'obj_det', split_name, view_mode, only_trajectory)\n",
    "\n",
    "# Debug preview: handles only the first episode, prints/draws the generated examples, and writes nothing\n",
    "prepare_dt_obj_dec(train_0d8k, 'train-0d8k', only_trajectory=False, debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "851fdd1a-ed9d-4f01-951a-71fdb0736cf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Object-detection datasets for every split (default only_trajectory setting)\n",
    "for split_keys, split_name in (\n",
    "    (train_0d8k, 'train-0d8k'),\n",
    "    (train_8k, 'train-8k'),\n",
    "    (train_80k, 'train-80k'),\n",
    "):\n",
    "    prepare_dt_obj_dec(split_keys, split_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b6ee23-2440-421c-bbc2-bad77b8bbced",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectory-only object-detection datasets (prompt-asset images are skipped)\n",
    "for split_keys, split_name in (\n",
    "    (train_0d8k, 'train-0d8k'),\n",
    "    (train_8k, 'train-8k'),\n",
    "    (train_80k, 'train-80k'),\n",
    "):\n",
    "    prepare_dt_obj_dec(split_keys, split_name, only_trajectory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a66f53d3-5e53-4f10-bedd-432aa47cc698",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Aux. Dataset 2: Object localization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e29a268d-2f68-4db3-b6c7-cf9a3ecee622",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Prompt templates for object localization. Each generated example samples one template at\n",
    "# random and substitutes `{object}` with the object description. Substitution uses\n",
    "# str.replace rather than str.format because `{w, h}` appears literally in every template.\n",
    "prompt_obj_loc = [\n",
    "    \"Where is {object} located in the image? Please use the format <b>(x, y), {w, h}</b> where x and y represent the center coordinates of the bounding box, and w and h are the width and height. The coordinates start from the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"Can you provide the location of {object} in the image? Format it as <b>(x, y), {w, h}</b>, with x and y as the center coordinates of the bounding box and w and h as the width and height. The coordinates should begin at the top left corner and be normalized from 0 to 1.\",\n",
    "    \"What are the coordinates of {object} in the image? Use the format <b>(x, y), {w, h}</b>, where x and y are the center of the bounding box, and w and h represent the width and height. Coordinates should start at the top left corner and be normalized to a range of 0 to 1.\",\n",
    "    \"Please specify the location of {object} in the image. List it in the format <b>(x, y), {w, h}</b>, where x and y denote the bounding box center coordinates, and w and h are the width and height. The coordinates begin from the top left corner and should be normalized to 0 to 1.\",\n",
    "    \"What is the position of {object} within the image? Use the format <b>(x, y), {w, h}</b> to describe it, with x and y as the center coordinates of the bounding box, and w and h as the width and height. The coordinates start at the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"Describe the location of {object} in the image using the format <b>(x, y), {w, h}</b>. In this format, x and y denote the center coordinates of the bounding box, while w and h represent its width and height. Coordinates should be normalized from the top left corner, ranging from 0 to 1.\",\n",
    "    \"Can you detail the location of {object} in the image? Format it as <b>(x, y), {w, h}</b>, where x and y indicate the bounding box center, and w and h represent the width and height. The coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Provide the location of {object} in the image using the format <b>(x, y), {w, h}</b>. Here, x and y are the center coordinates of the bounding box, and w and h are the width and height. The coordinates begin at the top left corner and are normalized from 0 to 1.\",\n",
    "    \"Where is {object} positioned in the image? Use the format <b>(x, y), {w, h}</b>, where x and y denote the center coordinates of the bounding box, and w and h are the width and height. The coordinates should be normalized to a range of 0 to 1 starting from the top left corner.\",\n",
    "    \"Specify the location of {object} in the image in the format <b>(x, y), {w, h}</b>. In this format, x and y represent the bounding box center, and w and h are the width and height. The coordinates should start from the top left corner and be normalized between 0 and 1.\",\n",
    "    \"What is the exact position of {object} in the image? Format the coordinates as <b>(x, y), {w, h}</b>, where x and y are the center of the bounding box and w and h denote its width and height. The coordinates start from the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"Describe where {object} is located in the image using the format <b>(x, y), {w, h}</b>. Here, x and y indicate the bounding box center coordinates, and w and h specify its width and height. The coordinates should be normalized starting from the top left corner, within the range of 0 to 1.\",\n",
    "    \"Could you tell me the location of {object} in the image? Use the format <b>(x, y), {w, h}</b>, where x and y denote the center of the bounding box and w and h are the width and height. Coordinates start at the top left corner and should be normalized between 0 and 1.\",\n",
    "    \"Provide the coordinates of {object} in the image in the format <b>(x, y), {w, h}</b>. Here, x and y are the center of the bounding box, while w and h represent its width and height. The coordinates should start from the top left corner and be normalized to 0 to 1.\",\n",
    "    \"How is the {object} located in the image? List its coordinates using the format <b>(x, y), {w, h}</b>, where x and y are the center coordinates of the bounding box, and w and h indicate its width and height. The coordinates begin at the top left corner and are normalized to a range of 0 to 1.\"\n",
    "]\n",
    "\n",
    "def prepare_dt_obj_loc(dkeys, split_name, view_mode='front', only_trajectory=False, debug=False):\n",
    "    \"\"\"Build the object-localization ('obj_loc') instruction-tuning dataset.\n",
    "\n",
    "    Every object found in an image's segmentation mask yields one single-turn example:\n",
    "    the human asks where the object is (random template from `prompt_obj_loc`) and the\n",
    "    answer is `get_obj_loc_desc(obj_desc, bbox)`.\n",
    "\n",
    "    Args:\n",
    "        dkeys: episode keys to process.\n",
    "        split_name: split label forwarded to `write_dataset`.\n",
    "        view_mode: camera view selecting the `rgb_<view>` / `segm_<view>` files.\n",
    "        only_trajectory: if True, label trajectory observations only and skip the\n",
    "            prompt-asset images.\n",
    "        debug: if True, stop after the first episode, print/draw every example, and\n",
    "            write nothing.\n",
    "    \"\"\"\n",
    "    data = []\n",
    "\n",
    "    def make_examples(episode_key, image_path, objects):\n",
    "        # one example per (bbox, description) pair, preserving mask order\n",
    "        examples = []\n",
    "        for bbox, obj_desc in objects:\n",
    "            question = random.choice(prompt_obj_loc).replace('{object}', obj_desc)\n",
    "            examples.append({\n",
    "                'id': f'obj_loc/{episode_key}/{image_path}',\n",
    "                'image': [f'{episode_key}/{image_path}'],\n",
    "                'conversations': [\n",
    "                    {'from': 'human', 'value': '<image>\\n' + question},\n",
    "                    {'from': 'gpt', 'value': get_obj_loc_desc(obj_desc, bbox)},\n",
    "                ],\n",
    "            })\n",
    "        return examples\n",
    "\n",
    "    def label_image(episode_key, episode, image_path, segm_path, id_to_info):\n",
    "        # segment one image, append its examples, and optionally visualize them\n",
    "        objects = get_obj_list_from_segm_mask(episode[segm_path], id_to_info)\n",
    "        data.extend(make_examples(episode_key, image_path, objects))\n",
    "        if debug:\n",
    "            print(image_path)\n",
    "            for bbox, obj_desc in objects:\n",
    "                print(f'{obj_desc} at {bbox_format(bbox)}.')\n",
    "                # draw each box on its own copy of the image\n",
    "                draw_norm_bbox(episode[image_path].copy(), bbox, dp=True)\n",
    "\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        obj_id_to_info = traj_meta['obj_id_to_info']\n",
    "        num_of_steps = episode['obs.pkl']['ee'].shape[0]\n",
    "\n",
    "        # every observation along the trajectory\n",
    "        for step in range(num_of_steps):\n",
    "            label_image(dkey, episode, f'rgb_{view_mode}/{step}.jpg',\n",
    "                        f'segm_{view_mode}/{step}.png', obj_id_to_info)\n",
    "\n",
    "        if not only_trajectory:\n",
    "            # prompt assets, the objects mentioned in the prompt\n",
    "            for asset_name, asset in traj_meta['prompt_assets'].items():\n",
    "                segm_info = asset['segm_obj_info']\n",
    "                if not isinstance(segm_info, list):\n",
    "                    segm_info = [segm_info]\n",
    "                label_image(dkey, episode, f'rgb_{view_mode}/a_{asset_name}.png',\n",
    "                            f'segm_{view_mode}/a_{asset_name}.png',\n",
    "                            {info['obj_id']: info for info in segm_info})\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        write_dataset(data, 'obj_loc', split_name, view_mode, only_trajectory)\n",
    "\n",
    "prepare_dt_obj_loc(train_0d8k, 'train-0d8k', debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd51f47d-eb7b-48a3-a9be-257715bc2553",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Object-localization datasets for every split (trajectory frames + prompt assets)\n",
    "for split_keys, split_name in (\n",
    "    (train_0d8k, 'train-0d8k'),\n",
    "    (train_8k, 'train-8k'),\n",
    "    (train_80k, 'train-80k'),\n",
    "):\n",
    "    prepare_dt_obj_loc(split_keys, split_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3525a507-a62d-4077-a09e-93fc7a5f7b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectory-only object-localization datasets (prompt-asset images are skipped)\n",
    "for split_keys, split_name in (\n",
    "    (train_0d8k, 'train-0d8k'),\n",
    "    (train_8k, 'train-8k'),\n",
    "    (train_80k, 'train-80k'),\n",
    "):\n",
    "    prepare_dt_obj_loc(split_keys, split_name, only_trajectory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9003cbb-3078-40e7-b7ed-91e18a35ef57",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Aux. Dataset 3: Action prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d54e96c3-5c10-4658-96b5-1461713575cc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Prompt templates for action prediction. One template is sampled per example and `{scene}`\n",
    "# is replaced with the description of the scene *after* the action (the next step's object\n",
    "# boxes), while the attached image shows the scene before it. `{w, h}` is literal text,\n",
    "# hence str.replace instead of str.format.\n",
    "prompt_act_pred = [\n",
    "    \"Could you detail the steps needed to transform the scene shown in the image into the second scene? The second scene is provided as a collection of object bounding boxes {scene}. The format for these bounding boxes is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h are the width and height. The coordinates should be normalized to a scale of 0 to 1, starting from the top left corner.\",\n",
    "    \"Can you describe what actions are required to rearrange the scene shown in the image to match the second scene? The second scene is given as a set of object bounding boxes {scene}. These bounding boxes follow the format <b>(x, y), {w, h}</b>, where x and y indicate the center coordinates, and w and h represent the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Could you list the steps necessary to modify the scene shown in the image to the second scene? The second scene is described as a collection of object bounding boxes {scene}. The bounding box format is <b>(x, y), {w, h}</b>, with x and y denoting the center coordinates, and w and h representing the width and height. The coordinates are normalized to a scale of 0 to 1, starting from the top left corner.\",\n",
    "    \"Can you explain what needs to be done to adjust the scene shown in the image to resemble the second scene? The second scene {scene} consists of object bounding boxes provided in the format <b>(x, y), {w, h}</b>. Here, x and y represent the center coordinates, and w and h are the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Could you outline the necessary actions to arrange the scene shown in the image into the second scene? The second scene is defined by a collection of object bounding boxes {scene}. These bounding boxes follow the format <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h are the width and height. The coordinates start from the top left corner and should be normalized to a scale of 0 to 1.\",\n",
    "    \"Can you specify what needs to be done to convert the scene shown in the image into the second scene? The second scene is provided as a series of object bounding boxes {scene}. The format for these bounding boxes is <b>(x, y), {w, h}</b>, with x and y representing the center coordinates, and w and h indicating the width and height. Coordinates should be normalized from the top left corner to a scale of 0 to 1.\",\n",
    "    \"Could you describe the steps required to change the scene shown in the image to the second scene? The second scene is depicted as a collection of object bounding boxes {scene}. The bounding box format is <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h represent the width and height. The coordinates are normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Can you list the actions necessary to transform the scene shown in the image into the second scene? The second scene is described using object bounding boxes {scene}. The format of these bounding boxes is <b>(x, y), {w, h}</b>, where x and y are the center coordinates, and w and h represent the width and height. Coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you explain the process to arrange the scene shown in the image to match the second scene? The second scene is provided as a collection of object bounding boxes {scene}. These bounding boxes are formatted as <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h are the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Can you detail what needs to be done to rearrange the scene shown in the image to the second scene? The second scene is given as a series of object bounding boxes {scene}. The bounding box format is <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h represent the width and height. Coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you specify the steps needed to modify the scene shown in the image to resemble the second scene? The second scene is described as a set of object bounding boxes {scene}. These bounding boxes follow the format <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h indicate the width and height. The coordinates start from the top left corner and should be normalized to a scale of 0 to 1.\",\n",
    "    \"Can you outline the necessary actions to change the scene shown in the image into the second scene? The second scene {scene} consists of object bounding boxes provided in the format <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h represent the width and height. Coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you describe the steps to adjust the scene shown in the image to the second scene? The second scene is given as a collection of object bounding boxes {scene}. The format for these bounding boxes is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h are the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Can you explain what needs to be done to transform the scene shown in the image into the second scene? The second scene is depicted using object bounding boxes {scene}. The bounding box format is <b>(x, y), {w, h}</b>, with x and y representing the center coordinates, and w and h indicating the width and height. The coordinates start from the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"Could you detail the steps necessary to convert the scene shown in the image to the second scene? The second scene is described as a set of object bounding boxes {scene}. These bounding boxes follow the format <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h denote the width and height. The coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\"\n",
    "]\n",
    "\n",
    "\n",
    "def prepare_dt_act_pred(dkeys, split_name, view_mode='front', debug=False):\n",
    "    \"\"\"Build the action-prediction ('act_pred') instruction-tuning dataset.\n",
    "\n",
    "    For every pair of consecutive observations (i - 1, i) of an episode, the image of\n",
    "    step i - 1 is shown together with the object boxes of step i, and the expected\n",
    "    answer is the action text that `get_action_info` produced for step i - 1.\n",
    "\n",
    "    Args:\n",
    "        dkeys: episode keys to process.\n",
    "        split_name: split label forwarded to `write_dataset`.\n",
    "        view_mode: camera view selecting the `rgb_<view>` / `segm_<view>` files.\n",
    "        debug: if True, stop after the first episode, print/draw every example, and\n",
    "            write nothing.\n",
    "    \"\"\"\n",
    "\n",
    "    def gen_example(local_image_path, action, obj_list, spatula):\n",
    "        # `action[1]` is the answer text; `spatula` selects the action-format instructions\n",
    "        return {\n",
    "            'id': f'act_pred/{dkey}/{local_image_path}',\n",
    "            'image': [f'{dkey}/{local_image_path}'],\n",
    "            'conversations': [\n",
    "                {\n",
    "                    'from': 'human',\n",
    "                    'value': '\\n'.join(['<image>',\n",
    "                                        random.choice(prompt_act_pred).replace('{scene}', get_scene_desc(obj_list)),\n",
    "                                        get_action_prompt(spatula)]),\n",
    "                },\n",
    "                {\n",
    "                    'from': 'gpt',\n",
    "                    'value': action[1]\n",
    "                }\n",
    "            ]\n",
    "        }\n",
    "\n",
    "    data = []\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        obj_id_to_info = traj_meta['obj_id_to_info']\n",
    "        obs = episode['obs.pkl']\n",
    "        num_of_steps = obs['ee'].shape[0]\n",
    "\n",
    "        action = episode['action.pkl']\n",
    "\n",
    "        actions = []\n",
    "        for i in range(num_of_steps):\n",
    "            local_image_path = f'rgb_{view_mode}/{i}.jpg'\n",
    "            obj_list = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{i}.png'], obj_id_to_info)\n",
    "            # prepare action list (the final observation has no action after it)\n",
    "            if i != num_of_steps - 1:\n",
    "                actions.append(get_action_info(episode, action, i, obj_id_to_info, view_mode, obs['ee'][i] > 0))\n",
    "            # prepare example from the last step and current step\n",
    "            if i > 0:\n",
    "                last_local_image_path = f'rgb_{view_mode}/{i - 1}.jpg'\n",
    "                # end-effector flag of step i - 1: the same flag get_action_info received\n",
    "                # when it built actions[i - 1], so prompt format and answer agree\n",
    "                spatula = obs['ee'][i - 1] > 0\n",
    "                data.append(gen_example(last_local_image_path, actions[i - 1], obj_list, spatula))\n",
    "\n",
    "                if debug:\n",
    "                    print(last_local_image_path, '->', local_image_path)\n",
    "                    print(actions[i - 1][1])\n",
    "                    last_obj_list = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{i - 1}.png'], obj_id_to_info)\n",
    "                    draw_norm_bbox(episode[last_local_image_path].copy(), last_obj_list[actions[i - 1][-1]][0], dp=True)\n",
    "                    draw_norm_bbox(episode[local_image_path].copy(), obj_list[actions[i - 1][-1]][0], dp=True)\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        write_dataset(data, 'act_pred', split_name, view_mode, False)\n",
    "\n",
    "prepare_dt_act_pred(train_0d8k, 'train-0d8k', debug=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "995a5592-f523-4da1-bcd3-250adc2b9614",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Action-prediction datasets for every split\n",
    "for split_keys, split_name in (\n",
    "    (train_0d8k, 'train-0d8k'),\n",
    "    (train_8k, 'train-8k'),\n",
    "    (train_80k, 'train-80k'),\n",
    "):\n",
    "    prepare_dt_act_pred(split_keys, split_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dbe9a78-62fa-4d82-b589-81c821dc4fd5",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Aux. Dataset 4: Future prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8515daab-0733-4018-874a-c75d3e20ad4d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Prompt templates for future prediction. One template is sampled per example and\n",
    "# `{pick up and drop}` is replaced with the action sentence, reworded by the caller so it can\n",
    "# follow 'Now you ...'. The expected answer lists the object boxes after the action.\n",
    "# `{w, h}` is literal text, hence str.replace instead of str.format.\n",
    "prompt_fut_pred = [\n",
    "    \"The image shows a scene with multiple objects. Now you {pick up and drop}, what will the scene look like? List the object bounding boxes. The bounding box format is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates of the bounding box, and w and h are its width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"An image depicts a scene containing multiple objects. Now you {pick up and drop}, what will the scene look like? Write the list of object bounding boxes. The bounding boxes should be formatted as <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h are the width and height. The coordinates start from the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"The image presents a scene with several objects. Now you {pick up and drop}, what will the scene look like? List the object bounding boxes. The format for these bounding boxes is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h are the width and height. Coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Displayed in the image is a scene containing multiple objects. Now you {pick up and drop}, what will the scene look like? Write down the list of object bounding boxes. These bounding boxes follow the format <b>(x, y), {w, h}</b>, with x and y as the center coordinates, and w and h as the width and height. The coordinates should be normalized starting from the top left corner to a scale of 0 to 1.\",\n",
    "    \"The image illustrates a scene with multiple objects. Now you {pick up and drop}, what will the scene look like? Write the list of object bounding boxes. The bounding boxes are formatted as <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h represent the width and height. Coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"The image depicts a scene with several objects. Now you {pick up and drop}, what will the scene look like? List the object bounding boxes. The bounding box format is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h denote the width and height. The coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"In the image, there is a scene with multiple objects. Now you {pick up and drop}, what will the scene look like? Write the list of object bounding boxes. The format of these bounding boxes is <b>(x, y), {w, h}</b>, where x and y indicate the center coordinates, and w and h represent the width and height. The coordinates start from the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"An image shows a scene with various objects. Now you {pick up and drop}, what will the scene look like? Write down the list of object bounding boxes. The bounding boxes follow the format <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h are the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"The image presents a scene containing several objects. Now you {pick up and drop}, what will the scene look like? List the object bounding boxes. The bounding box format is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h are the width and height. Coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"The image displays a scene with multiple objects. Now you {pick up and drop}, what will the scene look like? Write the list of object bounding boxes. The bounding boxes should be in the format <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h represent the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"An image illustrates a scene with multiple objects. Now you {pick up and drop}, what will the scene look like? Write down the list of object bounding boxes. These bounding boxes are formatted as <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h denote the width and height. Coordinates should be normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The image shows a scene with various objects. Now you {pick up and drop}, what will the scene look like? List the object bounding boxes. The format for these bounding boxes is <b>(x, y), {w, h}</b>, with x and y representing the center coordinates, and w and h as the width and height. Coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\",\n",
    "    \"Displayed in the image is a scene containing multiple objects. Now you {pick up and drop}, what will the scene look like? Write the list of object bounding boxes. The bounding box format is <b>(x, y), {w, h}</b>, where x and y denote the center coordinates, and w and h are the width and height. The coordinates start from the top left corner and are normalized to a scale of 0 to 1.\",\n",
    "    \"The image illustrates a scene with various objects. Now you {pick up and drop}, what will the scene look like? List the object bounding boxes. The bounding boxes are formatted as <b>(x, y), {w, h}</b>, where x and y indicate the center coordinates, and w and h represent the width and height. Coordinates should be normalized from the top left corner to a scale of 0 to 1.\",\n",
    "    \"An image depicts a scene with multiple objects. Now you {pick up and drop}, what will the scene look like? Write the list of object bounding boxes. The bounding box format is <b>(x, y), {w, h}</b>, where x and y represent the center coordinates, and w and h denote the width and height. The coordinates should start from the top left corner and be normalized to a scale of 0 to 1.\"\n",
    "]\n",
    "\n",
    "def prepare_dt_fut_pred(dkeys, split_name, view_mode='front', debug=False):\n",
    "    \"\"\"Build the future-prediction ('fut_pred') instruction-tuning dataset.\n",
    "\n",
    "    For every pair of consecutive observations (i - 1, i) of an episode, the image of\n",
    "    step i - 1 is shown together with the action taken at step i - 1, and the expected\n",
    "    answer is the scene description (object boxes) of step i.\n",
    "\n",
    "    Args:\n",
    "        dkeys: episode keys to process.\n",
    "        split_name: split label forwarded to `write_dataset`.\n",
    "        view_mode: camera view selecting the `rgb_<view>` / `segm_<view>` files.\n",
    "        debug: if True, stop after the first episode, print/draw every example, and\n",
    "            write nothing.\n",
    "    \"\"\"\n",
    "    def gen_example(local_image_path, action, obj_list):\n",
    "        # Reword the action sentence so it can follow 'Now you ...' in the prompt:\n",
    "        # drop a trailing period (only if present), lower-case the first word, and\n",
    "        # normalize whitespace. An empty action yields an empty string, not an error.\n",
    "        action_text = action[1].strip()\n",
    "        if action_text.endswith('.'):\n",
    "            action_text = action_text[:-1]\n",
    "        action_str_list = action_text.split()\n",
    "        if action_str_list:\n",
    "            action_str_list[0] = action_str_list[0].lower()\n",
    "        action_str = ' '.join(action_str_list)\n",
    "        return {\n",
    "            'id': f'fut_pred/{dkey}/{local_image_path}',\n",
    "            'image': [f'{dkey}/{local_image_path}'],\n",
    "            'conversations': [\n",
    "                {\n",
    "                    'from': 'human',\n",
    "                    'value': '<image>\\n' + random.choice(prompt_fut_pred).replace('{pick up and drop}', action_str),\n",
    "                },\n",
    "                {\n",
    "                    'from': 'gpt',\n",
    "                    'value': get_scene_desc(obj_list)\n",
    "                }\n",
    "            ]\n",
    "        }\n",
    "\n",
    "    data = []\n",
    "    # one example per pair of consecutive observations in each episode\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        obj_id_to_info = traj_meta['obj_id_to_info']\n",
    "        obs = episode['obs.pkl']\n",
    "        num_of_steps = obs['ee'].shape[0]\n",
    "\n",
    "        action = episode['action.pkl']\n",
    "\n",
    "        actions = []\n",
    "        for i in range(num_of_steps):\n",
    "            local_image_path = f'rgb_{view_mode}/{i}.jpg'\n",
    "            obj_list = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{i}.png'], obj_id_to_info)\n",
    "            # prepare action list (the final observation has no action after it)\n",
    "            if i != num_of_steps - 1:\n",
    "                actions.append(get_action_info(episode, action, i, obj_id_to_info, view_mode, obs['ee'][i] > 0))\n",
    "            # prepare example from the last step and current step\n",
    "            if i > 0:\n",
    "                last_local_image_path = f'rgb_{view_mode}/{i - 1}.jpg'\n",
    "                data.append(gen_example(last_local_image_path, actions[i - 1], obj_list))\n",
    "\n",
    "                if debug:\n",
    "                    print(last_local_image_path, '->', local_image_path)\n",
    "                    last_obj_list = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{i - 1}.png'], obj_id_to_info)\n",
    "                    print('\\n'.join([f'{obj_desc} at {bbox_format(bbox)}.' for bbox, obj_desc in last_obj_list]))\n",
    "                    print('-' * 10)\n",
    "                    print(actions[i - 1][1])\n",
    "                    print('-' * 10)\n",
    "                    print('\\n'.join([f'{obj_desc} at {bbox_format(bbox)}.' for bbox, obj_desc in obj_list]))\n",
    "\n",
    "                    draw_norm_bbox(episode[last_local_image_path].copy(), last_obj_list[actions[i - 1][-1]][0], dp=True)\n",
    "                    draw_norm_bbox(episode[local_image_path].copy(), obj_list[actions[i - 1][-1]][0], dp=True)\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        write_dataset(data, 'fut_pred', split_name, view_mode, False)\n",
    "\n",
    "prepare_dt_fut_pred(train_0d8k, 'train-0d8k', debug=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1c4fe7-9825-493c-9c72-fdbff5e45189",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Future-prediction datasets for every split\n",
    "for split_keys, split_name in (\n",
    "    (train_0d8k, 'train-0d8k'),\n",
    "    (train_8k, 'train-8k'),\n",
    "    (train_80k, 'train-80k'),\n",
    "):\n",
    "    prepare_dt_fut_pred(split_keys, split_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a274914e-c6ca-4394-99f8-21fef68d52f7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Aux. Dataset 5: Spatial relationship"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4925413a-7cea-4791-8401-40bf641f3bef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Paraphrased question templates for the spatial-relationship task; one is picked at random per\n",
    "# question. {Putego}/{Putref} are replaced with the two object descriptions in\n",
    "# get_spatial_relationship; {example} is filled later with another answer to show the output format.\n",
    "prompt_spa_rel = [\n",
    "    \"Can you describe the relative spatial locations of {Putego} compared to {Putref} in this image? Use relative location words like left, right, above, below, etc. Also, find the 2D center distance and the Euclidean center distance between them. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you describe the relative spatial positions of {Putego} in comparison to {Putref} in this image? Use terms like left, right, above, below, etc. Also, calculate the 2D center distance and the Euclidean center distance between them. Your output should be formatted as follows: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Please describe the relative spatial locations of {Putego} compared to {Putref} in this image. Use words like left, right, above, below, etc. Additionally, find the 2D center distance and the Euclidean center distance between them. Your output must be in this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Can you explain the relative spatial positions of {Putego} compared to {Putref} in this image? Use terms such as left, right, above, below, etc. Also, determine the 2D center distance and the Euclidean center distance between them. Your output should match this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Describe the relative spatial locations of {Putego} compared to {Putref} in this image using words like left, right, above, below, etc. Also, calculate the 2D center distance and the Euclidean center distance between them. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you describe the spatial relationship between {Putego} and {Putref} in this image using relative location words like left, right, above, below, etc.? Also, find the 2D center distance and the Euclidean center distance between them. Your output should be formatted as follows: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Can you detail the relative spatial positions of {Putego} compared to {Putref} in this image? Use words like left, right, above, below, etc. Also, determine the 2D center distance and the Euclidean center distance between them. Your output must be in this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you explain the spatial relationship between {Putego} and {Putref} in this image using terms such as left, right, above, below, etc.? Also, calculate the 2D center distance and the Euclidean center distance between them. Your output should match this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Describe the relative spatial positions of {Putego} compared to {Putref} in this image. Use relative location words like left, right, above, below, etc. Also, find the 2D center distance and the Euclidean center distance between them. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Can you describe how {Putego} is positioned relative to {Putref} in this image using words such as left, right, above, below, etc.? Also, find the 2D center distance and the Euclidean center distance between them. Your output should be in this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you detail the relative positions of {Putego} compared to {Putref} in this image using terms like left, right, above, below, etc.? Also, calculate the 2D center distance and the Euclidean center distance between them. Your output must be formatted as follows: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Please describe the spatial relationship of {Putego} in comparison to {Putref} in this image using relative location terms such as left, right, above, below, etc. Additionally, find the 2D center distance and the Euclidean center distance between them. Your output should match this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Can you describe the relative spatial locations of {Putego} compared to {Putref} in this image? Use relative location words like left, right, above, below, etc. Also, calculate the 2D center distance and the Euclidean center distance between them. Your output should follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"Could you describe the spatial locations of {Putego} relative to {Putref} in this image using words such as left, right, above, below, etc.? Additionally, find the 2D center distance and the Euclidean center distance between them. Your output must be in this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\"\n",
    "]\n",
    "\n",
    "\n",
    "def get_spatial_relationship(obj_list):\n",
    "    \"\"\"Build (question, answer) pairs describing where each object lies relative to every other one.\n",
    "\n",
    "    obj_list maps an object description to its bounding box. expand_bbox unpacks a box into six\n",
    "    values (presumably left, top, right, bottom, horizontal center, vertical center - judging by how\n",
    "    they are used) and bbox[:2] is passed to calc_dist. One pair is produced per ordered pair of\n",
    "    distinct objects; the question keeps the '{example}' placeholder, which the caller fills in.\n",
    "    \"\"\"\n",
    "    qa_pairs = []\n",
    "    for ego, obj_ego_bbox in obj_list.items():\n",
    "        ego_l, ego_t, ego_r, ego_b, ego_ch, ego_cv = expand_bbox(obj_ego_bbox)\n",
    "        for ref, obj_ref_bbox in obj_list.items():\n",
    "            if ego == ref:\n",
    "                continue\n",
    "            ref_l, ref_t, ref_r, ref_b, ref_ch, ref_cv = expand_bbox(obj_ref_bbox)\n",
    "            dist, dist_2d = calc_dist(obj_ego_bbox[:2], obj_ref_bbox[:2])\n",
    "\n",
    "            # Horizontal: 'left'/'right' by center, unless ego's box reaches past ref's far edge.\n",
    "            if ego_ch < ref_ch:\n",
    "                left_right_str = 'left' if ego_r < ref_r else 'horizontally containing'\n",
    "            else:\n",
    "                left_right_str = 'horizontally containing' if ego_l < ref_l else 'right'\n",
    "            # Vertical: same rule along the y axis (image coordinates start at the top).\n",
    "            # The containment label here used to say 'horizontally' by copy-paste mistake,\n",
    "            # and every containment label was misspelled 'containig'.\n",
    "            if ego_cv < ref_cv:\n",
    "                top_bottom_str = 'top' if ego_b < ref_b else 'vertically containing'\n",
    "            else:\n",
    "                top_bottom_str = 'vertically containing' if ego_t < ref_t else 'bottom'\n",
    "\n",
    "            ques = '<image>\\n' + random.choice(prompt_spa_rel).replace('{Putego}', ego).replace('{Putref}', ref)\n",
    "            answer = f'{ego} is {left_right_str} and {top_bottom_str} from {ref} with 2d center distance (x,y) of <d>({dist_2d[0]:.3f}, {dist_2d[1]:.3f})</d> and euclidean center distance of <e>{dist:.3f}</e>.'\n",
    "            qa_pairs.append((ques, answer))\n",
    "    return qa_pairs\n",
    "\n",
    "\n",
    "def prepare_dt_spa_rel(dkeys, split_name, view_mode='front', only_trajectory=False, debug=False):\n",
    "    \"\"\"Build the spatial-relationship auxiliary dataset ('spa_rel') and write it with write_dataset.\n",
    "\n",
    "    Every observation frame of each episode, plus (unless only_trajectory) every prompt-asset\n",
    "    image, yields one conversation per ordered pair of distinct objects in its segmentation mask.\n",
    "    With debug=True only the first episode is processed: its boxes are drawn, the examples are\n",
    "    printed, and nothing is written.\n",
    "    \"\"\"\n",
    "    def build_examples(episode_key, image_path, detected):\n",
    "        # detected is a sequence of (bbox, description) pairs\n",
    "        bbox_by_desc = {desc: box for box, desc in detected}\n",
    "        pairs = get_spatial_relationship(bbox_by_desc)\n",
    "        examples = []\n",
    "        for idx, (question, answer) in enumerate(pairs):\n",
    "            # the previous pair's answer (the last one when idx == 0) illustrates the output format\n",
    "            format_example = pairs[idx - 1][1]\n",
    "            if format_example.endswith('.'):\n",
    "                format_example = format_example[:-1]\n",
    "            examples.append({\n",
    "                'id': f'spa_rel/{episode_key}/{image_path}',\n",
    "                'image': [f'{episode_key}/{image_path}'],\n",
    "                'conversations': [\n",
    "                    {'from': 'human', 'value': question.replace('{example}', format_example)},\n",
    "                    {'from': 'gpt', 'value': answer},\n",
    "                ],\n",
    "            })\n",
    "        return examples\n",
    "\n",
    "    def show_boxes(episode, image_key, detected):\n",
    "        # debug helper: draw each box on episode[image_key]; dp=True is passed only for the last box\n",
    "        last_idx = len(detected) - 1\n",
    "        for idx, entry in enumerate(detected):\n",
    "            draw_norm_bbox(episode[image_key], entry[0], dp=idx == last_idx)\n",
    "\n",
    "    data = []\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        obj_id_to_info = traj_meta['obj_id_to_info']\n",
    "        num_of_steps = episode['obs.pkl']['ee'].shape[0]\n",
    "\n",
    "        # observation frames of the trajectory\n",
    "        for obs_i in range(num_of_steps):\n",
    "            image_path = f'rgb_{view_mode}/{obs_i}.jpg'\n",
    "            detected = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{obs_i}.png'], obj_id_to_info)\n",
    "            data.extend(build_examples(dkey, image_path, detected))\n",
    "            if debug:\n",
    "                print(image_path)\n",
    "                show_boxes(episode, image_path, detected)\n",
    "\n",
    "        # prompt assets: the objects/scenes referenced by the task prompt\n",
    "        if not only_trajectory:\n",
    "            for k, asset in traj_meta['prompt_assets'].items():\n",
    "                image_path = f'rgb_{view_mode}/a_{k}.png'\n",
    "                segm_info = asset['segm_obj_info']\n",
    "                segm_info = segm_info if isinstance(segm_info, list) else [segm_info]\n",
    "                asset_obj_id_to_info = {info['obj_id']: info for info in segm_info}\n",
    "                detected = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/a_{k}.png'], asset_obj_id_to_info)\n",
    "                data.extend(build_examples(dkey, image_path, detected))\n",
    "                if debug:\n",
    "                    print(image_path)\n",
    "                    show_boxes(episode, image_path.replace('segm', 'rgb'), detected)\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        write_dataset(data, 'spa_rel', split_name, view_mode, only_trajectory)\n",
    "\n",
    "prepare_dt_spa_rel(train_0d8k, 'train-0d8k', only_trajectory=True, debug=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f89935d-1899-4652-b3a0-6c627054708a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# spa_rel datasets from trajectory frames and prompt-asset images, for every training split.\n",
    "for _dkeys, _split_name in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k'), (train_80k, 'train-80k')]:\n",
    "    prepare_dt_spa_rel(_dkeys, _split_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dbdd355-900b-479d-8829-9f70e33cfdc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# spa_rel datasets from trajectory frames only (no prompt-asset images), for every training split.\n",
    "for _dkeys, _split_name in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k'), (train_80k, 'train-80k')]:\n",
    "    prepare_dt_spa_rel(_dkeys, _split_name, only_trajectory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "546772d2-f645-44c9-a7ce-42dc61256ca9",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Aux. Dataset 6: Temporal relationship"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5bfb18-102c-4913-a820-353166c00327",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Paraphrased question templates for the temporal-relationship task; one is picked at random per\n",
    "# question. {Putego}/{Putref} are filled in get_temporal_relationship; {scene} (text description of\n",
    "# the later frame) and {example} (another answer, showing the output format) are filled in prepare_dt_tmp_rel.\n",
    "prompt_tmp_rel = [\n",
    "    \"The image shows a scene at the first timestamp, while the second image described as {scene} shows the next timestamp. Can you describe the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use relative distance words like getting closer or further away, etc. Also, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"In the first timestamp, the image shows a scene, and the second image described as {scene} depicts the next timestamp. Can you describe the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use terms like getting closer or moving further away, etc. Additionally, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The scene in the first image is at the initial timestamp, and the second image described as {scene} shows the subsequent timestamp. Can you explain the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use words like getting closer or moving further apart, etc. Also, calculate the change in the 2D center distance and the Euclidean center distance between the two images. Your output should be formatted as follows: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"At the first timestamp, the image shows a scene, and the second image described as {scene} represents the next timestamp. Can you detail the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use relative distance words like moving closer or getting further away, etc. Additionally, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The first image shows a scene at an initial timestamp, and the second image described as {scene} depicts the next timestamp. Can you describe the change in the relative position of {Putego} compared to {Putref} between these two timestamps? Use terms such as getting closer or moving further apart, etc. Also, determine the change in the 2D center distance and the Euclidean center distance between the two images. Your output should follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The initial timestamp shows a scene in the first image, and the second image described as {scene} represents the next timestamp. Can you describe how the relative location of {Putego} compared to {Putref} changes between these two timestamps? Use relative distance words like getting closer or moving further away, etc. Also, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must be in this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The image shows a scene at the first timestamp, and the second image described as {scene} shows the subsequent timestamp. Can you detail the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use words like getting closer or moving further apart, etc. Also, calculate the change in the 2D center distance and the Euclidean center distance between the two images. Your output should be formatted as follows: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"At the initial timestamp, the image shows a scene, and the second image described as {scene} depicts the next timestamp. Can you describe the change in the relative position of {Putego} compared to {Putref} between these two timestamps? Use relative distance terms such as getting closer or moving further away, etc. Also, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The scene in the first image is at the initial timestamp, and the second image described as {scene} shows the following timestamp. Can you describe the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use words like getting closer or moving further apart, etc. Additionally, calculate the change in the 2D center distance and the Euclidean center distance between the two images. Your output should follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The first image shows a scene at an initial timestamp, and the second image described as {scene} depicts the next timestamp. Can you explain how the relative location of {Putego} compared to {Putref} changes between these two timestamps? Use relative distance words like moving closer or getting further away, etc. Also, determine the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The image shows a scene at the initial timestamp, and the second image described as {scene} shows the next timestamp. Can you describe the change in the relative position of {Putego} compared to {Putref} between these two timestamps? Use words like getting closer or moving further apart, etc. Also, calculate the change in the 2D center distance and the Euclidean center distance between the two images. Your output should follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"At the first timestamp, the image shows a scene, and the second image described as {scene} depicts the next timestamp. Can you detail how the relative location of {Putego} compared to {Putref} changes between these two timestamps? Use relative distance terms such as moving closer or getting further away, etc. Also, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The initial timestamp shows a scene in the first image, and the second image described as {scene} represents the next timestamp. Can you describe the change in the relative location of {Putego} compared to {Putref} between these two timestamps? Use words like getting closer or moving further apart, etc. Also, determine the change in the 2D center distance and the Euclidean center distance between the two images. Your output should follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\",\n",
    "    \"The first image shows a scene at the initial timestamp, and the second image described as {scene} shows the following timestamp. Can you describe how the relative position of {Putego} compared to {Putref} changes between these two timestamps? Use terms like moving closer or getting further away, etc. Additionally, find the change in the 2D center distance and the Euclidean center distance between the two images. Your output must follow this format: {example}. The coordinates are image coordinates normalized to a scale of 0 to 1 starting from the top left corner.\"\n",
    "]\n",
    "\n",
    "def get_temporal_relationship(obj_list1, obj_list2):\n",
    "    \"\"\"Build (question, answer) pairs describing how the distance between two objects changes\n",
    "    from an earlier frame (obj_list1) to a later frame (obj_list2).\n",
    "\n",
    "    Both arguments map an object description to its bounding box; bbox[:2] is passed to calc_dist.\n",
    "    One pair is produced per ordered pair of distinct objects present in both frames. The question\n",
    "    keeps the '{scene}' and '{example}' placeholders, which the caller fills in.\n",
    "    \"\"\"\n",
    "    qa_pairs = []\n",
    "\n",
    "    # Only objects present in both frames can be compared (iterated in the first frame's order).\n",
    "    # An assert used to demand identical key sets, aborting the whole run if an object appeared\n",
    "    # or disappeared between frames; when the key sets match this yields exactly the same keys.\n",
    "    keys = [k for k in obj_list1 if k in obj_list2]\n",
    "\n",
    "    for ego in keys:\n",
    "        obj_ego_bbox1 = obj_list1[ego]\n",
    "        obj_ego_bbox2 = obj_list2[ego]\n",
    "\n",
    "        for ref in keys:\n",
    "            if ego == ref:\n",
    "                continue\n",
    "            obj_ref_bbox1 = obj_list1[ref]\n",
    "            obj_ref_bbox2 = obj_list2[ref]\n",
    "\n",
    "            dist_1, dist_1_2d = calc_dist(obj_ego_bbox1[:2], obj_ref_bbox1[:2])\n",
    "            dist_2, dist_2_2d = calc_dist(obj_ego_bbox2[:2], obj_ref_bbox2[:2])\n",
    "\n",
    "            del_dist = dist_2 - dist_1\n",
    "            del_dist_2d = dist_2_2d - dist_1_2d\n",
    "\n",
    "            # changes within +/-0.01 are reported as 'no change'\n",
    "            if del_dist > 0.01:\n",
    "                txt = f'{ref} moves far away from {ego}. '\n",
    "            elif del_dist < -0.01:\n",
    "                txt = f'{ref} moves closer to {ego}. '\n",
    "            else:\n",
    "                txt = f'The distance between {ref} and {ego} does not change. '\n",
    "\n",
    "            # newline after the image token (this used to be the literal text '/n')\n",
    "            ques = '<image>\\n' + random.choice(prompt_tmp_rel).replace('{Putego}', ego).replace('{Putref}', ref)\n",
    "            # no trailing newline: the caller strips a trailing '.' when reusing an answer as the format example\n",
    "            answer = f'{txt}2d center distance (x,y) of {ego} from {ref} changes by <d>({del_dist_2d[0]:.3f}, {del_dist_2d[1]:.3f})</d> and Euclidean center distance between them <e>{del_dist:.3f}</e>.'\n",
    "            qa_pairs.append((ques, answer))\n",
    "    return qa_pairs\n",
    "\n",
    "\n",
    "def prepare_dt_tmp_rel(dkeys, split_name, view_mode='front', debug=False):\n",
    "    \"\"\"Build the temporal-relationship auxiliary dataset ('tmp_rel') and write it with write_dataset.\n",
    "\n",
    "    For each pair of consecutive observation frames of an episode, the earlier frame is the image\n",
    "    and the later frame is given as a text scene description (get_scene_desc); one conversation\n",
    "    is emitted per ordered pair of distinct objects. With debug=True only the first episode is\n",
    "    processed: its boxes are drawn, the examples are printed, and nothing is written.\n",
    "    \"\"\"\n",
    "    def gen_example(dkey, local_image_path, obj_list1, obj_list2):\n",
    "        objs1 = {obj_desc: bbox for bbox, obj_desc in obj_list1}\n",
    "        objs2 = {obj_desc: bbox for bbox, obj_desc in obj_list2}\n",
    "        scene_desc2 = get_scene_desc(obj_list2)\n",
    "        qa_pairs = get_temporal_relationship(objs1, objs2)\n",
    "        result = []\n",
    "\n",
    "        for i, (q, a) in enumerate(qa_pairs):\n",
    "            # Another pair's answer (the previous one, wrapping around for i == 0) is shown as the\n",
    "            # format example. Strip trailing whitespace first: an answer ending in '.\\n' would\n",
    "            # otherwise keep its dot and newline inside the question.\n",
    "            old_a = qa_pairs[i - 1][1].rstrip()\n",
    "            if old_a.endswith('.'):\n",
    "                old_a = old_a[:-1]\n",
    "\n",
    "            result.append({\n",
    "                'id': f'tmp_rel/{dkey}/{local_image_path}',\n",
    "                'image': [f'{dkey}/{local_image_path}'],\n",
    "                'conversations': [\n",
    "                    {\n",
    "                        'from': 'human',\n",
    "                        'value': q.replace('{example}', old_a).replace('{scene}', scene_desc2),\n",
    "                    },\n",
    "                    {\n",
    "                        'from': 'gpt',\n",
    "                        'value': a\n",
    "                    }\n",
    "                ]\n",
    "            })\n",
    "        return result\n",
    "\n",
    "    data = []\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        obj_id_to_info = traj_meta['obj_id_to_info']\n",
    "        obs = episode['obs.pkl']\n",
    "        num_of_steps = obs['ee'].shape[0]\n",
    "\n",
    "        last_list = None  # objects detected in the previous frame\n",
    "        for obs_i in range(num_of_steps):\n",
    "            # the image is the PREVIOUS frame; the current frame is only described in text\n",
    "            local_image_path = f'rgb_{view_mode}/{obs_i - 1}.jpg'\n",
    "            obj_list = get_obj_list_from_segm_mask(episode[f'segm_{view_mode}/{obs_i}.png'], obj_id_to_info)\n",
    "            if last_list:\n",
    "                data.extend(gen_example(dkey, local_image_path, last_list, obj_list))\n",
    "            if debug and obs_i > 0:\n",
    "                print(local_image_path)\n",
    "                for t in range(len(last_list)):\n",
    "                    draw_norm_bbox(episode[local_image_path], last_list[t][0], dp=t == len(last_list) - 1)\n",
    "            last_list = obj_list\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        write_dataset(data, 'tmp_rel', split_name, view_mode, False)\n",
    "\n",
    "prepare_dt_tmp_rel(train_0d8k, 'train-0d8k', debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77f9a25c-55ec-4bf7-a0b5-30c3327dcf01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the temporal-relationship dataset for every training split.\n",
    "for _dkeys, _split_name in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k'), (train_80k, 'train-80k')]:\n",
    "    prepare_dt_tmp_rel(_dkeys, _split_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02e837a4-b692-4b55-bf7e-67252a8e66fb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Step 4: Generate Instruction-Tuning BC Datasets\n",
    "This step builds the `inBC` and `RT-2 style` behavior-cloning datasets, plus the `D-` variant of each.\n",
    "\n",
    "**Update**: Set `is_rt_raw_action` to `False` to generate the `RT-2 style (I)` baseline datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b140477-3911-4482-ab22-0f0c58fea9b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch \n",
    "# Tokenizer of the LLaVA-1.5 model; the RT-2 helpers below (get_rt2_action_info and\n",
    "# get_image_corr_rt2_action_info) use this global to map actions to token ids.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    \"llava-hf/llava-1.5-7b-hf\",\n",
    ")\n",
    "print(tokenizer.vocab_size)\n",
    "\n",
    "# this is for RT-2\n",
    "def quantize_float_to_text(num: np.ndarray, tokenizer, token_pool_start=31000, token_num=256):\n",
    "    \"\"\"Encode values in [0, 1] as RT-2 style action tokens.\n",
    "\n",
    "    Each value is mapped to one of `token_num` consecutive token ids starting at `token_pool_start`\n",
    "    (0 -> first id, 1 -> last id, rounded to the nearest id with np.round) and the ids are decoded\n",
    "    with `tokenizer`. Any array-like (e.g. a plain list) is accepted.\n",
    "    Raises AssertionError if a value lies outside [0, 1].\n",
    "    \"\"\"\n",
    "    num = np.asarray(num, dtype=float)\n",
    "    assert np.all(num <= 1) and np.all(num >= 0), f'values must lie in [0, 1], got {num}'\n",
    "    ids = (token_pool_start + np.round(num * (token_num - 1))).astype(int)\n",
    "    return tokenizer.decode(ids)\n",
    "\n",
    "\n",
    "def text_to_float(s: str, tokenizer, token_pool_start=31000, token_num=256):\n",
    "    \"\"\"Decode RT-2 style action tokens back to floats in [0, 1] (inverse of quantize_float_to_text).\n",
    "\n",
    "    Token ids outside [token_pool_start, token_pool_start + token_num) are skipped.\n",
    "    \"\"\"\n",
    "    pool_end = token_pool_start + token_num\n",
    "    scale = token_num - 1\n",
    "    values = []\n",
    "    for token_id in tokenizer.encode(s):\n",
    "        if token_pool_start <= token_id < pool_end:\n",
    "            values.append(float(token_id - token_pool_start) / scale)\n",
    "    return values\n",
    "\n",
    "\n",
    "def get_rt2_action_info(episode, action, step, view_mode):\n",
    "    \"\"\"Encode the action at `step` as RT-2 style action tokens in robot coordinates.\n",
    "\n",
    "    Returns (step, text): text decodes five quantized values - pose0 x, pose0 y, pose1 x, pose1 y\n",
    "    and the change of the last Euler angle between pose0_rotation and pose1_rotation.\n",
    "    `episode` and `view_mode` are unused; they keep the signature interchangeable with\n",
    "    get_image_corr_rt2_action_info.\n",
    "    \"\"\"\n",
    "    act_pos_start = action['pose0_position'][step]\n",
    "    act_pos_end = action['pose1_position'][step]\n",
    "    act_rot_start = action['pose0_rotation'][step]\n",
    "    act_rot_end = action['pose1_rotation'][step]\n",
    "\n",
    "    # Convert the quaternions to Euler angles (degrees); only the last component is used\n",
    "    e1 = R.from_quat(act_rot_start).as_euler('xyz', degrees=True)\n",
    "    e2 = R.from_quat(act_rot_end).as_euler('xyz', degrees=True)\n",
    "\n",
    "    rotation_degree = int(e2[-1] - e1[-1])\n",
    "    # Each angle lies in [-180, 180], so the raw difference can reach +/-360. Wrap it back into\n",
    "    # [-180, 180] (same physical rotation); otherwise the normalized value below leaves [0, 1]\n",
    "    # and the range assertion in quantize_float_to_text aborts the run.\n",
    "    if rotation_degree > 180:\n",
    "        rotation_degree -= 360\n",
    "    elif rotation_degree < -180:\n",
    "        rotation_degree += 360\n",
    "\n",
    "    norm_action = [\n",
    "        (act_pos_start[0] - 0.25) * 2,  # 0.25, 0.75 => 0, 1\n",
    "        act_pos_start[1] + 0.5,  # -0.5, 0.5 => 0, 1\n",
    "        (act_pos_end[0] - 0.25) * 2,  # 0.25, 0.75 => 0, 1\n",
    "        act_pos_end[1] + 0.5,  # -0.5, 0.5 => 0, 1\n",
    "        (float(rotation_degree) + 180) / 360,  # -180, 180 => 0, 1\n",
    "    ]\n",
    "\n",
    "    return step, quantize_float_to_text(np.array(norm_action), tokenizer)\n",
    "\n",
    "def get_image_corr_rt2_action_info(episode, action, step, view_mode):\n",
    "    \"\"\"Encode the action at `step` as RT-2 style action tokens in image coordinates.\n",
    "\n",
    "    The start/end points come from get_image_corr_action and are quantized as-is, so they must\n",
    "    already lie in [0, 1]; the rotation is mapped from [-180, 180] degrees to [0, 1].\n",
    "    Returns (step, text).\n",
    "    \"\"\"\n",
    "    # NOTE(review): assumes get_image_corr_action returns rotation_degree within [-180, 180];\n",
    "    # values outside that range make quantize_float_to_text's range assertion fail - confirm.\n",
    "    start_x, start_y, end_x, end_y, rotation_degree = get_image_corr_action(episode, action, step, view_mode)\n",
    "    rotation_norm = (float(rotation_degree) + 180) / 360  # -180, 180 => 0, 1\n",
    "    encoded = quantize_float_to_text(np.array([start_x, start_y, end_x, end_y, rotation_norm]), tokenizer)\n",
    "    return step, encoded\n",
    "\n",
    "\n",
    "def prepare_behavior_cloning_dataset(dkeys,\n",
    "                                     split_name,\n",
    "                                     state_mode='text', # options: text, visual\n",
    "                                     action_mode='multi', # options: single, multi\n",
    "                                     is_rt = False, \n",
    "                                     is_rt_raw_action = True,\n",
    "                                     view_mode='front', # options: front top\n",
    "                                     describe=True,\n",
    "                                     debug=False):\n",
    "\n",
    "    if describe:\n",
    "        det_dataset = json.load(open(f'aux-obj_det-{split_name}-{view_mode}.json', 'r'))\n",
    "        det_dataset = {i['id']: i for i in det_dataset}\n",
    "        \n",
    "        loc_dataset = json.load(open(f'aux-obj_loc-{split_name}-{view_mode}.json', 'r'))\n",
    "        loc_dataset = {i['id']: i for i in loc_dataset}\n",
    "    \n",
    "    data = []\n",
    "    # state {text, visual} x {no, name, colorful name}\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "\n",
    "        example_base = {\n",
    "            'image': [],\n",
    "            'conversations': []\n",
    "        }\n",
    "\n",
    "        image_idx = 0\n",
    "\n",
    "        def get_img_token(image_idx):\n",
    "            return f'<image{image_idx}>'\n",
    "\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        prompt = traj_meta['prompt']\n",
    "        obs = episode['obs.pkl']\n",
    "        obj_id_to_info  = traj_meta['obj_id_to_info']\n",
    "        action = episode['action.pkl']\n",
    "        num_of_steps = obs['ee'].shape[0]\n",
    "  \n",
    "        if debug:\n",
    "            print(dkey)\n",
    "            print(prompt)\n",
    "        \n",
    "        # prompt assets, the objects mentioned in the prompt\n",
    "        assets = traj_meta['prompt_assets']\n",
    "        \n",
    "        p = f'<task>{prompt}</task>\\n'\n",
    "        # print(episode.keys())\n",
    "        # replace the prompt with pure text\n",
    "        for k in assets:\n",
    "            obj_info = assets[k]['segm_obj_info']\n",
    "            if isinstance(obj_info, list):\n",
    "                to_rep = k # multiple objects, it is a scene\n",
    "                if describe:\n",
    "                    to_rep = det_dataset[f'obj_det/{dkey}/rgb_{view_mode}/a_{k}.png']['conversations'][1]['value']\n",
    "            else:\n",
    "                to_rep = obj_format(obj_info)\n",
    "                if describe:\n",
    "                    to_rep = loc_dataset[f'obj_loc/{dkey}/rgb_{view_mode}/a_{k}.png']['conversations'][1]['value'][:-1] # remove the last dot\n",
    "                    \n",
    "            if state_mode == 'text':\n",
    "                pass\n",
    "            elif state_mode == 'visual':\n",
    "                example_base['image'].append(f'{dkey}/rgb_{view_mode}/a_{k}.png')\n",
    "                if debug:\n",
    "                    print(k)\n",
    "                    display(episode[f'rgb_{view_mode}/a_{k}.png'])\n",
    "                to_rep = get_img_token(image_idx) + to_rep\n",
    "                image_idx += 1\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "            p = p.replace('{' + k + '}', to_rep)\n",
    "\n",
    "        if not is_rt:\n",
    "            p += get_action_prompt(obs['ee'][0] > 0)\n",
    "            \n",
    "        actions = []\n",
    "        # get action\n",
    "        for i in range(num_of_steps - 1):\n",
    "            if i != num_of_steps - 1:\n",
    "                if is_rt:\n",
    "                    if is_rt_raw_action:\n",
    "                        actions.append(get_rt2_action_info(episode, action, i, view_mode))\n",
    "                    else:\n",
    "                        actions.append(get_image_corr_rt2_action_info(episode, action, i, view_mode))\n",
    "                else:\n",
    "                    actions.append(get_action_info(episode, action, i, obj_id_to_info, view_mode, obs['ee'][i] > 0))\n",
    "\n",
    "        example_base['conversations'].append({\n",
    "            'from': 'human',\n",
    "            'value': get_img_token(image_idx) + '\\n' + p,\n",
    "        })\n",
    "        for action_i, action_tuple in enumerate(actions):\n",
    "            example = copy.deepcopy(example_base)\n",
    "            example['id'] = f'{dkey}_{action_i}'\n",
    "            example['image'].append(f'{dkey}/rgb_{view_mode}/{action_i}.jpg')\n",
    "\n",
    "            if action_i > 0:\n",
    "                example['conversations'][0]['value'] += '\\nYou have finished: ' + '\\n'.join([f'Step {i + 1}: ' + re.sub(r'<p>.+?</p>', 'object', a[1]) for i, a in enumerate(actions[:action_i])])\n",
    "\n",
    "            if action_mode == 'single':\n",
    "                example['conversations'].append({\n",
    "                    'from': 'gpt',\n",
    "                    'value': f'Step {action_i + i + 1}: {actions[action_i][1]}',\n",
    "                })\n",
    "            elif action_mode == 'multi':\n",
    "                example['conversations'].append({\n",
    "                    'from': 'gpt',\n",
    "                    'value': '\\n'.join([f'Step {action_i + i + 1}: {actions[i + action_i][1]}' for i in range(len(actions[action_i:]))]),\n",
    "                })\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "\n",
    "            if debug:\n",
    "                display(episode[f'rgb_{view_mode}/{action_i}.jpg'])\n",
    "                \n",
    "            data.append(example)\n",
    "            \n",
    "        if debug:\n",
    "            display(episode[f'rgb_{view_mode}/{num_of_steps - 1}.jpg'])\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "            \n",
    "    if not debug:\n",
    "        method_name = 'RT2' if is_rt else 'inBC'\n",
    "        desc = 'D-' if describe else ''\n",
    "        \n",
    "        filename = f'{desc}{method_name}-{state_mode}-{action_mode}-{split_name}-{view_mode}.json'\n",
    "        if not is_rt_raw_action:\n",
    "            filename = 'I-' + filename\n",
    "        with open(filename, 'w') as f:\n",
    "            json.dump(data, f, indent=2)\n",
    "        print(f'Write {len(data)} samples to {filename}.')\n",
    "\n",
    "\n",
    "# Debug preview: builds samples for the first episode only, prints them and writes no file\n",
    "prepare_behavior_cloning_dataset(train_0d8k, 'train-0d8k', debug=True, state_mode='text', describe=True, is_rt=True, is_rt_raw_action=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0f02956-82b4-43bc-9b1b-3dbacff49e78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default method flags, 0.8k and 8k splits, with and without object descriptions\n",
    "for split_keys, split_label in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k')]:\n",
    "    for use_description in (True, False):\n",
    "        prepare_behavior_cloning_dataset(split_keys, split_label, describe=use_description)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6794b3d7-6091-4284-bdac-bb4a117eda48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default method flags, 80k split, with and without object descriptions\n",
    "for use_description in (True, False):\n",
    "    prepare_behavior_cloning_dataset(train_80k, 'train-80k', describe=use_description)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f398eb02-77e7-4bf2-81c6-3ab7b7ab89aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RT2-style datasets (is_rt=True) for every split, with and without object descriptions\n",
    "for split_keys, split_label in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k'), (train_80k, 'train-80k')]:\n",
    "    for use_description in (True, False):\n",
    "        prepare_behavior_cloning_dataset(split_keys, split_label, is_rt=True, describe=use_description)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76419c34-172f-44c7-87ac-bb1600e1987a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RT2-style datasets with is_rt_raw_action=False (output filenames get an 'I-' prefix)\n",
    "for split_keys, split_label in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k'), (train_80k, 'train-80k')]:\n",
    "    for use_description in (True, False):\n",
    "        prepare_behavior_cloning_dataset(split_keys, split_label, is_rt=True, describe=use_description, is_rt_raw_action=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b780b230-1fff-474a-83e4-0b0030c2e8fb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
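    "## Step 5: Merge Datasets\n",
    "Merge the main inBC dataset with the auxiliary datasets. Each auxiliary dataset is shuffled and subsampled to `ratio * len(main dataset)` samples before being appended."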
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f419c6fa-7074-4bc3-ae53-4bf981a7a2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge datasets\n",
    "\n",
    "def merge_dts(main_dt, aux_dt, ratio, split_name, view_mode='front', seed=None):\n",
    "    \"\"\"Merge a main dataset with randomly subsampled auxiliary datasets.\n",
    "\n",
    "    Each auxiliary dataset aux_dt[i] is shuffled and truncated to\n",
    "    int(len(main) * ratio[i]) samples; the kept samples are appended after all\n",
    "    main samples and the result is written as a single JSON list.\n",
    "\n",
    "    The output filename is built from the input filenames with their\n",
    "    '-{split_name}' and '-{view_mode}' parts stripped, joined by '_', followed\n",
    "    directly (no separator) by the '_'-joined ratios and\n",
    "    '-{split_name}-{view_mode}.json'. If that file already exists it is not\n",
    "    rewritten; only its size is reported.\n",
    "\n",
    "    Args:\n",
    "        main_dt: path to the main dataset JSON (a list of samples).\n",
    "        aux_dt: list of auxiliary dataset JSON paths.\n",
    "        ratio: one sampling ratio per auxiliary dataset, relative to the\n",
    "            main dataset size.\n",
    "        split_name: split label embedded in the filenames, e.g. 'train-0d8k'.\n",
    "        view_mode: camera-view suffix embedded in the filenames.\n",
    "        seed: optional seed for the auxiliary shuffles; None (default) keeps\n",
    "            using the global `random` state.\n",
    "    \"\"\"\n",
    "    assert len(ratio) == len(aux_dt), 'expected exactly one ratio per auxiliary dataset'\n",
    "\n",
    "    def _stem(path):\n",
    "        # e.g. 'aux-obj_det-train-0d8k-front.json' -> 'aux-obj_det'\n",
    "        return path.split(f'-{view_mode}')[0].replace(f'-{split_name}', '')\n",
    "\n",
    "    # No separator between the last stem and the first ratio; kept unchanged\n",
    "    # so previously generated merged files keep their names.\n",
    "    merge_target = ('_'.join(_stem(p) for p in [main_dt] + list(aux_dt))\n",
    "                    + '_'.join(f'{r}' for r in ratio)\n",
    "                    + f'-{split_name}-{view_mode}.json')\n",
    "\n",
    "    with open(main_dt, 'r') as f:\n",
    "        main_db = json.load(f)\n",
    "\n",
    "    if osp.exists(merge_target):\n",
    "        print('We already have ', merge_target)\n",
    "        with open(merge_target) as f:\n",
    "            db = json.load(f)\n",
    "        # Guard against an empty main dataset (would divide by zero).\n",
    "        times = len(db) // len(main_db) if main_db else 'n/a'\n",
    "        print(f'Found {len(db)} samples, which is {times} times of the main dataset size.')\n",
    "        return\n",
    "\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "    n = len(main_db)\n",
    "    merged_db = list(main_db)\n",
    "    # Load one auxiliary file at a time; shuffles happen in the same order as before.\n",
    "    for aux_path, r in zip(aux_dt, ratio):\n",
    "        with open(aux_path, 'r') as f:\n",
    "            aux_db = json.load(f)\n",
    "        rng.shuffle(aux_db)\n",
    "        k = int(n * r)\n",
    "        if k > len(aux_db):\n",
    "            # Slicing silently caps at the available samples; make that visible.\n",
    "            print(f'Warning: {aux_path} has {len(aux_db)} samples, fewer than the {k} requested.')\n",
    "        merged_db += aux_db[:k]\n",
    "\n",
    "    with open(merge_target, 'w') as f:\n",
    "        json.dump(merged_db, f, indent=2)\n",
    "    print('Dump to ' + merge_target, f'{len(merged_db)} samples.')\n",
    "\n",
    "\n",
    "# Example: merge the described (D-) text/multi 0.8k dataset with all four auxiliary datasets at ratio 1\n",
    "aux_tasks = ['obj_det', 'obj_loc', 'act_pred', 'fut_pred']\n",
    "aux_dt = [f'aux-{task}-train-0d8k-front.json' for task in aux_tasks]\n",
    "ratio = [1 for _ in aux_dt]\n",
    "merge_dts('D-inBC-text-multi-train-0d8k-front.json', aux_dt, ratio, 'train-0d8k')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba6b2695-a307-478a-b004-8f05ef58124d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## [Optional] Prepare the segmentation dataset for Mask R-CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3992964b-6289-4e2a-9b1e-b9e4e6d7d611",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def prepare_segm_dataset(dkeys, split_name, view_mode='front', debug=False):\n",
    "    \"\"\"Build a Mask R-CNN style segmentation index from VIMA episodes.\n",
    "\n",
    "    For each episode key one record is produced per trajectory frame\n",
    "    (image 'rgb_{view_mode}/{i}.jpg', mask 'segm_{view_mode}/{i}.png') and one\n",
    "    per prompt asset (image 'rgb_{view_mode}/a_{k}.png', mask\n",
    "    'segm_{view_mode}/a_{k}.png'). Paths are prefixed with the episode key,\n",
    "    and each record lists its objects as {'id', 'cls', 'color'} dicts.\n",
    "\n",
    "    Args:\n",
    "        dkeys: episode keys passed to get_all_files_for_episode.\n",
    "        split_name: split label used in the output filename.\n",
    "        view_mode: camera view used in the image/mask paths.\n",
    "        debug: if True, only the first episode is processed; images and masks\n",
    "            are displayed, the records are printed and no file is written.\n",
    "\n",
    "    Side effects:\n",
    "        Unless debug, writes 'maskrcnn-{split_name}-{view_mode}.json' to the\n",
    "        current working directory.\n",
    "    \"\"\"\n",
    "    data = []\n",
    "\n",
    "    def make_obj_label(obj_id, obj_name, texture_name):\n",
    "        # ids may arrive as strings (e.g. JSON dict keys), hence int().\n",
    "        if debug:\n",
    "            print(obj_id, obj_name, texture_name)\n",
    "        return {'id': int(obj_id), 'cls': obj_name, 'color': texture_name}\n",
    "\n",
    "    def add_sample(dkey, episode, local_image_path, local_seg_path, obj_labels):\n",
    "        if debug:\n",
    "            # The mask is only decoded here, where it is actually used.\n",
    "            seg = np.array(episode[local_seg_path])\n",
    "            print(local_image_path, np.unique(seg))\n",
    "            display(episode[local_image_path])\n",
    "            # Stretch ids to the full gray range in a wide dtype so they neither\n",
    "            # wrap around (uint8 overflow) nor saturate when displayed.\n",
    "            scale = 255 // max(int(seg.max()), 1)\n",
    "            display(Image.fromarray((seg.astype(np.int64) * scale).astype(np.uint8)))\n",
    "        data.append({\n",
    "            'image_path': dkey + '/' + local_image_path,\n",
    "            'mask_path': dkey + '/' + local_seg_path,\n",
    "            'object': obj_labels,\n",
    "        })\n",
    "\n",
    "    for dkey in tqdm(dkeys):\n",
    "        episode = get_all_files_for_episode(dkey)\n",
    "        traj_meta = episode['trajectory.json']\n",
    "        num_of_steps = episode['obs.pkl']['ee'].shape[0]\n",
    "\n",
    "        # Trajectory frames: every frame is labelled with all episode objects.\n",
    "        scene_objs = traj_meta['obj_id_to_info']\n",
    "        for obs_i in range(num_of_steps):\n",
    "            obj_labels = [make_obj_label(obj_id, info['obj_name'], info['texture_name'])\n",
    "                          for obj_id, info in scene_objs.items()]\n",
    "            add_sample(dkey, episode, f'rgb_{view_mode}/{obs_i}.jpg',\n",
    "                       f'segm_{view_mode}/{obs_i}.png', obj_labels)\n",
    "\n",
    "        # Prompt assets: a single object is a dict, a scene is a list of dicts.\n",
    "        for k, asset in traj_meta['prompt_assets'].items():\n",
    "            asset_objs = asset['segm_obj_info']\n",
    "            if not isinstance(asset_objs, list):\n",
    "                asset_objs = [asset_objs]\n",
    "            obj_labels = [make_obj_label(o['obj_id'], o['obj_name'], o['obj_color'])\n",
    "                          for o in asset_objs]\n",
    "            add_sample(dkey, episode, f'rgb_{view_mode}/a_{k}.png',\n",
    "                       f'segm_{view_mode}/a_{k}.png', obj_labels)\n",
    "\n",
    "        if debug:\n",
    "            print(json.dumps(data, indent=2))\n",
    "            break\n",
    "\n",
    "    if not debug:\n",
    "        with open(f'maskrcnn-{split_name}-{view_mode}.json', 'w') as f:\n",
    "            json.dump(data, f)\n",
    "\n",
    "# Demo on a single episode in debug mode: shows images and masks, prints the records, writes no file\n",
    "prepare_segm_dataset(train_0d8k, split_name='train-0d8k', debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96201ddb-49f8-4e03-8e7d-deb625f60a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes maskrcnn-{split}-front.json for every split\n",
    "for split_keys, split_label in [(train_0d8k, 'train-0d8k'), (train_8k, 'train-8k'), (train_80k, 'train-80k')]:\n",
    "    prepare_segm_dataset(split_keys, split_label)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sel",
   "language": "python",
   "name": "sel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
