{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d9f33d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import jax\n",
    "\n",
    "from ulee_repo.evaluations.diayn_evals import eval_diayn_finetune\n",
    "from ulee_repo.evaluations.meta_learner_evals import eval_meta_learner_finetune, eval_meta_learner_finetune_on_meta_rl\n",
    "from ulee_repo.evaluations.rollouts_on_trained import rollout_on_trained_weights\n",
    "from ulee_repo.experiments.paths import build_best_weights_rollouts_path, build_finetuned_on_meta_rl_path, build_finetuned_weights_path, build_trained_weights_path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d3e1b5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask Weights & Biases to stay quiet on the console for the runs launched below.\n",
    "os.environ.update({\"WANDB_SILENT\": \"true\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aeae191f",
   "metadata": {},
   "source": [
    "## Rollouts on trained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35cfec63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by every rollout cell in this section.\n",
    "eval_seed = 42\n",
    "eval_rng = jax.random.key(eval_seed)  # base PRNG key handed to the rollout calls below\n",
    "\n",
    "# Size of each evaluation run.\n",
    "eval_num_envs, eval_num_episodes = 16_384, 30\n",
    "\n",
    "# Identifiers used both to build paths and as arguments to the rollout calls.\n",
    "env_id, benchmark_id = \"XLand-MiniGrid-R4-13x13\", \"small-1m\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c23e9d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ULEE\n",
    "\n",
    "# Variant whose weights / results paths are built below.\n",
    "# NOTE(review): the other ULEE cells in this notebook use \"ppo\" / \"bounded_uniform\" - confirm this variant is intended here.\n",
    "goal_search_algorithm = \"random\"\n",
    "goal_sampling_method = \"uniform\"\n",
    "seeds = [10, 20, 30, 40]\n",
    "\n",
    "# Build one trained-weights path and one rollout-results path per training seed.\n",
    "ulee_trained_weights_paths, ulee_best_weights_rollouts_paths = [], []\n",
    "for seed in seeds:\n",
    "    ulee_path_args = (\"ulee\", env_id, benchmark_id, seed, goal_search_algorithm, goal_sampling_method)\n",
    "    ulee_trained_weights_paths.append(build_trained_weights_path(*ulee_path_args))\n",
    "    ulee_best_weights_rollouts_paths.append(build_best_weights_rollouts_path(*ulee_path_args))\n",
    "\n",
    "# Arguments that stay the same from one seed to the next.\n",
    "ulee_rollout_kwargs = {\n",
    "    \"rng\": eval_rng,\n",
    "    \"num_envs\": eval_num_envs,\n",
    "    \"num_episodes\": eval_num_episodes,\n",
    "    \"algorithm_id\": \"ulee\",\n",
    "    \"env_id\": env_id,\n",
    "    \"benchmark_id\": benchmark_id,\n",
    "    \"eval_on_test_benchmark\": True,\n",
    "}\n",
    "\n",
    "# One rollout per seed, pairing each weights path with its results path.\n",
    "for weight_path, result_path in zip(ulee_trained_weights_paths, ulee_best_weights_rollouts_paths, strict=True):\n",
    "    rollout_on_trained_weights(weights_path=weight_path, results_path=result_path, **ulee_rollout_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d253d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RANDOM POLICY\n",
    "\n",
    "seeds = [10, 20, 30, 40]\n",
    "\n",
    "# One results path per seed; the random policy has no trained weights to load.\n",
    "random_rollouts_paths = [build_best_weights_rollouts_path(\"random\", env_id, benchmark_id, seed) for seed in seeds]\n",
    "\n",
    "# With weights_path=None the seed can only reach the run through the PRNG key.\n",
    "# Fold each seed into the base key: passing the same key every time would give all\n",
    "# four runs identical inputs, differing only in where the results are written.\n",
    "for seed, result_path in zip(seeds, random_rollouts_paths, strict=True):\n",
    "    rollout_on_trained_weights(\n",
    "        rng=jax.random.fold_in(eval_rng, seed),\n",
    "        num_envs=eval_num_envs,\n",
    "        num_episodes=eval_num_episodes,\n",
    "        algorithm_id=\"random\",\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        weights_path=None,\n",
    "        results_path=result_path,\n",
    "        eval_on_test_benchmark=True,  # same setting as the ULEE / DIAYN / PPO rollout cells\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5780ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DIAYN\n",
    "\n",
    "seeds = [10, 20, 30, 40]\n",
    "\n",
    "# Build one trained-weights path and one rollout-results path per training seed.\n",
    "diayn_trained_weights_paths, diayn_best_weights_rollouts_paths = [], []\n",
    "for seed in seeds:\n",
    "    diayn_path_args = (\"diayn\", env_id, benchmark_id, seed)\n",
    "    diayn_trained_weights_paths.append(build_trained_weights_path(*diayn_path_args))\n",
    "    diayn_best_weights_rollouts_paths.append(build_best_weights_rollouts_path(*diayn_path_args))\n",
    "\n",
    "# Arguments that stay the same from one seed to the next.\n",
    "diayn_rollout_kwargs = {\n",
    "    \"rng\": eval_rng,\n",
    "    \"num_envs\": eval_num_envs,\n",
    "    \"num_episodes\": eval_num_episodes,\n",
    "    \"algorithm_id\": \"diayn\",\n",
    "    \"env_id\": env_id,\n",
    "    \"benchmark_id\": benchmark_id,\n",
    "    \"eval_on_test_benchmark\": True,\n",
    "}\n",
    "\n",
    "# One rollout per seed, pairing each weights path with its results path.\n",
    "for weight_path, result_path in zip(diayn_trained_weights_paths, diayn_best_weights_rollouts_paths, strict=True):\n",
    "    rollout_on_trained_weights(weights_path=weight_path, results_path=result_path, **diayn_rollout_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a0eeb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPO\n",
    "\n",
    "seeds = [10, 20, 30, 40]\n",
    "\n",
    "# Build one trained-weights path and one rollout-results path per training seed.\n",
    "ppo_trained_weights_paths, ppo_best_weights_rollouts_paths = [], []\n",
    "for seed in seeds:\n",
    "    ppo_path_args = (\"ppo\", env_id, benchmark_id, seed)\n",
    "    ppo_trained_weights_paths.append(build_trained_weights_path(*ppo_path_args))\n",
    "    ppo_best_weights_rollouts_paths.append(build_best_weights_rollouts_path(*ppo_path_args))\n",
    "\n",
    "# Arguments that stay the same from one seed to the next.\n",
    "# NOTE(review): paths are built with \"ppo\" while the rollout uses algorithm_id \"standard_ppo\" - confirm both ids are the expected ones.\n",
    "ppo_rollout_kwargs = {\n",
    "    \"rng\": eval_rng,\n",
    "    \"num_envs\": eval_num_envs,\n",
    "    \"num_episodes\": eval_num_episodes,\n",
    "    \"algorithm_id\": \"standard_ppo\",\n",
    "    \"env_id\": env_id,\n",
    "    \"benchmark_id\": benchmark_id,\n",
    "    \"eval_on_test_benchmark\": True,\n",
    "}\n",
    "\n",
    "# One rollout per seed, pairing each weights path with its results path.\n",
    "for weight_path, result_path in zip(ppo_trained_weights_paths, ppo_best_weights_rollouts_paths, strict=True):\n",
    "    rollout_on_trained_weights(weights_path=weight_path, results_path=result_path, **ppo_rollout_kwargs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "355c4724",
   "metadata": {},
   "source": [
    "## Fine-tuning on fixed tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3778c9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the fine-tuning cells in this section.\n",
    "eval_seed = 42\n",
    "eval_rng = jax.random.key(eval_seed)  # base PRNG key for the fine-tuning runs\n",
    "env_id, benchmark_id = \"XLand-MiniGrid-R4-13x13\", \"small-1m\"\n",
    "\n",
    "# Run sizes forwarded unchanged to the fine-tuning calls below.\n",
    "num_envs = 2048\n",
    "total_timesteps = 1_000_000_000\n",
    "num_steps_per_env, num_steps_per_update = 5120, 256\n",
    "eval_num_episodes = 30\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e750c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ULEE\n",
    "goal_search_algorithm = \"ppo\"\n",
    "goal_sampling_method = \"bounded_uniform\"\n",
    "train_seeds = [10, 20, 30, 40]  # Seeds on which pre-training was performed\n",
    "finetune_seeds = [210, 220, 230, 240]  # Seeds to carry out fine-tuning\n",
    "\n",
    "\n",
    "# Pre-trained weights to start from, one path per training seed.\n",
    "ulee_trained_weights_paths = [build_trained_weights_path(\"ulee\", env_id, benchmark_id, train_seed, goal_search_algorithm, goal_sampling_method) for train_seed in train_seeds]\n",
    "\n",
    "# Results paths, one per (train_seed, finetune_seed) pair.\n",
    "ulee_finetuned_weights_paths = [\n",
    "    build_finetuned_weights_path(\"ulee\", env_id, benchmark_id, train_seed, finetune_seed, goal_search_algorithm, goal_sampling_method)\n",
    "    for (train_seed, finetune_seed) in zip(train_seeds, finetune_seeds, strict=True)\n",
    "]\n",
    "\n",
    "\n",
    "# set extra configurations for fine-tuning\n",
    "extra_configs = {\n",
    "    \"eval_num_episodes\": eval_num_episodes,\n",
    "}\n",
    "\n",
    "# perform fine-tuning evaluation\n",
    "# Each run gets its own PRNG key, derived from the base key and its fine-tune seed, so the\n",
    "# seed encoded in the results path is the one that actually drives the run.\n",
    "for weight_path, result_path, finetune_seed in zip(ulee_trained_weights_paths, ulee_finetuned_weights_paths, finetune_seeds, strict=True):\n",
    "    eval_meta_learner_finetune(\n",
    "        rng=jax.random.fold_in(eval_rng, finetune_seed),\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        weights_path=weight_path,\n",
    "        results_path=result_path,\n",
    "        num_envs=num_envs,\n",
    "        total_timesteps=total_timesteps,\n",
    "        num_steps_per_env=num_steps_per_env,\n",
    "        num_steps_per_update=num_steps_per_update,\n",
    "        eval_on_test_benchmark=True,\n",
    "        **extra_configs,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab977502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DIAYN\n",
    "train_seeds = [10, 20, 30, 40]  # Seeds on which pre-training was performed\n",
    "finetune_seeds = [210, 220, 230, 240]  # Seeds to carry out fine-tuning\n",
    "\n",
    "\n",
    "# Pre-trained weights to start from, one path per training seed.\n",
    "diayn_trained_weights_paths = [build_trained_weights_path(\"diayn\", env_id, benchmark_id, train_seed) for train_seed in train_seeds]\n",
    "\n",
    "# Results paths, one per (train_seed, finetune_seed) pair.\n",
    "diayn_finetuned_weights_paths = [\n",
    "    build_finetuned_weights_path(\"diayn\", env_id, benchmark_id, train_seed, finetune_seed) for (train_seed, finetune_seed) in zip(train_seeds, finetune_seeds, strict=True)\n",
    "]\n",
    "\n",
    "\n",
    "# set extra configurations for fine-tuning\n",
    "extra_configs = {\n",
    "    # When fine-tuning, num_eval_episodes_with_best_skill controls the total number of\n",
    "    # eval episodes executed per environment.\n",
    "    \"num_eval_episodes_with_best_skill\": eval_num_episodes,\n",
    "    # When fine-tuning, num_eval_episodes_per_skill controls the number of episodes per skill\n",
    "    # on each env used to determine the best skill for each env (which remains fixed\n",
    "    # throughout the fine-tuning process).\n",
    "    \"num_eval_episodes_per_skill\": 10,\n",
    "}\n",
    "\n",
    "# perform fine-tuning evaluation\n",
    "# Each run gets its own PRNG key, derived from the base key and its fine-tune seed, so the\n",
    "# seed encoded in the results path is the one that actually drives the run.\n",
    "# NOTE(review): unlike the ULEE fine-tuning call, eval_on_test_benchmark is not passed here - confirm the default is the intended benchmark.\n",
    "for weight_path, result_path, finetune_seed in zip(diayn_trained_weights_paths, diayn_finetuned_weights_paths, finetune_seeds, strict=True):\n",
    "    eval_diayn_finetune(\n",
    "        rng=jax.random.fold_in(eval_rng, finetune_seed),\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        weights_path=weight_path,\n",
    "        results_path=result_path,\n",
    "        num_envs=num_envs,\n",
    "        total_timesteps=total_timesteps,\n",
    "        num_steps_per_env=num_steps_per_env,\n",
    "        num_steps_per_update=num_steps_per_update,\n",
    "        **extra_configs,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39a51029",
   "metadata": {},
   "source": [
    "## Method 4 - Evaluation of fine-tuning on meta-RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b8c188a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the meta-RL fine-tuning cell below.\n",
    "eval_seed = 42\n",
    "eval_rng = jax.random.key(eval_seed)  # base PRNG key for the fine-tuning runs\n",
    "env_id, benchmark_id = \"XLand-MiniGrid-R4-13x13\", \"small-1m\"\n",
    "\n",
    "# Run sizes forwarded unchanged to the fine-tuning call below.\n",
    "num_envs = 2048\n",
    "total_timesteps = 5_000_000_000\n",
    "num_steps_per_env, num_steps_per_update = 5120, 256\n",
    "eval_num_episodes = 25\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deac95b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ULEE\n",
    "goal_search_algorithm = \"ppo\"\n",
    "goal_sampling_method = \"bounded_uniform\"\n",
    "train_seeds = [10, 20, 30, 40]  # Seeds on which pre-training was performed\n",
    "finetune_seeds = [210, 220, 230, 240]  # Seeds to carry out fine-tuning\n",
    "\n",
    "\n",
    "# Pre-trained weights to start from, one path per training seed.\n",
    "ulee_trained_weights_paths = [build_trained_weights_path(\"ulee\", env_id, benchmark_id, train_seed, goal_search_algorithm, goal_sampling_method) for train_seed in train_seeds]\n",
    "\n",
    "# Results paths, one per (train_seed, finetune_seed) pair.\n",
    "ulee_finetuned_on_meta_rl_weights_paths = [\n",
    "    build_finetuned_on_meta_rl_path(\"ulee\", env_id, benchmark_id, train_seed, finetune_seed, goal_search_algorithm, goal_sampling_method)\n",
    "    for (train_seed, finetune_seed) in zip(train_seeds, finetune_seeds, strict=True)\n",
    "]\n",
    "\n",
    "\n",
    "extra_configs = {\n",
    "    \"eval_num_episodes\": eval_num_episodes,\n",
    "}\n",
    "\n",
    "# perform fine-tuning evaluation on meta rl for meta learning algorithm\n",
    "# Each run gets its own PRNG key, derived from the base key and its fine-tune seed, so the\n",
    "# seed encoded in the results path is the one that actually drives the run.\n",
    "for weight_path, result_path, finetune_seed in zip(ulee_trained_weights_paths, ulee_finetuned_on_meta_rl_weights_paths, finetune_seeds, strict=True):\n",
    "    eval_meta_learner_finetune_on_meta_rl(\n",
    "        rng=jax.random.fold_in(eval_rng, finetune_seed),\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        weights_path=weight_path,\n",
    "        results_path=result_path,\n",
    "        num_envs=num_envs,\n",
    "        total_timesteps=total_timesteps,\n",
    "        num_steps_per_env=num_steps_per_env,\n",
    "        num_steps_per_update=num_steps_per_update,\n",
    "        **extra_configs,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e464672",
   "metadata": {},
   "source": [
    "## Eval on MiniGrid environments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a543067",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the MiniGrid rollout cell below.\n",
    "eval_seed = 42\n",
    "eval_rng = jax.random.key(eval_seed)  # base PRNG key handed to the rollout calls\n",
    "eval_num_envs, eval_num_episodes = 2048, 30\n",
    "\n",
    "benchmark_id = \"small-1m\"\n",
    "\n",
    "# MiniGrid environments to evaluate on, grouped by family (order is significant only for run order).\n",
    "envs_ids = [\n",
    "    \"MiniGrid-BlockedUnlockPickUp\",\n",
    "    \"MiniGrid-DoorKey-5x5\", \"MiniGrid-DoorKey-8x8\", \"MiniGrid-DoorKey-16x16\",\n",
    "    \"MiniGrid-Empty-8x8\", \"MiniGrid-Empty-16x16\",\n",
    "    \"MiniGrid-EmptyRandom-8x8\", \"MiniGrid-EmptyRandom-16x16\",\n",
    "    \"MiniGrid-FourRooms\",\n",
    "    \"MiniGrid-LockedRoom\",\n",
    "    \"MiniGrid-MemoryS8\", \"MiniGrid-MemoryS16\", \"MiniGrid-MemoryS64\",\n",
    "    \"MiniGrid-Unlock\", \"MiniGrid-UnlockPickUp\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6577f85b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ULEE\n",
    "\n",
    "goal_search_algorithm = \"ppo\"\n",
    "goal_sampling_method = \"bounded_uniform\"\n",
    "seeds = [10, 20, 30, 40]\n",
    "\n",
    "# The weights always come from this training run; only the evaluation environment varies below.\n",
    "train_env_id = \"XLand-MiniGrid-R4-13x13\"\n",
    "train_benchmark_id = \"small-1m\"\n",
    "\n",
    "ulee_trained_weights_paths = [build_trained_weights_path(\"ulee\", train_env_id, train_benchmark_id, seed, goal_search_algorithm, goal_sampling_method) for seed in seeds]\n",
    "\n",
    "# Use a dedicated loop variable: looping with `env_id` would overwrite the notebook-level `env_id`,\n",
    "# leaving it set to the last MiniGrid environment for any earlier cell that is re-run afterwards.\n",
    "for minigrid_env_id in envs_ids:\n",
    "    # Results paths for this environment, one per training seed.\n",
    "    ulee_best_weights_rollouts_paths = [build_best_weights_rollouts_path(\"ulee\", minigrid_env_id, benchmark_id, seed, goal_search_algorithm, goal_sampling_method) for seed in seeds]\n",
    "\n",
    "    for weight_path, result_path in zip(ulee_trained_weights_paths, ulee_best_weights_rollouts_paths, strict=True):\n",
    "        rollout_on_trained_weights(\n",
    "            rng=eval_rng,\n",
    "            num_envs=eval_num_envs,\n",
    "            num_episodes=eval_num_episodes,\n",
    "            algorithm_id=\"ulee\",\n",
    "            env_id=minigrid_env_id,\n",
    "            benchmark_id=benchmark_id,\n",
    "            weights_path=weight_path,\n",
    "            results_path=result_path,\n",
    "            eval_on_test_benchmark=True,\n",
    "        )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env_2025",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
