{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 1,
            "id": "aac2cca0",
            "metadata": {},
            "outputs": [],
            "source": [
                "import torch\n",
                "import numpy as np\n",
                "import scipy.stats as st\n",
                "import sys\n",
                "import os\n",
                "sys.path.append(\"..\")\n",
                "\n",
                "from metamotivo.agents.fb.flow_bc.agent import FBFlowBCAgent\n",
                "from metamotivo.envs.ogbench import OGBenchEnvConfig, ALL_TASKS\n",
                "from metamotivo.data_loading.ogbench import OGBenchDataConfig\n",
                "from metamotivo.envs.utils.rollout import rollout\n",
                "\n",
                "def calculate_success(infos):\n",
                "    \"\"\"Computes binary success (True/False) for each episode in the rollout.\"\"\"\n",
                "    return [any([step.get(\"success\", False) for step in info]) for info in infos]\n",
                "\n",
                "def get_ogbench_config(results_path, domain, task, results_root=\"/home/jovyan/bobrin/td_jepa/results_fb_ogbench_proprio\"):\n",
                "    \"\"\"Robustly find best OGBench params from your sweep file, falling back to disk search if missing.\"\"\"\n",
                "    import json\n",
                "    from pathlib import Path\n",
                "    \n",
                "    ckpt_path = None\n",
                "    best_params = None\n",
                "    \n",
                "    # 1. Try to load from sweep results\n",
                "    if results_path and os.path.exists(results_path):\n",
                "        try:\n",
                "            with open(results_path, \"r\") as f:\n",
                "                data = json.load(f)\n",
                "                summary = data.get(\"summary\", {})\n",
                "                best_entry = None\n",
                "            for entry_data in summary.values():\n",
                "                if entry_data.get(\"domain\") == domain and entry_data.get(\"task\") == task:\n",
                "                    # Handle both 'best_zol_sr' and the older 'best_zol_score'\n",
                "                    score = entry_data.get(\"best_zol_sr\", entry_data.get(\"best_zol_score\", -1))\n",
                "                    if best_entry is None or score > best_entry.get(\"best_zol_sr\", best_entry.get(\"best_zol_score\", -1)):\n",
                "                        best_entry = entry_data\n",
                "            \n",
                "            if best_entry: \n",
                "                ckpt_path = best_entry.get(\"checkpoint\")\n",
                "                best_params = best_entry.get(\"best_zol_params\")\n",
                "        except (json.JSONDecodeError, IOError):\n",
                "            pass\n",
                "            \n",
                "    # 2. Fallback to searching on disk if checkpoint not found\n",
                "    if not ckpt_path:\n",
                "        root = Path(results_root)\n",
                "        domain_dir = root / domain\n",
                "        if domain_dir.exists():\n",
                "            for seed_dir in domain_dir.glob(\"*\"):\n",
                "                potential_ckpt = seed_dir / \"checkpoint\"\n",
                "                if potential_ckpt.exists():\n",
                "                    ckpt_path = str(potential_ckpt)\n",
                "                    print(f\"  Fallback: Found checkpoint on disk: {ckpt_path}\")\n",
                "                    break\n",
                "                    \n",
                "    return ckpt_path, best_params\n",
                "\n",
                "def run_ogbench_zol_task_evaluation(agent, task_env, env_cfg, domain, task, batch, zol_search_params=None, num_episodes=100):\n",
                "    \"\"\"Evaluate one task before and after ZOL latent search and report success rates.\n",
                "\n",
                "    Steps: relabel the batch with the task reward, infer an initial latent via\n",
                "    `agent._model.reward_inference`, roll it out, refine the latent with\n",
                "    `agent.zol_latent_search`, then roll out the refined latent.\n",
                "\n",
                "    Args:\n",
                "        agent: Loaded agent exposing `device`, `_model` and `zol_latent_search`.\n",
                "        task_env: Environment passed to `rollout` and to the latent search.\n",
                "        env_cfg: Environment config providing `get_relabel_fn(task)`.\n",
                "        domain: Domain name (unused here; kept for the existing call signature).\n",
                "        task: Task name, used for relabelling and in the log output.\n",
                "        batch: Replay batch; reads batch['next']['physics'],\n",
                "            batch['next']['observation'] and batch['action'].\n",
                "        zol_search_params: Keyword arguments forwarded to `zol_latent_search`;\n",
                "            a default set is used when None.\n",
                "        num_episodes: Rollout episodes per evaluation (default 100).\n",
                "\n",
                "    Returns:\n",
                "        Dict mapping 'base_sr', 'zol_sr', 'base_rew' and 'zol_rew' to the\n",
                "        `(mean, ci)` tuples produced by `get_stats`.\n",
                "    \"\"\"\n",
                "    print(f\"\\n--- [OGBench Task: {task}] ---\")\n",
                "    device = agent.device\n",
                "\n",
                "    if zol_search_params is None:\n",
                "        zol_search_params = {\n",
                "            \"mu_source\": \"init\",      # Better for navigation\n",
                "            \"use_exp_weights\": True,\n",
                "            \"weight_temp\": 1.0,       # Lower temp for stability\n",
                "            \"mu_reward_top_frac\": 0.05,\n",
                "            \"self_normalized_obj\": True,\n",
                "        }\n",
                "\n",
                "    # Use .detach() to avoid RuntimeError\n",
                "    relabel_fn = env_cfg.get_relabel_fn(task)\n",
                "    next_physics = batch[\"next\"][\"physics\"].detach().cpu().numpy()\n",
                "    actions = batch[\"action\"].detach().cpu().numpy()\n",
                "\n",
                "    # CRITICAL: shift rewards by +1 (the original note says this maps them to [0, 1];\n",
                "    # presumably the relabelled rewards lie in [-1, 0] -- TODO confirm).\n",
                "    # Done out of place so the array returned by relabel_fn is never mutated and an\n",
                "    # integer-typed reward array does not fail on an in-place float add.\n",
                "    rewards_np = relabel_fn(next_physics, actions) + 1.0\n",
                "\n",
                "    rewards = torch.tensor(rewards_np, dtype=torch.float32).to(device)\n",
                "    batch_obs = batch[\"next\"][\"observation\"].to(device)\n",
                "\n",
                "    # Baseline\n",
                "    initial_z = agent._model.reward_inference(batch_obs, rewards.reshape(-1, 1))\n",
                "    print(f\"  Evaluating Baseline ({num_episodes} episodes)...\")\n",
                "    base_stats, base_infos, _ = rollout(task_env, agent=agent._model, ctx=initial_z, num_episodes=num_episodes)\n",
                "    base_sr_m, base_sr_ci = get_stats(calculate_success(base_infos))\n",
                "    base_rew_m, base_rew_ci = get_stats(base_stats['reward'])\n",
                "    print(f\"  Baseline: {base_sr_m*100:.1f}% success ({base_rew_m:.2f} reward)\")\n",
                "\n",
                "    # ZOL optimization\n",
                "    print(\"  Optimizing z (ZOL Search)...\")\n",
                "    z_zol = agent.zol_latent_search(task_env, batch_obs, rewards.flatten(), initial_z, **zol_search_params)\n",
                "\n",
                "    # Final Eval\n",
                "    print(f\"  Evaluating ZOL ({num_episodes} episodes)...\")\n",
                "    zol_stats, zol_infos, _ = rollout(task_env, agent=agent._model, ctx=z_zol, num_episodes=num_episodes)\n",
                "    zol_sr_m, zol_sr_ci = get_stats(calculate_success(zol_infos))\n",
                "    zol_rew_m, zol_rew_ci = get_stats(zol_stats['reward'])\n",
                "    print(f\"  ZOL:      {zol_sr_m*100:.1f}% success ({zol_rew_m:.2f} reward)\")\n",
                "\n",
                "    return {\n",
                "        \"base_sr\": (base_sr_m, base_sr_ci),\n",
                "        \"zol_sr\": (zol_sr_m, zol_sr_ci),\n",
                "        # Reward statistics were already computed above; expose them as well.\n",
                "        \"base_rew\": (base_rew_m, base_rew_ci),\n",
                "        \"zol_rew\": (zol_rew_m, zol_rew_ci),\n",
                "    }\n",
                "\n",
                "def get_stats(data):\n",
                "    mean = np.mean(data)\n",
                "    sem = st.sem(data)\n",
                "    ci = 1.96 * sem\n",
                "    return mean, ci"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "id": "78b59b96",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Loading data from: /home/jovyan/bobrin/td_jepa/ogbench_data/antmaze-large-navigate-v0/buffer\n",
                        "  Fallback: Found checkpoint on disk: /home/jovyan/bobrin/td_jepa/results_fb_ogbench_proprio/antmaze-large-navigate-v0/1/checkpoint\n",
                        "compile True\n",
                        "compiling with mode 'reduce-overhead'\n",
                        "cudagraphs False\n",
                        "\n",
                        "--- [OGBench Task: antmaze-large-navigate-singletask-task1-v0] ---\n",
                        "  Evaluating Baseline (100 episodes)...\n",
                        "  Baseline: 60.0% success (-786.76 reward)\n",
                        "  Optimizing z (ZOL Search)...\n",
                        "  Evaluating ZOL (100 episodes)...\n",
                        "  ZOL:      62.0% success (-776.51 reward)\n",
                        "\n",
                        "--- [OGBench Task: antmaze-large-navigate-singletask-task2-v0] ---\n",
                        "  Evaluating Baseline (100 episodes)...\n",
                        "  Baseline: 81.0% success (-664.59 reward)\n",
                        "  Optimizing z (ZOL Search)...\n",
                        "  Evaluating ZOL (100 episodes)...\n",
                        "  ZOL:      82.0% success (-669.18 reward)\n",
                        "\n",
                        "--- [OGBench Task: antmaze-large-navigate-singletask-task3-v0] ---\n",
                        "  Evaluating Baseline (100 episodes)...\n",
                        "  Baseline: 24.0% success (-904.72 reward)\n",
                        "  Optimizing z (ZOL Search)...\n",
                        "  Evaluating ZOL (100 episodes)...\n",
                        "  ZOL:      0.0% success (-1000.00 reward)\n",
                        "\n",
                        "--- [OGBench Task: antmaze-large-navigate-singletask-task4-v0] ---\n",
                        "  Evaluating Baseline (100 episodes)...\n",
                        "  Baseline: 75.0% success (-668.46 reward)\n",
                        "  Optimizing z (ZOL Search)...\n",
                        "  Evaluating ZOL (100 episodes)...\n",
                        "  ZOL:      79.0% success (-614.47 reward)\n",
                        "\n",
                        "--- [OGBench Task: antmaze-large-navigate-singletask-task5-v0] ---\n",
                        "  Evaluating Baseline (100 episodes)...\n",
                        "  Baseline: 86.0% success (-609.94 reward)\n",
                        "  Optimizing z (ZOL Search)...\n",
                        "  Evaluating ZOL (100 episodes)...\n",
                        "  ZOL:      93.0% success (-598.48 reward)\n",
                        "\n",
                        "Task                           | Baseline Success     | ZOL Success         \n",
                        "---------------------------------------------------------------------------\n",
                        "antmaze-large-navigate-singletask-task1-v0 |   60.0% ±  9.7%     |   62.0% ±  9.6%\n",
                        "antmaze-large-navigate-singletask-task2-v0 |   81.0% ±  7.7%     |   82.0% ±  7.6%\n",
                        "antmaze-large-navigate-singletask-task3-v0 |   24.0% ±  8.4%     |    0.0% ±  0.0%\n",
                        "antmaze-large-navigate-singletask-task4-v0 |   75.0% ±  8.5%     |   79.0% ±  8.0%\n",
                        "antmaze-large-navigate-singletask-task5-v0 |   86.0% ±  6.8%     |   93.0% ±  5.0%\n"
                    ]
                }
            ],
            "source": [
                "# Configuration\n",
                "OG_DOMAIN = \"antmaze-large-navigate-v0\"\n",
                "RESULTS_PATH = \"../zol_sweep_ogbench_results.json\"\n",
                "DATASET_ROOT = \"/home/jovyan/bobrin/td_jepa/ogbench_data\"\n",
                "# Sweep keys that are applied to agent.cfg.train.zol; every other sweep key is\n",
                "# forwarded to the latent search as a keyword argument.\n",
                "ZOL_CONFIG_KEYS = {\"lr\", \"num_steps\", \"n_mu\", \"early_stop_patience\", \"early_stop_tol\",\n",
                "                   \"chi2_coef\", \"trust_l2_coef\", \"weight_clip\", \"center_rewards\"}\n",
                "\n",
                "# 1. Load Data for Reward Inference\n",
                "data_cfg = OGBenchDataConfig(domain=OG_DOMAIN, dataset_root=DATASET_ROOT)\n",
                "replay_buffer = data_cfg.build(buffer_device=\"cuda\", batch_size=10_000, frame_stack=1)\n",
                "batch = replay_buffer[\"train\"].sample(10_000)\n",
                "\n",
                "# 2. Iterate through tasks\n",
                "all_results = {}\n",
                "agent = None\n",
                "base_cfg = None  # agent config as loaded; restored before every task\n",
                "\n",
                "for task in ALL_TASKS[OG_DOMAIN]:\n",
                "    env_cfg = OGBenchEnvConfig(domain=OG_DOMAIN, task=task)\n",
                "    task_env, _ = env_cfg.build()\n",
                "    try:\n",
                "        # Look up the sweep entry for *this* task on every iteration; doing it only\n",
                "        # while loading the agent would reuse the first task's parameters for all tasks.\n",
                "        ckpt_path, best_params = get_ogbench_config(RESULTS_PATH, OG_DOMAIN, task)\n",
                "\n",
                "        if agent is None:\n",
                "            if ckpt_path is None:\n",
                "                raise ValueError(f\"Could not find checkpoint for {OG_DOMAIN}. Check your RESULTS_PATH or disk.\")\n",
                "\n",
                "            # A single agent (from the first task's checkpoint) is shared by all tasks.\n",
                "            agent = FBFlowBCAgent.load(\n",
                "                ckpt_path,\n",
                "                device=\"cuda\",\n",
                "                obs_space=task_env.observation_space,\n",
                "                action_dim=batch[\"action\"].shape[-1]\n",
                "            )\n",
                "            agent._model.train(False)\n",
                "            base_cfg = agent.cfg\n",
                "\n",
                "        # Start from the loaded config so one task's overrides never leak into the next.\n",
                "        agent.cfg = base_cfg\n",
                "\n",
                "        # Update agent config if we have best params from sweep\n",
                "        search_kwargs = None\n",
                "        if best_params:\n",
                "            print(f\"  Applying sweep params: {best_params}\")\n",
                "            cfg_updates = {k: v for k, v in best_params.items() if k in ZOL_CONFIG_KEYS}\n",
                "            search_kwargs = {k: v for k, v in best_params.items() if k not in ZOL_CONFIG_KEYS}\n",
                "\n",
                "            if cfg_updates:\n",
                "                agent.cfg = base_cfg.model_copy(update={\n",
                "                    \"train\": base_cfg.train.model_copy(\n",
                "                        update={\"zol\": base_cfg.train.zol.model_copy(update=cfg_updates)}\n",
                "                    )\n",
                "                })\n",
                "\n",
                "        all_results[task] = run_ogbench_zol_task_evaluation(\n",
                "            agent, task_env, env_cfg, OG_DOMAIN, task, batch, zol_search_params=search_kwargs\n",
                "        )\n",
                "    finally:\n",
                "        # Release the environment even when evaluation fails.\n",
                "        task_env.close()\n",
                "\n",
                "# 3. Print Final Summary Table\n",
                "# Size the first column to the longest task name so the columns line up.\n",
                "task_col = max([30] + [len(name) for name in all_results])\n",
                "print(f\"\\n{'Task':<{task_col}} | {'Baseline Success':<20} | {'ZOL Success':<20}\")\n",
                "print(\"-\" * (task_col + 45))\n",
                "for task, res in all_results.items():\n",
                "    b_m, b_ci = res[\"base_sr\"]\n",
                "    z_m, z_ci = res[\"zol_sr\"]\n",
                "    print(f\"{task:<{task_col}} | {b_m*100:6.1f}% ± {b_ci*100:4.1f}%     | {z_m*100:6.1f}% ± {z_ci*100:4.1f}%\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "19811770",
            "metadata": {},
            "outputs": [],
            "source": []
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "bd3f934d",
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": ".venv",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.13"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 5
}
