{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8381fb8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "import time\n",
    "from datetime import datetime\n",
    "\n",
    "import flax\n",
    "import jax\n",
    "import orbax\n",
    "import xminigrid\n",
    "from flax.training import orbax_utils\n",
    "from jax.tree_util import Partial\n",
    "from xminigrid.environment import EnvParams\n",
    "\n",
    "from ulee_repo.DIAYN.config import TrainConfig as DIAYNTrainConfig\n",
    "from ulee_repo.DIAYN.main_loop import full_training as diayn_full_training\n",
    "from ulee_repo.DIAYN.setups import set_up_for_training as diayn_set_up_for_training\n",
    "from ulee_repo.experiments.paths import build_trained_weights_path\n",
    "from ulee_repo.PPO.config import TrainConfig as PPOTrainConfig\n",
    "from ulee_repo.PPO.main_loop import full_training as ppo_full_training\n",
    "from ulee_repo.PPO.main_loop import full_training_on_fixed_envs as ppo_full_training_on_fixed_envs\n",
    "from ulee_repo.PPO.setups import set_up_for_training as ppo_set_up_for_training\n",
    "from ulee_repo.RL2.config import TrainConfig as RL2TrainConfig\n",
    "from ulee_repo.RL2.main_loop import full_training as rl2_full_training\n",
    "from ulee_repo.RL2.setups import set_up_for_training as rl2_set_up_for_training\n",
    "from ulee_repo.shared_code.logging import (\n",
    "    generate_run_name,\n",
    "    wandb_log_training_metrics,\n",
    "    wandb_log_ulee_training_metrics,\n",
    ")\n",
    "from ulee_repo.ULEE.config import TrainConfig as ULEETrainConfig\n",
    "from ulee_repo.ULEE.main_loop import full_training as ulee_full_training\n",
    "from ulee_repo.ULEE.setups import set_up_for_training as ulee_set_up_for_training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46def356",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suppress wandb's console log statements for every run launched from this notebook.\n",
    "import os\n",
    "\n",
    "os.environ.update({\"WANDB_SILENT\": \"true\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcecf7da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List what xminigrid has registered; useful when choosing the `benchmark_id` / `env_id` values used below.\n",
    "print(\n",
    "    xminigrid.registered_benchmarks(),\n",
    "    \"-----------------------------------\",\n",
    "    xminigrid.registered_environments(),\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a546be7",
   "metadata": {},
   "source": [
    "## ULEE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "336facdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def complete_config(config: ULEETrainConfig, env_params: EnvParams) -> ULEETrainConfig:\n",
    "    \"\"\"Fill in config fields that can only be computed once the environment is known.\n",
    "\n",
    "    Sets ``config.goal_search.goal_searching_steps_per_env`` to\n",
    "    ``goal_searching_episodes_per_env * env_params.max_steps``. The config is modified in\n",
    "    place and also returned.\n",
    "    \"\"\"\n",
    "    config.goal_search.goal_searching_steps_per_env = config.goal_search.goal_searching_episodes_per_env * env_params.max_steps\n",
    "    return config\n",
    "\n",
    "\n",
    "def run_ulee_training(config: ULEETrainConfig):\n",
    "    \"\"\"Train ULEE for a single configuration, save the results to disk and log them to wandb.\n",
    "\n",
    "    Args:\n",
    "        config: ULEE training configuration. Fields that depend on the environment are filled\n",
    "            in by ``complete_config`` once the environment has been set up.\n",
    "\n",
    "    Saving is best-effort: an exception while writing the checkpoint is printed and the\n",
    "    function still goes on to log the training metrics to wandb.\n",
    "    \"\"\"\n",
    "    # setup\n",
    "    rng, env_no_goals, env_unsup_goals, env_real_goals, env_params, benchmark, meta_learner_train_state, judge_train_state, goal_search_train_state, judge_replay_buffer = ulee_set_up_for_training(\n",
    "        config\n",
    "    )\n",
    "    config = complete_config(config, env_params)\n",
    "\n",
    "    # train (the measured duration includes JIT compilation, which happens on the first call)\n",
    "    print(f\"Training with seed {config.train_seed}\")\n",
    "    t = time.perf_counter()  # monotonic clock: unaffected by system clock adjustments\n",
    "    full_training_partial = Partial(\n",
    "        ulee_full_training, env_no_goals=env_no_goals, env_unsup_goals=env_unsup_goals, env_real_goals=env_real_goals, env_params=env_params, benchmark=benchmark, config=config\n",
    "    )\n",
    "    jitted_full_training = jax.jit(full_training_partial)\n",
    "    train_info = jax.block_until_ready(\n",
    "        jitted_full_training(\n",
    "            rng=rng, meta_learner_train_state=meta_learner_train_state, judge_train_state=judge_train_state, goal_search_train_state=goal_search_train_state, judge_replay_buffer=judge_replay_buffer\n",
    "        )\n",
    "    )\n",
    "    elapsed_time = time.perf_counter() - t\n",
    "    print(f\"Done in {elapsed_time / 60:.2f}min\")\n",
    "\n",
    "    try:\n",
    "        # store results on disk\n",
    "        save_path = build_trained_weights_path(\n",
    "            algorithm_id=\"ulee\",\n",
    "            env_id=config.env_id,\n",
    "            benchmark_id=config.benchmark_id,\n",
    "            seed=config.train_seed,\n",
    "            goal_search_algorithm=config.goal_search_algorithm,\n",
    "            goal_sampling_method=config.goal_sampling_method,\n",
    "        )\n",
    "        save_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "        train_config = flax.core.freeze(dataclasses.asdict(config))\n",
    "        # train_info[\"best\"][1] and [2] are saved as the best meta-learner / goal-search params.\n",
    "        # NOTE(review): tuple layout inferred from this usage; confirm in ULEE.main_loop.\n",
    "        if config.goal_search_algorithm == \"ppo\":\n",
    "            goal_search_params = train_info[\"goal_search_train_state\"].params\n",
    "            best_goal_search_params = train_info[\"best\"][2]\n",
    "        elif config.goal_search_algorithm == \"diayn\":\n",
    "            # (policy, discriminator) train states; the policy's `constants` are stored next to\n",
    "            # its params, and the best policy params reuse the final state's constants\n",
    "            policy_train_state, discriminator_train_state = train_info[\"goal_search_train_state\"]\n",
    "            goal_search_params = ({\"params\": policy_train_state.params, \"constants\": policy_train_state.constants}, discriminator_train_state.params)\n",
    "            best_goal_search_params = ({\"params\": train_info[\"best\"][2][0], \"constants\": policy_train_state.constants}, train_info[\"best\"][2][1])\n",
    "        elif config.goal_search_algorithm == \"random\":\n",
    "            # nothing learned to save for random goal search\n",
    "            goal_search_params = None\n",
    "            best_goal_search_params = None\n",
    "        else:\n",
    "            # fail with an explicit message instead of a NameError on the unset variables below\n",
    "            raise ValueError(f\"Unknown goal_search_algorithm: {config.goal_search_algorithm!r}\")\n",
    "\n",
    "        training_results = {\n",
    "            \"config\": train_config,\n",
    "            \"meta_learner_params\": train_info[\"meta_learner_state\"].params,\n",
    "            \"judge_params\": train_info[\"judge_train_state\"].params,\n",
    "            \"goal_search_params\": goal_search_params,\n",
    "            \"best_meta_learner_params\": train_info[\"best\"][1],\n",
    "            \"best_goal_search_params\": best_goal_search_params,\n",
    "            \"metrics\": train_info[\"metrics\"],\n",
    "        }\n",
    "\n",
    "        # NOTE(review): by default orbax's Checkpointer.save does not overwrite an existing\n",
    "        # directory (force=False), so re-running an identical config/seed will likely end up in\n",
    "        # the except below - confirm for the installed orbax version. Make save_path unique\n",
    "        # (e.g. append a timestamp) to keep several runs side by side.\n",
    "        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "        save_args = orbax_utils.save_args_from_target(training_results)\n",
    "        orbax_checkpointer.save(save_path, training_results, save_args=save_args)\n",
    "        print(\"saved training results to\", save_path)\n",
    "    except Exception as e:\n",
    "        # best-effort: a failed checkpoint must not prevent the wandb logging below\n",
    "        print(f\"Error while saving training results to disk: {e}\")\n",
    "\n",
    "    # save logs to wandb\n",
    "    run_name = generate_run_name(algorithm_name=\"ULEE\", config=config, prefix=\"\")\n",
    "    tags = [\"ulee\", \"train\"]\n",
    "    wandb_log_ulee_training_metrics(train_info[\"metrics\"], config, run_name=run_name, tags=tags, num_final_episodes_for_evaluating_performance=10)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f2c718",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One ULEE run per seed; run i uses train_seeds[i], goal_search_algorithms[i] and goal_sampling_methods[i].\n",
    "train_seeds = [10, 20, 30, 40]\n",
    "goal_search_algorithms = [\"ppo\" for _ in train_seeds]\n",
    "goal_sampling_methods = [\"bounded_uniform\" for _ in train_seeds]\n",
    "total_timesteps = 5_000_000_000\n",
    "env_id = \"XLand-MiniGrid-R4-13x13\"\n",
    "benchmark_id = \"small-1m\"\n",
    "\n",
    "for seed, search_algo, sampling_method in zip(train_seeds, goal_search_algorithms, goal_sampling_methods):\n",
    "    config = ULEETrainConfig(\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        total_timesteps=total_timesteps,\n",
    "        goal_search_algorithm=search_algo,\n",
    "        goal_sampling_method=sampling_method,\n",
    "        train_seed=seed,\n",
    "        benchmark_split_seed=seed + 100,  # split seed is derived from the training seed\n",
    "    )\n",
    "    run_ulee_training(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6f9920f",
   "metadata": {},
   "source": [
    "## PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29edd320",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_ppo_training(config: PPOTrainConfig, fixed_envs: bool):\n",
    "    \"\"\"Train PPO for a single configuration, save the results to disk and log them to wandb.\n",
    "\n",
    "    Args:\n",
    "        config: PPO training configuration.\n",
    "        fixed_envs: if True, train with ``full_training_on_fixed_envs`` and save no best-agent\n",
    "            params; otherwise train with ``full_training`` and additionally save\n",
    "            ``train_info[\"best\"][1]`` as ``best_agent_params``.\n",
    "\n",
    "    Saving is best-effort: an exception while writing the checkpoint is printed and the\n",
    "    function still goes on to log the training metrics to wandb.\n",
    "    \"\"\"\n",
    "    # setup\n",
    "    rng, env, env_params, benchmark, train_state = ppo_set_up_for_training(config)\n",
    "\n",
    "    # train: both variants are called the same way, only the entry point differs.\n",
    "    # The measured duration includes JIT compilation, which happens on the first call.\n",
    "    print(f\"Training with seed {config.train_seed}\")\n",
    "    t = time.perf_counter()  # monotonic clock: unaffected by system clock adjustments\n",
    "    training_fn = ppo_full_training_on_fixed_envs if fixed_envs else ppo_full_training\n",
    "    full_training_partial = Partial(training_fn, env=env, env_params=env_params, benchmark=benchmark, config=config)\n",
    "    jitted_full_training = jax.jit(full_training_partial)\n",
    "    train_info = jax.block_until_ready(jitted_full_training(rng=rng, train_state=train_state))\n",
    "    elapsed_time = time.perf_counter() - t\n",
    "    print(f\"Done in {elapsed_time / 60:.2f}min\")\n",
    "\n",
    "    try:\n",
    "        # store results on disk\n",
    "        # NOTE(review): the path does not depend on fixed_envs, so fixed-env and regular runs\n",
    "        # with the same env/benchmark/seed presumably share one checkpoint location.\n",
    "        save_path = build_trained_weights_path(\n",
    "            algorithm_id=\"ppo\",\n",
    "            env_id=config.env_id,\n",
    "            benchmark_id=config.benchmark_id,\n",
    "            seed=config.train_seed,\n",
    "        )\n",
    "        save_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "        training_results = {\n",
    "            \"config\": flax.core.freeze(dataclasses.asdict(config)),\n",
    "            \"agent_params\": train_info[\"agent_state\"].params,\n",
    "        }\n",
    "        if not fixed_envs:\n",
    "            training_results[\"best_agent_params\"] = train_info[\"best\"][1]\n",
    "        training_results[\"metrics\"] = train_info[\"metrics\"]\n",
    "\n",
    "        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "        save_args = orbax_utils.save_args_from_target(training_results)\n",
    "        orbax_checkpointer.save(save_path, training_results, save_args=save_args)\n",
    "        print(\"saved training results to\", save_path)\n",
    "    except Exception as e:\n",
    "        # best-effort: a failed checkpoint must not prevent the wandb logging below\n",
    "        print(f\"Error while saving training results to disk: {e}\")\n",
    "\n",
    "    # log metrics to wandb\n",
    "    extra_logs = {\n",
    "        \"training/lr\": train_info[\"metrics\"][\"lr\"],\n",
    "    }\n",
    "    run_name = generate_run_name(algorithm_name=\"PPO\", config=config, prefix=\"\")\n",
    "    tags = [\"ppo\", \"train\"]\n",
    "    if fixed_envs:\n",
    "        # lets fixed-env runs be filtered apart from regular PPO runs in wandb\n",
    "        tags.append(\"fixed_envs\")\n",
    "    wandb_log_training_metrics(train_info[\"metrics\"], config, run_name, project_name=\"ULEE\", tags=tags, extra_batch_metrics=extra_logs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90fe0ac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPO baseline: one run per seed, using the fixed-environment training variant.\n",
    "train_seeds = [10, 20, 30, 40]\n",
    "total_timesteps = 5_000_000_000\n",
    "env_id = \"XLand-MiniGrid-R4-13x13\"\n",
    "benchmark_id = \"small-1m\"\n",
    "\n",
    "for seed in train_seeds:\n",
    "    config = PPOTrainConfig(\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        total_timesteps=total_timesteps,\n",
    "        train_seed=seed,\n",
    "        benchmark_split_seed=seed + 100,  # split seed is derived from the training seed\n",
    "    )\n",
    "    run_ppo_training(config, fixed_envs=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85e5edb2",
   "metadata": {},
   "source": [
    "## DIAYN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e02c21d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_diayn_training(config: DIAYNTrainConfig):\n",
    "    \"\"\"Train DIAYN for a single configuration, save the results to disk and log them to wandb.\n",
    "\n",
    "    Args:\n",
    "        config: DIAYN training configuration.\n",
    "\n",
    "    Saving is best-effort: an exception while writing the checkpoint is printed and the\n",
    "    function still goes on to log the training metrics to wandb. The function then sleeps\n",
    "    for 15 seconds before returning.\n",
    "    \"\"\"\n",
    "    # setup\n",
    "    rng, env_no_goals, env_real_goals, env_params, benchmark, agent_train_state, discriminator_train_state = diayn_set_up_for_training(config)\n",
    "\n",
    "    # train (the measured duration includes JIT compilation, which happens on the first call)\n",
    "    print(f\"Training with seed {config.train_seed}\")\n",
    "    t = time.perf_counter()  # monotonic clock: unaffected by system clock adjustments\n",
    "    full_training_partial = Partial(diayn_full_training, env_no_goals=env_no_goals, env_real_goals=env_real_goals, env_params=env_params, benchmark=benchmark, config=config)\n",
    "    jitted_full_training = jax.jit(full_training_partial)\n",
    "    train_info = jax.block_until_ready(\n",
    "        jitted_full_training(\n",
    "            rng=rng,\n",
    "            agent_train_state=agent_train_state,\n",
    "            discriminator_train_state=discriminator_train_state,\n",
    "        )\n",
    "    )\n",
    "    elapsed_time = time.perf_counter() - t\n",
    "    print(f\"Done in {elapsed_time / 60:.2f}min\")\n",
    "\n",
    "    try:\n",
    "        # store results on disk\n",
    "        save_path = build_trained_weights_path(\n",
    "            algorithm_id=\"diayn\",\n",
    "            env_id=config.env_id,\n",
    "            benchmark_id=config.benchmark_id,\n",
    "            seed=config.train_seed,\n",
    "        )\n",
    "        save_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        train_config = flax.core.freeze(dataclasses.asdict(config))\n",
    "        # the agent's `constants` are stored next to its params; the best params reuse the\n",
    "        # constants of the final agent state\n",
    "        agent_state = train_info[\"agent_state\"]\n",
    "        agent_params = {\"params\": agent_state.params, \"constants\": agent_state.constants}\n",
    "        best_agent_params = {\"params\": train_info[\"best\"][1], \"constants\": agent_state.constants}\n",
    "        training_results = {\n",
    "            \"config\": train_config,\n",
    "            \"agent_params\": agent_params,\n",
    "            \"best_agent_params\": best_agent_params,\n",
    "            \"metrics\": train_info[\"metrics\"],\n",
    "        }\n",
    "\n",
    "        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "        save_args = orbax_utils.save_args_from_target(training_results)\n",
    "        orbax_checkpointer.save(save_path, training_results, save_args=save_args)\n",
    "        print(\"saved training results to\", save_path)\n",
    "    except Exception as e:\n",
    "        # best-effort: a failed checkpoint must not prevent the wandb logging below\n",
    "        print(f\"Error while saving training results to disk: {e}\")\n",
    "\n",
    "    # log metrics to wandb\n",
    "    extra_logs = {\n",
    "        \"training/lr\": train_info[\"metrics\"][\"lr\"],\n",
    "        \"discriminator/discriminator_loss\": train_info[\"metrics\"][\"discriminator_loss\"],\n",
    "        \"discriminator/skills_logprob\": train_info[\"metrics\"][\"skills_log_prob\"],\n",
    "    }\n",
    "    run_name = generate_run_name(algorithm_name=\"DIAYN\", config=config, prefix=\"\")\n",
    "    tags = [\"diayn\", \"train\"]\n",
    "    wandb_log_training_metrics(train_info[\"metrics\"], config, run_name, project_name=\"ULEE\", tags=tags, extra_batch_metrics=extra_logs)\n",
    "    # NOTE(review): fixed pause between consecutive runs, presumably to let wandb finish\n",
    "    # syncing before the next run starts; the other run_* helpers do not do this - confirm.\n",
    "    time.sleep(15)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f2b8bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DIAYN: one run per seed.\n",
    "train_seeds = [10, 20, 30, 40]\n",
    "total_timesteps = 5_000_000_000\n",
    "env_id = \"XLand-MiniGrid-R4-13x13\"\n",
    "benchmark_id = \"small-1m\"\n",
    "\n",
    "for seed in train_seeds:\n",
    "    config = DIAYNTrainConfig(\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        total_timesteps=total_timesteps,\n",
    "        train_seed=seed,\n",
    "        benchmark_split_seed=seed + 100,  # split seed is derived from the training seed\n",
    "    )\n",
    "    run_diayn_training(config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27cdfb29",
   "metadata": {},
   "source": [
    "## RL2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87e46ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_rl2_training(config: RL2TrainConfig):\n",
    "    \"\"\"Train RL2 for a single configuration, save the results to disk and log them to wandb.\n",
    "\n",
    "    Args:\n",
    "        config: RL2 training configuration.\n",
    "\n",
    "    Saving is best-effort: an exception while writing the checkpoint is printed and the\n",
    "    function still goes on to log the training metrics to wandb.\n",
    "    \"\"\"\n",
    "    # setup\n",
    "    rng, env, env_params, benchmark, train_state = rl2_set_up_for_training(config)\n",
    "\n",
    "    # train (the measured duration includes JIT compilation, which happens on the first call)\n",
    "    print(f\"Training with seed {config.train_seed}\")\n",
    "    t = time.perf_counter()  # monotonic clock: unaffected by system clock adjustments\n",
    "    full_training_partial = Partial(rl2_full_training, env=env, env_params=env_params, benchmark=benchmark, config=config)\n",
    "    jitted_full_training = jax.jit(full_training_partial)\n",
    "    train_info = jax.block_until_ready(jitted_full_training(rng=rng, train_state=train_state))\n",
    "    elapsed_time = time.perf_counter() - t\n",
    "    print(f\"Done in {elapsed_time / 60:.2f}min\")\n",
    "\n",
    "    try:\n",
    "        # store results on disk\n",
    "        save_path = build_trained_weights_path(\n",
    "            algorithm_id=\"rl2\",\n",
    "            env_id=config.env_id,\n",
    "            benchmark_id=config.benchmark_id,\n",
    "            seed=config.train_seed,\n",
    "        )\n",
    "        save_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "        training_results = {\n",
    "            \"config\": flax.core.freeze(dataclasses.asdict(config)),\n",
    "            \"agent_params\": train_info[\"agent_state\"].params,\n",
    "            \"best_agent_params\": train_info[\"best\"][1],\n",
    "            \"metrics\": train_info[\"metrics\"],\n",
    "        }\n",
    "\n",
    "        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "        save_args = orbax_utils.save_args_from_target(training_results)\n",
    "        orbax_checkpointer.save(save_path, training_results, save_args=save_args)\n",
    "        print(\"saved training results to\", save_path)\n",
    "    except Exception as e:\n",
    "        # best-effort: a failed checkpoint must not prevent the wandb logging below\n",
    "        print(f\"Error while saving training results to disk: {e}\")\n",
    "\n",
    "    # log metrics to wandb\n",
    "    extra_logs = {\n",
    "        \"training/lr\": train_info[\"metrics\"][\"lr\"],\n",
    "    }\n",
    "    run_name = generate_run_name(algorithm_name=\"RL2\", config=config, prefix=\"\")\n",
    "    tags = [\"rl2\", \"train\"]\n",
    "    wandb_log_training_metrics(train_info[\"metrics\"], config, run_name, project_name=\"ULEE\", tags=tags, extra_batch_metrics=extra_logs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f756c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RL2: one run per seed.\n",
    "train_seeds = [10, 20, 30, 40]\n",
    "total_timesteps = 5_000_000_000\n",
    "env_id = \"XLand-MiniGrid-R4-13x13\"\n",
    "benchmark_id = \"small-1m\"\n",
    "\n",
    "for seed in train_seeds:\n",
    "    config = RL2TrainConfig(\n",
    "        env_id=env_id,\n",
    "        benchmark_id=benchmark_id,\n",
    "        total_timesteps=total_timesteps,\n",
    "        train_seed=seed,\n",
    "        benchmark_split_seed=seed + 100,  # split seed is derived from the training seed\n",
    "    )\n",
    "    run_rl2_training(config)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env_2025",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
