{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "096be2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import pickle\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18340d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_path = ''\n",
    "target_path = ''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73366936",
   "metadata": {},
   "source": [
    "Prepare einspace data\n",
    "- split by seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e13929d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "einspace = pd.read_csv(os.path.join(source_path, 'einspace.csv'))\n",
    "einspace_aug = pd.read_csv(os.path.join(source_path, 'einspace_augmentation.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4e8978fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['onnx_encoding', 'accuracy', 'onnx_encoding_tokens', 'dataset', 'name'], dtype='object')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "einspace.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "950dd45f",
   "metadata": {},
   "outputs": [],
   "source": [
    "einspace['seed'] = einspace['name'].apply(lambda x: x.split('_')[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "462baf84",
   "metadata": {},
   "outputs": [],
   "source": [
    "einspace_val = einspace[einspace['seed'] == 'seed=4']\n",
    "einspace_train = einspace[(einspace['seed'] != 'seed=4') & (einspace['seed'] != 'seed=0')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "71eaf7f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "einspace_train = pd.concat([einspace_train, einspace_aug], ignore_index=True)\n",
    "einspace_train = einspace_train[einspace_train['accuracy'] > 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "753dbaa5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set shape: (37416, 6)\n",
      "Validation set shape: (1582, 6)\n"
     ]
    }
   ],
   "source": [
    "print(\"Training set shape:\", einspace_train.shape)\n",
    "print(\"Validation set shape:\", einspace_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "25b7642f",
   "metadata": {},
   "outputs": [],
   "source": [
    "einspace_train.to_csv(os.path.join(target_path, 'einspace_train.csv'), index=False)\n",
    "einspace_val.to_csv(os.path.join(target_path, 'einspace_val.csv'), index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca724313",
   "metadata": {},
   "source": [
    "Prepare hnasbench201 data\n",
    "- split by seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cb331d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "hnasbench201 = pd.read_csv(os.path.join(source_path, 'hnasbench201.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "56a6d7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "hnasbench201['seed'] = hnasbench201['name'].apply(lambda x: x.split('_')[0])\n",
    "hnasbench201['index'] = hnasbench201['name'].apply(lambda x: x.split('_')[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "94bfe490",
   "metadata": {},
   "outputs": [],
   "source": [
    "hnasbench201_val = hnasbench201[hnasbench201['seed'] == 'seed=4']\n",
    "hnasbench201_train = hnasbench201[hnasbench201['seed'] != 'seed=4']\n",
    "hnasbench201_train = hnasbench201_train[hnasbench201_train['accuracy'] > 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0725ed2e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set shape: (6403, 7)\n",
      "Validation set shape: (1000, 7)\n"
     ]
    }
   ],
   "source": [
    "print(\"Training set shape:\", hnasbench201_train.shape)\n",
    "print(\"Validation set shape:\", hnasbench201_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6738868d",
   "metadata": {},
   "outputs": [],
   "source": [
    "hnasbench201_train.to_csv(os.path.join(target_path, 'hnasbench201_train.csv'), index=False)\n",
    "hnasbench201_val.to_csv(os.path.join(target_path, 'hnasbench201_val.csv'), index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "290c3ba5",
   "metadata": {},
   "source": [
    "Prepare nasbench201 + natsbench data\n",
    "- these two search space considered together\n",
    "- random split by 80%/20%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "80bc7989",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench201 = pd.read_csv(os.path.join(source_path, 'nasbench201.csv'))\n",
    "natsbench = pd.read_csv(os.path.join(source_path, 'natsbench.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c7882e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "nas201nats = pd.concat([nasbench201, natsbench], ignore_index=True)\n",
    "nas201nats_train, nas201nats_val = train_test_split(\n",
    "    nas201nats, test_size=0.2, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b3ece777",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set shape: (38714, 5)\n",
      "Validation set shape: (9679, 5)\n"
     ]
    }
   ],
   "source": [
    "print(\"Training set shape:\", nas201nats_train.shape)\n",
    "print(\"Validation set shape:\", nas201nats_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7146bda2",
   "metadata": {},
   "outputs": [],
   "source": [
    "nas201nats_train.to_csv(os.path.join(target_path, 'nas201nats_train.csv'), index=False)\n",
    "nas201nats_val.to_csv(os.path.join(target_path, 'nas201nats_val.csv'), index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707415b5",
   "metadata": {},
   "source": [
    "Prepare nasbench101\n",
    "- split based on hash file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8cb18f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"nb101_hash.txt\", \"rb\") as fp:\n",
    "    nb101_hash = pickle.load(fp)\n",
    "nasbench101 = pd.read_csv(os.path.join(source_path, 'nasbench101.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "53b65c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench101['hash'] = nasbench101['name'].apply(lambda x: x.split('_')[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8f24e041",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench101_val = nasbench101[nasbench101['hash'].isin(nb101_hash)]\n",
    "nasbench101_train = nasbench101[~nasbench101['hash'].isin(nb101_hash)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2b905052",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set shape: (416334, 6)\n",
      "Validation set shape: (7290, 6)\n"
     ]
    }
   ],
   "source": [
    "print(\"Training set shape:\", nasbench101_train.shape)\n",
    "print(\"Validation set shape:\", nasbench101_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b8e1c867",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench101_train.to_csv(os.path.join(target_path, 'nasbench101_train.csv'), index=False)\n",
    "nasbench101_val.to_csv(os.path.join(target_path, 'nasbench101_val.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f0824698",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench101_train = nasbench101_train.sample(n = 50_000, random_state=42).reset_index(drop=True)\n",
    "nasbench101_train.to_csv(os.path.join(target_path, 'nasbench101_50k.csv'), index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f41f324",
   "metadata": {},
   "source": [
    "Prepare nasbench301\n",
    "- split based on source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8acde6dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench301 = pd.read_csv(os.path.join(source_path, 'nasbench301.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c99db3c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_onnx_files(folder_path):\n",
    "    \"\"\"\n",
    "    Return a list of all .onnx file paths found within 'folder_path' (recursively).\n",
    "    \"\"\"\n",
    "    onnx_files = []\n",
    "    for root, dirs, files in os.walk(folder_path):\n",
    "        for file in files:\n",
    "            if file.lower().endswith('.onnx'):\n",
    "                full_path = os.path.join(root, file)\n",
    "                onnx_files.append(full_path)\n",
    "    return onnx_files\n",
    "\n",
    "onnx_files = list_onnx_files('../onnx/nasbench301')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "b26d98dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "index_source_map = {}\n",
    "for i in onnx_files:\n",
    "    index = i.split('/')[-1].split('.')[0]\n",
    "    source = i.split('/')[-2]\n",
    "    index_source_map[index] = source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "033860e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench301['source'] = nasbench301['name'].apply(lambda x: index_source_map[str(x)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b44835e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_sources = ['only_avg_pool_3x3', 'local_search', 'bananas']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "21fd76b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val_sources = random.sample(list(nasbench301['source'].unique()), 3)\n",
    "# val_sources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "abf3481e",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench301_train = nasbench301[~nasbench301['source'].isin(val_sources)]\n",
    "nasbench301_val = nasbench301[nasbench301['source'].isin(val_sources)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "fd340c98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set shape: (51297, 6)\n",
      "Validation set shape: (5892, 6)\n"
     ]
    }
   ],
   "source": [
    "print(\"Training set shape:\", nasbench301_train.shape)\n",
    "print(\"Validation set shape:\", nasbench301_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "dd38808d",
   "metadata": {},
   "outputs": [],
   "source": [
    "nasbench301_train.to_csv(os.path.join(target_path, 'nasbench301_train.csv'), index=False)\n",
    "nasbench301_val.to_csv(os.path.join(target_path, 'nasbench301_val.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9cc324d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddc70b8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "einspace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
