{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95157476-70f1-40b5-b7ad-5aad4f68ccee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from deploy_utils import get_available_gpus\n",
    "from deploy import NotebookDeployer, make_param_grid\n",
    "import os\n",
    "import glob\n",
    "import itertools\n",
    "import time\n",
    "import asyncio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "298491cc-1e65-4c4a-ad37-de40144c4e6a",
   "metadata": {},
   "source": [
    "# GPU Tasks\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b129f267-0d69-4372-b8a3-4701b06af2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_slurm = True\n",
    "\n",
    "if run_slurm:\n",
    "    gpus = [0,1,2,3] # our nodes always have 4 gpus per node\n",
    "    os.environ[\"PARTITIONS\"] = \"kisski-h100\"\n",
    "    md = NotebookDeployer(gpu_list=gpus, default_backend=\"python-slurm\")\n",
    "else:\n",
    "    gpus = get_available_gpus(max_utilization=20)\n",
    "    md = NotebookDeployer(gpu_list=gpus, default_backend=\"python\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a52b3ed-3896-418e-9dfc-af325b1752f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "md.status()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3ef48ea-af5c-4af6-a7bc-6339f7f6feba",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"tot_test_split\", \"four_digits_unsolvable\"]\n",
    "models = [\"mistralai/Mistral-Small-24B-Instruct-2501\", \"microsoft/phi-4\", \"RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9e206d-dfd5-4a67-8ca0-bb5e52e74542",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "def is_exception_pickle(path, size_threshold=100_000):\n",
    "    \"\"\"\n",
    "    Check whether the pickle file at `path` contains an Exception.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    path : str\n",
    "        Path to the pickle file.\n",
    "    size_threshold : int, optional\n",
    "        Max size (in bytes) to attempt loading. Defaults to 100 KB.\n",
    "        Larger files are assumed to be \"normal runs\" and skipped.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    bool\n",
    "        True if the file contains an Exception object, False otherwise.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(path):\n",
    "        return False\n",
    "\n",
    "    if os.path.getsize(path) < size_threshold:\n",
    "        print(\"checking\", path,\"for being an exception\")\n",
    "        try:\n",
    "            with open(path, \"rb\") as f:\n",
    "                obj = pickle.load(f)\n",
    "         \n",
    "        except Exception as e:\n",
    "            obj = e\n",
    "        if isinstance(obj, BaseException):\n",
    "            print(f\"'{obj}'\", \"for path\", logdir)\n",
    "            return True\n",
    "        return \n",
    "    else:\n",
    "        return False\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "107ff598-ac68-4225-80fc-8abed922b9e5",
   "metadata": {},
   "source": [
    "## IO + CoT Baselines\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0444ce0-5b27-4ac7-b727-6e574d1a60b8",
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "methods = [\"io\", \"cot\"]\n",
    "n = 100\n",
    "for model_name in models:\n",
    "    for method in methods:\n",
    "        for d in datasets:\n",
    "            logdir = f\"data/logs/{d}/{method}/{model_name}/\"\n",
    "            os.makedirs(logdir, exist_ok=True)\n",
    "            if f\"{n-1}.dat\" in os.listdir(logdir):\n",
    "                print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "                continue\n",
    "            t = md.enqueue(notebook=f\"baseline_io_cot.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                           dataset=d, n=n, model_name=model_name, baseline=method)\n",
    "            await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05b27e53-ce55-4d93-b7fd-cb384c2dff25",
   "metadata": {},
   "source": [
    "# ToT Baseline\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffc19e71-f23e-488b-804d-96385a622a94",
   "metadata": {},
   "source": [
    "## broad search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b57768-d5e5-4e32-8947-9d7f9e394cf6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "method = \"tot_broad_search\"\n",
    "\n",
    "param_ranges = {\n",
    "    \"do_verify\": [True, False],\n",
    "    \"n_propose_sample\": [1],\n",
    "    \"n_verify_sample\": [3],\n",
    "    \"n_evaluate_sample\": [3,7,11],\n",
    "    \"n_select_sample\": [1,3,5,10,20],\n",
    "}\n",
    "\n",
    "d = \"tot_test_split\"\n",
    "for model_name in models:\n",
    "    for values in itertools.product(*param_ranges.values()):\n",
    "        params = dict(zip(param_ranges.keys(), values))\n",
    "        folder = \"_\".join(str(int(v)) if isinstance(v, bool) else str(v) for v in values)\n",
    "        logdir = f\"data/logs/{d}/{method}/{model_name}/{folder}/\"\n",
    "        os.makedirs(logdir, exist_ok=True)\n",
    "        if \"0.dat\" in os.listdir(logdir) and not is_exception_pickle(logdir+\"0.dat\"):\n",
    "            print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "            continue\n",
    "        t = md.enqueue(notebook=\"baseline_tot.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                       dataset=d, model_name=model_name, **params)\n",
    "        await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "882ee1d2-4aea-4e85-b252-dbaf52a5c486",
   "metadata": {},
   "source": [
    "## repetitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "859b15ba-7f3c-4b22-8ec7-79847eb20c95",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_ranges = {\n",
    "    \"do_verify\": [True, False],\n",
    "    \"n_propose_sample\": [1],\n",
    "    \"n_verify_sample\": [3],\n",
    "    \"n_evaluate_sample\": [3,7,11],\n",
    "    \"n_select_sample\": [1,3,5,10,20],\n",
    "}\n",
    "\n",
    "top_params = [\n",
    " {'model_name': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 20}, \n",
    " {'model_name': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'do_verify': True, 'n_evaluate_sample': 7, 'n_select_sample': 20}, \n",
    " {'model_name': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'do_verify': False, 'n_evaluate_sample': 7, 'n_select_sample': 20}, \n",
    " {'model_name': 'microsoft/phi-4', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 5},\n",
    " {'model_name': 'microsoft/phi-4', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 20}, \n",
    " {'model_name': 'microsoft/phi-4', 'do_verify': True, 'n_evaluate_sample': 7, 'n_select_sample': 20}, \n",
    " {'model_name': 'mistralai/Mistral-Small-24B-Instruct-2501', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 10}, \n",
    " {'model_name': 'mistralai/Mistral-Small-24B-Instruct-2501', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 20}, \n",
    " {'model_name': 'mistralai/Mistral-Small-24B-Instruct-2501', 'do_verify': False, 'n_evaluate_sample': 11, 'n_select_sample': 3}]\n",
    "\n",
    "for d in top_params:\n",
    "    d[\"n_propose_sample\"] = 1\n",
    "    d[\"n_verify_sample\"] = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9ec760-a6d6-4429-893d-8656ef97ea53",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n = 5\n",
    "d = \"tot_test_split\"\n",
    "method = \"tot\"\n",
    "\n",
    "for tp in top_params:\n",
    "    for i in range(n):\n",
    "        params = tp.copy()\n",
    "        model_name = params.pop(\"model_name\")\n",
    "        folder = \"_\".join(str(int(v)) if isinstance(v, bool) else str(v) for v in [params[k] for k in param_ranges.keys()])\n",
    "        logdir = f\"data/logs/{d}/{method}_{i+1}/{model_name}/{folder}/\"\n",
    "        os.makedirs(logdir, exist_ok=True)\n",
    "        if \"0.dat\" in os.listdir(logdir) and not is_exception_pickle(logdir+\"0.dat\"):\n",
    "            print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "            continue\n",
    "        t = md.enqueue(notebook=\"baseline_tot.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                       dataset=d, model_name=model_name, **params)\n",
    "        await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8657f8b4-9f36-496c-b962-f25e0217fc98",
   "metadata": {},
   "source": [
    "## unsolvable repetitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0832832-608f-493d-a501-225e3a3abf59",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# hyper parameter search gave the same config for all models...\n",
    "param_ranges = {\n",
    "    \"do_verify\": [True],\n",
    "    \"n_propose_sample\": [1],\n",
    "    \"n_verify_sample\": [3],\n",
    "    \"n_evaluate_sample\": [11],\n",
    "    \"n_select_sample\": [20],\n",
    "}\n",
    "\n",
    "method = \"tot\"\n",
    "d = \"four_digits_unsolvable\"\n",
    "n = 5\n",
    "for model_name in models:\n",
    "    for i in range(n):\n",
    "        for values in itertools.product(*param_ranges.values()):\n",
    "            params = dict(zip(param_ranges.keys(), values))\n",
    "            folder = \"_\".join(str(int(v)) if isinstance(v, bool) else str(v) for v in values)\n",
    "            logdir = f\"data/logs/{d}/{method}_{i+1}/{model_name}/{folder}/\"\n",
    "            os.makedirs(logdir, exist_ok=True)\n",
    "            if \"0.dat\" in os.listdir(logdir) and not is_exception_pickle(logdir+\"0.dat\"):\n",
    "                print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "                continue\n",
    "            t = md.enqueue(notebook=\"baseline_tot.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                           dataset=d, model_name=model_name, **params)\n",
    "            await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea3e0848-89bc-4e29-9175-18f14de0e357",
   "metadata": {},
   "source": [
    "# Holistic Prompting\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e732b4d-0c7c-4efa-97d9-9be458c14471",
   "metadata": {},
   "source": [
    "## Main experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5954b0dd-6c4b-43fe-90ee-aea9f1fd7ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_ranges = {\n",
    "    \"do_verify_moves\": [True],\n",
    "    \"do_verify_nodes\": [True],\n",
    "    \"do_shortcut\": [True],\n",
    "    \"do_evaluate_nodes\": [True],\n",
    "    \"n_propose_sample\": [1],\n",
    "    \"n_evaluate_sample\": [3],\n",
    "    \"n_select_sample\": [10],\n",
    "    \"n_verify_sample\": [3],\n",
    "    \"shortcut_factor\": [0, 1, 5, 10, 20],\n",
    "    \"tot_b\": [5]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd0d9cb-d463-4898-a4f7-107427b40154",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n = 5\n",
    "method = \"hp\"\n",
    "d = \"tot_test_split\"\n",
    "de = False\n",
    "for model_name in models:\n",
    "    for values in itertools.product(*param_ranges.values()):\n",
    "        for i in range(n):\n",
    "            params = dict(zip(param_ranges.keys(), values))\n",
    "            folder = \"_\".join(str(int(v)) if isinstance(v, bool) else str(v) for v in values)\n",
    "            logdir = f\"data/logs/{d}/{method}_{i+1}/{model_name}/{folder}/\"\n",
    "            os.makedirs(logdir, exist_ok=True)\n",
    "            if \"0.dat\" in os.listdir(logdir) and not is_exception_pickle(logdir+\"0.dat\"):\n",
    "                print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "                continue\n",
    "            params[\"n_shortcuts\"] = params[\"shortcut_factor\"] * params[\"n_select_sample\"]\n",
    "            t = md.enqueue(notebook=f\"hp_shortcut_rounds.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                           dataset=d, model_name=model_name, **params)\n",
    "            await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3e55bd0-0f65-4905-9f58-61d75c36921b",
   "metadata": {},
   "source": [
    "## Unsolvable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c54e04-81e4-436e-8c84-ecc4d5f397e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "defaults = {\n",
    "    \"do_verify_moves\": True,\n",
    "    \"do_verify_nodes\": True,\n",
    "    \"do_shortcut\": True,\n",
    "    \"do_evaluate_nodes\": True,\n",
    "    \"n_propose_sample\": 1,\n",
    "    \"n_evaluate_sample\": 3,\n",
    "    \"n_select_sample\": 10,\n",
    "    \"n_verify_sample\": 3,\n",
    "    \"shortcut_factor\": 0,\n",
    "    \"tot_b\": 5\n",
    "}\n",
    "\n",
    "top_params = [\n",
    " {'model_name': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', \"shortcut_factor\":1}, \n",
    " {'model_name': 'microsoft/phi-4', \"shortcut_factor\":10}, \n",
    " {'model_name': 'mistralai/Mistral-Small-24B-Instruct-2501', \"shortcut_factor\": 20}]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aa51470-8657-48c2-9ac3-5f7e365c5100",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n = 5\n",
    "method = \"hp\"\n",
    "d = \"four_digits_unsolvable\"\n",
    "\n",
    "for i in range(n):\n",
    "    for tp in top_params:\n",
    "        params = defaults.copy()\n",
    "        params.update(tp)\n",
    "        model_name = params.pop(\"model_name\")\n",
    "        folder = \"_\".join(str(int(v)) if isinstance(v, bool) else str(v) for v in [params[k] for k in defaults.keys()])\n",
    "        logdir = f\"data/logs/{d}/{method}_{i+1}/{model_name}/{folder}/\"\n",
    "        os.makedirs(logdir, exist_ok=True)\n",
    "        if \"0.dat\" in os.listdir(logdir) and not is_exception_pickle(logdir+\"0.dat\"):\n",
    "            print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "            continue\n",
    "        params[\"n_shortcuts\"] = params[\"shortcut_factor\"] * params[\"n_select_sample\"]\n",
    "        t = md.enqueue(notebook=\"hp_shortcut_rounds.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                       dataset=d, model_name=model_name, **params)\n",
    "        await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3924c04-2bc5-411a-a476-2735c309fd7d",
   "metadata": {},
   "source": [
    "## Ablation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eeed185-c60d-489f-bfde-fad062f0d022",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_ranges = {\n",
    "    \"do_verify_moves\": [True, False],\n",
    "    \"do_verify_nodes\": [True], ### is set to do_verify_moves below!!!\n",
    "    \"do_shortcut\": [True, False],\n",
    "    \"do_evaluate_nodes\": [True],\n",
    "    \"n_propose_sample\": [1],\n",
    "    \"n_evaluate_sample\": [3],\n",
    "    \"n_select_sample\": [10],\n",
    "    \"n_verify_sample\": [3],\n",
    "    \"shortcut_factor\": [10],\n",
    "    \"tot_b\": [5]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55b0d3c8-9ff7-417f-b121-3de1fd3d0a45",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n = 5\n",
    "method = \"hp_ablation\"\n",
    "d = \"tot_test_split\"\n",
    "model_name = \"microsoft/phi-4\"\n",
    "\n",
    "for values in itertools.product(*param_ranges.values()):\n",
    "    for i in range(n):\n",
    "        params = dict(zip(param_ranges.keys(), values))\n",
    "        params[\"do_verify_nodes\"] = params[\"do_verify_moves\"]\n",
    "        folder = \"_\".join(str(int(v)) if isinstance(v, bool) else str(v) for v in values)\n",
    "        logdir = f\"data/logs/{d}/{method}_{i+1}/{model_name}/{folder}/\"\n",
    "        os.makedirs(logdir, exist_ok=True)\n",
    "        if \"0.dat\" in os.listdir(logdir) and not is_exception_pickle(logdir+\"0.dat\"):\n",
    "            print(f\"Skipping '{logdir}' as it seems already present\")\n",
    "            continue\n",
    "        params[\"n_shortcuts\"] = params[\"shortcut_factor\"] * params[\"n_select_sample\"]\n",
    "        t = md.enqueue(notebook=f\"hp_shortcut_rounds.ipynb\", logdir=logdir, overwrite_logdir=True, num_gpus=len(gpus), \n",
    "                       dataset=d, model_name=model_name, **params)\n",
    "        await asyncio.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a0c49c9-02ff-4cd0-a446-2eca1416dc83",
   "metadata": {},
   "source": [
    "# eval\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9be72e72-7f97-43dd-9230-b5bc39b0d713",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "Ws = {d:nx.read_graphml(f\"data/make_24/optimal_graphs/{d}_W.graphml\") for d in datasets}\n",
    "Ts = {d:nx.read_graphml(f\"data/make_24/optimal_graphs/{d}_T.graphml\") for d in datasets}\n",
    "\n",
    "samples = {d:[n for n in T if T.in_degree(n) == 0] for d, T in Ts.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06def84-5be4-4745-9c93-03b19c84f3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import pickle\n",
    "import json\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import get_unsolved_roots\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "def dirs_without_config_json(base_path):\n",
    "    result = []\n",
    "    for dirpath, dirnames, filenames in os.walk(base_path):\n",
    "        has_ipynb = any(f.endswith('.ipynb') for f in filenames)\n",
    "        has_config = 'config.json' in filenames\n",
    "        if has_ipynb and not has_config:\n",
    "            result.append(dirpath)\n",
    "    return result\n",
    "\n",
    "import shutil\n",
    "def delete_folders_with_errors(folders):\n",
    "    # runs through all given folders, attempts to load the 0.dat file\n",
    "    # if it is an error object, deletes the whole folder\n",
    "    for folder in folders:\n",
    "        with open(folder + \"0.dat\", \"rb\") as f:\n",
    "            o = pickle.load(f)\n",
    "        if isinstance(o, BaseException):\n",
    "            shutil.rmtree(folder)\n",
    "            print(\"deleted\", folder)\n",
    "\n",
    "def fix_tot_backends(folders):\n",
    "    # sometimes, the backend was set to gpt4 instead of the true model...\n",
    "    for folder in folders:\n",
    "        for model_name in model_names.keys():\n",
    "            if model_name in folder:\n",
    "                break\n",
    "        else:\n",
    "            print(\"Found no model\")\n",
    "        if not os.path.exists(folder + \"0.dat\"):\n",
    "            print(f\"Skipping folder {folder} since no 0.dat exists\")\n",
    "        with open(folder + \"0.dat\", \"rb\") as f:\n",
    "            logs = pickle.load(f)\n",
    "        logs[\"args\"].backend = model_name\n",
    "        with open(folder + \"0.dat\", \"wb\") as f:\n",
    "            pickle.dump(logs, f)\n",
    "\n",
    "hp_dirs = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/game24/**/hp/**/config.json\", recursive=True)]\n",
    "tot_dirs = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/tot_test_split/tot/**/0.dat\", recursive=True)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4133e6-554a-4a91-9455-b608315304d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirs_without_config_json(\"data/logs/tot_test_split/hp/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f65a0d0a-e01b-4670-b97b-825fa73c3e22",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyse_tot_baseline_solvable(folder):\n",
    "    \"\"\"\n",
    "    Summarise the ToT baseline run stored in `folder` + `0.dat`.\n",
    "\n",
    "    If the pickle holds an exception object, prints a FAILED notice and returns\n",
    "    that exception. Otherwise returns a dict with `visited_nodes`,\n",
    "    `success_rates`, `completion_tokens` and `prompt_tokens` (lists with one\n",
    "    entry per column of the logged `results`) plus the run's `args`.\n",
    "    Entry k of `success_rates` is the share of rows with any truthy value in\n",
    "    the first k+1 columns of `results`.\n",
    "    Relies on the notebook-level `Ws` and `samples` for the sanity check.\n",
    "    \"\"\"\n",
    "    with open(folder+ \"0.dat\", \"rb\") as handle:\n",
    "        run_log = pickle.load(handle)\n",
    "    # failed runs pickle the raised exception instead of a log dict\n",
    "    if isinstance(run_log, BaseException):\n",
    "        print(\"FAILED\", folder)\n",
    "        return run_log\n",
    "\n",
    "    run_args = run_log[\"args\"]\n",
    "    dataset = run_args.dataset\n",
    "    solved = run_log[\"results\"]  # presumably a 2-D array (samples x attempts) - TODO confirm\n",
    "    outputs = run_log[\"outputs\"]  # list of tuples (idx, (ys, infos))\n",
    "\n",
    "    if any(root not in Ws[dataset] for root in samples[dataset]):\n",
    "        print(\"WARNING! There are unsolvable samples in your task, this is the wrong method!\")\n",
    "\n",
    "    def count_new_nodes(steps):\n",
    "        # total number of `new_ys` entries over all steps of one sample\n",
    "        return sum(len(step[\"new_ys\"]) for step in steps)\n",
    "\n",
    "    infos = pd.DataFrame([entry[1][1] for entry in outputs])\n",
    "    mean_visited_nodes = infos.steps.apply(count_new_nodes).mean()\n",
    "    n_attempts = solved.shape[1]\n",
    "    rates = [solved[:, :k].any(axis=1).mean() for k in range(1, n_attempts + 1)]\n",
    "    return dict(\n",
    "        visited_nodes=[mean_visited_nodes] * len(rates),\n",
    "        success_rates=rates,\n",
    "        args=run_args,\n",
    "        completion_tokens=[run_log[\"completion_tokens\"]] * len(rates),\n",
    "        prompt_tokens=[run_log[\"prompt_tokens\"]] * len(rates),\n",
    "    )\n",
    "\n",
    "def analyse_tot_baseline_with_unsolvables(fp):\n",
    "    \"\"\"\n",
    "    Placeholder for analysing ToT runs on datasets that contain unsolvable samples.\n",
    "\n",
    "    Currently raises NotImplementedError unconditionally; everything after the\n",
    "    `raise` is an unreachable draft.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError()\n",
    "    # NOTE(review): the draft below uses `d`, `claims_24` and\n",
    "    # `evaluate_results_with_unsolvable`, none of which this function defines and\n",
    "    # the latter two are not defined anywhere visible in this notebook - confirm before enabling.\n",
    "    with open(fp, \"rb\") as f:\n",
    "        logs = pickle.load(f)\n",
    "    \n",
    "    results = logs[\"results\"]\n",
    "    out = logs[\"outputs\"] # list of tuples (idx, (ys, infos))\n",
    "    b = pd.DataFrame([x[1][1] for x in out])\n",
    "    mean_visited_nodes = b.steps.apply(lambda x: sum([len(y[\"new_ys\"]) for y in x])).mean()\n",
    "    \n",
    "    roots = [samples[d][o[0]] for o in out]\n",
    "    solutions = [o[1][0][0].strip().split(\"\\n\")[-1] for o in out]\n",
    "    result_df = pd.DataFrame({\"root\":roots, \"solution\":solutions})\n",
    "    result_df[\"claims_24\"] = result_df.solution.apply(claims_24)\n",
    "    result_df[\"solution\"] = result_df.apply(lambda row: row.solution if row.claims_24 else None, axis=1)\n",
    "    \n",
    "    results = evaluate_results_with_unsolvable(Ws[d], result_df)\n",
    "    sr = results.is_correct.mean()\n",
    "    r = dict(visited_nodes=[mean_visited_nodes], success_rates=[sr], results=results, args=logs[\"args\"],\n",
    "             completion_tokens=[logs[\"completion_tokens\"]], prompt_tokens=[logs[\"prompt_tokens\"]])\n",
    "    return r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8caf11ff-01ba-47bf-bea3-986866a35b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tot_data = [y for x in tqdm(tot_dirs) if not isinstance((y:=analyse_tot_baseline_solvable(x)), BaseException)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f61d34-2de4-4768-a68c-7ec487e2b178",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Success rate against completion tokens, one line per ToT configuration.\n",
    "# `tot_data` is already free of exception objects (filtered where it is built),\n",
    "# and the loop variable no longer overwrites the dataset name `d`.\n",
    "fig, ax = plt.subplots()\n",
    "for run in tot_data:\n",
    "    run_args = run[\"args\"]\n",
    "    label = f\"v={run_args.do_verify},b={run_args.n_select_sample},m={run_args.n_evaluate_sample}\"\n",
    "    ax.plot(run[\"completion_tokens\"], run[\"success_rates\"], \"-o\", label=label)\n",
    "\n",
    "ax.set_xlabel(\"completion tokens\")\n",
    "ax.set_ylabel(\"success rate\")\n",
    "ax.legend(loc=\"center right\", bbox_to_anchor=(1.8,0.5), ncols=2);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e30515e-acf0-43b4-978a-e4bc7a826b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = pd.DataFrame(tot_data)\n",
    "a[\"b\"] = a.args.apply(lambda x: x.n_select_sample)\n",
    "a[\"m\"] = a.args.apply(lambda x: x.n_evaluate_sample)\n",
    "a[\"do_verify\"] = a.args.apply(lambda x: x.do_verify)\n",
    "a = a.sort_values([\"b\",\"m\", \"do_verify\"])\n",
    "a = a.explode([\"visited_nodes\", \"success_rates\", \"completion_tokens\",\"prompt_tokens\"])\n",
    "a[\"bofn\"] = a.index.to_series().groupby(a.index).cumcount() + 1\n",
    "\n",
    "#mi = pd.MultiIndex.from_arrays([a.index, a.index.to_series().groupby(a.index).cumcount().values + 1])\n",
    "#a.index = mi\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f68b4a9-0899-4f69-b54e-729195d2ac99",
   "metadata": {},
   "outputs": [],
   "source": [
    "a[a.bofn == 1].iloc[:2].args.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c2b3a06-9d4a-4dc6-8450-01aea12e4481",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = a[a.bofn==1]\n",
    "#\n",
    "groups = x.groupby(\"b\")\n",
    "for b, g in groups:\n",
    "    l = f\"b={b},m={g.m.min()}..{g.m.max()}\"\n",
    "    plt.plot(g.completion_tokens, g.success_rates, \"-o\", label=l)\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "groups = x.groupby(\"m\")\n",
    "for m, g in groups:\n",
    "    l = f\"m={m},b={g.b.min()}..{g.b.max()}\"\n",
    "    plt.plot(g.completion_tokens, g.success_rates, \"-o\", label=l)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ab06cc-e606-4a23-a00d-ac0094358b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e334ae45-22a8-40c9-b9a5-c4cf55ba6b29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `data` is never defined in this notebook, so the cell failed on a fresh kernel.\n",
    "# NOTE(review): assumes the intended source is `tot_data`, whose entries keep the\n",
    "# run settings under \"args\" - confirm.\n",
    "c = [run for run in tot_data if run[\"args\"].n_select_sample == 5][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea82399-fe96-42f9-8a60-05290529dc0e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c535dc30-e2ba-44cb-89f5-7070da524088",
   "metadata": {},
   "source": [
    "# CPU Tasks\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e9ab15-3009-4a0b-a968-4cd778b69e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_slurm = False\n",
    "\n",
    "if run_slurm:\n",
    "    os.environ[\"PARTITIONS\"] = \"PARTITION GOES HERE\"\n",
    "    md = NotebookDeployer(default_backend=\"python-slurm\")\n",
    "else:\n",
    "    md = NotebookDeployer(gpu_list=[], default_backend=\"python\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9790394-9377-4452-aa7c-62bc041d065f",
   "metadata": {},
   "source": [
    "## Retrosynthesis\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2935857f-77dc-422f-ac60-678eb53f43aa",
   "metadata": {},
   "source": [
    "### Baseline\n",
    "***\n",
    "See `retrosynthesis_baseline.ipynb`!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7426633-fbc5-4cd6-92fc-e2301cd84ba5",
   "metadata": {},
   "source": [
    "### Ours\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60479062-f548-4aae-b248-924f493955d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "for iterations in [10,50,100,200,300,400,500]:\n",
    "    for expansion_topk in [50]:\n",
    "        use_value_fn = True\n",
    "\n",
    "        folder = f\"data/chem/hp/{iterations}_{expansion_topk}_{use_value_fn}/\"\n",
    "        if os.path.exists(folder + \"result.json\"):\n",
    "            print(\"skipping\", folder)\n",
    "            continue\n",
    "        md.enqueue(notebook=\"retrosynthesis_hp.ipynb\", iterations=iterations, expansion_topk=expansion_topk, \n",
    "                   use_value_fn=use_value_fn, logdir=folder, num_gpus=0, overwrite_logdir=True)\n",
    "        await asyncio.sleep(0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1207cd87-0c46-4f0d-8ba2-c06a834805ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "md.status()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdd903d-75fa-4d9d-8382-666feb389c6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "md.stop_all()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
