{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dfdd8dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def star_graph(degSource, pathLen, numNodes, reverse=False):\n",
    "    source = np.random.randint(0, numNodes, 1)[0]\n",
    "    goal = np.random.randint(0, numNodes, 1)[0]\n",
    "    while goal == source:\n",
    "        goal = np.random.randint(0, numNodes, 1)[0]\n",
    "\n",
    "    path = [source]\n",
    "    edge_list = []\n",
    "\n",
    "    # Choose random nodes along the path\n",
    "    for _ in range(pathLen - 2):\n",
    "        node = np.random.randint(0, numNodes, 1)[0]\n",
    "        while node in path or node == goal:\n",
    "            node = np.random.randint(0, numNodes, 1)[0]\n",
    "        path.append(node)\n",
    "\n",
    "    path.append(goal)\n",
    "    # Connect the path\n",
    "    for i in range(len(path) - 1):\n",
    "        edge_list.append([path[i], path[i + 1]])\n",
    "\n",
    "    remaining_nodes = []\n",
    "    for i in range(numNodes):\n",
    "        if i not in path:\n",
    "            remaining_nodes.append(i)\n",
    "\n",
    "    i = 0\n",
    "    deg_nodes = set()\n",
    "    while i < degSource - 1:\n",
    "        # Add neighbour to source\n",
    "        node = source\n",
    "        next_node = np.random.randint(0, numNodes, 1)[0]\n",
    "        l = 1\n",
    "        while l < pathLen:\n",
    "            if next_node not in deg_nodes and next_node not in path:\n",
    "                edge_list.append([node, next_node])\n",
    "                deg_nodes.add(next_node)\n",
    "                node = next_node\n",
    "                l += 1\n",
    "            next_node = np.random.randint(0, numNodes, 1)[0]\n",
    "\n",
    "        i += 1\n",
    "\n",
    "    random.shuffle(edge_list)\n",
    "    if reverse:\n",
    "        path = path[::-1]\n",
    "\n",
    "    return path, edge_list, source, goal\n",
    "\n",
    "\n",
    "def _format_graph_example(path, edge_list, start, goal):\n",
    "    \"\"\"Serialise one graph as 'u,v|u,v|.../start,goal=p0,p1,...' (shared by both splits).\"\"\"\n",
    "    edge_str = '|'.join(str(u) + ',' + str(v) for u, v in edge_list)\n",
    "    path_str = ','.join(str(node) for node in path)\n",
    "    return edge_str + '/' + str(start) + ',' + str(goal) + '=' + path_str\n",
    "\n",
    "\n",
    "def _write_graph_split(filename, n_examples, degSource, pathLen, numNodes, reverse):\n",
    "    \"\"\"Sample n_examples star graphs and write one serialised graph per line.\"\"\"\n",
    "    with open(filename, 'w') as f:\n",
    "        for _ in range(n_examples):\n",
    "            path, edge_list, start, goal = star_graph(degSource, pathLen, numNodes, reverse=reverse)\n",
    "            f.write(_format_graph_example(path, edge_list, start, goal) + '\\n')\n",
    "\n",
    "\n",
    "def generate_and_save(n_train, n_test, degSource, pathLen, numNodes, reverse=False,\n",
    "                      out_dir='./data/datasets/graphs/'):\n",
    "    \"\"\"\n",
    "    Generate train and test star graphs and save them to text files for reproducibility.\n",
    "\n",
    "    Writes ``deg_{degSource}_path_{pathLen}_nodes_{numNodes}_train_{n_train}.txt``\n",
    "    and the matching ``..._test_{n_test}.txt`` into ``out_dir`` (created if missing).\n",
    "    Both files use the same line format: edges ``u,v`` joined by ``|``, then\n",
    "    ``/source,goal=`` followed by the comma-separated path. Train graphs are\n",
    "    sampled before test graphs, so seeding the RNGs beforehand reproduces both.\n",
    "\n",
    "    Args:\n",
    "        n_train, n_test: number of graphs in each split.\n",
    "        degSource, pathLen, numNodes, reverse: forwarded to ``star_graph``.\n",
    "        out_dir: output directory; the default is the original hard-coded location.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    stem = 'deg_' + str(degSource) + '_path_' + str(pathLen) + '_nodes_' + str(numNodes)\n",
    "    # Train first, then test: keeps the RNG draw order of earlier runs.\n",
    "    _write_graph_split(os.path.join(out_dir, stem + '_train_' + str(n_train) + '.txt'),\n",
    "                       n_train, degSource, pathLen, numNodes, reverse)\n",
    "    _write_graph_split(os.path.join(out_dir, stem + '_test_' + str(n_test) + '.txt'),\n",
    "                       n_test, degSource, pathLen, numNodes, reverse)\n",
    "\n",
    "\n",
    "def prefix_target_list(filename=None, reverse=False):\n",
    "    \"\"\"\n",
    "    Load graphs and split them into prefix and target and return the list\n",
    "    \"\"\"\n",
    "    data_list = []\n",
    "    with open(filename, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    for line in lines:\n",
    "        prefix = line.strip().split('=')[0] + '='\n",
    "        target = line.strip().split('=')[1]\n",
    "        if reverse:\n",
    "            target = ','.join(target.split(',')[::-1])\n",
    "        data_list.append((prefix, target))\n",
    "\n",
    "    return data_list\n",
    "\n",
    "\n",
    "def get_edge_list(x, num_nodes, path_len):\n",
    "    \"\"\"\n",
    "    Given the tokenised input for the Transformer, map back to the edge_list\n",
    "    \"\"\"\n",
    "    edge_list = []\n",
    "    pair = []\n",
    "    x = x.squeeze().cpu().numpy()\n",
    "\n",
    "    for i, n in enumerate(x):\n",
    "        if n in range(num_nodes):\n",
    "            pair.append(n)\n",
    "        if len(pair) == 2:\n",
    "            edge_list.append(pair)\n",
    "            pair = []\n",
    "        if n == num_nodes + 2:\n",
    "            break\n",
    "\n",
    "    start = x[i + 1]\n",
    "    goal = x[i + 2]\n",
    "    path = [x[i + j] for j in range(4, 4 + path_len)]\n",
    "\n",
    "    return edge_list, start, goal, path\n",
    "\n",
    "\n",
    "def get_edge_list_byte(x, num_nodes, path_len, decode):\n",
    "    \"\"\"\n",
    "    Given the tokenised input for the Transformer, map back to the edge_list\n",
    "    when tokens must be decoded one at a time with ``decode``.\n",
    "\n",
    "    Each token id is decoded individually via ``decode([token_id])``. Decoded\n",
    "    values other than ',', '|', '=' and '->' are paired into edges until the\n",
    "    first '->' token. With k the index of that '->' token, the remaining fields\n",
    "    are read at fixed offsets: goal = dec[k + 1], start = dec[k + 3], and\n",
    "    ``path_len - 2`` path entries at dec[k + 5], dec[k + 7], ...\n",
    "\n",
    "    NOTE(review): these offsets encode an assumed layout after '->' (goal before\n",
    "    start, one delimiter token between fields) -- confirm against the tokenizer.\n",
    "    '/' is not in the delimiter list, so a decoded '/' before '->' would be paired\n",
    "    into an edge. ``num_nodes`` is not used. If no '->' token is present, the\n",
    "    offsets are applied past the last token and indexing raises IndexError; an\n",
    "    empty ``x`` leaves ``i`` unbound (UnboundLocalError).\n",
    "\n",
    "    Returns:\n",
    "        (edge_list, start, goal, path) built from the values returned by\n",
    "        ``decode`` (presumably str, since they are compared to string delimiters).\n",
    "    \"\"\"\n",
    "    edge_list = []\n",
    "    x = list(x.squeeze().cpu().numpy())\n",
    "    # Decode token by token so each position maps to exactly one decoded value\n",
    "    dec = [decode([val]) for val in x]\n",
    "    edge = []\n",
    "    for i, val in enumerate(dec):\n",
    "        # Anything that is not a known delimiter is treated as a node id\n",
    "        if val not in [',', '|', '=', '->']:\n",
    "            edge.append(val)\n",
    "        if len(edge) == 2:\n",
    "            edge_list.append(edge)\n",
    "            edge = []\n",
    "\n",
    "        if val == '->':\n",
    "            break\n",
    "    # i is the index k of '->' here; after the shift the reads below are\n",
    "    # start = dec[k + 3], goal = dec[k + 1], path = dec[k + 5 + 2j]\n",
    "    i += 2\n",
    "    start = dec[i + 1]\n",
    "    goal = dec[i -1]\n",
    "    path = [dec[i + 3 + 2 * j] for j in range(0, path_len - 2)]\n",
    "\n",
    "    return edge_list, start, goal, path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aedc95a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset, DatasetDict\n",
    "from huggingface_hub import HfApi\n",
    "import os\n",
    "\n",
    "def _hf_parse_graph_lines(lines):\n",
    "    \"\"\"Split '<graph>/<source>,<destination>=<path>' lines into string columns.\"\"\"\n",
    "    columns = {\"graph\": [], \"source\": [], \"destination\": [], \"path\": []}\n",
    "    for line in lines:\n",
    "        if not line.strip():\n",
    "            continue  # skip blank lines rather than failing to unpack them\n",
    "        question, path = line.split('=')\n",
    "        graph, source_and_destination = question.split(\"/\")\n",
    "        source, destination = source_and_destination.split(',')\n",
    "        columns[\"graph\"].append(graph)\n",
    "        columns[\"source\"].append(source)\n",
    "        columns[\"destination\"].append(destination)\n",
    "        columns[\"path\"].append(path)\n",
    "    return columns\n",
    "\n",
    "\n",
    "def push_dataset_to_hf_hub(n_train, n_test, degSource, pathLen, numNodes, reverse=False,\n",
    "                           hf_username=None):\n",
    "    \"\"\"\n",
    "    Upload the train/test graph files written by generate_and_save to the HF Hub.\n",
    "\n",
    "    Each line ``<graph>/<source>,<destination>=<path>`` becomes one row with string\n",
    "    columns graph, source, destination and path. The two splits are pushed as a\n",
    "    DatasetDict to ``<username>/star-graph-deg-{degSource}-path-{pathLen}-nodes-{numNodes}``.\n",
    "\n",
    "    Args:\n",
    "        n_train, n_test, degSource, pathLen, numNodes: select the input files\n",
    "            (same values as passed to generate_and_save).\n",
    "        reverse: unused; kept for signature parity with generate_and_save.\n",
    "        hf_username: Hugging Face namespace to push to; falls back to the\n",
    "            HF_USERNAME environment variable.\n",
    "\n",
    "    Returns:\n",
    "        (train_dataset, test_dataset) as ``Dataset`` objects.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no username is available (checked before any file is read).\n",
    "    \"\"\"\n",
    "    username = hf_username or os.environ.get(\"HF_USERNAME\", \"\")\n",
    "    if not username:\n",
    "        raise ValueError(\"No Hugging Face username: pass hf_username=... or set the \"\n",
    "                         \"HF_USERNAME environment variable\")\n",
    "    dataset_name = f\"star-graph-deg-{degSource}-path-{pathLen}-nodes-{numNodes}\"\n",
    "    repo_id = f\"{username}/{dataset_name}\"\n",
    "\n",
    "    stem = ('./data/datasets/graphs/' + 'deg_' + str(degSource) + '_path_' + str(pathLen)\n",
    "            + '_nodes_' + str(numNodes))\n",
    "    # Read both splits with context managers so the file handles are closed\n",
    "    with open(stem + '_train_' + str(n_train) + '.txt', 'r') as f:\n",
    "        trn_pts = f.read().splitlines()\n",
    "    with open(stem + '_test_' + str(n_test) + '.txt', 'r') as f:\n",
    "        tst_pts = f.read().splitlines()\n",
    "\n",
    "    # Show the first raw training example as a quick sanity check of the format\n",
    "    if trn_pts:\n",
    "        question, _, path = trn_pts[0].partition('=')\n",
    "        print(question)\n",
    "        print(path)\n",
    "\n",
    "    train_dataset = Dataset.from_dict(_hf_parse_graph_lines(trn_pts))\n",
    "    test_dataset = Dataset.from_dict(_hf_parse_graph_lines(tst_pts))\n",
    "    dataset_dict = DatasetDict({\n",
    "        \"train\": train_dataset,\n",
    "        \"test\": test_dataset,\n",
    "    })\n",
    "    dataset_dict.push_to_hub(repo_id)\n",
    "    return train_dataset, test_dataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e78dc5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create graphs and save.\n",
    "# star_graph draws node ids from np.random and shuffles edges with random,\n",
    "# so seed both to make the generated files reproducible.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "n_train = 200000\n",
    "n_test = 20000\n",
    "deg = 128\n",
    "path_len = 3\n",
    "num_nodes = 300\n",
    "reverse = False\n",
    "generate_and_save(n_train=n_train, n_test=n_test, degSource=deg, pathLen=path_len, numNodes=num_nodes,\n",
    "                  reverse=reverse)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "591fd6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the generated train/test splits to the Hugging Face Hub\n",
    "train_dataset, test_dataset = push_dataset_to_hf_hub(\n",
    "    n_train=n_train,\n",
    "    n_test=n_test,\n",
    "    degSource=deg,\n",
    "    pathLen=path_len,\n",
    "    numNodes=num_nodes,\n",
    "    reverse=reverse,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
