{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you don't wanna change subset ratios through re-partitioning dataset\n",
    "# Here is the fast script\n",
    "import numpy as np\n",
    "import pickle\n",
    "import json\n",
    "import hashlib\n",
    "\n",
    "dataset = \"cifar10\"\n",
    "\n",
    "new_testset_ratio = 0.5\n",
    "new_valset_ratio = 0\n",
    "\n",
    "partition = pickle.load(open(f\"{dataset}/partition.pkl\", \"rb\"))\n",
    "\n",
    "for i in range(len(partition[\"data_indices\"])):\n",
    "    indices = np.concatenate(\n",
    "        [\n",
    "            partition[\"data_indices\"][i][\"train\"],\n",
    "            partition[\"data_indices\"][i][\"val\"],\n",
    "            partition[\"data_indices\"][i][\"test\"],\n",
    "        ],\n",
    "        axis=0,\n",
    "    )\n",
    "    new_testset_size = int(len(indices) * new_testset_ratio)\n",
    "    new_valset_size = int(len(indices) * new_valset_ratio)\n",
    "\n",
    "    partition[\"data_indices\"][i][\"test\"] = indices[:new_testset_size]\n",
    "    partition[\"data_indices\"][i][\"val\"] = indices[\n",
    "        new_testset_size : new_testset_size + new_valset_size\n",
    "    ]\n",
    "    partition[\"data_indices\"][i][\"train\"] = indices[\n",
    "        new_testset_size + new_valset_size :\n",
    "    ]\n",
    "\n",
    "pickle.dump(partition, open(f\"{dataset}/partition.pkl\", \"wb\"))\n",
    "\n",
    "args = json.load(open(f\"{dataset}/args.json\", \"r\"))\n",
    "args[\"test_ratio\"] = new_testset_ratio\n",
    "args[\"val_ratio\"] = new_valset_ratio\n",
    "json.dump(args, open(f\"{dataset}/args.json\", \"w\"), indent=4)\n",
    "with open(f\"{dataset}/partition_md5.txt\", \"w\") as f:\n",
    "    f.write(hashlib.md5(json.dumps(args.__dict__).encode()).hexdigest())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
