{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89436e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import src.ds_utils as ds_utils\n",
    "import os\n",
    "import torch\n",
    "import src.pytorch_datasets as pytorch_datasets\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f72ecc43",
   "metadata": {},
   "source": [
    "# CIFAR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "988f1988",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = 'cifar'\n",
    "config = f\"dataset_configs/{ds_name}.yaml\"\n",
    "hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\")\n",
    "\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save(unlabeled_split, path)\n",
    "\n",
    "only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save(only_val_split, path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf02f5ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR 25%\n",
    "# VAL: 1/5, UNLABELED: 3/5, TRAIN: 1/5\n",
    "config = f\"dataset_configs/cifar.yaml\"\n",
    "\n",
    "ds_name = 'cifar_0.25'\n",
    "hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\")\n",
    "\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=4)\n",
    "unlabeled_split = {\n",
    "    'val_indices': unlabeled_split['val_indices'],\n",
    "    'train_indices': unlabeled_split['unlabeled_indices'],\n",
    "    'unlabeled_indices': unlabeled_split['train_indices'],\n",
    "}\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save(unlabeled_split, path)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dca2d89b",
   "metadata": {},
   "source": [
    "# CIFAR-100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ea6b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = 'cifar100'\n",
    "config = f\"dataset_configs/{ds_name}.yaml\"\n",
    "hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\")\n",
    "\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save(unlabeled_split, path)\n",
    "\n",
    "only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save(only_val_split, path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba57f983",
   "metadata": {},
   "source": [
    "## ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf12e6a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = 'imagenet'\n",
    "config = f\"dataset_configs/{ds_name}.yaml\"\n",
    "hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\")\n",
    "\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save(unlabeled_split, path)\n",
    "\n",
    "only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save(only_val_split, path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ba30bc7",
   "metadata": {},
   "source": [
    "# Super CIFAR-100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e88912d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = 'supercifar100'\n",
    "config = f\"dataset_configs/{ds_name}.yaml\"\n",
    "hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\")\n",
    "\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save(unlabeled_split, path)\n",
    "\n",
    "only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save(only_val_split, path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22e895b6",
   "metadata": {},
   "source": [
    "# Spurious CIFAR 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66d3070a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "ds = pytorch_datasets.SuperCIFAR100(root=\"REDACTED\", train=True)\n",
    "config = f\"dataset_configs/supercifar100.yaml\"\n",
    "hparams, train_labels = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\")\n",
    "\n",
    "subclass_targets = np.array(ds.subclass_targets)\n",
    "targets = np.array(ds.targets)\n",
    "classes_to_drop = []\n",
    "# for c in range(20):\n",
    "#     classes_to_drop.append(np.unique(np.array(subclass_targets[targets == c]))[0])\n",
    "classes_to_drop = [4, 73, 54, 10, 51, 40, 84, 18, 3, 12, 33, 38, 64, 45, 2, 44, 80, 96, 13, 81]\n",
    "\n",
    "def split_spurious_cifar(orig_indices):\n",
    "    new_train_indices = []\n",
    "    for c in range(100):\n",
    "        mask = subclass_targets[orig_indices] == c\n",
    "        if c in classes_to_drop:\n",
    "            new_train_indices.append(orig_indices[mask][::4])\n",
    "        else:\n",
    "            new_train_indices.append(orig_indices[mask])\n",
    "    new_train_indices = torch.cat(new_train_indices)\n",
    "    return new_train_indices\n",
    "\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5, unlabeled_split_amt=2)\n",
    "unlabeled_split['train_indices'] = split_spurious_cifar(unlabeled_split['train_indices'])\n",
    "unlabeled_split['unlabeled_indices'] = unlabeled_split['unlabeled_indices']\n",
    "unlabeled_split['classes_to_drop'] = classes_to_drop\n",
    "\n",
    "only_val_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], val_split_amt=5)\n",
    "subsampled = split_spurious_cifar(only_val_split['train_indices'])\n",
    "only_val_split['train_indices'] = subsampled\n",
    "only_val_split['classes_to_drop'] = classes_to_drop\n",
    "\n",
    "ds_name = 'spurious_cifar100'\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save(unlabeled_split, path)\n",
    "\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save(only_val_split, path)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62b3cda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.arange(10)[::0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c1dbb5e",
   "metadata": {},
   "source": [
    "# CelebA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c20fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = \"dataset_configs/celeba.yaml\"\n",
    "\n",
    "hparams, train_labels, train_spuriouses = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\", include_spurious=True)\n",
    "hparams, val_labels, val_spuriouses = ds_utils.get_all_beton_labels(config, 'val', \"REDACTED\", include_spurious=True)\n",
    "\n",
    "# spuriouses is 1 if blond, 2 if black hair, 0 if neither\n",
    "# train_labels is 0 if female, 1 if male"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1102f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_celeba_split(labels, spuriouses, majority_multiplier=7):\n",
    "    all_indices = np.arange(len(labels))\n",
    "\n",
    "    female = labels == 0\n",
    "    male = labels == 1\n",
    "    blond = spuriouses == 1\n",
    "    black = spuriouses == 2\n",
    "\n",
    "    female_blond = all_indices[female & blond]\n",
    "    female_black = all_indices[female & black]\n",
    "    male_blond = all_indices[male & blond]\n",
    "    male_black = all_indices[male & black]\n",
    "    print(len(female_blond), len(female_black), len(male_black), len(male_blond))\n",
    "\n",
    "    smallest_minority = len(male_blond)\n",
    "    smallest_majority = len(male_black)\n",
    "    minority = min(smallest_minority, int(smallest_majority/majority_multiplier))\n",
    "    \n",
    "    majority = minority*majority_multiplier\n",
    "    train_indices = [female_blond[:majority], female_black[:minority], male_black[:majority], male_blond[:minority]]\n",
    "    print([len(u) for u in train_indices])\n",
    "    return np.concatenate(train_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03578d86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val\n",
    "mult = 5\n",
    "\n",
    "ds_name = f\"celeba_1_{mult}\"\n",
    "\n",
    "val_indices = get_celeba_split(val_labels, val_spuriouses, majority_multiplier=1)\n",
    "\n",
    "# with unlabeled\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], unlabeled_split_amt=2)\n",
    "orig_train_indices = unlabeled_split['train_indices']\n",
    "train_indices = orig_train_indices[get_celeba_split(\n",
    "    train_labels[orig_train_indices], \n",
    "    train_spuriouses[orig_train_indices],\n",
    "    majority_multiplier=mult,\n",
    ")]\n",
    "\n",
    "orig_unlabeled_indices = unlabeled_split['unlabeled_indices']\n",
    "unlabeled_indices = orig_unlabeled_indices[get_celeba_split(\n",
    "    train_labels[orig_unlabeled_indices], \n",
    "    train_spuriouses[orig_unlabeled_indices],\n",
    "    majority_multiplier=mult,\n",
    ")]\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save({\n",
    "    'val_indices': val_indices,\n",
    "    'train_indices': train_indices,\n",
    "    'unlabeled_indices': unlabeled_indices\n",
    "}, path)\n",
    "\n",
    "# without unlabeled\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save({\n",
    "    'val_indices': val_indices,\n",
    "    'train_indices': get_celeba_split(train_labels, train_spuriouses, majority_multiplier=mult),\n",
    "}, path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b05ae900",
   "metadata": {},
   "source": [
    "# CelebA Age"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "127fe177",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = \"dataset_configs/celeba.yaml\"\n",
    "\n",
    "hparams, train_labels, train_spuriouses = ds_utils.get_all_beton_labels(config, 'train', \"REDACTED\", include_spurious=True)\n",
    "hparams, val_labels, val_spuriouses = ds_utils.get_all_beton_labels(config, 'val', \"REDACTED\", include_spurious=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47351cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# spurious: \n",
    "# old is male\n",
    "# young is female\n",
    "\n",
    "def get_celeba_split(labels, spuriouses, majority_multiplier=7):\n",
    "    all_indices = np.arange(len(labels))\n",
    "\n",
    "    old = labels == 0\n",
    "    young = labels == 1\n",
    "    female = spuriouses == 0\n",
    "    male = spuriouses == 1\n",
    "\n",
    "    old_male = all_indices[old & male]\n",
    "    old_female = all_indices[old & female]\n",
    "    young_male = all_indices[young & male]\n",
    "    young_female = all_indices[young & female]\n",
    "\n",
    "    print(\"OLD MALE\", len(old_male), \"OLD FEMALE\", len(old_female), \"YOUNG MALE\", len(young_male), \"YOUNG FEMALE\", len(young_female))\n",
    "\n",
    "    smallest_minority = len(old_female)\n",
    "    smallest_majority = len(old_male)\n",
    "    minority = min(smallest_minority, int(smallest_majority/majority_multiplier))\n",
    "    \n",
    "    majority = minority*majority_multiplier\n",
    "    train_indices = [old_female[:minority], old_male[:majority], young_female[:majority], young_male[:minority]]\n",
    "    print([len(u) for u in train_indices])\n",
    "    return np.concatenate(train_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "234eed76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val\n",
    "mult = 4\n",
    "\n",
    "ds_name = f\"celeba_age_1_{mult}\"\n",
    "\n",
    "val_indices = get_celeba_split(val_labels, val_spuriouses, majority_multiplier=1)\n",
    "\n",
    "# with unlabeled\n",
    "unlabeled_split = ds_utils.create_dataset_split(train_labels, hparams['num_classes'], unlabeled_split_amt=2)\n",
    "orig_train_indices = unlabeled_split['train_indices']\n",
    "train_indices = orig_train_indices[get_celeba_split(\n",
    "    train_labels[orig_train_indices], \n",
    "    train_spuriouses[orig_train_indices],\n",
    "    majority_multiplier=mult,\n",
    ")]\n",
    "\n",
    "orig_unlabeled_indices = unlabeled_split['unlabeled_indices']\n",
    "unlabeled_indices = orig_unlabeled_indices[get_celeba_split(\n",
    "    train_labels[orig_unlabeled_indices], \n",
    "    train_spuriouses[orig_unlabeled_indices],\n",
    "    majority_multiplier=mult,\n",
    ")]\n",
    "path = os.path.join('index_files', f'{ds_name}_with_unlabeled.pt')\n",
    "torch.save({\n",
    "    'val_indices': val_indices,\n",
    "    'train_indices': train_indices,\n",
    "    'unlabeled_indices': unlabeled_indices\n",
    "}, path)\n",
    "\n",
    "# without unlabeled\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save({\n",
    "    'val_indices': val_indices,\n",
    "    'train_indices': get_celeba_split(train_labels, train_spuriouses, majority_multiplier=mult),\n",
    "}, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c0bc46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val with same as train split\n",
    "mult = 4\n",
    "\n",
    "ds_name = f\"celeba_age_1_{mult}_same_val_as_train\"\n",
    "\n",
    "val_indices = get_celeba_split(val_labels, val_spuriouses, majority_multiplier=mult)\n",
    "# without unlabeled\n",
    "path = os.path.join('index_files', f'{ds_name}.pt')\n",
    "torch.save({\n",
    "    'val_indices': val_indices,\n",
    "    'train_indices': get_celeba_split(train_labels, train_spuriouses, majority_multiplier=mult),\n",
    "}, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4073ed37",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}