{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "525a064e-4754-4e06-97fb-dddddec33818",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: additionally validate winning paths (every round, extend good_moves by all winning moves with fewer validations than X; same for states).\n",
    "# TODO: store actions and their outcomes, e.g. \"first proposal on node x gave y\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3d90bb-84e8-40a7-adf5-adce6e4c1906",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "#from huggingface_hub import login\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "import pickle\n",
    "\n",
    "tqdm.pandas()\n",
    "\n",
    "pd.set_option('display.max_colwidth', 400)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52f68788-7998-4681-8672-fca9b3309f10",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parametrization\n",
    "# This cell carries the \"parameters\" tag; every name below is a module-level setting\n",
    "# read by the functions defined further down.\n",
    "from deploy_utils import get_available_gpus\n",
    "\n",
    "# NOTE: only the last uncommented model_name assignment takes effect;\n",
    "# the quantized w8a8 checkpoint is immediately overwritten by the line after it.\n",
    "#model_name = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "model_name = \"RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8\"\n",
    "model_name = \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
    "#model_name = \"microsoft/phi-4\"\n",
    "#model_name = \"RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic\"\n",
    "\n",
    "# model instantiation kwargs\n",
    "model_kwargs = {\n",
    "    \"enable_prefix_caching\": True, \n",
    "    \"tensor_parallel_size\": 1, \n",
    "    \"max_model_len\": 4096,\n",
    "    \"max_num_seqs\": 1024,\n",
    "}\n",
    "\n",
    "\n",
    "# NOTE(review): available_gpus is not referenced in the visible part of the notebook\n",
    "available_gpus =  get_available_gpus(max_utilization=5)\n",
    "\n",
    "# openai api\n",
    "# NOTE: this dict is overwritten by the vllm backend_kwargs right below and never used\n",
    "backend_kwargs = dict(\n",
    "    base_url=\"localhost\",\n",
    "    api_key = os.environ.get(\"OPENAI_API_KEY\",\"\")\n",
    ")\n",
    "\n",
    "# vllm backend kwargs\n",
    "# os.environ[\"HF_TOKEN\"] raises KeyError when the variable is not set\n",
    "backend_kwargs = dict(\n",
    "    base_port=8000,\n",
    "    api_key = os.environ[\"HF_TOKEN\"],\n",
    "    offline=True,\n",
    ")\n",
    "backend = \"vllm\"\n",
    "\n",
    "\n",
    "# folder that save_to_disk writes 0.dat and config.json into\n",
    "logdir = \"data/temp/\"\n",
    "# folder of a previous run to continue from, or None to start with a fresh graph\n",
    "resume_from = None\n",
    "\n",
    "dataset = \"tot_test_split\"\n",
    "#dataset = \"four_digits_unsolvable\"\n",
    "\n",
    "# tot_b enters the exploration budget (task.max_depth * tot_b * len(task.samples))\n",
    "tot_b = 5\n",
    "do_verify_moves = False\n",
    "do_verify_nodes = False\n",
    "do_shortcut = True\n",
    "do_evaluate_nodes = True\n",
    "\n",
    "# number of completions requested per prompt (n_*_sample) / nodes expanded per round (n_select_sample)\n",
    "n_propose_sample = 1\n",
    "n_evaluate_sample=3\n",
    "n_select_sample=25\n",
    "n_verify_sample=3\n",
    "# maximum number of shortcut candidates queried per shortcut round\n",
    "n_shortcuts = n_select_sample * 10\n",
    "# expedience assigned to root nodes when a new graph is set up\n",
    "root_expedience = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f404e09-a65c-4070-922c-ad81930f60b6",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "# Notebook logic\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0b44ff-0aa0-4576-8000-27a5ff022769",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from inference_client import MultiOpenAIClient\n",
    "\n",
    "# environment is configured before the client is created\n",
    "# NOTE(review): presumably both variables must be set before the backend starts - confirm\n",
    "os.environ['TORCH_CUDA_ARCH_LIST'] = \"9.0\" # H100 cards\n",
    "os.environ[\"VLLM_CONFIGURE_LOGGING\"] = \"0\"\n",
    "\n",
    "# single inference client shared by all functions below (proposals, verification, evaluation, ...)\n",
    "client = MultiOpenAIClient(backend=backend, model_name=model_name, model_kwargs=model_kwargs, **backend_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af2469b0-2926-4f83-9d25-6525874ce639",
   "metadata": {},
   "source": [
    "# Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26137aa3-1728-4b2e-b112-888283235655",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import networkx as nx\n",
    "import itertools\n",
    "import json\n",
    "\n",
    "from tasks import get_task\n",
    "from utils import (get_solved_roots, get_unsolved_roots, subgraph_from, get_winning_subgraph, visualize_graph, get_solution_nodes, get_clean_subgraph)\n",
    "\n",
    "task = get_task(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1326a904-a08c-4e39-bc7c-0f2d7713dbfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_proposals(data):\n",
    "    '''\n",
    "    Ask the model for next-move proposals for every node in data.\n",
    "\n",
    "    data is a df with at least a node (str) column.\n",
    "\n",
    "    returns a copy of data with one more column:\n",
    "    \"proposals\": list of str, the de-duplicated lines of all n_propose_sample completions\n",
    "    in first-seen order. The input frame itself is not modified.\n",
    "    '''\n",
    "    # work on a copy: callers pass slices of other frames, and writing columns into a slice\n",
    "    # raises SettingWithCopyWarning and leaks helper columns into the caller's frame\n",
    "    data = data.copy()\n",
    "    if len(data) == 0:\n",
    "        # nothing to expand, skip the model call\n",
    "        data[\"proposals\"] = []\n",
    "        return data\n",
    "\n",
    "    prompts = data.node.apply(task.propose_convo_wrap).tolist()\n",
    "    completions = client.chat(messages=prompts, n=n_propose_sample, temperature=0.7, max_tokens=1000, return_format=\"chatml\", token_usage_key=\"proposals\")\n",
    "\n",
    "    def unique_lines(samples):\n",
    "        # every sample is a conversation, its last turn holds one proposal per line\n",
    "        lines = itertools.chain.from_iterable(x[-1][\"content\"].split(\"\\n\") for x in samples)\n",
    "        # dict.fromkeys de-duplicates like set() but keeps a deterministic order\n",
    "        # NOTE(review): blank lines are kept as proposals, exactly as before\n",
    "        return list(dict.fromkeys(lines))\n",
    "\n",
    "    data[\"proposals\"] = [unique_lines(samples) for samples in completions]\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace24bf3-6afd-4c56-9bdf-1259740a3e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def verify_moves(data):\n",
    "    '''\n",
    "    Check every proposed move with the task's move verification prompts.\n",
    "\n",
    "    data is a df with at least node (str), proposals (list of str) and is_shortcut_candidate columns\n",
    "    \n",
    "    returns a new df with one row per (node, proposal) and the columns\n",
    "    node, proposals, is_shortcut_candidate plus:\n",
    "    \"move_failure_reasons\": tuple of verification prompts that failed majority vote check\n",
    "    \"is_verified_move\" bool, whether all verification prompts passed.\n",
    "\n",
    "    NOTE(review): the do_verify_moves setting is not consulted in this function.\n",
    "    '''\n",
    "    moves = data.explode(\"proposals\")\n",
    "    \n",
    "    # one column per verification prompt returned by the task\n",
    "    convos = moves.apply(lambda row: task.get_move_verification_prompts(row.node, row.proposals) , axis=1, result_type=\"expand\")\n",
    "    moves = pd.concat([moves, convos], axis=1)\n",
    "    \n",
    "    # stack together all prompts, giving a long format df instead of wide\n",
    "    moves = moves.melt(id_vars=[\"node\", \"proposals\",], value_vars=convos.columns, var_name=\"prompt_type\", value_name=\"convos\")\n",
    "    \n",
    "    # first turn: n_verify_sample completions per prompt\n",
    "    moves[\"convos\"] = client.chat(messages=moves.convos.tolist(), n=n_verify_sample, temperature=0.7, max_tokens=1000, return_format=\"chatml\", token_usage_key=\"verify_moves\")\n",
    "    \n",
    "    # second turn: one short follow-up per sample, counted as passed if the reply contains \"yes\"\n",
    "    moves_flat = moves.explode(\"convos\")\n",
    "    moves_flat[\"convos\"] = moves_flat.convos.apply(task.add_verification_turn)\n",
    "    moves_flat[\"convos\"] = client.chat(messages=moves_flat.convos.tolist(), n=1, temperature=0.1, max_tokens=10, return_format=\"chatml\", token_usage_key=\"verify_moves_followup\")\n",
    "    moves_flat[\"is_correct\"] = moves_flat.convos.apply(lambda x: \"yes\" in x[0][-1][\"content\"].lower())\n",
    "    \n",
    "    g = moves_flat.groupby([\"node\", \"proposals\", \"prompt_type\"])\n",
    "        \n",
    "    # majority vote per prompt type: strictly more than half of the samples must pass\n",
    "    aggs = g.apply(lambda x: x.is_correct.mean() > 0.5, include_groups=False)\n",
    "    aggs.name = \"is_correct\"\n",
    "    aggs = aggs.to_frame().reset_index()\n",
    "    \n",
    "    def aggregate_move_verifications(df):\n",
    "        # a move is verified only if every prompt type passed its majority vote\n",
    "        return pd.Series({\"move_failure_reasons\": tuple(df.prompt_type[~df.is_correct]), \n",
    "                          \"is_verified_move\": df.is_correct.all()})\n",
    "        \n",
    "    r = aggs.groupby([\"node\", \"proposals\"]).apply(aggregate_move_verifications, include_groups=False).reset_index()\n",
    "    # inner merge back onto the exploded input to carry is_shortcut_candidate along\n",
    "    verification_results = data[[\"node\",\"proposals\",\"is_shortcut_candidate\"]].explode(\"proposals\").merge(r, on=[\"node\",\"proposals\"])\n",
    "    \n",
    "    return verification_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cfafcab-983d-4145-88a4-3fe18d16d666",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_moves(moves):\n",
    "    '''\n",
    "    Let the model apply each move to its node in order to obtain the follow-up node.\n",
    "\n",
    "    moves is a df with one row per move and at least a node (str) and a proposals column;\n",
    "    row.proposals is handed to task.apply_convo_wrap as is, nothing is exploded here.\n",
    "\n",
    "    returns a copy of moves with additional columns:\n",
    "    \"convos\": the conversation after the follow-up turn\n",
    "    \"next_node\": str, content of the last message of that conversation\n",
    "    '''\n",
    "    data_flat = moves.copy()\n",
    "\n",
    "    # first turn: apply the move (temperature 0.7)\n",
    "    data_flat[\"convos\"] = data_flat.apply(lambda row: task.apply_convo_wrap(row.node, row.proposals), axis=1)\n",
    "    data_flat[\"convos\"] = client.chat(messages=data_flat.convos.tolist(), temperature=0.7, max_tokens=1000, n=1, return_format=\"chatml\", token_usage_key=\"apply\")\n",
    "    # second turn: follow-up added by the task (temperature 0.1), its answer becomes next_node\n",
    "    data_flat[\"convos\"] = data_flat.convos.apply(lambda x: task.add_apply_follow_up_turn(x[0]))\n",
    "    data_flat[\"convos\"] = client.chat(messages=data_flat.convos.tolist(), temperature=0.1, max_tokens=1000, n=1, return_format=\"chatml\", token_usage_key=\"apply_fu\")\n",
    "    \n",
    "    data_flat[\"next_node\"] = data_flat.convos.apply(lambda x: x[0][-1][\"content\"])\n",
    "    return data_flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c6744f-b0f8-43d9-9af2-157251c3110b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def verify_new_nodes(new_nodes):\n",
    "    '''\n",
    "    Check every (node, move, next_node) triple with the task's node verification prompts.\n",
    "\n",
    "    new_nodes is a df with at least node (str) and proposals (str) and next_node (str) columns\n",
    "    \n",
    "    returns data with two more columns:\n",
    "    \"node_failure_reasons\": tuple of verification prompts that failed majority vote check\n",
    "    \"is_verified_node\" bool, whether all verification prompts passed.\n",
    "\n",
    "    With do_verify_nodes switched off, both columns are written into new_nodes itself\n",
    "    (every node counts as verified) and that same frame is returned.\n",
    "    '''\n",
    "    if not do_verify_nodes:\n",
    "        new_nodes[\"is_verified_node\"] = True\n",
    "        new_nodes[\"node_failure_reasons\"] = [tuple()] * len(new_nodes)\n",
    "        return new_nodes\n",
    "    \n",
    "    nodes = new_nodes.copy()\n",
    "    # one column per verification prompt returned by the task\n",
    "    convos = nodes.apply(lambda row: task.get_node_verification_prompts(row.node, row.next_node, row.proposals) , axis=1, result_type=\"expand\")\n",
    "    nodes = pd.concat([nodes, convos], axis=1)\n",
    "    \n",
    "    # long format: one row per (node, move, next_node, prompt type)\n",
    "    nodes = nodes.melt(id_vars=[\"node\", \"proposals\", \"next_node\",], value_vars=convos.columns, var_name=\"prompt_type\", value_name=\"convos\")\n",
    "    \n",
    "    # first turn: n_verify_sample completions per prompt\n",
    "    nodes[\"convos\"] = client.chat(messages=nodes.convos.tolist(), n=n_verify_sample, temperature=0.7, max_tokens=1000, return_format=\"chatml\", token_usage_key=\"verify_nodes\")\n",
    "    \n",
    "    # second turn: one short follow-up per sample, counted as passed if the reply contains \"yes\"\n",
    "    nodes_flat = nodes.explode(\"convos\")\n",
    "    nodes_flat[\"convos\"] = nodes_flat.convos.apply(task.add_verification_turn)\n",
    "    nodes_flat[\"convos\"] = client.chat(messages=nodes_flat.convos.tolist(), n=1, temperature=0.1, max_tokens=10, return_format=\"chatml\", token_usage_key=\"verify_nodes_fu\")\n",
    "    nodes_flat[\"is_correct\"] = nodes_flat.convos.apply(lambda x: \"yes\" in x[0][-1][\"content\"].lower())\n",
    "    \n",
    "    # majority vote per prompt type: strictly more than half of the samples must pass\n",
    "    g = nodes_flat.groupby([\"node\", \"proposals\", \"next_node\", \"prompt_type\"])    \n",
    "    aggs = g.apply(lambda x: x.is_correct.mean() > 0.5, include_groups=False)\n",
    "    aggs.name = \"is_correct\"\n",
    "    aggs = aggs.to_frame().reset_index()\n",
    "    \n",
    "    def aggregate_node_verifications(df):\n",
    "        # a node is verified only if every prompt type passed its majority vote\n",
    "        return pd.Series({\"node_failure_reasons\": tuple(df.prompt_type[~df.is_correct]), \"is_verified_node\": df.is_correct.all()})\n",
    "        \n",
    "    r = aggs.groupby([\"node\", \"proposals\", \"next_node\"]).apply(aggregate_node_verifications, include_groups=False).reset_index()\n",
    "    # inner merge: the verification columns are attached to the original rows\n",
    "    results = new_nodes.merge(r, on=[\"node\",\"proposals\", \"next_node\"])\n",
    "    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f7f787b-3965-4338-b98c-5a15cccbae2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def all_inputs_have_a_solution_criterion(G):\n",
    "    '''\n",
    "    Stopping condition (a): all inputs need at least one path to a solution node.\n",
    "\n",
    "    G is the search graph; roots carry is_root=True, solutions carry is_solution=True.\n",
    "    returns True if every root can reach at least one solution node (also True without roots).\n",
    "    '''\n",
    "    root_nodes, solution_nodes = [], []\n",
    "    for name, attrs in G.nodes(data=True):\n",
    "        if attrs.get(\"is_root\", False):\n",
    "            root_nodes.append(name)\n",
    "        if attrs.get(\"is_solution\", False):\n",
    "            solution_nodes.append(name)\n",
    "\n",
    "    def is_solved(root):\n",
    "        # one reachable solution is enough\n",
    "        return any(nx.has_path(G, root, s) for s in solution_nodes)\n",
    "\n",
    "    return all(is_solved(r) for r in root_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5957392a-0da9-4c16-8443-7e0b7aa4b04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_nodes(leaves):\n",
    "    '''\n",
    "    Estimate the expedience of every next_node with the task's value prompt.\n",
    "\n",
    "    leaves is a df with at least node, proposals and next_node columns.\n",
    "\n",
    "    returns a df with the columns node, proposals, next_node and expedience, where expedience\n",
    "    is the mean over the n_evaluate_sample evaluations of that row.\n",
    "    If do_evaluate_nodes is off or leaves is empty, an expedience of 1.0 is written into leaves\n",
    "    itself and that frame is returned with all of its columns.\n",
    "    '''\n",
    "    if not do_evaluate_nodes or len(leaves) == 0:\n",
    "        leaves[\"expedience\"] = 1.0\n",
    "        return leaves\n",
    "    \n",
    "    nodes = leaves.copy()\n",
    "    # first turn: n_evaluate_sample value judgements per node\n",
    "    nodes[\"convos\"] = nodes.next_node.apply(task.value_prompt_wrap)\n",
    "    nodes[\"convos\"] = client.chat(messages=nodes.convos.tolist(), n=n_evaluate_sample, temperature=0.7, max_tokens=1000, return_format=\"chatml\", token_usage_key=\"eval\")\n",
    "    \n",
    "    # second turn: short follow-up whose answer is converted by task.value_outputs_unwrap\n",
    "    nodes_flat = nodes.explode(\"convos\")\n",
    "    nodes_flat[\"convos\"] = nodes_flat.convos.apply(task.add_value_turn)\n",
    "    nodes_flat[\"convos\"] = client.chat(messages=nodes_flat.convos.tolist(), n=1, temperature=0.1, max_tokens=10, return_format=\"chatml\", token_usage_key=\"eval_fu\")\n",
    "    nodes_flat[\"expedience\"] = nodes_flat.convos.apply(lambda x: task.value_outputs_unwrap(x[0][-1][\"content\"]))\n",
    "    \n",
    "    # average the samples per (node, move, next_node)\n",
    "    g = nodes_flat.groupby([\"node\",\"proposals\", \"next_node\"])\n",
    "    exp = g.expedience.mean().reset_index()\n",
    "    \n",
    "    return exp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "396d09e1-530d-43eb-b493-446e6efa2bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the string below is disabled code (path selection by number of verifications), it is never executed\n",
    "'''\n",
    "    solved_roots = pd.DataFrame({\"root\":[n for n in W_pred.nodes() if W_pred.in_degree(n) == 0]})\n",
    "    solved_roots[\"paths\"] = solved_roots.root.apply(lambda r: list(get_paths_to_solutions(r)))\n",
    "    solved_roots = solved_roots.explode(\"paths\").reset_index(drop=True)\n",
    "    # num_verified is not implemented currently...\n",
    "    solved_roots[\"min_verified\"] = solved_roots.paths.apply(lambda p: min([min(W_pred.nodes[l].get(\"num_verified\",0), \n",
    "                                                                               W_pred.nodes[r].get(\"num_verified\",0),     \n",
    "                                                                               W_pred.edges[(l,r,d)].get(\"num_verified\",0)) \n",
    "                                                                           for l,r,d in p]))\n",
    "    \n",
    "    # when multiple paths are available, choose the one with the strongest weakest link (measured in number of verifications)\n",
    "    idx = solved_roots.groupby(\"root\").min_verified.idxmax()\n",
    "    solved_roots = solved_roots.loc[idx.values]\n",
    "'''\n",
    "\n",
    "def get_final_answers(W_pred):\n",
    "    '''\n",
    "    Let the model write out the final answer for every root of the winning subgraph.\n",
    "\n",
    "    W_pred is the winning subgraph; its nodes without incoming edges are treated as roots.\n",
    "\n",
    "    returns a df with the columns root, paths, moves, convos and solution (str, content of the\n",
    "    last message of the finalize conversation).\n",
    "    '''\n",
    "    solution_nodes = get_solution_nodes(W_pred)\n",
    "    def get_paths_to_solutions(r):\n",
    "        # lazily yields every simple edge path from r to any solution node\n",
    "        for s in solution_nodes:\n",
    "            for p in nx.all_simple_edge_paths(W_pred, r, s):\n",
    "                yield tuple(p)\n",
    "    \n",
    "    solved_roots = pd.DataFrame({\"root\":[n for n in W_pred.nodes() if W_pred.in_degree(n) == 0]})\n",
    "    # only the first path found is used; next() raises StopIteration if a root has no path at all\n",
    "    solved_roots[\"paths\"] = solved_roots.root.apply(lambda r: next(get_paths_to_solutions(r)))\n",
    "    # e[2] is the third element of every edge tuple - presumably the edge key holding the move text; TODO confirm\n",
    "    solved_roots[\"moves\"] = solved_roots.paths.apply(lambda l: [e[2] for e in l])\n",
    "    \n",
    "    # first turn: write up the answer from root and moves\n",
    "    solved_roots[\"convos\"] = solved_roots.apply(lambda row: task.finalize_answer_prompt_wrap(row.root, row.moves), axis=1)\n",
    "    solved_roots[\"convos\"] = client.chat(messages=solved_roots.convos.tolist(), n=1, temperature=0.7, max_tokens=2000, return_format=\"chatml\", token_usage_key=\"final_answers\")\n",
    "    \n",
    "    # second turn: short follow-up whose answer is stored as the solution\n",
    "    solved_roots[\"convos\"] = solved_roots.convos.apply(lambda x: task.add_finalize_answer_turn(x[0]))\n",
    "    solved_roots[\"convos\"] = client.chat(messages=solved_roots.convos.tolist(), n=1, temperature=0.1, max_tokens=128, return_format=\"chatml\", token_usage_key=\"final_answers_fu\")\n",
    "    \n",
    "    solved_roots[\"solution\"] = solved_roots.convos.apply(lambda x: x[0][-1][\"content\"])\n",
    "    return solved_roots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c510f099-f59a-4468-af91-5940720d4be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "def get_shortcut_candidates(G):\n",
    "    '''\n",
    "    Rank (parent, child) node pairs that could be connected by a single shortcut move.\n",
    "\n",
    "    parents are nodes reachable from still unsolved roots (error nodes are not traversed),\n",
    "    children are nodes one BFS layer deeper in the currently winning subgraph.\n",
    "\n",
    "    returns None if there is no winning subgraph, no open root or no candidate pair.\n",
    "    otherwise a df with the columns parent, child, layer, distance, expected_nodes_saved and\n",
    "    expected_reward, sorted by expected_reward (descending).\n",
    "    '''\n",
    "    #we need two subgraphs to operate:\n",
    "    # the subgraph of currently winning nodes (W_pred)\n",
    "    W_pred = get_winning_subgraph(G)\n",
    "    if W_pred is None:\n",
    "        return None\n",
    "\n",
    "    # the subgraph of open roots\n",
    "    open_roots = get_unsolved_roots(G)\n",
    "    if len(open_roots) == 0:\n",
    "        # everything is solved; nx.compose_all raises ValueError on an empty list\n",
    "        return None\n",
    "    G_clean = nx.subgraph(G, [x for x in G.nodes() if not G.nodes[x].get(\"is_error\", False)])\n",
    "    reachable_from_open = nx.compose_all([nx.bfs_tree(G_clean, n) for n in open_roots])\n",
    "    OR = nx.subgraph(G, reachable_from_open.nodes())\n",
    "    \n",
    "    # we can then easily construct the \"interesting\" edges - those that could, to the best of our knowledge, actually be a shortcut\n",
    "    layers_or = list(nx.bfs_layers(OR, sources=open_roots))\n",
    "    layers_w = list(nx.bfs_layers(W_pred, sources=[n for n in W_pred.nodes if W_pred.in_degree(n) == 0]))\n",
    "    \n",
    "    # pair layer i of the open subgraph with layer i+1 of the winning subgraph\n",
    "    interesting_edges = []\n",
    "    for i, (lor, lwp) in enumerate(zip(layers_or, layers_w[1:])):\n",
    "        a = pd.DataFrame(itertools.product(lor, lwp), columns=[\"parent\",\"child\"])\n",
    "        a[\"layer\"] = i\n",
    "        interesting_edges.append(a)\n",
    "    if len(interesting_edges) == 0:\n",
    "        # no layer pair at all; pd.concat raises ValueError on an empty list\n",
    "        return None\n",
    "    interesting_edges = pd.concat(interesting_edges, axis=0)\n",
    "\n",
    "    # compute distances\n",
    "    from Levenshtein import distance as lstein\n",
    "    interesting_edges[\"distance\"] = interesting_edges.apply(lambda row: lstein(row.parent, row.child, weights=(10,1,10)), axis=1)\n",
    "\n",
    "    # estimate some weak notion of \"reward\"\n",
    "    # assume a shortcut on layer i saves us materialising\n",
    "    # a subtree approximately as large as the average subtree of a node in layer i+1\n",
    "    # to estimate, we need to exclude shortcuts, since those could drastically lower the subtree size\n",
    "    g = G.copy()\n",
    "    g.remove_edges_from([e for e in G.edges if G.edges[e].get(\"is_shortcut\", False)])\n",
    "    solved_roots = get_solved_roots(g)\n",
    "    subtrees = [subgraph_from(g, s) for s in solved_roots]\n",
    "    \n",
    "    # for each tree, get the layer sizes\n",
    "    sizes = []\n",
    "    for st, sr in zip(subtrees, solved_roots):\n",
    "        sizes.append(list(len(x) for x in nx.bfs_layers(st, sources=sr)))\n",
    "    \n",
    "    mean_layer_size = pd.DataFrame(sizes).mean(axis=0).to_numpy()\n",
    "    # at layer i (i=0 is the root layer), we expect expected_new_nodes[i] new nodes to discover\n",
    "    expected_new_nodes = mean_layer_size[::-1].cumsum()[::-1][1:]\n",
    "    # layers deeper than any solved tree get an infinite estimate\n",
    "    interesting_edges[\"expected_nodes_saved\"] = interesting_edges.layer.apply(lambda x: expected_new_nodes[x] if x < len(expected_new_nodes) else np.inf)\n",
    "    # NOTE(review): a distance of 0 (identical parent and child text) yields an infinite reward\n",
    "    interesting_edges[\"expected_reward\"] = interesting_edges.expected_nodes_saved * ( 1 / interesting_edges.distance)\n",
    "    interesting_edges = interesting_edges.sort_values(\"expected_reward\", ascending=False)\n",
    "\n",
    "    return interesting_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d853b4f3-4126-4420-8192-02c03281c78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_shortcut_moves(candidates):\n",
    "    '''\n",
    "    Ask the model whether a single move leads from parent to child and, if so, which one.\n",
    "\n",
    "    candidates is a df with at least parent (str) and child (str) columns.\n",
    "\n",
    "    returns None if the model found no shortcut at all. otherwise a df with the columns of\n",
    "    candidates (parent renamed to node) and a proposals column holding a one-element list\n",
    "    with the answer of the follow-up turn.\n",
    "    '''\n",
    "    def answered_yes(samples):\n",
    "        # the last \"yes\" has to appear after the last \"no\" in the lower-cased reply\n",
    "        reply = samples[0][-1][\"content\"].lower()\n",
    "        return reply.rfind(\"yes\") > reply.rfind(\"no\")\n",
    "\n",
    "    asked = candidates.copy()\n",
    "    asked[\"convos\"] = asked.apply(lambda row: task.shortcut_prompt_wrap(row.parent, row.child), axis=1)\n",
    "    asked[\"convos\"] = client.chat(\n",
    "        messages=asked.convos.tolist(), n=1, temperature=0.7, max_tokens=1000,\n",
    "        return_format=\"chatml\", token_usage_key=\"shortcuts\",\n",
    "    )\n",
    "\n",
    "    has_shortcut = asked.convos.apply(answered_yes)\n",
    "    found = asked[has_shortcut].copy()\n",
    "    if len(found) == 0:\n",
    "        # nothing to follow up on\n",
    "        return None\n",
    "\n",
    "    # follow-up turn asks for the move itself\n",
    "    found[\"convos\"] = found.convos.apply(lambda x: task.add_shortcut_turn(x[0]))\n",
    "    found[\"convos\"] = client.chat(\n",
    "        messages=found.convos.tolist(), n=1, temperature=0.7, max_tokens=1000,\n",
    "        return_format=\"chatml\", token_usage_key=\"shortcuts_fu\",\n",
    "    )\n",
    "\n",
    "    found_proposals = found[candidates.columns].copy()\n",
    "    found_proposals = found_proposals.rename({\"parent\":\"node\"}, axis=1)\n",
    "    found_proposals[\"proposals\"] = found.convos.apply(lambda x: [x[0][-1][\"content\"]])\n",
    "\n",
    "    return found_proposals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b61212bc-09ff-4f5f-88b4-1c48b17166d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_move_already_verified(G, node, move):\n",
    "    '''True if node has an outgoing edge labelled move whose target node is marked as verified.'''\n",
    "    return any(d.get(\"label\", None) == move and G.nodes[c].get(\"verified\", False)\n",
    "               for (_, c, d) in G.edges(node, data=True))\n",
    "\n",
    "def get_proposal_candidates(G):\n",
    "    '''\n",
    "    Select the nodes to expand this round and fetch regular (non-shortcut) proposals for them.\n",
    "\n",
    "    returns the df of get_proposals with is_shortcut_candidate set to False.\n",
    "    '''\n",
    "    starting_nodes = get_starting_nodes(G)\n",
    "    proposal_results = get_proposals(starting_nodes)\n",
    "    proposal_results[\"is_shortcut_candidate\"] = False\n",
    "    return proposal_results\n",
    "\n",
    "def get_shortcut_proposals(G, shortcut_cache=None):\n",
    "    '''\n",
    "    Query the model for shortcut moves on the best not yet investigated candidate pairs.\n",
    "\n",
    "    G is the search graph.\n",
    "    shortcut_cache is a set of (parent, child) pairs that were investigated before; the pairs\n",
    "    selected in this call are added to it in place. Without a cache a fresh set is used per call.\n",
    "\n",
    "    returns None if shortcuts are disabled or nothing new was found. otherwise a df with the\n",
    "    columns node (str), proposals (list of str) and is_shortcut_candidate (always True).\n",
    "    '''\n",
    "    if not do_shortcut:\n",
    "        # do not spend time on ranking candidates that are never used\n",
    "        return None\n",
    "    if shortcut_cache is None:\n",
    "        # a set() default would be shared between all calls of this function\n",
    "        shortcut_cache = set()\n",
    "\n",
    "    candidates = get_shortcut_candidates(G)\n",
    "    if candidates is None or len(candidates) == 0:\n",
    "        return None\n",
    "\n",
    "    # filter out already investigated nodes\n",
    "    # plain boolean array: row-wise apply returns a frame instead of a series for empty input\n",
    "    is_cached = np.fromiter(((p, c) in shortcut_cache for p, c in zip(candidates[\"parent\"], candidates[\"child\"])),\n",
    "                            dtype=bool, count=len(candidates))\n",
    "    print(\"cache hit rate\", is_cached.sum(),\"/\", len(is_cached), f\"({100*is_cached.mean():0.2f}%)\")\n",
    "    print(\"cache size\", len(shortcut_cache))\n",
    "    candidates = candidates[~is_cached]\n",
    "\n",
    "    # get next best candidates\n",
    "    selection = candidates[:n_shortcuts]\n",
    "    # return early if there are no uncached candidates\n",
    "    if len(selection) == 0:\n",
    "        return None\n",
    "    # add to cache\n",
    "    shortcut_cache.update(zip(selection['parent'], selection['child']))\n",
    "\n",
    "    shortcut_proposals = get_shortcut_moves(selection)\n",
    "    if shortcut_proposals is None:\n",
    "        # no moves were found\n",
    "        return None\n",
    "\n",
    "    proposals = shortcut_proposals[[\"node\", \"proposals\"]].copy().explode(\"proposals\")\n",
    "    # throw out moves that are already verified in the graph\n",
    "    is_new = [not is_move_already_verified(G, n, m) for n, m in zip(proposals[\"node\"], proposals[\"proposals\"])]\n",
    "    proposals = proposals[is_new]\n",
    "    if len(proposals) == 0:\n",
    "        return None\n",
    "    # stitch the moves of a node back together into one list per node;\n",
    "    # sum() would concatenate the exploded move strings into one garbled move\n",
    "    proposals = proposals.groupby(\"node\")[\"proposals\"].agg(list).reset_index()\n",
    "    proposals[\"is_shortcut_candidate\"] = True\n",
    "    return proposals[[\"node\",\"proposals\", \"is_shortcut_candidate\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "607951f0-bc55-47c0-92a0-2ab1eb807ce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_open_nodes(G):\n",
    "    '''\n",
    "    Collect the nodes that are still worth expanding, best candidates first.\n",
    "\n",
    "    A node is open if it is not finished, not an error node, not marked as unverified and\n",
    "    reachable from (or identical to) at least one root that has no solution yet.\n",
    "\n",
    "    returns a df sorted by num_expanded (ascending), then expedience, num_roots and depth\n",
    "    (descending) with the columns node, roots_per_leaf, num_roots, depth, expedience, num_expanded.\n",
    "    If nothing is open, an empty df with only a \"node\" column is returned.\n",
    "    '''\n",
    "    # work on a copy, the caller's graph keeps its shortcut edges\n",
    "    G = G.copy()\n",
    "    W_pred = get_winning_subgraph(G)\n",
    "    # get_winning_subgraph may return None (get_shortcut_candidates checks for that as well);\n",
    "    # in that case no shortcut edge leads into a winning node\n",
    "    winning_nodes = set(W_pred.nodes) if W_pred is not None else set()\n",
    "    # ignore shortcut edges unless they lead into the winning subgraph\n",
    "    G.remove_edges_from([e for e in G.edges if G.edges[e].get(\"is_shortcut\", False) and e[1] not in winning_nodes])\n",
    "\n",
    "    # get all nodes that are not _finished_ nodes (all solution nodes are finished nodes, too) or error nodes\n",
    "    leaves = [x for x, d in G.nodes(data=True) if \n",
    "                                                  not d.get(\"is_finished\", False)\n",
    "                                                  and not d.get(\"is_error\", False) \n",
    "                                                  and d.get(\"verified\", True)]\n",
    "    open_roots = get_unsolved_roots(G)\n",
    "    if len(open_roots) == 0 or len(leaves) == 0:\n",
    "        # nothing to do, return before paying for the transitive closure\n",
    "        return pd.DataFrame(columns=[\"node\"])\n",
    "\n",
    "    # in the closure of the reversed graph, has_edge(l, r) means that l is reachable from r\n",
    "    gr = G.reverse()\n",
    "    tc = nx.transitive_closure(gr)\n",
    "    \n",
    "    # filter out all nodes that are not reachable from an open root\n",
    "    leaves = [l for l in leaves if any(tc.has_edge(l, r) or r == l for r in open_roots)]\n",
    "    if len(leaves) == 0:\n",
    "        return pd.DataFrame(columns=[\"node\"])\n",
    "\n",
    "    ldf = pd.DataFrame({\"node\":leaves})\n",
    "    ldf[\"roots_per_leaf\"] = ldf.node.apply(lambda l: [r for r in open_roots if tc.has_edge(l, r) or l==r])\n",
    "    ldf[\"num_roots\"] = ldf.roots_per_leaf.apply(len)\n",
    "\n",
    "    # depth = 1 + shortest distance to any of the open roots the node belongs to\n",
    "    ldf[\"depth\"] = ldf.apply(lambda row: min((nx.shortest_path_length(G, r, row.node) for r in row.roots_per_leaf), default=0), axis=1) + 1\n",
    "    ldf[\"expedience\"] = ldf.node.apply(lambda l: G.nodes[l].get(\"expedience\",0))\n",
    "    # raises KeyError for nodes that carry no num_expanded attribute\n",
    "    ldf[\"num_expanded\"] = ldf.node.apply(lambda x: G.nodes[x][\"num_expanded\"])\n",
    "    \n",
    "    #ldf = ldf[ldf.num_expanded == ldf.num_expanded.min()]\n",
    "    #ldf[\"prod\"] = ldf.num_roots * ldf.expedience * ldf.depth/ 2**(ldf.num_expanded)\n",
    "    #sorted_leaves = ldf.sort_values(\"prod\", ascending=False)\n",
    "    \n",
    "    sorted_leaves = ldf.sort_values([\"num_expanded\", \"expedience\", \"num_roots\", \"depth\"], ascending=[True, False, False, False])\n",
    "\n",
    "    return sorted_leaves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71231c84-0fa9-4c73-85b5-1851fd396189",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(x, temperature=1.0):\n",
    "    x = x / temperature\n",
    "    e_x = np.exp(x - np.max(x))\n",
    "    return e_x / e_x.sum()\n",
    "\n",
    "def get_starting_nodes(G):\n",
    "    # we need a better idea on which nodes to explore next!\n",
    "    o = get_open_nodes(G)\n",
    "    data = o[:n_select_sample]\n",
    "    #o[\"rank\"] = range(len(o)+1, 1, -1)\n",
    "    #o[\"prob\"] = o[\"rank\"] / o[\"rank\"].sum()\n",
    "    #o[\"prob\"] = softmax(o[\"prob\"], temperature=0.2)\n",
    "    #data = o.sample(n_select_sample, weights=\"prob\")\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6426dad-b7ca-4dd2-b46e-f24206c9ae4c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def load_run(folder, update_global_params=False, overwrite_logdir=False, iteration=None):\n",
    "    '''\n",
    "    Load the logs and config of a previous run so that it can be inspected or resumed.\n",
    "\n",
    "    folder: run folder containing 0.dat (pickled list of log entries) and config.json\n",
    "    update_global_params: if True, the stored config is written into this notebook's globals\n",
    "    overwrite_logdir: if False, the current logdir is kept when the globals are updated\n",
    "    iteration: latest log index to consider, defaults to the last entry\n",
    "\n",
    "    Side effect: the token usage and call count of client are reset to the last log entry.\n",
    "\n",
    "    returns (G, l, shortcut_cache, config, logs) where G and shortcut_cache come from the\n",
    "    latest entry i <= iteration (i > 0) that stored a shortcut cache, l is logs[:i] and logs\n",
    "    is the complete list.\n",
    "    raises ValueError if no such entry exists.\n",
    "    '''\n",
    "    # NOTE: pickle.load can execute arbitrary code, only load run folders you trust\n",
    "    with open(os.path.join(folder, \"0.dat\"), \"rb\") as f:\n",
    "        logs = pickle.load(f)\n",
    "    \n",
    "    with open(os.path.join(folder, \"config.json\"), \"r\") as f:\n",
    "        config = json.load(f)\n",
    "\n",
    "    if update_global_params:\n",
    "        if not overwrite_logdir:\n",
    "            config[\"logdir\"] = logdir\n",
    "        # NOTE(review): this writes every config key into the globals, including \"backend\"\n",
    "        globals().update(config)\n",
    "\n",
    "    # patch token usage\n",
    "    # NOTE(review): always taken from the last entry, even if an earlier iteration is requested\n",
    "    client.reset_token_usage()\n",
    "    client.completion_token_usage.update(logs[-1][\"token_usage\"])\n",
    "    client.call_count = logs[-1].get(\"model_calls\", 0)\n",
    "\n",
    "    if iteration is None:\n",
    "        iteration = len(logs) - 1\n",
    "\n",
    "    for i in range(iteration,0,-1):\n",
    "        s = logs[i].get(\"shortcut_cache\", None)\n",
    "        # an empty cache is a valid checkpoint, only entries without a cache are skipped\n",
    "        if s is not None:\n",
    "            print(\"Latest found iteration:\",i)\n",
    "            return logs[i][\"graph\"], logs[:i], s, config, logs\n",
    "\n",
    "    # previously this fell through to an UnboundLocalError on the return statement\n",
    "    raise ValueError(f\"no log entry with a shortcut cache found in {folder} (searched iterations {iteration} down to 1)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba675ff-0b81-4464-a695-f54fac9940c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "stop_criterion = \"return_first\"\n",
    "\n",
    "# stop_func is only defined for the \"return_first\" criterion\n",
    "if stop_criterion == \"return_first\":\n",
    "    stop_func = all_inputs_have_a_solution_criterion\n",
    "\n",
    "def save_to_disk(logs):\n",
    "    '''\n",
    "    Persist the run: pickles logs to <logdir>/0.dat and writes the current parameters to\n",
    "    <logdir>/config.json. logdir is created if it does not exist yet.\n",
    "    '''\n",
    "    import json\n",
    "\n",
    "    # a fresh logdir would otherwise make open() fail with FileNotFoundError\n",
    "    if logdir:\n",
    "        os.makedirs(logdir, exist_ok=True)\n",
    "\n",
    "    # write to disk\n",
    "    with open(os.path.join(logdir, \"0.dat\"), \"wb\") as f:\n",
    "        pickle.dump(logs, f)\n",
    "\n",
    "    # write config to disk too\n",
    "    params = {\n",
    "        \"dataset\": dataset,\n",
    "        \"tot_b\": tot_b,\n",
    "        \"do_verify_moves\": do_verify_moves,\n",
    "        \"do_verify_nodes\": do_verify_nodes,\n",
    "        \"do_shortcut\": do_shortcut,\n",
    "        \"do_evaluate_nodes\": do_evaluate_nodes,\n",
    "        \"n_propose_sample\": n_propose_sample,\n",
    "        \"n_evaluate_sample\": n_evaluate_sample,\n",
    "        \"n_select_sample\": n_select_sample,\n",
    "        \"n_verify_sample\": n_verify_sample,\n",
    "        \"n_shortcuts\": n_shortcuts,\n",
    "        \"shortcut_factor\": n_shortcuts // n_select_sample,\n",
    "        \"logdir\":logdir,\n",
    "        # NOTE(review): this key holds model_name, not the backend setting; load_run with\n",
    "        # update_global_params=True writes it into the global named backend\n",
    "        \"backend\": model_name,\n",
    "    }\n",
    "\n",
    "    # Write to JSON file\n",
    "    with open(os.path.join(logdir, \"config.json\"), \"w\") as f:\n",
    "        json.dump(params, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1173967-bab3-4ca8-a2a6-aa4d68ef74df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "do_plot_graph=False\n",
    "\n",
    "# exploration budget: depth * breadth * number of samples node selections in total,\n",
    "# spent in rounds of n_select_sample selections each\n",
    "max_total_selections = task.max_depth * tot_b * len(task.samples)\n",
    "max_steps, remainder = divmod(max_total_selections, n_select_sample)\n",
    "print(\"We are allowed to explore\", max_total_selections, \"nodes for:\")\n",
    "print(f\"\\tper sample breadth (tot_b): {tot_b}\\n\\tdepth per sample (steps): {task.max_depth}\\n\\tnum samples: {len(task.samples)}\")\n",
    "print(f\"With n_select_sample = {n_select_sample}, we can run {max_steps} iterations at most for a fair comparison.\")\n",
    "# a remainder means the budget cannot be spent exactly with this round size\n",
    "if remainder != 0:\n",
    "    print(f\"WARNING: division resulted in a remainder of {remainder} selections.\")\n",
    "\n",
    "# NOTE(review): presumably blocks until the inference backend is ready - confirm\n",
    "client.wait_done()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36b5d777-d4e1-4b91-9ba3-05f87b8022f6",
   "metadata": {},
   "source": [
    "logdir=\"data/temp/a/\"\n",
    "resume_from = \"data/logs/tot_test_split/hp/mistralai/Mistral-Small-24B-Instruct-2501/1_1_1_1_1_3_10_3_20_5/\"\n",
    "#resume_from = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7c6c580-8ba4-4201-b12c-61e6e3d90840",
   "metadata": {},
   "outputs": [],
   "source": [
    "# either start from scratch or continue a stored run\n",
    "if resume_from is None:\n",
    "    logs = []\n",
    "    shortcut_cache = set()\n",
    "\n",
    "    # every task sample becomes a verified, not yet expanded root node\n",
    "    G = nx.MultiDiGraph()\n",
    "    for n in task.samples:\n",
    "        G.add_node(n, is_root=True, expedience=root_expedience, verified=True, num_expanded=0)\n",
    "    print(\"Setup new graph with\",len(G),\"tasks\")\n",
    "else:\n",
    "    print(\"Resuming from\", resume_from)\n",
    "    # update_global_params=True also overwrites the parameters of this notebook with the stored config\n",
    "    G, logs, shortcut_cache, config, all_logs = load_run(resume_from, update_global_params=True, iteration=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b96a2f-3338-462d-a1b6-4a7853272fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_log(start_time, **kwargs):\n",
    "    '''\n",
    "    Append a snapshot of the current state to logs and write the logs to disk.\n",
    "\n",
    "    start_time: time.time() value taken at the start of the step, used for the \"time\" entry\n",
    "    kwargs: additional entries stored in the snapshot\n",
    "    '''\n",
    "    elapsed = time.time() - start_time\n",
    "    # graph and cache are deep-copied so that later steps do not alter the snapshot\n",
    "    entry = dict(\n",
    "        token_usage=dict(client.completion_token_usage),\n",
    "        model_calls=client.call_count,\n",
    "        graph=deepcopy(G),\n",
    "        time=elapsed,\n",
    "        shortcut_cache=deepcopy(shortcut_cache),\n",
    "        **kwargs,\n",
    "    )\n",
    "    logs.append(entry)\n",
    "    save_to_disk(logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60db3f69-0e83-4400-a1da-7e9d260b6027",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for current_step in range(len(logs), max_steps):\n",
    "    print(current_step, max_steps, \"SHORTCUT ROUND\" if current_step % 2 == 1 else \"PROPOSAL ROUND\")\n",
    "    start_time = time.time()\n",
    "\n",
    "    open_roots = get_unsolved_roots(G)\n",
    "    print(\"Open problems:\", len(open_roots))\n",
    "    if open_roots == 0:\n",
    "        print(\"no more roots to solve, early stopping at\", current_step+1,\"of\",max_steps)\n",
    "        break\n",
    "\n",
    "    # get move candidates\n",
    "    if current_step % 2 == 0:\n",
    "        print(\"Getting proposals\")\n",
    "        proposal_results = get_proposal_candidates(G)\n",
    "        for i, row in proposal_results.iterrows():\n",
    "            G.nodes[row.node][\"num_expanded\"] += n_propose_sample\n",
    "    else:\n",
    "        print(\"Getting shortcuts\")\n",
    "        proposal_results = get_shortcut_proposals(G, shortcut_cache=shortcut_cache)\n",
    "\n",
    "    if proposal_results is None or len(proposal_results) == 0:\n",
    "        # we didnt find any nodes\n",
    "        print(\"NO NODES, CONTINUING\")\n",
    "        save_log(start_time)\n",
    "        continue\n",
    "\n",
    "    # TODO: find all edges in W_pred with less verifications than necessary\n",
    "\n",
    "    # verify moves\n",
    "    print(\"Verifying moves\")\n",
    "    move_verification_results = verify_moves(proposal_results)\n",
    "    good_moves = move_verification_results[move_verification_results.is_verified_move]\n",
    "    bad_moves = move_verification_results[~move_verification_results.is_verified_move]\n",
    "\n",
    "    # add \"error\" nodes to graph\n",
    "    # these need to be unique, or else the visualization looks bad\n",
    "    k = len([x for x,d in G.nodes(data=True) if d.get(\"is_error\", False)])\n",
    "    for i, row in bad_moves.iterrows():\n",
    "        node = f\"error_{k}\"\n",
    "        G.add_node(node, label=\"X\", is_error=True, num_expanded=0)\n",
    "        G.add_edge(row.node, node, key=row.proposals, label=row.proposals, is_error=True, \n",
    "                   is_shortcut=row.is_shortcut_candidate)\n",
    "        k += 1\n",
    "\n",
    "    if len(good_moves) == 0:\n",
    "        print(\"ALL MOVES WERE BAD\")\n",
    "        save_log(start_time)\n",
    "        continue\n",
    "\n",
    "    # apply the good moves\n",
    "    print(\"Applying moves\")\n",
    "    application_results = apply_moves(good_moves)\n",
    "\n",
    "    # filter new nodes\n",
    "    print(\"Verifying new nodes\")\n",
    "    node_verification_results = verify_new_nodes(application_results.drop(\"convos\", axis=1))\n",
    "    good_new = node_verification_results[node_verification_results.is_verified_node].copy()\n",
    "\n",
    "    # add to graph\n",
    "    for _, row in node_verification_results.iterrows():\n",
    "        if row.is_verified_node:\n",
    "            ne = 0 if row.next_node not in G else G.nodes[row.next_node].get(\"num_expanded\",0)\n",
    "            is_shortcut = row.is_shortcut_candidate if (row.node, row.next_node) not in G.edges else G.edges[(row.node, row.next_node)][\"is_shortcut\"]\n",
    "            G.add_node(row.next_node, verified=row.is_verified_node, num_expanded=ne)\n",
    "            G.add_edge(row.node, row.next_node, key=row.proposals, label=row.proposals, is_shortcut=is_shortcut)\n",
    "        else:\n",
    "            # let the move, though deemed a correct move, point to an error node, since the apply step caused an error\n",
    "            node = f\"error_{k}\"\n",
    "            G.add_node(node, label=row.next_node, is_error=True, num_expanded=0)\n",
    "            G.add_edge(row.node, node, key=row.proposals, label=row.proposals, is_shortcut=row.is_shortcut_candidate)\n",
    "            k += 1\n",
    "\n",
    "    if len(good_new) == 0:\n",
    "        print(\"ALL NODES WERE BAD\")\n",
    "        save_log(start_time)\n",
    "        continue\n",
    "    \n",
    "    print(\"Getting finished nodes\")\n",
    "    good_new[\"has_finished\"] = good_new.next_node.apply(task.is_irreducible_node)\n",
    "    good_new[\"is_solution\"] = good_new.next_node.apply(task.is_solution_node)\n",
    "\n",
    "    # add finished state to graph\n",
    "    for _, row in good_new[good_new.has_finished].iterrows():\n",
    "        G.add_node(row.next_node, is_finished=True, expedience=-1)\n",
    "\n",
    "    # add state to graph\n",
    "    for _, row in good_new[good_new.is_solution].iterrows():\n",
    "        G.add_node(row.next_node, is_solution=True, expedience=20)          \n",
    "\n",
    "    # get all non-finished nodes -- the new leaves in the game tree\n",
    "    leaves = good_new[~good_new.has_finished][[\"node\",\"proposals\", \"next_node\"]].copy()\n",
    "    # filter out those that have been evaluated before!\n",
    "    if len(leaves) > 0:\n",
    "        leaves = leaves[leaves.next_node.apply(lambda x: not \"expedience\" in G.nodes[x])]\n",
    "    if len(leaves) > 0:\n",
    "        # estimate \"reward\", here called \"expedience\"\n",
    "        print(\"Evaluate nodes\")\n",
    "        evaluation_results = evaluate_nodes(leaves)\n",
    "        # add to graph\n",
    "        for _, row in evaluation_results.iterrows():\n",
    "            G.add_node(row.next_node, expedience=row.expedience)\n",
    "\n",
    "    save_log(start_time)\n",
    "    # check done\n",
    "    if stop_func(G):\n",
    "        print(f\"Early stopping at {current_step+1}/{max_steps}\")\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f2bfe21-7c10-43ed-8801-5b2f4a747f29",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "W_pred = get_winning_subgraph(G)\n",
    "if W_pred is None:\n",
    "    solutions = pd.DataFrame({\"root\":task.samples, \"solution\":None, \"moves\":None})\n",
    "    predictions = solutions\n",
    "else:\n",
    "    solutions = get_final_answers(W_pred)\n",
    "    unsolved_samples = list(set(task.samples).difference(solutions.root))\n",
    "    \n",
    "    u = pd.DataFrame({\"root\": unsolved_samples, \"solution\": None})\n",
    "    predictions = pd.concat([solutions[[\"root\",\"moves\",\"solution\"]], u], axis=0)\n",
    "\n",
    "# save to disk\n",
    "save_log(start_time, solutions=solutions, results=predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f847c01d-0b9d-4c2d-a843-7ff58e260346",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(task.evaluate_results(predictions).is_correct.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ced29dd5-6179-450b-a946-ad15013adef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sum(client.completion_token_usage.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee2ab006-ab66-4b5a-8467-28503423ed65",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "client.close()\n",
    "sys.exit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "480c25e3-c29f-4b75-a967-6443ccb97198",
   "metadata": {},
   "source": [
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e287c57e-4c65-455e-a0a4-a8f8caa3c8cd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a8f1f729-8b2d-4329-865d-de7ff875a5c2",
   "metadata": {},
   "source": [
    "# Debug stuff\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a056fa0-ad5d-429a-89a3-87cb73c0480a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = proposal_results.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39986f5f-2aeb-4608-a189-7322370727ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "moves = data.explode(\"proposals\")\n",
    "\n",
    "convos = moves.apply(lambda row: task.get_move_verification_prompts(row.node, row.proposals) , axis=1, result_type=\"expand\")\n",
    "moves = pd.concat([moves, convos], axis=1)\n",
    "\n",
    "# stack together all prompts, giving a long format df instead of wide\n",
    "moves = moves.melt(id_vars=[\"node\", \"proposals\",], value_vars=convos.columns, var_name=\"prompt_type\", value_name=\"convos\")\n",
    "\n",
    "moves[\"convos\"] = client.chat(messages=moves.convos.tolist(), n=1, temperature=0.7, max_tokens=1000, return_format=\"openai\", token_usage_key=\"verify_moves\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c804b2-9141-4a29-9860-3450dec11cc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fractions import Fraction\n",
    "import glob\n",
    "import re\n",
    "from collections import Counter\n",
    "import sympy as sp\n",
    "from functools import lru_cache\n",
    "\n",
    "TOKEN_REGEX = re.compile(r'[^\\s]+')\n",
    "@lru_cache(maxsize=None)\n",
    "def extract_numbers_from_eq(s):\n",
    "    tokens = TOKEN_REGEX.findall(s.replace(\"(\", \" \").replace(\")\", \" \").strip())\n",
    "    numbers = []\n",
    "    for token in tokens:\n",
    "        try:\n",
    "            expr = sp.nsimplify(token)\n",
    "            for node in sp.preorder_traversal(expr):\n",
    "                if isinstance(node, sp.Number):\n",
    "                    numbers.append(node)\n",
    "        except Exception:\n",
    "            continue  # Skip tokens that can't be parsed\n",
    "    return numbers\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def canonicalize_node(node_str):\n",
    "    numbers = sorted(extract_numbers_from_eq(node_str))\n",
    "    return \" \".join([str(x) for x in numbers])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a96be1-8820-461d-9250-9cc0ab23d76e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import reachable_leaves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "752928c6-718d-43f9-b761-90c7595634ec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data = []\n",
    "for l in tqdm(old_logs):\n",
    "    if \"graph\" not in l: continue\n",
    "    G = l[\"graph\"]\n",
    "    open_roots = get_unsolved_roots(G)\n",
    "    if len(open_roots) == 0: continue\n",
    "    sc = get_shortcut_candidates(G)\n",
    "    if sc is None or len(sc) == 0:\n",
    "        data.append((0,0))\n",
    "        continue\n",
    "    sc = sc.reset_index(drop=True)\n",
    "    \n",
    "    sc[\"pc\"] = sc.parent.apply(canonicalize_node)\n",
    "    sc[\"cc\"] = sc.child.apply(canonicalize_node)\n",
    "    \n",
    "    sc[\"is_winning_edge\"] = sc.apply(lambda row: W.has_edge(row.pc, row.cc), axis=1)\n",
    "    \n",
    "    parents = set(sc[sc.is_winning_edge].parent)\n",
    "    # all roots that reach a node in parents\n",
    "    gr = G.reverse()\n",
    "    all_reachable_roots = reachable_leaves(gr, set(sc[sc.is_winning_edge].parent))\n",
    "    selected_reachable_roots = reachable_leaves(gr, set(sc.parent[:n_shortcuts][sc.is_winning_edge]))\n",
    "    \n",
    "    total_num_shortcuttable_roots = len(all_reachable_roots)\n",
    "    selected_num_shortcuttable_roots = len(selected_reachable_roots)\n",
    "    data.append((selected_num_shortcuttable_roots, total_num_shortcuttable_roots))\n",
    "    \n",
    "    #print(f\"We have {selected_num_shortcuttable_roots} good shortcuts in our batch of {len(sc[:n_shortcuts])} (total of {total_num_shortcuttable_roots} in {len(sc)} candidates)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb0f1de-756b-4d2d-8fa3-76dda86524c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot([x[0] for x in data], label=\"Good Shortcuts in Selection\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21bfa6f4-df16-432c-abe1-2f9ba64860bd",
   "metadata": {},
   "source": [
    "# Debug stuff\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe2493e7-8caf-4eb0-a273-b99763521539",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fractions import Fraction\n",
    "import glob\n",
    "\n",
    "\n",
    "def parse_number(x):\n",
    "    try:\n",
    "        return int(x) if '.' not in x and '/' not in x else float(Fraction(x))\n",
    "    except ValueError:\n",
    "        return None   # or raise or handle differently\n",
    "\n",
    "def is_winning_node(s):\n",
    "    try:\n",
    "        if s in task.samples: return False\n",
    "        numbers = [parse_number(x) for x in s.replace(\",\",\"\").split()]\n",
    "        numbers_sorted = sorted(numbers)\n",
    "        numbers_tuple_str = \" \".join(str(int(n)) if n.is_integer() else str(n) for n in numbers_sorted)\n",
    "    except:\n",
    "        print(s)\n",
    "        return False\n",
    "    return numbers_tuple_str in task.W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e84418e6-8397-4751-bc14-fac9f6d98d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_run(analyzed_config, x=\"_idx_\"):\n",
    "    root_shortcut_rate = analyzed_config[\"root shortcut rate\"]\n",
    "    winning_paths_with_shortcuts_rate = analyzed_config[\"winning paths with shortcuts rate\"]\n",
    "    a = analyzed_config[\"count winning\"]\n",
    "    b = analyzed_config[\"idx of first winning\"]\n",
    "    d = analyzed_config[\"num solved\"]\n",
    "    x = range(len(a)) if x == \"_idx_\" else analyzed_config[x]\n",
    "    logs = analyzed_config[\"logs\"]\n",
    "\n",
    "    fig, ax1 = plt.subplots()\n",
    "    ax2 = ax1.twinx()\n",
    "\n",
    "    l4, = ax2.plot(x,root_shortcut_rate, \"--o\", c=\"tab:red\", label=\"root shortcut rate\")\n",
    "    l5, = ax2.plot(x,winning_paths_with_shortcuts_rate, \"--o\", c=\"black\", label=\"winning paths with shortcuts rate\")\n",
    "\n",
    "    l1, = ax1.plot(x,a, \"-o\", label=\"count winning\")\n",
    "    l2, = ax1.plot(x, b, \"-o\", label=\"idx of first winning\")\n",
    "    l3, = ax1.plot(x, d, \"-o\", label=\"num solved at iter\")\n",
    "    lines = [l1,l2,l3,l4,l5]\n",
    "    \n",
    "    ax3 = ax1.twinx()\n",
    "    kk = analyzed_config[\"nodes per layer\"]\n",
    "    for i in range(kk.shape[1]):\n",
    "        l6, = ax3.plot(x, kk[:,i], label=f\"nodes in layer {i+1}\")\n",
    "        lines.append(l6)\n",
    "    \n",
    "\n",
    "    if analyzed_config[\"success_rates\"] is not None:\n",
    "        l6, = ax2.plot(x, analyzed_config[\"success_rates\"], \"-o\",c=\"pink\", label=\"success rate\")\n",
    "        lines.append(l6)\n",
    "\n",
    "    labels = [l.get_label() for l in lines]\n",
    "\n",
    "    \n",
    "    ax1.legend(lines, labels, loc='upper right', bbox_to_anchor=(2.1, 1.0), ncol=2)\n",
    "\n",
    "    if \"results\" in logs[-1]:\n",
    "        plt.title(task.evaluate_results(logs[-1][\"results\"]).is_correct.mean())\n",
    "    props = dict(boxstyle='round', facecolor='whitesmoke', alpha=0.8)\n",
    "    \n",
    "    keys_to_extract = [\"task_name\", \"dataset\", \"tot_b\", \"do_verify_moves\", \"do_verify_nodes\",\n",
    "                       \"do_shortcut\", \"do_evaluate_nodes\", \"n_propose_sample\", \"n_evaluate_sample\",\n",
    "                       \"n_select_sample\", \"n_verify_sample\", \"n_shortcuts\" ]\n",
    "\n",
    "    # Extract subdict\n",
    "    sub_config = {k: analyzed_config[k] for k in keys_to_extract if k in analyzed_config}    # Place textbox outside the right of the plot area\n",
    "    plt.figtext(1.5, .4, json.dumps(sub_config, indent=4)[1:-1].strip(), ha='right', va='center', fontsize=10,\n",
    "                bbox=props, fontfamily='monospace')\n",
    "    \n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
