{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af74a755-a4a1-4461-89ee-585a21330369",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ortools.linear_solver import pywraplp\n",
    "from tqdm import tqdm\n",
    "\n",
    "def solve_directed_steiner(W, num_threads=1):\n",
    "    \"\"\"Pick a minimum number of edges of ``W`` that still serves every terminal.\n",
    "\n",
    "    ``W`` is copied into a plain ``DiGraph`` and reversed. In the reversed graph\n",
    "    ``R``, nodes without incoming edges are the flow sources (roots) and nodes\n",
    "    without outgoing edges are the terminals (leaves). One unit of flow per\n",
    "    terminal is routed from the roots, flow may only use selected edges, and\n",
    "    the number of selected edges is minimised with SCIP.\n",
    "\n",
    "    Note: every root must emit exactly one unit per terminal while only the\n",
    "    terminal absorbs one, so the model is feasible only when ``R`` has exactly\n",
    "    one root.\n",
    "\n",
    "    Args:\n",
    "        W: Directed graph (anything accepted by ``nx.DiGraph``).\n",
    "        num_threads: Thread count handed to the solver.\n",
    "\n",
    "    Returns:\n",
    "        ``(steiner_nodes, selected_edges)``: ``selected_edges`` is the list of\n",
    "        chosen edges flipped back to the orientation of ``W``;\n",
    "        ``steiner_nodes`` is the set of their endpoints minus roots and leaves.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If the SCIP backend cannot be created.\n",
    "        ValueError: If the solver does not report an optimal solution.\n",
    "    \"\"\"\n",
    "    # Local import: keeps this definition usable even when the cell that\n",
    "    # imports networkx at notebook level has not been run yet.\n",
    "    import networkx as nx\n",
    "\n",
    "    solver = pywraplp.Solver.CreateSolver('SCIP')\n",
    "    if solver is None:\n",
    "        # CreateSolver returns None when the requested backend is unavailable.\n",
    "        raise RuntimeError(\"SCIP backend is not available in this OR-Tools installation\")\n",
    "    solver.SetNumThreads(num_threads)\n",
    "\n",
    "    R = nx.DiGraph(W).reverse()\n",
    "\n",
    "    roots = [n for n in R.nodes() if R.in_degree(n) == 0]\n",
    "    leaves = [n for n in R.nodes() if R.out_degree(n) == 0]\n",
    "    root_set = set(roots)  # constant-time membership inside the constraint loop\n",
    "\n",
    "    edges = list(R.edges())\n",
    "    nodes = list(R.nodes())\n",
    "\n",
    "    # Decision variables: x[e] == 1 if edge e is in the solution\n",
    "    x = {e: solver.BoolVar(f\"x_{e}\") for e in edges}\n",
    "\n",
    "    # Flow variables: one per terminal per edge\n",
    "    f = {(e, t): solver.NumVar(0, 1, f\"f_{e}_{t}\") for e in edges for t in leaves}\n",
    "\n",
    "    # Objective: minimize number of selected edges\n",
    "    solver.Minimize(solver.Sum(x[e] for e in edges))\n",
    "\n",
    "    # Incident edge lists do not depend on the terminal, so build them once.\n",
    "    in_edges = {v: [(u, v) for u in R.predecessors(v)] for v in nodes}\n",
    "    out_edges = {v: [(v, w) for w in R.successors(v)] for v in nodes}\n",
    "\n",
    "    # Flow conservation, one commodity per terminal.\n",
    "    for t in leaves:\n",
    "        for v in nodes:\n",
    "            inflow = solver.Sum(f[(e, t)] for e in in_edges[v])\n",
    "            outflow = solver.Sum(f[(e, t)] for e in out_edges[v])\n",
    "\n",
    "            if v in root_set:\n",
    "                solver.Add(outflow - inflow == 1)\n",
    "            elif v == t:\n",
    "                solver.Add(inflow - outflow == 1)\n",
    "            else:\n",
    "                solver.Add(outflow - inflow == 0)\n",
    "\n",
    "    # Flow only allowed on selected edges\n",
    "    for e in edges:\n",
    "        for t in leaves:\n",
    "            solver.Add(f[(e, t)] <= x[e])\n",
    "\n",
    "    status = solver.Solve()\n",
    "    print(\"Solving took:\", solver.WallTime() / 1000)\n",
    "\n",
    "    if status == pywraplp.Solver.OPTIMAL:\n",
    "        # Edges were reversed for the model; flip them back to W's orientation.\n",
    "        selected_edges = [(b,a) for (a,b) in edges if x[(a,b)].solution_value() > 0.5]\n",
    "        # extract nodes\n",
    "        steiner_nodes = set()\n",
    "        for a,b in selected_edges:\n",
    "            steiner_nodes.add(a)\n",
    "            steiner_nodes.add(b)\n",
    "        steiner_nodes = steiner_nodes.difference(roots).difference(leaves)\n",
    "        return steiner_nodes, selected_edges\n",
    "    else:\n",
    "        raise ValueError(\"Failed to solve\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc4ec6ef-1a20-4051-90fe-cdb54bfa1456",
   "metadata": {},
   "source": [
    "# Make 24"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "921883fc-3495-4752-a5e0-8e94ca0b04c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "\n",
    "from itertools import combinations, permutations, combinations_with_replacement \n",
    "from functools import lru_cache\n",
    "import multiprocess as mp\n",
    "\n",
    "ops = {\n",
    "    \"+\": combinations,\n",
    "    \"*\": combinations,\n",
    "    \"/\": permutations,\n",
    "    \"-\": permutations\n",
    "}\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def get_new_node(l, r, op, n_tuple):\n",
    "    \"\"\"Combine two entries of ``n_tuple`` with ``op`` and describe the resulting edge.\n",
    "\n",
    "    Args:\n",
    "        l: Left operand; must be an element of ``n_tuple``.\n",
    "        r: Right operand; must be an element of ``n_tuple``.\n",
    "        op: Operator symbol placed between the two operands.\n",
    "        n_tuple: Current state, a hashable tuple of numbers.\n",
    "\n",
    "    Returns:\n",
    "        ``(n_tuple, new_node, eq)`` where ``new_node`` is the sorted tuple of\n",
    "        the result plus the unused entries and ``eq`` is the expression text,\n",
    "        or ``None`` when the result is not a real number or any step raises.\n",
    "    \"\"\"\n",
    "    def _operand(value):\n",
    "        # Non-integer operands print as e.g. 1/4; pasted in bare, 6 / 1/4\n",
    "        # would be parsed left to right as (6/1)/4. Integers stay unwrapped.\n",
    "        text = str(value)\n",
    "        return text if text.lstrip(\"-\").isdigit() else f\"({text})\"\n",
    "\n",
    "    try:\n",
    "        eq = f\"({_operand(l)} {op} {_operand(r)})\"\n",
    "        num = sympy.together(eq)\n",
    "        if not num.is_real:\n",
    "            # Rejects non-real results (for example from a division by zero).\n",
    "            return None\n",
    "        leftovers = list(n_tuple)\n",
    "        leftovers.remove(l)\n",
    "        leftovers.remove(r)\n",
    "        new_node = tuple(sorted((num,) + tuple(leftovers)))\n",
    "        return (n_tuple, new_node, eq)\n",
    "    except Exception:\n",
    "        # Best effort: an operand pair that cannot be evaluated yields no edge.\n",
    "        return None\n",
    "\n",
    "def process_leaf_batch(batch):\n",
    "    \"\"\"Expand each node of ``batch`` with every operator and operand pair.\n",
    "\n",
    "    Returns the list of truthy ``get_new_node`` results, in generation order.\n",
    "    \"\"\"\n",
    "    candidates = (\n",
    "        get_new_node(left, right, symbol, node)\n",
    "        for node in batch\n",
    "        for symbol, pair_maker in ops.items()\n",
    "        for left, right in pair_maker(node, 2)\n",
    "    )\n",
    "    return [edge for edge in candidates if edge]\n",
    "\n",
    "def chunkify(iterable, size):\n",
    "    for i in range(0, len(iterable), size):\n",
    "        yield iterable[i:i + size]\n",
    "\n",
    "def compute_full_graph(inputs, max_workers=32, chunk_size=4):\n",
    "    \"\"\"Expand every input tuple into the graph of all states reachable from it.\n",
    "\n",
    "    Starting from the inputs, every node that holds more than one entry and has\n",
    "    no outgoing edge yet is expanded in a process pool via\n",
    "    ``process_leaf_batch`` until no such node remains. Nodes are finally\n",
    "    renamed to the space-separated string form of their entries.\n",
    "\n",
    "    Args:\n",
    "        inputs: Sized collection of number tuples.\n",
    "        max_workers: Number of worker processes in the pool.\n",
    "        chunk_size: Number of nodes handed to a worker per task.\n",
    "\n",
    "    Returns:\n",
    "        ``nx.MultiDiGraph`` with one edge per produced expression; the\n",
    "        expression text is stored as both the edge key and the ``label``\n",
    "        attribute.\n",
    "    \"\"\"\n",
    "    print(\"Solving\", len(inputs), \"problems\")\n",
    "    start_nodes = [tuple(sympy.sympify(str(value)) for value in problem) for problem in inputs]\n",
    "\n",
    "    T = nx.MultiDiGraph()\n",
    "    T.add_nodes_from(start_nodes)\n",
    "\n",
    "    def unexpanded():\n",
    "        # Nodes that still hold several numbers and have no outgoing edge yet.\n",
    "        return [node for node in T.nodes() if T.out_degree(node) == 0 and len(node) > 1]\n",
    "\n",
    "    with mp.Pool(processes=max_workers) as pool:\n",
    "        frontier = unexpanded()\n",
    "        while frontier:\n",
    "            batches = list(chunkify(frontier, chunk_size))\n",
    "            progress = tqdm(\n",
    "                pool.imap_unordered(process_leaf_batch, batches),\n",
    "                total=len(batches),\n",
    "                desc=\"Processing batches\",\n",
    "            )\n",
    "            # Drain the pool completely before the graph is modified.\n",
    "            for edge_batch in list(progress):\n",
    "                for src, dst, label in edge_batch:\n",
    "                    T.add_edge(src, dst, key=label, label=label)\n",
    "            frontier = unexpanded()\n",
    "\n",
    "    # relabel to desired format; the assert guards against two nodes sharing a name\n",
    "    names = {node: \" \".join(str(value) for value in node) for node in T.nodes}\n",
    "    assert len(set(names.keys())) == len(set(names.values()))\n",
    "    return nx.relabel_nodes(T, names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "056dcdc8-1902-4543-8caf-968578b292f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_make_24_dataset(inputs, folder, prefix, solve_optimal=True):\n",
    "    \"\"\"Build the Make-24 graphs for ``inputs``, print statistics and write GraphML files.\n",
    "\n",
    "    Three files are written into ``folder``:\n",
    "    ``{prefix}_T.graphml`` (full graph from ``compute_full_graph``),\n",
    "    ``{prefix}_W.graphml`` (subgraph of nodes from which the node 24 can be\n",
    "    reached, empty when 24 is absent) and ``{prefix}_S.graphml`` (subgraph of W\n",
    "    on the nodes touched by the solver's edge selection, empty when no solve\n",
    "    was run).\n",
    "\n",
    "    Args:\n",
    "        inputs: Iterable of number sequences, one per problem.\n",
    "        folder: Output directory; created if it does not exist.\n",
    "        prefix: File name prefix for the three GraphML files.\n",
    "        solve_optimal: When true and W is non-empty, run\n",
    "            ``solve_directed_steiner`` on W.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    # Create the output directory up front so a long computation is not lost\n",
    "    # to a missing folder when the files are written at the end.\n",
    "    if folder:\n",
    "        os.makedirs(folder, exist_ok=True)\n",
    "\n",
    "    inputs = [tuple(x) for x in inputs]\n",
    "    T = compute_full_graph(inputs)\n",
    "    # inputs were likely reformatted, lets recover them\n",
    "    inputs = [n for n in T if T.in_degree(n) == 0]\n",
    "    print(\"# nodes global\", len(T))\n",
    "    print(\"# edges global\", len(T.edges))\n",
    "\n",
    "    winning_node = \"24\"\n",
    "\n",
    "    # dataset stats: how many nodes would I get if inputs were isolated?\n",
    "    single_sizes = [(len(k := nx.subgraph(T, nx.bfs_tree(T, x, reverse=False))),len(k.edges)) for x in tqdm(inputs)]\n",
    "    max_num_edges = sum(x[1] for x in single_sizes)\n",
    "    max_num_nodes = sum(x[0] for x in single_sizes)\n",
    "    print(\"# nodes isolation\", max_num_nodes)\n",
    "    print(\"# edges isolation\", max_num_edges)\n",
    "    X = T.copy().to_undirected()\n",
    "    if winning_node in T:\n",
    "        X.remove_node(winning_node)\n",
    "    comps = sorted(nx.connected_components(X), key=len)\n",
    "    print(\"# connected comps in T:\", len(comps))\n",
    "    \n",
    "    # construct tree of only winning paths\n",
    "    if winning_node not in T:\n",
    "        print(\"Your node\",winning_node,\"is not in the graph, typo?\\n\\tReturning empty W and S\")\n",
    "        W = nx.MultiDiGraph()\n",
    "    else:\n",
    "        # Reverse BFS from the winner collects every node that can reach it.\n",
    "        reachable_from_winner = nx.bfs_tree(T, winning_node, reverse=True)\n",
    "        W = nx.subgraph(T, reachable_from_winner.nodes())\n",
    "        X = W.copy().to_undirected()\n",
    "        X.remove_node(winning_node)\n",
    "        comps = sorted(nx.connected_components(X), key=len)\n",
    "        print(\"# connected comps in W:\", len(comps))\n",
    "\n",
    "    # Count paths without materialising each path list.\n",
    "    num_paths = sum(\n",
    "        sum(1 for _ in nx.all_simple_paths(W, s, winning_node))\n",
    "        for s in [x for x in W.nodes() if W.in_degree(x) == 0]\n",
    "    )\n",
    "    # dataset stats: how large is the winning subgraph in T\n",
    "    print(\"# winning paths:\", num_paths)\n",
    "    print(\"# needed edges:\", len(W.edges))\n",
    "    print(\"# needed nodes\", len(W))\n",
    "    print(\"# unsolved problems:\", len([x for x in inputs if x not in W]), \"/\", len(inputs))\n",
    "\n",
    "    \n",
    "    if len(W) > 0 and solve_optimal:\n",
    "        stn, ste = solve_directed_steiner(W)\n",
    "    \n",
    "        print(\"# minimal edges:\", len(ste))\n",
    "        print(\"# minimal nodes:\", len(stn))\n",
    "        \n",
    "        # proportion of all possible edges to minimally needed\n",
    "        ratio = len(ste) / len(W.edges)\n",
    "        print(f\"{ratio * 100:.2f}%\")\n",
    "    \n",
    "        # optimal graph with all possible routes,\n",
    "        S = W.subgraph(nx.DiGraph(ste).nodes())\n",
    "    else:\n",
    "        S = nx.DiGraph()\n",
    "\n",
    "    nx.write_graphml(T, f\"{folder}/{prefix}_T.graphml\")\n",
    "    nx.write_graphml(W, f\"{folder}/{prefix}_W.graphml\")\n",
    "    nx.write_graphml(S, f\"{folder}/{prefix}_S.graphml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "19197d13-919d-4179-87df-95173f3a0fc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.makedirs(\"data/optimal_graphs\", exist_ok=True)\n",
    "# The dataset cells below write into this folder, so make sure it exists too.\n",
    "os.makedirs(\"data/make_24/optimal_graphs\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ab37c08-dbb7-4de3-8210-ea62c6d91677",
   "metadata": {},
   "source": [
    "## All four digits (1-12) with replacement\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ee58e6b4-fdd4-4f35-be2d-7d3e193c074b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving 1365 problems\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 342/342 [00:01<00:00, 251.25it/s]\n",
      "Processing batches: 100%|██████████| 2860/2860 [00:06<00:00, 426.52it/s]\n",
      "Processing batches: 100%|██████████| 8360/8360 [00:12<00:00, 649.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes global 79496\n",
      "# edges global 425724\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1365/1365 [00:13<00:00, 97.84it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes isolation 979716\n",
      "# edges isolation 2268259\n",
      "# connected comps in T: 1\n",
      "# connected comps in W: 7\n",
      "# winning paths: 13058\n",
      "# needed edges: 8447\n",
      "# needed nodes 2156\n",
      "# unsolved problems: 301 / 1365\n",
      "Solving took: 425.548\n",
      "# minimal edges: 1221\n",
      "# minimal nodes: 157\n",
      "14.45%\n"
     ]
    }
   ],
   "source": [
    "# Every multiset of four numbers drawn from 1..12.\n",
    "input_combinations = [*combinations_with_replacement(range(1, 13), 4)]\n",
    "create_make_24_dataset(\n",
    "    input_combinations,\n",
    "    folder=\"data/make_24/optimal_graphs/\",\n",
    "    prefix=\"four_digits_combinations\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3760160e-2eb8-4339-a46d-b243974f60b5",
   "metadata": {},
   "source": [
    "# Solvable & Unsolvable four digits (1-12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "339d2d81-8b4b-4ca7-83a1-1de09e9eadb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import random\n",
    "random.seed(42)\n",
    "# Load the graphs from the folder the four-digit dataset cell writes to.\n",
    "T,W,S = [nx.read_graphml(f\"data/make_24/optimal_graphs/four_digits_combinations_{x}.graphml\") for x in \"TWS\"]\n",
    "# Problems are the nodes of T without incoming edges; unsolved ones never reach W.\n",
    "samples = [n for n in T.nodes() if T.in_degree(n) == 0]\n",
    "unsolved = [x for x in samples if x not in W]\n",
    "input_combinations = random.sample(unsolved, 100)\n",
    "# Node names are space-separated numbers; turn them back into lists of ints.\n",
    "input_combinations = [[int(y) for y in x.split()] for x in input_combinations]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "16b530a8-22c8-46fd-9f77-cbf7ef4f119f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving 100 problems\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 295.06it/s]\n",
      "Processing batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 374/374 [00:01<00:00, 233.31it/s]\n",
      "Processing batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2051/2051 [00:04<00:00, 448.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes global 20476\n",
      "# edges global 75115\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 162.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes isolation 61495\n",
      "# edges isolation 136433\n",
      "# connected comps in T: 1\n",
      "Your node 24 is not in the graph, typo?\n",
      "\tReturning empty W and S\n",
      "# winning paths: 0\n",
      "# needed edges: 0\n",
      "# needed nodes 0\n",
      "# unsolved problems: 100 / 100\n"
     ]
    }
   ],
   "source": [
    "# unsolvable (written next to the other Make-24 datasets)\n",
    "create_make_24_dataset(input_combinations, folder=\"data/make_24/optimal_graphs/\", prefix=\"four_digits_unsolvable\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "fc6a8062-a69b-40f8-82e3-f07879c9a8ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllll}\n",
      "\\toprule\n",
      "0 & 1 & 2 & 3 & 4 \\\\\n",
      "\\midrule\n",
      "1 6 6 7 & 2 9 11 12 & 1 6 10 11 & 3 3 4 10 & 7 7 7 9 \\\\\n",
      "1 1 2 5 & 6 9 11 11 & 5 7 11 12 & 4 7 7 10 & 3 3 5 11 \\\\\n",
      "3 7 10 12 & 3 10 10 11 & 1 2 8 11 & 6 6 7 8 & 7 7 8 12 \\\\\n",
      "3 3 11 11 & 1 1 1 4 & 4 4 9 9 & 2 3 5 12 & 1 7 11 11 \\\\\n",
      "2 10 10 10 & 1 10 10 10 & 5 6 7 11 & 1 2 8 12 & 4 7 7 12 \\\\\n",
      "1 8 9 9 & 9 10 10 10 & 2 5 5 6 & 2 3 9 11 & 7 11 11 12 \\\\\n",
      "1 4 11 11 & 5 5 5 10 & 5 5 7 12 & 2 2 9 9 & 8 8 11 11 \\\\\n",
      "8 10 10 10 & 8 9 10 11 & 1 1 5 11 & 1 4 11 12 & 1 5 7 7 \\\\\n",
      "1 2 10 10 & 1 9 10 10 & 1 1 2 4 & 4 11 12 12 & 5 5 5 11 \\\\\n",
      "6 6 10 11 & 2 9 9 9 & 4 11 11 11 & 1 7 10 11 & 1 10 11 11 \\\\\n",
      "1 1 5 10 & 5 5 5 7 & 1 6 6 8 & 5 5 7 9 & 5 8 9 9 \\\\\n",
      "1 1 5 9 & 9 10 11 11 & 5 8 9 10 & 8 9 9 9 & 5 8 10 10 \\\\\n",
      "1 3 7 11 & 9 9 10 11 & 1 8 11 11 & 4 9 9 11 & 1 1 4 11 \\\\\n",
      "2 9 9 10 & 5 8 8 11 & 1 1 6 10 & 1 1 5 12 & 7 8 8 8 \\\\\n",
      "3 3 7 10 & 1 4 7 10 & 6 7 7 7 & 4 5 5 12 & 1 1 1 10 \\\\\n",
      "7 9 9 12 & 5 6 6 11 & 1 6 7 7 & 7 7 8 8 & 1 9 10 11 \\\\\n",
      "1 1 3 3 & 5 5 6 9 & 6 7 7 12 & 1 2 9 10 & 2 6 7 7 \\\\\n",
      "2 5 11 11 & 3 5 9 11 & 1 1 7 11 & 3 6 7 11 & 6 10 10 12 \\\\\n",
      "9 10 10 12 & 1 1 7 7 & 2 5 5 5 & 5 6 8 11 & 9 9 11 11 \\\\\n",
      "6 6 9 9 & 7 7 7 8 & 9 11 11 12 & 1 6 10 10 & 1 5 5 7 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "# One space-joined string per problem, laid out as 5 rows and printed transposed.\n",
    "a = pd.DataFrame(\n",
    "    np.reshape([\" \".join(map(str, digits)) for digits in input_combinations], (5, -1))\n",
    ")\n",
    "print(a.transpose().to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b21e7f1-df96-4612-9034-20745f3b7d81",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "5d01d32f-10ba-4e52-b735-6b714c58efe1",
   "metadata": {},
   "source": [
    "## Five digits with replacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "2854ec6b-d22c-4269-8c6f-06e8778c4f2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "random.seed(42)\n",
    "input_combinations = random.sample(list(combinations_with_replacement(range(1, 13), 5)), 100)\n",
    "len(input_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "69d3961b-d5f9-405b-a463-0375038b8102",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving 100 problems\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 25/25 [00:00<00:00, 87.88it/s]\n",
      "Processing batches: 100%|██████████| 772/772 [00:02<00:00, 263.00it/s]\n",
      "Processing batches: 100%|██████████| 7821/7821 [00:20<00:00, 390.33it/s]\n",
      "Processing batches: 100%|██████████| 26876/26876 [00:40<00:00, 668.16it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes global 237928\n",
      "# edges global 1274764\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:20<00:00,  4.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes isolation 873394\n",
      "# edges isolation 2959661\n",
      "# connected comps in T: 1\n",
      "# connected comps in W: 1\n",
      "# winning paths: 29354\n",
      "# needed edges: 16572\n",
      "# needed nodes 3833\n",
      "# unsolved problems: 0 / 100\n",
      "Solving took: 746.771\n",
      "# minimal edges: 168\n",
      "# minimal nodes: 68\n",
      "1.01%\n"
     ]
    }
   ],
   "source": [
    "create_make_24_dataset(\n",
    "    input_combinations,\n",
    "    folder=\"data/make_24/optimal_graphs/\",\n",
    "    prefix=\"five_digits_combinations\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a58e7554-3f6c-40d1-a178-a4a4e8d34754",
   "metadata": {},
   "source": [
    "# debugging dataset (2 tasks, sharing paths)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d8c643ea-144b-4efe-b720-d73fa98c5a30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving 2 problems\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 1/1 [00:00<00:00, 22.30it/s]\n",
      "Processing batches: 100%|██████████| 18/18 [00:00<00:00, 263.30it/s]\n",
      "Processing batches: 100%|██████████| 148/148 [00:00<00:00, 727.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes global 1662\n",
      "# edges global 4785\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 78.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# nodes isolation 1896\n",
      "# edges isolation 4931\n",
      "# connected comps in T: 1\n",
      "# connected comps in W: 1\n",
      "# winning paths: 64\n",
      "# needed edges: 106\n",
      "# needed nodes 48\n",
      "# unsolved problems: 0 / 2\n",
      "Solving took: 0.021\n",
      "# minimal edges: 5\n",
      "# minimal nodes: 3\n",
      "4.72%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Two small problems for a quick end-to-end check.\n",
    "inputs = [(2, 4, 6, 8), (3, 5, 7, 9)]\n",
    "create_make_24_dataset(\n",
    "    inputs,\n",
    "    folder=\"data/make_24/optimal_graphs/\",\n",
    "    prefix=\"debug_2\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74def24c-aff0-4eca-8c81-b71d4a56ea7a",
   "metadata": {},
   "source": [
    "## Tasks from the Tree of Thoughts (ToT) paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "11bac2cf-b8f5-4edb-9013-9ad425251c89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 four-number puzzles for this section, written ten per row.\n",
    "inputs = [[4, 5, 6, 10], [1, 2, 4, 7], [2, 5, 8, 11], [3, 4, 4, 13], [6, 7, 8, 9], [1, 11, 11, 13], [1, 8, 10, 11], [2, 3, 6, 9], [1, 3, 5, 9], [3, 3, 7, 12], \n",
    "          [4, 5, 7, 9], [1, 2, 8, 13], [4, 6, 6, 9], [1, 4, 4, 8], [1, 5, 10, 11], [3, 4, 6, 11], [2, 4, 8, 9], [1, 4, 5, 13], [2, 2, 7, 12], [3, 3, 6, 7], \n",
    "          [1, 5, 9, 13], [5, 6, 7, 13], [5, 5, 8, 10], [2, 4, 6, 12], [6, 7, 8, 11], [7, 9, 9, 13], [3, 6, 9, 12], [6, 9, 12, 13], [4, 7, 9, 13], [5, 6, 8, 12],\n",
    "          [2, 4, 6, 7], [2, 5, 10, 10], [6, 6, 7, 12], [6, 9, 9, 11], [5, 8, 11, 12], [5, 6, 8, 10], [6, 11, 12, 13], [2, 2, 8, 8], [2, 7, 12, 13], [2, 6, 8, 12], \n",
    "          [3, 4, 9, 13], [4, 5, 10, 12], [1, 2, 7, 11], [4, 5, 6, 8], [6, 10, 12, 13], [1, 3, 9, 9], [1, 4, 4, 11], [2, 3, 9, 10], [1, 2, 3, 13], [1, 6, 6, 6], \n",
    "          [1, 2, 2, 9], [1, 3, 6, 11], [5, 10, 12, 13], [2, 3, 6, 6], [6, 7, 10, 12], [7, 8, 8, 12], [3, 4, 6, 8], [1, 7, 9, 11], [2, 3, 6, 13], [2, 2, 5, 12], \n",
    "          [2, 6, 8, 13], [8, 8, 10, 12], [1, 3, 8, 13], [4, 4, 7, 10], [1, 7, 10, 13], [1, 9, 10, 13], [3, 3, 4, 11], [2, 5, 7, 7], [3, 9, 10, 13], [2, 3, 4, 7], \n",
    "          [4, 4, 8, 12], [1, 2, 6, 10], [1, 5, 12, 12], [5, 6, 6, 8], [7, 7, 8, 11], [1, 3, 7, 10], [3, 3, 9, 12], [3, 5, 7, 10], [4, 10, 12, 13], [2, 3, 10, 12], \n",
    "          [3, 4, 6, 6], [5, 8, 8, 8], [6, 8, 8, 12], [2, 3, 4, 9], [2, 6, 7, 11], [5, 9, 12, 12], [1, 2, 7, 12], [2, 4, 5, 6], [5, 5, 8, 13], [2, 3, 3, 10], \n",
    "          [3, 4, 8, 12], [2, 4, 6, 11], [2, 2, 8, 9], [1, 5, 6, 7], [5, 8, 10, 11], [4, 4, 9, 12], [2, 5, 6, 6], [2, 4, 9, 12], [4, 8, 11, 13], [4, 9, 10, 13]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "1e739764-eec1-4273-9e05-d348c18cad6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllll}\n",
      "\\toprule\n",
      "0 & 1 & 2 & 3 & 4 \\\\\n",
      "\\midrule\n",
      "4 5 6 10 & 1 5 9 13 & 3 4 9 13 & 2 6 8 13 & 3 4 6 6 \\\\\n",
      "1 2 4 7 & 5 6 7 13 & 4 5 10 12 & 8 8 10 12 & 5 8 8 8 \\\\\n",
      "2 5 8 11 & 5 5 8 10 & 1 2 7 11 & 1 3 8 13 & 6 8 8 12 \\\\\n",
      "3 4 4 13 & 2 4 6 12 & 4 5 6 8 & 4 4 7 10 & 2 3 4 9 \\\\\n",
      "6 7 8 9 & 6 7 8 11 & 6 10 12 13 & 1 7 10 13 & 2 6 7 11 \\\\\n",
      "1 11 11 13 & 7 9 9 13 & 1 3 9 9 & 1 9 10 13 & 5 9 12 12 \\\\\n",
      "1 8 10 11 & 3 6 9 12 & 1 4 4 11 & 3 3 4 11 & 1 2 7 12 \\\\\n",
      "2 3 6 9 & 6 9 12 13 & 2 3 9 10 & 2 5 7 7 & 2 4 5 6 \\\\\n",
      "1 3 5 9 & 4 7 9 13 & 1 2 3 13 & 3 9 10 13 & 5 5 8 13 \\\\\n",
      "3 3 7 12 & 5 6 8 12 & 1 6 6 6 & 2 3 4 7 & 2 3 3 10 \\\\\n",
      "4 5 7 9 & 2 4 6 7 & 1 2 2 9 & 4 4 8 12 & 3 4 8 12 \\\\\n",
      "1 2 8 13 & 2 5 10 10 & 1 3 6 11 & 1 2 6 10 & 2 4 6 11 \\\\\n",
      "4 6 6 9 & 6 6 7 12 & 5 10 12 13 & 1 5 12 12 & 2 2 8 9 \\\\\n",
      "1 4 4 8 & 6 9 9 11 & 2 3 6 6 & 5 6 6 8 & 1 5 6 7 \\\\\n",
      "1 5 10 11 & 5 8 11 12 & 6 7 10 12 & 7 7 8 11 & 5 8 10 11 \\\\\n",
      "3 4 6 11 & 5 6 8 10 & 7 8 8 12 & 1 3 7 10 & 4 4 9 12 \\\\\n",
      "2 4 8 9 & 6 11 12 13 & 3 4 6 8 & 3 3 9 12 & 2 5 6 6 \\\\\n",
      "1 4 5 13 & 2 2 8 8 & 1 7 9 11 & 3 5 7 10 & 2 4 9 12 \\\\\n",
      "2 2 7 12 & 2 7 12 13 & 2 3 6 13 & 4 10 12 13 & 4 8 11 13 \\\\\n",
      "3 3 6 7 & 2 6 8 12 & 2 2 5 12 & 2 3 10 12 & 4 9 10 13 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "# One space-joined string per problem, laid out as 5 rows and printed transposed.\n",
    "a = pd.DataFrame(\n",
    "    np.reshape([\" \".join(map(str, digits)) for digits in inputs], (5, -1))\n",
    ")\n",
    "print(a.transpose().to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "933e4320-3f1f-4cad-a45b-ec079b2b6df0",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_make_24_dataset(\n",
    "    inputs,\n",
    "    folder=\"data/make_24/optimal_graphs/\",\n",
    "    prefix=\"tot_test_split\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81beafe8-2073-4724-a81a-7a15d8455f7e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
