{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c35901d",
   "metadata": {},
   "source": [
    "# Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60d1af89",
   "metadata": {},
   "outputs": [],
   "source": [
    "## NOTE: please install any packages if necessary; environment.yml is provided."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "43c8d9b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Import relevant packages.\n",
    "import copy\n",
    "from munch import Munch\n",
    "from src.train import train\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f802e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional wandb logging: set USE_WANDB to 1 to import wandb and log in.\n",
    "# The same flag is forwarded to training as `report_to_wandb`.\n",
    "USE_WANDB = 0\n",
    "\n",
    "if USE_WANDB:\n",
    "    import wandb\n",
    "\n",
    "    wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f3eda11a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline training arguments for the k-partite experiments that follow.\n",
    "# Each run deep-copies this mapping and overrides individual entries.\n",
    "default_args = dict(\n",
    "    random_seed=0,\n",
    "    max_steps=30000,\n",
    "    graph_dir=\"graphs\",\n",
    "    train_type=\"Two-levelGraph_SFT2\",\n",
    "    graph_type=\"Two-levelGraph\",\n",
    "    eval_rate=0.999,\n",
    "    graph_data_dir=None,\n",
    "    max_generation_length=100,\n",
    "    num_hidden_layers=3,\n",
    "    num_attention_heads=3,\n",
    "    vocab_size=5100,\n",
    "    hidden_size=768,\n",
    "    position_embedding=\"learned\",\n",
    "    lr=3e-4,\n",
    "    batch_size=256,\n",
    "    log_steps=128,\n",
    "    save_steps=128,\n",
    "    eval_steps=128,\n",
    "    eval_size=1024,\n",
    "    weight_decay=0,\n",
    "    warmup_ratio=0,\n",
    "    model_dir=None,\n",
    "    output_dir=\"model\",\n",
    "    model_config_path=\"config/gpt2_tiny_wpetrain.py\",\n",
    "    world_size=1,\n",
    "    report_to_wandb=USE_WANDB,  # mirrors the wandb toggle defined above\n",
    "    random_planning=False,\n",
    "    fix_interval=False,\n",
    "    planning_with_cluster_token=False,\n",
    "    onehot_embed=False,\n",
    "    provide_planning=False,\n",
    "    planning_with_ST=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "834443cd",
   "metadata": {},
   "source": [
    "# Anchoring helps with pathfinding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c378f25b",
   "metadata": {},
   "source": [
    "## Comparing w/ and w/o Anchoring in k-partite graph (Figure 1(a)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59dbec04",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" \n",
    "Generating the Graphs\n",
    "- Two-level Graph: N1=50, N2=100, p1=1, p2=0.2\n",
    "\"\"\"\n",
    "from src.gen_graphs.kpartite_graph import gen_graph_kpartite\n",
    "config = {\n",
    "    \"Type\": \"KPartiteGraph_Bernoulli\",\n",
    "    \"K\": 9,\n",
    "    \"N\": 2500,\n",
    "    \"edge_probability\": 0.001,\n",
    "    \"directed\": True,\n",
    "    \"random_seed\": 0,\n",
    "}\n",
    "gen_graph_kpartite(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afaa6d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Pathfinding without anchoring\n",
    "\"\"\"\n",
    "\n",
    "# Deep-copy the defaults so the overrides below stay local to this run.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"vocab_size\": 22600,\n",
    "    \"max_steps\": 60000,\n",
    "    \"eval_rate\": 0.9,\n",
    "    \"train_type\": \"KPartiteGraph1\",\n",
    "    \"graph_type\": \"KPartiteGraph_Bernoulli\",\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/KPartiteGraph_Bernoulli_9_2500_0.001_rs0.json\",\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c2f561",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Pathfinding with anchoring\n",
    "\"\"\"\n",
    "\n",
    "# Same setup as the run without anchoring, plus the anchoring train type\n",
    "# and the `split_layer` entry.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"vocab_size\": 22600,\n",
    "    \"max_steps\": 60000,\n",
    "    \"eval_rate\": 0.9,\n",
    "    \"train_type\": \"KPartiteGraph1_planning2\",\n",
    "    \"graph_type\": \"KPartiteGraph_Bernoulli\",\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/KPartiteGraph_Bernoulli_9_2500_0.001_rs0.json\",\n",
    "    # Anchor at layer x+1 for each listed x; several anchoring tokens are allowed, e.g., [3, 5].\n",
    "    \"split_layer\": [4],\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f07f08b8",
   "metadata": {},
   "source": [
    "## Comparing w/ and w/o Anchoring in two-level graph (Figure 1(b)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d72c321",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Generating the Graphs\n",
    "- Two-level Graph: N1=50, N2=100, p1=1, p2=0.2\n",
    "\"\"\"\n",
    "from src.gen_graphs.twolevel_graph import gen_graph_twolevel\n",
    "\n",
    "# Only p2 is given explicitly; the upper level is declared as a \"Clique\".\n",
    "config = dict(\n",
    "    Type=\"Two-levelGraph\",\n",
    "    N1=50,\n",
    "    N2=100,\n",
    "    graph_type1=\"Clique\",\n",
    "    graph_type2=\"TAE\",\n",
    "    additional_edge_probability=0.2,  # p2\n",
    "    random_seed=0,\n",
    "    directed=False,\n",
    ")\n",
    "gen_graph_twolevel(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "223d0d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default training arguments for the two-level graph experiments below.\n",
    "# This rebinds `default_args`: compared with the earlier definition it uses\n",
    "# max_generation_length=40 (was 100), fix_interval=None (was False) and has\n",
    "# no `planning_with_ST` entry.\n",
    "default_args = dict(\n",
    "    random_seed=0,\n",
    "    max_steps=30000,\n",
    "    graph_dir=\"graphs\",\n",
    "    train_type=\"Two-levelGraph_SFT2\",\n",
    "    graph_type=\"Two-levelGraph\",\n",
    "    eval_rate=0.999,\n",
    "    graph_data_dir=None,\n",
    "    max_generation_length=40,\n",
    "    num_hidden_layers=3,\n",
    "    num_attention_heads=3,\n",
    "    vocab_size=5100,\n",
    "    hidden_size=768,\n",
    "    position_embedding=\"learned\",\n",
    "    lr=3e-4,\n",
    "    batch_size=256,\n",
    "    log_steps=128,\n",
    "    save_steps=128,\n",
    "    eval_steps=128,\n",
    "    eval_size=1024,\n",
    "    weight_decay=0,\n",
    "    warmup_ratio=0,\n",
    "    model_dir=None,\n",
    "    output_dir=\"model\",\n",
    "    model_config_path=\"config/gpt2_tiny_wpetrain.py\",\n",
    "    world_size=1,\n",
    "    report_to_wandb=USE_WANDB,\n",
    "    random_planning=False,\n",
    "    fix_interval=None,\n",
    "    planning_with_cluster_token=False,\n",
    "    onehot_embed=False,\n",
    "    provide_planning=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff0efa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Training (without anchoring) -- Each run takes about 0.5 hours on a single 3090 GPU.\n",
    "\"\"\"\n",
    "\n",
    "# `train`, `copy` and `Munch` come from the import cell at the top of the notebook.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"train_type\": \"Two-levelGraph_SFT2\",\n",
    "    \"eval_rate\": 0.999,\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/Two-levelGraph_Clique_TAE_0.2_50_100.json\",\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8879ab09",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Training (with anchoring) -- Each run takes about 0.5 hours on a single 3090 GPU.\n",
    "\"\"\"\n",
    "\n",
    "# `train`, `copy` and `Munch` come from the import cell at the top of the notebook.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"train_type\": \"Two-levelGraph_SFT2_planning1\",\n",
    "    \"eval_rate\": 0.999,\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/Two-levelGraph_Clique_TAE_0.2_50_100.json\",\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f44c07f",
   "metadata": {},
   "source": [
    "# Comparing Different Anchoring Strategies for Two-level Path-Finding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a0dba76",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Generating the Graphs\n",
    "- Two-level Graph: N1=200, N2=50, p1=0.1, p2=0.4\n",
    "\"\"\"\n",
    "\n",
    "from src.gen_graphs.twolevel_graph import gen_graph_twolevel\n",
    "\n",
    "# Both levels use the \"TAE\" graph type, each with its own edge probability.\n",
    "config = dict(\n",
    "    Type=\"Two-levelGraph\",\n",
    "    N1=200,\n",
    "    N2=50,\n",
    "    graph_type1=\"TAE\",\n",
    "    graph_type2=\"TAE\",\n",
    "    upper_edge_probability=0.1,       # p1\n",
    "    additional_edge_probability=0.4,  # p2\n",
    "    random_seed=0,\n",
    "    upper_directed=False,\n",
    "    directed=False,\n",
    ")\n",
    "gen_graph_twolevel(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21e160a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Training (without anchoring)\n",
    "\"\"\"\n",
    "\n",
    "# Deep-copy the defaults so the overrides below stay local to this run.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"train_type\": \"Two-levelGraph_SFT3\",\n",
    "    \"eval_rate\": 0.9995,\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/Two-levelGraph_TAE_0.1_TAE_0.4_200_50.json\",\n",
    "    \"vocab_size\": 10100,\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8a2577e",
   "metadata": {},
   "source": [
    "## Comparing 3 different anchoring strategies\n",
    "- Inter-Cluster Anchoring: the anchors are the endpoints of inter-cluster edges.\n",
    "- Fixed-interval Anchoring: the anchors are placed at a fixed interval; in this case, every other step.\n",
    "- Random Anchoring: the anchors are randomly selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a90bf42d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Inter-Cluster Anchoring\n",
    "\"\"\"\n",
    "# `fix_interval` and `random_planning` are left at their defaults for this run.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"train_type\": \"Two-levelGraph_SFT3_planning2\",\n",
    "    \"eval_rate\": 0.9995,\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/Two-levelGraph_TAE_0.1_TAE_0.4_200_50.json\",\n",
    "    \"vocab_size\": 10100,\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0626085f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Fixed-interval Anchoring\n",
    "\"\"\"\n",
    "# Same settings as the inter-cluster run, with `fix_interval` switched on.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"train_type\": \"Two-levelGraph_SFT3_planning2\",\n",
    "    \"eval_rate\": 0.9995,\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/Two-levelGraph_TAE_0.1_TAE_0.4_200_50.json\",\n",
    "    \"vocab_size\": 10100,\n",
    "    \"fix_interval\": True,\n",
    "})\n",
    "train(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f2e9a3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Random anchoring\n",
    "\"\"\"\n",
    "# Same settings as the inter-cluster run, with `random_planning` switched on.\n",
    "args = Munch({\n",
    "    **copy.deepcopy(default_args),\n",
    "    \"train_type\": \"Two-levelGraph_SFT3_planning2\",\n",
    "    \"eval_rate\": 0.9995,\n",
    "    \"graph_data_dir\": \"data/Exp2_GeneralPathFinding/Graphs/Two-levelGraph_TAE_0.1_TAE_0.4_200_50.json\",\n",
    "    \"vocab_size\": 10100,\n",
    "    \"random_planning\": True,\n",
    "})\n",
    "train(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "common",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
