{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import pandas as pd\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_random_splits(num_elem, val_ratio, test_ratio, seed=42):\n",
    "    \"\"\"Make random splits of the data into train, validation, and test sets.\n",
    "\n",
    "    Args:\n",
    "        num_elem (int): Number of elements in the dataset.\n",
    "        val_ratio (float): Ratio of the dataset to use for validation.\n",
    "        test_ratio (float): Ratio of the dataset to use for testing.\n",
    "\n",
    "    Returns:\n",
    "        train_idx (list): List of indices for the training set.\n",
    "        val_idx (list): List of indices for the validation set.\n",
    "        test_idx (list): List of indices for the test set.\n",
    "    \"\"\"\n",
    "    # Create a list of indices\n",
    "    idx = list(range(num_elem))\n",
    "    # Shuffle the list of indices\n",
    "    random.seed(seed)\n",
    "    random.shuffle(idx)\n",
    "    # Compute the number of elements in each set\n",
    "    num_val = int(num_elem * val_ratio)\n",
    "    num_test = int(num_elem * test_ratio)\n",
    "    num_train = num_elem - num_val - num_test\n",
    "    # Split the list of indices into three sets\n",
    "    train_idx = idx[:num_train]\n",
    "    val_idx = idx[num_train:num_train + num_val]\n",
    "    test_idx = idx[num_train + num_val:]\n",
    "    # Return the three lists of indices\n",
    "    return train_idx, val_idx, test_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_random_splits_from_file(in_file, out_file, val_ratio, test_ratio, seed=42):\n",
    "    \"\"\"Create random train/val/test index splits for a CSV dataset and save them.\n",
    "\n",
    "    The dataset size is the number of rows read from the \"smiles\" column of\n",
    "    `in_file`. The splits are written with `torch.save` as a dict with the\n",
    "    keys \"train\", \"val\" and \"test\", each holding a list of row indices.\n",
    "\n",
    "    Args:\n",
    "        in_file: CSV file to read; only its \"smiles\" column is loaded.\n",
    "        out_file: Destination passed to `torch.save`.\n",
    "        val_ratio (float): Ratio of the dataset to use for validation.\n",
    "        test_ratio (float): Ratio of the dataset to use for testing.\n",
    "        seed (int): Seed forwarded to `make_random_splits`.\n",
    "    \"\"\"\n",
    "    # Only one column is loaded because just the row count is needed\n",
    "    num_rows = len(pd.read_csv(in_file, usecols=[\"smiles\"]))\n",
    "    train, val, test = make_random_splits(num_rows, val_ratio, test_ratio, seed=seed)\n",
    "\n",
    "    # Sanity checks: the sets are pairwise disjoint and cover every row\n",
    "    train_set, val_set, test_set = set(train), set(val), set(test)\n",
    "    assert train_set.isdisjoint(val_set)\n",
    "    assert train_set.isdisjoint(test_set)\n",
    "    assert val_set.isdisjoint(test_set)\n",
    "    assert len(train) + len(val) + len(test) == num_rows\n",
    "\n",
    "    # Save the splits\n",
    "    torch.save({\"train\": train, \"val\": val, \"test\": test}, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the random splits of each dataset: 10% validation, 10% test,\n",
    "# and the remaining rows for training (default seed)\n",
    "make_random_splits_from_file(\n",
    "    in_file=\"qm9.csv.gz\", out_file=\"qm9_random_splits.pt\", val_ratio=0.1, test_ratio=0.1\n",
    ")\n",
    "make_random_splits_from_file(\n",
    "    in_file=\"Tox21-7k-12-labels.csv.gz\", out_file=\"Tox21_random_splits.pt\", val_ratio=0.1, test_ratio=0.1\n",
    ")\n",
    "make_random_splits_from_file(\n",
    "    in_file=\"ZINC12k.csv.gz\", out_file=\"ZINC12k_random_splits.pt\", val_ratio=0.1, test_ratio=0.1\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphium",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
