{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 834860389382879,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "# Path to the jepa-wms checkout. Set the JEPA_WMS_DIR environment variable to point\n",
        "# at your clone; when it is unset, the placeholder default below is used and has to\n",
        "# be edited by hand.\n",
        "jepa_dir = os.environ.get(\"JEPA_WMS_DIR\", \"/home/name/jepa-wms\")\n",
        "\n",
        "# Put the repo first on sys.path so `macros` and `app` resolve from it. The check\n",
        "# keeps re-running this cell from stacking duplicate entries.\n",
        "if sys.path[:1] != [jepa_dir]:\n",
        "    sys.path.insert(0, jepa_dir)\n",
        "\n",
        "# Local configuration comes from macros.py (gitignored).\n",
        "# Run `python setup_macros.py` to generate it from your environment variables.\n",
        "from macros import JEPAWM_HOME, JEPAWM_LOGS\n",
        "\n",
        "# Directory under which the cells below create their per-sweep output folders.\n",
        "local_plan_common_dir = os.path.join(jepa_dir, \"app/plan_common/local\")\n",
        "print(f\"{local_plan_common_dir=}\")\n",
        "\n",
        "from app.plan_common.plot.logs_plan_joint_unif_utils import (\n",
        "    collect_task_eval_data,\n",
        "    plot_task_eval_data,\n",
        ")\n",
        "from app.plan_common.plot.aliases import eval_setup_aliases"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 757696170056722,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [],
      "source": [
        "# Directory passed to plot_task_eval_data as base_dir for the mw sweep.\n",
        "base_dir = os.path.join(local_plan_common_dir, 'paper-app', 'mw_sweep')\n",
        "os.makedirs(base_dir, exist_ok=True)\n",
        "\n",
        "# (training folder, label) pairs; every folder sits in JEPAWM_LOGS/mw_final_sweep.\n",
        "model_training_folders = [\n",
        "    (os.path.join(JEPAWM_LOGS, f'mw_final_sweep/{run_name}'), label)\n",
        "    for run_name, label in [\n",
        "        # Central run\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_seed1', 'WM'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_hist7', r'$\\text{WM}_W$'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_depth6_repro_1roll_save', 'WM-prop'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_2roll', 'WM-2-step'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_3roll', 'WM-3-step'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_6roll_hist7', r'$\\text{WM}_W$-6-step'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_dinovitb_depth6_noprop_repro_1roll_save', 'WM-B'),\n",
        "        ('mw_4f_fsk5_ask1_r224_pred_dino_wm_dinovitl_depth6_noprop_repro_1roll_save', 'WM-L'),\n",
        "    ]\n",
        "]\n",
        "\n",
        "task_subset = [\"mw-reach\", \"mw-reach-wall\"]\n",
        "\n",
        "# Gather the evaluation data for the listed runs and tasks, then plot it.\n",
        "task_eval_data = collect_task_eval_data(\n",
        "    model_training_folders,\n",
        "    task_subset,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        ")\n",
        "\n",
        "plot_task_eval_data(\n",
        "    task_eval_data,\n",
        "    y_values=['SR'],\n",
        "    smooth=True,\n",
        "    alpha=0.2,\n",
        "    show_original=True,\n",
        "    base_dir=base_dir,\n",
        "    put_title=True,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    truncate_epoch=50,\n",
        "    y_min=0,\n",
        "    y_max=90,\n",
        "    average_seeds=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 1879828612621588,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [],
      "source": [
        "# Directory passed to plot_task_eval_data as base_dir for the wall sweep.\n",
        "base_dir = os.path.join(local_plan_common_dir, 'paper-app', 'wall_sweep')\n",
        "os.makedirs(base_dir, exist_ok=True)\n",
        "\n",
        "# (training folder, label) pairs; every folder sits in JEPAWM_LOGS/wall_sweep.\n",
        "model_training_folders = [\n",
        "    (os.path.join(JEPAWM_LOGS, f'wall_sweep/{run_name}'), label)\n",
        "    for run_name, label in [\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_2n', 'WM'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_hist7_2n_lightfreq100', r'$\\text{WM}_W$'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_repro_1roll_save_2n', 'WM-prop'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_2roll_save_2n', 'WM-2-step'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_3roll_save_2n', 'WM-3-step'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_6roll_hist7_bs4_save_2n', r'$\\text{WM}_W$-6-step'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_dinovitb_depth6_noprop_repro_1roll_save_2n', 'WM-B'),\n",
        "        ('wall_4f_fsk5_ask1_r224_pred_dino_wm_dinovitl_depth6_noprop_repro_1roll_save_2n', 'WM-L'),\n",
        "    ]\n",
        "]\n",
        "\n",
        "# Folders forwarded to collect_task_eval_data through its hist1_folders argument.\n",
        "# NOTE(review): these paths carry a 'vjepa_wm/' prefix that the runs above do not;\n",
        "# confirm this matches the actual log layout.\n",
        "hist1_folders = [\n",
        "    os.path.join(JEPAWM_LOGS, relative_path)\n",
        "    for relative_path in (\n",
        "        \"vjepa_wm/wall_sweep/wall_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_hist1_save_2n\",\n",
        "        'vjepa_wm/wall_sweep/wall_dwm-repro-r196-hist1',\n",
        "    )\n",
        "]\n",
        "\n",
        "task_subset = [\"wall\"]\n",
        "\n",
        "# Gather the evaluation data for the listed runs and tasks, then plot it.\n",
        "task_eval_data = collect_task_eval_data(\n",
        "    model_training_folders,\n",
        "    task_subset,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    hist1_folders=hist1_folders,\n",
        ")\n",
        "\n",
        "plot_task_eval_data(\n",
        "    task_eval_data,\n",
        "    y_values=['SR'],\n",
        "    smooth=True,\n",
        "    alpha=0.2,\n",
        "    show_original=True,\n",
        "    base_dir=base_dir,\n",
        "    put_title=True,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    y_min=0,\n",
        "    y_max=90,\n",
        "    average_seeds=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "output": {
          "id": 868148555586209,
          "loadingStatus": "loaded"
        }
      },
      "outputs": [],
      "source": [
        "# Directory passed to plot_task_eval_data as base_dir for the pt sweep.\n",
        "base_dir = os.path.join(local_plan_common_dir, 'paper-app', 'pt_sweep')\n",
        "os.makedirs(base_dir, exist_ok=True)\n",
        "\n",
        "# (training folder, label) pairs; every folder sits in JEPAWM_LOGS/pt_sweep.\n",
        "model_training_folders = [\n",
        "    (os.path.join(JEPAWM_LOGS, f'pt_sweep/{run_name}'), label)\n",
        "    for run_name, label in [\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save', 'WM'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_hist7', r'$\\text{WM}_W$'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_repro_1roll_save', 'WM-prop'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_2roll_save', 'WM-2-step'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_3roll_save', 'WM-3-step'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_6roll_hist7_bs4_save', r'$\\text{WM}_W$-6-step'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_dinovitb_depth6_noprop_repro_1roll_save', 'WM-B'),\n",
        "        ('pt_4f_fsk5_ask1_r224_pred_dino_wm_dinovitl_depth6_noprop_repro_1roll_save', 'WM-L'),\n",
        "        ('pt_4f_fsk5_ask1_r224_predAdaLN0_noprop_depth6_repro_1roll', 'AdaLN0'),\n",
        "    ]\n",
        "]\n",
        "\n",
        "task_subset = [\"pt\"]\n",
        "\n",
        "# Folders forwarded to collect_task_eval_data through its hist1_folders argument.\n",
        "hist1_folders = [\n",
        "    os.path.join(JEPAWM_LOGS, relative_path)\n",
        "    for relative_path in (\n",
        "        \"pt_sweep/pt_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_hist1_save\",\n",
        "    )\n",
        "]\n",
        "\n",
        "# Gather the evaluation data for the listed runs and tasks, then plot it.\n",
        "task_eval_data = collect_task_eval_data(\n",
        "    model_training_folders,\n",
        "    task_subset,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    hist1_folders=hist1_folders,\n",
        ")\n",
        "\n",
        "plot_task_eval_data(\n",
        "    task_eval_data,\n",
        "    y_values=['SR'],\n",
        "    smooth=True,\n",
        "    alpha=0.2,\n",
        "    show_original=True,\n",
        "    base_dir=base_dir,\n",
        "    put_title=True,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    y_min=0,\n",
        "    y_max=90,\n",
        "    truncate_epoch=50,\n",
        "    average_seeds=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Directory passed to plot_task_eval_data as base_dir for the mz sweep.\n",
        "base_dir = os.path.join(local_plan_common_dir, 'paper-app', 'mz_sweep')\n",
        "os.makedirs(base_dir, exist_ok=True)\n",
        "\n",
        "# (training folder, label) pairs; every folder sits in JEPAWM_LOGS/mz_sweep.\n",
        "model_training_folders = [\n",
        "    (os.path.join(JEPAWM_LOGS, f'mz_sweep/{run_name}'), label)\n",
        "    for run_name, label in [\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_2n', 'WM'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_save_hist7_2n_lightfreq100', r'$\\text{WM}_W$'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_repro_1roll_save_2n', 'WM-prop'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_2roll_save_2n', 'WM-2-step'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_3roll_save_2n', 'WM-3-step'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_6roll_hist7_bs4_save_2n', r'$\\text{WM}_W$-6-step'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_dinovitl_depth6_noprop_repro_1roll_save_2n', 'WM-L'),\n",
        "        ('mz_4f_fsk5_ask1_r224_pred_dino_wm_dinovitb_depth6_noprop_repro_1roll_save_2n', 'WM-B'),\n",
        "    ]\n",
        "]\n",
        "\n",
        "# Folders forwarded to collect_task_eval_data through its hist1_folders argument.\n",
        "hist1_folders = [\n",
        "    os.path.join(JEPAWM_LOGS, relative_path)\n",
        "    for relative_path in (\n",
        "        \"mz_sweep/mz_4f_fsk5_ask1_r224_pred_dino_wm_depth6_noprop_repro_1roll_hist1_save_2n\",\n",
        "    )\n",
        "]\n",
        "\n",
        "task_subset = [\"mz\"]\n",
        "\n",
        "# Gather the evaluation data for the listed runs and tasks, then plot it.\n",
        "task_eval_data = collect_task_eval_data(\n",
        "    model_training_folders,\n",
        "    task_subset,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    hist1_folders=hist1_folders,\n",
        ")\n",
        "\n",
        "plot_task_eval_data(\n",
        "    task_eval_data,\n",
        "    y_values=['SR'],\n",
        "    smooth=True,\n",
        "    alpha=0.2,\n",
        "    show_original=True,\n",
        "    base_dir=base_dir,\n",
        "    put_title=True,\n",
        "    eval_setup_aliases=eval_setup_aliases,\n",
        "    y_min=0,\n",
        "    y_max=100,\n",
        "    average_seeds=True,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "fileHeader": "",
    "fileUid": "4f5f8e8b-afd4-4c31-927b-76ef3f8057be",
    "isAdHoc": false,
    "kernelspec": {
      "display_name": "jepa-wms (conda)",
      "language": "python",
      "name": "conda_jepa-wms"
    },
    "language_info": {
      "name": "python"
    },
    "orig_nbformat": 4
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
