{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate STAIR's Performance in Stage-alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "from IPython.display import display\n",
    "import time\n",
    "from collections import defaultdict\n",
    "\n",
    "# setting\n",
    "env_name = 'window-close'  # ['door-open', 'window-close', 'window-open']\n",
    "long_env_name = 'metaworld_' + env_name + '-v2'\n",
    "METHODS = ['pebble', 'stair']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_queries(env_name, long_env_name, method):\n",
    "    queries_path = f'./query/{env_name}/{method}/'\n",
    "    if not os.path.exists(queries_path):\n",
    "        raise FileNotFoundError(f\"Queries path {queries_path} does not exist.\")\n",
    "    \n",
    "    queries = []\n",
    "    groups = defaultdict(dict)\n",
    "\n",
    "    \n",
    "    for root, dirs, files in os.walk(queries_path):\n",
    "        for file in files:\n",
    "            if file.endswith(\".gif\"):\n",
    "                \n",
    "                name = os.path.splitext(file)[0]\n",
    "                parts = name.split('_')\n",
    "                \n",
    "                env_name = parts[0]\n",
    "                method = parts[1]\n",
    "                seed = int(parts[2].replace(\"seed\", \"\"))     \n",
    "                step = int(parts[3])\n",
    "                count = int(parts[4])\n",
    "                query_id = parts[5]                          \n",
    "                en_idx = int(parts[6].replace(\"en\", \"\"))     \n",
    "                start_index = int(parts[7])\n",
    "                group_key = (env_name, method, seed, step, count)\n",
    "                \n",
    "                groups[group_key][query_id] = {\n",
    "                    \"en_idx\": en_idx,\n",
    "                    \"start_idx\": start_index,\n",
    "                    \"video_path\": os.path.join('.', 'query', env_name, method, file)\n",
    "                }\n",
    "\n",
    "    \n",
    "    for group_key, pairs in groups.items():\n",
    "        if \"0\" in pairs and \"1\" in pairs:  \n",
    "            q0 = pairs[\"0\"]\n",
    "            q1 = pairs[\"1\"]\n",
    "            queries.append({\n",
    "                \"seed\": group_key[2],\n",
    "                \"step\": group_key[3],\n",
    "                \"en_idx0\": q0[\"en_idx\"],\n",
    "                \"start_idx0\": q0[\"start_idx\"],\n",
    "                \"video_path0\": q0[\"video_path\"],\n",
    "                \"en_idx1\": q1[\"en_idx\"],\n",
    "                \"start_idx1\": q1[\"start_idx\"],\n",
    "                \"video_path1\": q1[\"video_path\"]\n",
    "            })\n",
    "\n",
    "    \n",
    "    print(f\"Generated {len(queries)} query pairs\")\n",
    "    print(queries[0])  \n",
    "    return queries\n",
    "\n",
    "obvious_step = {\n",
    "    \"door-open\": 500000,\n",
    "    \"window-close\": 300000, \n",
    "    \"window-open\": 300000,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries_dict = {\n",
    "    method: load_queries(env_name, long_env_name, method)\n",
    "    for method in METHODS\n",
    "}\n",
    "\n",
    "import datetime\n",
    "def getDataTimeString():\n",
    "    return datetime.datetime.now().strftime('%Y%m%d-%H%M%S')[2:]\n",
    "\n",
    "\n",
    "base_path = f\"./human_label/{env_name}/\"\n",
    "if not os.path.exists(base_path):\n",
    "    os.makedirs(base_path)\n",
    "output_csv = os.path.join(base_path, f\"human_fair_{getDataTimeString()}.csv\")\n",
    "if os.path.exists(output_csv):\n",
    "    df = pd.read_csv(output_csv)\n",
    "else:\n",
    "    df = pd.DataFrame(columns=[\n",
    "        'method',\n",
    "        'segment0_start_idx', 'segment0_episode', \n",
    "        'segment1_start_idx', 'segment1_episode', \n",
    "        'is_stage_aligned', \n",
    "    ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompt\n",
    "\n",
    "### door-open\n",
    "\n",
    "The target behavior is that the robot arm smoothly rotates the door until it stays fully open at a clearly visible angle.\n",
    "\n",
    "### window-open\n",
    "The target behavior is that the window slides horizontally to a clearly open position with coordinated gripper guidance.\n",
    "\n",
    "### window-close\n",
    "The target behavior is that the window slides horizontally to a clearly close position with coordinated gripper guidance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output, HTML, Image\n",
    "import base64\n",
    "\n",
    "N_QUERIES = 10\n",
    "method_idxs = list(range(0, len(METHODS))) * N_QUERIES\n",
    "random.shuffle(method_idxs)\n",
    "n_all_queries = len(method_idxs)\n",
    "\n",
    "for i in range(n_all_queries):\n",
    "    clear_output()\n",
    "    print(f\"{i+1}th among total {n_all_queries} feedbacks\")\n",
    "    method = METHODS[method_idxs[i]]\n",
    "    queries = queries_dict[method]\n",
    "    query = random.sample(queries, 1)[0]\n",
    "    while int(query['step']) < obvious_step[env_name]:\n",
    "        query = random.sample(queries, 1)[0]\n",
    "\n",
    "    time.sleep(0.1)\n",
    "    display(HTML(f'''\n",
    "    <div style=\"display: inline-block; margin-right: 100px;\">\n",
    "        <p> Please enter '1' for stage-alignment, '0' for misalignment. </p>\n",
    "        <img src=\"{query['video_path0']}\" width=\"400\" loop=\"true\" >\n",
    "        <img src=\"{query['video_path1']}\" width=\"400\" loop=\"true\" >\n",
    "    </div>\n",
    "    '''))\n",
    "\n",
    "    time.sleep(1)\n",
    "    select = input(\"Please enter '1' for stage-alignment, '0' for misalignment.\")\n",
    "    if select == 'quit' or select == 'exit':\n",
    "        break\n",
    "    elif select not in ['1', '0']:\n",
    "        print(\"Invalid input. Please enter '1' for stage-alignment, '0' for misalignment.\")\n",
    "        continue\n",
    "    new_entry = {\n",
    "        'method': method,\n",
    "        'segment0_start_idx': query['start_idx0'],\n",
    "        'segment0_episode': query['en_idx0'],\n",
    "        'segment1_start_idx': query['start_idx1'],\n",
    "        'segment1_episode': query['en_idx1'],\n",
    "        'is_stage_aligned': select\n",
    "    }\n",
    "    df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)\n",
    "\n",
    "    df.to_csv(output_csv, index=False)\n",
    "    print(f\"\\nresult save to {output_csv}\")\n",
    "\n",
    "\n",
    "df.to_csv(output_csv, index=False)\n",
    "print(f\"\\nresult save to {output_csv}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "urlb4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
