{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import os\n",
    "from ffcv.writer import DatasetWriter\n",
    "from ffcv.fields import RGBImageField, IntField\n",
    "import torch\n",
    "import src.pytorch_datasets as pytorch_datasets\n",
    "from src import ffcv_utils\n",
    "import yaml\n",
    "from src.config_parsing import ffcv_read_check_override_config\n",
    "import pprint\n",
    "from src.ffcv_utils import get_training_loaders\n",
    "from src.pytorch_datasets import create_val_split, get_unlabeled_indices, IndexedDataset\n",
    "from robustness.tools.vis_tools import show_image_row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_dataset(config, pipeline_subset=None):\n",
    "    \"\"\"Build the train/val/test loaders described by a dataset config file.\n",
    "\n",
    "    Args:\n",
    "        config: File name of a YAML config inside ``dataset_configs/``.\n",
    "        pipeline_subset: Field names forwarded to ``get_training_loaders``.\n",
    "            Defaults to ``['image', 'label', 'index']``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(train_loader, val_loader, test_loader)``.\n",
    "    \"\"\"\n",
    "    # Build the default per call: a list literal in the signature is a single\n",
    "    # shared object that a callee could mutate for every later call.\n",
    "    if pipeline_subset is None:\n",
    "        pipeline_subset = ['image', 'label', 'index']\n",
    "    with open(f\"dataset_configs/{config}\", 'r') as file:\n",
    "        hparams = yaml.safe_load(file)\n",
    "    hparams = ffcv_read_check_override_config(hparams)\n",
    "    print(\"=========== Current Config ==================\")\n",
    "    pprint.pprint(hparams, indent=4)\n",
    "    train_loader, val_loader, test_loader = get_training_loaders(hparams, pipeline_subset=pipeline_subset)\n",
    "    return train_loader, val_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: build the loaders from 'imagenet_c.yaml', additionally requesting the 'spurious' field.\n",
    "# NOTE(review): presumably needs the .beton files written further down to exist already -- confirm.\n",
    "train_load, val_load, test_load = test_dataset('imagenet_c.yaml', pipeline_subset=['image', 'label', 'spurious', 'index'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imagenet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = 'REDACTED'\n",
    "supervised_file_path = data_dir + 'distorted_indices'\n",
    "\n",
    "# Read one index per line from the file.\n",
    "distorted_inds = []\n",
    "with open(supervised_file_path, 'r') as f:\n",
    "    for line in f:\n",
    "        strnum = line.strip()\n",
    "        if not strnum:\n",
    "            # Tolerate blank lines instead of crashing in float('').\n",
    "            continue\n",
    "        # Store indices as ints; going through float() still accepts\n",
    "        # values written as e.g. \"12.0\".\n",
    "        distorted_inds.append(int(float(strnum)))\n",
    "\n",
    "# Dict keyed by index (value is always 1), used for membership lookups.\n",
    "distorted_inds_dict = {ind: 1 for ind in distorted_inds}\n",
    "print('No repeats: ', len(set(distorted_inds))==len(distorted_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct indices read from the file.\n",
    "len(set(distorted_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the distinct indices (the dict keys).\n",
    "set(distorted_inds_dict.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: dict keys are unique by construction, so this comparison is always True.\n",
    "print('No repeats: ', len(set(distorted_inds_dict.keys()))==len(distorted_inds_dict.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ImagenetC(torch.utils.data.Dataset):\n",
    "    \"\"\"Wrap a dataset so every sample also carries a 'spurious' flag.\n",
    "\n",
    "    The flag is True when the sample's index is listed in the index file\n",
    "    given at construction time.\n",
    "\n",
    "    Args:\n",
    "        ds: Indexable dataset whose items unpack as ``(x, y)``.\n",
    "        supervised_file_path: Text file with one index per line.\n",
    "\n",
    "    Attributes:\n",
    "        distorted_inds: Indices in file order, as ints.\n",
    "        distorted_inds_dict: The same indices as dict keys (value 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ds, supervised_file_path):\n",
    "        self.ds = ds\n",
    "        self.supervised_file_path = supervised_file_path\n",
    "\n",
    "        # A single pass over the file. The earlier nested loops let the outer\n",
    "        # loop swallow the first line, and left the lists undefined when the\n",
    "        # file was empty.\n",
    "        distorted_inds = []\n",
    "        with open(supervised_file_path, 'r') as f:\n",
    "            for line in f:\n",
    "                strnum = line.strip()\n",
    "                if not strnum:\n",
    "                    continue  # ignore blank lines\n",
    "                # int(float(...)) also accepts values written as \"12.0\".\n",
    "                distorted_inds.append(int(float(strnum)))\n",
    "        print('No repeats: ', len(set(distorted_inds))==len(distorted_inds))\n",
    "\n",
    "        self.distorted_inds = distorted_inds\n",
    "        self.distorted_inds_dict = {ind: 1 for ind in distorted_inds}\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the length of the wrapped dataset.\"\"\"\n",
    "        return len(self.ds)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(x, y, spurious)`` for sample ``idx``.\"\"\"\n",
    "        x, y = self.ds[idx]\n",
    "        spurious = idx in self.distorted_inds_dict\n",
    "        # TODO: sanity check the listed indices, i.e. that flagged images\n",
    "        # actually look distorted and unflagged ones do not.\n",
    "        return x, y, spurious"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as trn\n",
    "\n",
    "# Per-channel constants handed to trn.Normalize below. They are defined here\n",
    "# (as well as in a later cell) so this cell also runs on a fresh kernel.\n",
    "mean = [0.485, 0.456, 0.406]\n",
    "std = [0.229, 0.224, 0.225]\n",
    "\n",
    "data_dir = 'REDACTED'\n",
    "supervised_file_path = data_dir + 'distorted_indices'\n",
    "distorted_dataset = dset.ImageFolder(\n",
    "            root=data_dir,\n",
    "            transform=trn.Compose([trn.CenterCrop(224), trn.ToTensor(), trn.Normalize(mean, std)]))\n",
    "# Wrap the folder dataset so each sample also yields its spurious flag.\n",
    "tryINC = ImagenetC(distorted_dataset, supervised_file_path)\n",
    "spur_loader = torch.utils.data.DataLoader(tryINC, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: the dict lookup raises KeyError when index 0 is not listed in the file.\n",
    "tryINC.distorted_inds_dict[0]\n",
    "tryINC.distorted_inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same display as the last line of the previous cell.\n",
    "tryINC.distorted_inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `spurious` is first bound by the loop in the next cell, so on a fresh\n",
    "# kernel this line raises NameError when the notebook is run top to bottom.\n",
    "spurious.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Print spurious-flag counts for batches 1-3 and render the first 20 images\n",
    "# of batches 1 and 2; the loop stops at the third batch.\n",
    "for batch_no, (x, y, spurious) in enumerate(spur_loader, start=1):\n",
    "    first_twenty = spurious[0:20]\n",
    "    print('spurious', spurious.sum(), first_twenty.sum())\n",
    "    if batch_no > 2:\n",
    "        break\n",
    "    show_image_row([x[0:10], x[10:20]],\n",
    "             [\"Examples1\", \"Examples2\"],\n",
    "             fontsize=18,\n",
    "             filename=\"./ignore.png\")\n",
    "\n",
    "    print(first_twenty[0:10])\n",
    "    print('--')\n",
    "    print(first_twenty[10:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flags of the first 20 samples of whichever batch the loop above left bound to `spurious`.\n",
    "print(spurious[0:10])\n",
    "print('--')\n",
    "print(spurious[10:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as trn\n",
    "import torch\n",
    "\n",
    "# Not used in this cell: the ToTensor/Normalize steps below are commented out.\n",
    "mean = [0.485, 0.456, 0.406]\n",
    "std = [0.229, 0.224, 0.225]\n",
    "\n",
    "\n",
    "# Validation split: folder dataset plus its index file, both under data_dir.\n",
    "data_dir = 'REDACTED'\n",
    "supervised_file_path_val = data_dir + 'distorted_indices'\n",
    "distorted_dataset_val = dset.ImageFolder(\n",
    "            root=data_dir,\n",
    "            transform=trn.Compose([trn.CenterCrop(224)])) #, trn.ToTensor(), trn.Normalize(mean, std)]))\n",
    "\n",
    "train_dir = 'REDACTED'\n",
    "# NEED TO MAKE A WRAPPER THAT INCLUDES ATTRIBUTE (distorted or not) \n",
    "# NOTE(review): this path is built from data_dir, not train_dir, so the train\n",
    "# wrapper reads the same index file as the validation wrapper -- confirm\n",
    "# whether train_dir was intended.\n",
    "supervised_file_path_train = data_dir + 'distorted_indices'\n",
    "distorted_dataset_train = dset.ImageFolder(\n",
    "            root=train_dir,\n",
    "            transform=trn.Compose([trn.CenterCrop(224)])) #, trn.ToTensor(), trn.Normalize(mean, std)]))\n",
    "\n",
    "# Wrap both folder datasets so samples come back as (x, y, spurious).\n",
    "final_val = ImagenetC(distorted_dataset_val, supervised_file_path_val)\n",
    "final_train = ImagenetC(distorted_dataset_train, supervised_file_path_train)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(distorted_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_val_split_from_test(train_targets, test_targets, num_classes, split_amt=5):\n",
    "    \"\"\"Carve a validation split out of the test set, class by class.\n",
    "\n",
    "    For each class ``c`` in ``range(num_classes)`` every ``split_amt``-th test\n",
    "    sample of that class (starting with the first) goes to validation; the\n",
    "    remaining test samples stay in the test split. All train samples are kept.\n",
    "\n",
    "    Args:\n",
    "        train_targets: Sequence of train labels; only its length is used.\n",
    "        test_targets: Test labels (tensor or anything ``torch.as_tensor`` accepts).\n",
    "        num_classes: Number of classes; labels are compared against 0..num_classes-1.\n",
    "        split_amt: Stride used to pick validation samples within a class.\n",
    "\n",
    "    Returns:\n",
    "        Dict with index tensors under 'val_indices' (grouped by class),\n",
    "        'train_indices' and 'test_indices' (both ascending).\n",
    "    \"\"\"\n",
    "    test_targets = torch.as_tensor(test_targets)\n",
    "    N = len(test_targets)\n",
    "    all_test = torch.arange(N)\n",
    "    val_indices = torch.cat(\n",
    "        [all_test[test_targets == c][::split_amt] for c in range(num_classes)]\n",
    "    )\n",
    "    # Complement via a boolean mask: linear time, instead of testing every\n",
    "    # index for membership in the val_indices tensor.\n",
    "    is_val = torch.zeros(N, dtype=torch.bool)\n",
    "    is_val[val_indices] = True\n",
    "    test_indices = all_test[~is_val]\n",
    "    train_indices = torch.arange(len(train_targets))\n",
    "    indices_dict = {\n",
    "        'val_indices': val_indices,\n",
    "        'train_indices': train_indices,\n",
    "        'test_indices': test_indices,\n",
    "    }\n",
    "    return indices_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_val_split(train_targets, num_classes, split_amt=5):\n",
    "    \"\"\"Split train indices into train/val, class by class.\n",
    "\n",
    "    For each class ``c`` in ``range(num_classes)`` every ``split_amt``-th\n",
    "    sample of that class (starting with the first) goes to validation; all\n",
    "    other samples stay in train.\n",
    "\n",
    "    NOTE: this definition rebinds the ``create_val_split`` name imported from\n",
    "    ``src.pytorch_datasets`` in the first cell.\n",
    "\n",
    "    Args:\n",
    "        train_targets: Labels (tensor or anything ``torch.as_tensor`` accepts).\n",
    "        num_classes: Number of classes; labels are compared against 0..num_classes-1.\n",
    "        split_amt: Stride used to pick validation samples within a class.\n",
    "\n",
    "    Returns:\n",
    "        Dict with index tensors under 'val_indices' (grouped by class) and\n",
    "        'train_indices' (ascending).\n",
    "    \"\"\"\n",
    "    train_targets = torch.as_tensor(train_targets)\n",
    "    N = len(train_targets)\n",
    "    all_indices = torch.arange(N)\n",
    "    val_indices = torch.cat(\n",
    "        [all_indices[train_targets == c][::split_amt] for c in range(num_classes)]\n",
    "    )\n",
    "    # Complement via a boolean mask: linear time, instead of testing every\n",
    "    # index for membership in the val_indices tensor.\n",
    "    is_val = torch.zeros(N, dtype=torch.bool)\n",
    "    is_val[val_indices] = True\n",
    "    train_indices = all_indices[~is_val]\n",
    "    indices_dict = {\n",
    "        'val_indices': val_indices,\n",
    "        'train_indices': train_indices,\n",
    "    }\n",
    "    return indices_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IndexedDataset():\n",
    "    \"\"\"Dataset wrapper that appends the sample index to every item.\n",
    "\n",
    "    NOTE: rebinds the ``IndexedDataset`` name imported from\n",
    "    ``src.pytorch_datasets`` in the first cell.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ds):\n",
    "        # Wrapped dataset; its items must be unpackable (e.g. tuples).\n",
    "        self.ds = ds\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the length of the wrapped dataset.\"\"\"\n",
    "        wrapped = self.ds\n",
    "        return len(wrapped)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the wrapped item's fields followed by ``idx``, as one tuple.\"\"\"\n",
    "        sample = self.ds[idx]\n",
    "        return (*sample, idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory under which the write_*_betons functions create one sub-directory per dataset name.\n",
    "BETON_ROOT = \"REDACTED\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plain_fields():\n",
    "    \"\"\"Return fresh writer fields for (image, label, index) samples.\"\"\"\n",
    "    return {\n",
    "        # Tune options to optimize dataset size, throughput at train-time\n",
    "        'image': RGBImageField(),\n",
    "        'label': IntField(),\n",
    "        'index': IntField(),\n",
    "    }\n",
    "\n",
    "\n",
    "def _spurious_fields():\n",
    "    \"\"\"Return fresh writer fields for (image, label, spurious, index) samples.\"\"\"\n",
    "    return {\n",
    "        # Tune options to optimize dataset size, throughput at train-time\n",
    "        'image': RGBImageField(max_resolution=75),\n",
    "        'label': IntField(),\n",
    "        'spurious': IntField(),\n",
    "        'index': IntField(),\n",
    "    }\n",
    "\n",
    "\n",
    "def _write_split_betons(ds_name, ds_pairs, make_fields):\n",
    "    \"\"\"Write one ``{ds_name}_{split}.beton`` per (split, dataset) pair under BETON_ROOT/ds_name.\"\"\"\n",
    "    os.makedirs(os.path.join(BETON_ROOT, ds_name), exist_ok=True)\n",
    "    for split_name, ds in ds_pairs:\n",
    "        # Append the sample index so it is stored in the 'index' field.\n",
    "        ds = IndexedDataset(ds)\n",
    "        write_path = os.path.join(BETON_ROOT, ds_name, f\"{ds_name}_{split_name}.beton\")\n",
    "        # A new field dict per split, exactly one writer per output file.\n",
    "        writer = DatasetWriter(write_path, make_fields())\n",
    "        writer.from_indexed_dataset(ds)\n",
    "\n",
    "\n",
    "def write_betons(ds_name, train_ds, test_ds, val_ds=None):\n",
    "    \"\"\"Write train/test (and optionally val) betons with image, label and index fields.\n",
    "\n",
    "    Args:\n",
    "        ds_name: Dataset name; used for the output directory and file names.\n",
    "        train_ds: Dataset written as the 'train' split.\n",
    "        test_ds: Dataset written as the 'test' split.\n",
    "        val_ds: Optional dataset written as the 'val' split.\n",
    "    \"\"\"\n",
    "    ds_pairs = [('train', train_ds), ('test', test_ds)]\n",
    "    if val_ds is not None:\n",
    "        ds_pairs.append(('val', val_ds))\n",
    "    _write_split_betons(ds_name, ds_pairs, _plain_fields)\n",
    "\n",
    "\n",
    "def write_imagenet_c_betons(ds_name, train_ds, test_ds, val_ds=None):\n",
    "    \"\"\"Write train (and optionally val) betons that also carry a 'spurious' field.\n",
    "\n",
    "    Args:\n",
    "        ds_name: Dataset name; used for the output directory and file names.\n",
    "        train_ds: Dataset written as the 'train' split.\n",
    "        test_ds: Accepted for signature parity but NOT written (the test\n",
    "            split was disabled in the original code).\n",
    "        val_ds: Optional dataset written as the 'val' split.\n",
    "    \"\"\"\n",
    "    ds_pairs = [('train', train_ds)]\n",
    "    if val_ds is not None:\n",
    "        ds_pairs.append(('val', val_ds))\n",
    "    _write_split_betons(ds_name, ds_pairs, _spurious_fields)\n",
    "\n",
    "\n",
    "def write_celeba_betons(ds_name, train_ds, test_ds, val_ds=None):\n",
    "    \"\"\"Write train/test (and optionally val) betons that also carry a 'spurious' field.\n",
    "\n",
    "    Args:\n",
    "        ds_name: Dataset name; used for the output directory and file names.\n",
    "        train_ds: Dataset written as the 'train' split.\n",
    "        test_ds: Dataset written as the 'test' split.\n",
    "        val_ds: Optional dataset written as the 'val' split.\n",
    "    \"\"\"\n",
    "    ds_pairs = [('train', train_ds), ('test', test_ds)]\n",
    "    if val_ds is not None:\n",
    "        ds_pairs.append(('val', val_ds))\n",
    "    _write_split_betons(ds_name, ds_pairs, _spurious_fields)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CIFAR10 train and test splits; these bind the notebook-wide train_ds / test_ds names.\n",
    "orig_ds_path = \"REDACTED\"\n",
    "train_ds = torchvision.datasets.CIFAR10(orig_ds_path, train=True)\n",
    "test_ds = torchvision.datasets.CIFAR10(orig_ds_path, train=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the train/test betons, then save a per-class train/val split\n",
    "# (10 classes, every 5th sample of each class goes to validation).\n",
    "write_betons('cifar', train_ds, test_ds, val_ds=None)\n",
    "indices_dict = create_val_split(torch.tensor(train_ds.targets), 10, 5)\n",
    "# NOTE(review): assumes the index_files/ directory already exists -- confirm.\n",
    "torch.save(indices_dict, 'index_files/cifar_indices.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 10\n",
    "cifar_targets = torch.tensor(train_ds.targets)\n",
    "\n",
    "# The base train/val split does not depend on `fold`, so compute it once\n",
    "# instead of on every loop iteration.\n",
    "indices_dict = create_val_split(cifar_targets, num_classes, 5)\n",
    "base_inds = indices_dict['train_indices']\n",
    "\n",
    "# write subsets\n",
    "for fold in [2, 4, 5, 10]:\n",
    "    # Split the base train indices again: every fold-th sample per class\n",
    "    # becomes the labelled train set, the rest is saved as unlabeled.\n",
    "    sub_indices_dict = create_val_split(cifar_targets[base_inds], num_classes, fold)\n",
    "    result_indices = {\n",
    "        'train_indices': base_inds[sub_indices_dict['val_indices']],\n",
    "        'unlabeled_indices': base_inds[sub_indices_dict['train_indices']],\n",
    "        'val_indices': indices_dict['val_indices']\n",
    "    }\n",
    "    print(\"--\", fold, \"--\")\n",
    "    for k, v in result_indices.items():\n",
    "        print(k, len(v))\n",
    "    torch.save(result_indices, f'index_files/cifar_indices_{fold}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the loaders from 'cifar_mslurm.yaml' and pull the first training batch.\n",
    "train_loader, val_loader, test_loader = test_dataset('cifar_mslurm.yaml')\n",
    "\n",
    "# Breaking immediately leaves the first batch bound to `batch`.\n",
    "for batch in train_loader:\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CIFAR100 train and test splits; this rebinds train_ds / test_ds from the CIFAR10 section.\n",
    "orig_ds_path = \"REDACTED\"\n",
    "train_ds = torchvision.datasets.CIFAR100(orig_ds_path, train=True)\n",
    "test_ds = torchvision.datasets.CIFAR100(orig_ds_path, train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `get_unlabeled` is neither defined nor imported in the visible notebook\n",
    "# (the first cell imports `get_unlabeled_indices`), so this raises NameError on a\n",
    "# fresh kernel -- confirm which helper was intended.\n",
    "get_unlabeled(\"cifar100\", torch.tensor(train_ds.targets), 100, [2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the train/test betons, then save a per-class train/val split\n",
    "# (100 classes, every 5th sample of each class goes to validation).\n",
    "write_betons('cifar100', train_ds, test_ds, val_ds=None)\n",
    "indices_dict = create_val_split(torch.tensor(train_ds.targets), 100, 5)\n",
    "torch.save(indices_dict, 'index_files/cifar100_indices.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Super CIFAR100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): train_ds / test_ds are last bound to plain CIFAR100 in the cells above, while\n",
    "# this split uses 20 classes and a later cell reads train_ds.subclass_targets. Presumably a\n",
    "# super-class CIFAR100 dataset was loaded in a cell that is no longer here -- confirm.\n",
    "write_betons('supercifar100', train_ds, test_ds, val_ds=None)\n",
    "indices_dict = create_val_split(torch.tensor(train_ds.targets), 20, 5)\n",
    "torch.save(indices_dict, 'index_files/supercifar100_indices.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Disabled call, kept for reference:\n",
    "\n",
    "```python\n",
    "get_unlabeled(\"supercifar100\", torch.tensor(train_ds.targets), 20, [2])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get spurious\n",
    "import numpy as np\n",
    "\n",
    "orig_indices = torch.load(\"index_files/supercifar100_indices_2.pt\")\n",
    "subclass_targets = np.array(train_ds.subclass_targets)\n",
    "targets = np.array(train_ds.targets)\n",
    "\n",
    "# For each of the 20 classes take the first entry of np.unique over its\n",
    "# subclass ids as the subclass to thin out.\n",
    "classes_to_drop = [np.unique(subclass_targets[targets == c])[0] for c in range(20)]\n",
    "\n",
    "# Keep every 7th train index of a thinned subclass and all indices of the others.\n",
    "labelled = orig_indices['train_indices']\n",
    "labelled_subclasses = subclass_targets[labelled]\n",
    "new_train_indices = torch.cat([\n",
    "    labelled[labelled_subclasses == c][::(7 if c in classes_to_drop else 1)]\n",
    "    for c in range(100)\n",
    "])\n",
    "\n",
    "orig_indices['train_indices'] = new_train_indices\n",
    "orig_indices['classes_to_drop'] = torch.tensor(classes_to_drop)\n",
    "torch.save(orig_indices, 'index_files/spurious_supercifar100_indices.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes_to_drop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CelebA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the three CelebA splits and write them as betons that include the 'spurious' field.\n",
    "# This rebinds train_ds / test_ds from the CIFAR sections and introduces val_ds.\n",
    "train_ds = pytorch_datasets.SpuriousAttributeCelebA(root=\"REDACTED\", split='train') \n",
    "val_ds = pytorch_datasets.SpuriousAttributeCelebA(root=\"REDACTED\", split='valid') \n",
    "test_ds = pytorch_datasets.SpuriousAttributeCelebA(root=\"REDACTED\", split='test') \n",
    "write_celeba_betons('celeba', train_ds, test_ds, val_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Training Splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Boolean masks over the TRAIN split, one per attribute.\n",
    "# NOTE: the same variable names are rebound for the validation split further\n",
    "# down, so the train-split cells below must run before that cell.\n",
    "is_male = train_ds.male_targets == 1\n",
    "is_blond = train_ds.blond_hair_targets == 1\n",
    "is_black_hair = train_ds.black_hair_targets == 1\n",
    "\n",
    "# The four (gender, hair colour) groups.\n",
    "male_and_blond = is_male & is_blond\n",
    "male_and_black_hair = is_male & (is_black_hair)\n",
    "female_and_blond = (~is_male) & is_blond\n",
    "female_and_black_hair = (~is_male) & (is_black_hair)\n",
    "\n",
    "# Positions 0..len(train_ds)-1; indexed with the masks above to get group indices.\n",
    "overall_indices = torch.arange(len(train_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1 to 7 split\n",
    "K = 7\n",
    "np.random.seed(10)\n",
    "# All male+blond samples are kept; their count sets the base group size.\n",
    "taken_male_and_blond = overall_indices[male_and_blond]\n",
    "lowest = len(taken_male_and_blond)\n",
    "# male+black-hair and female+blond are sampled at K times the base size,\n",
    "# female+black-hair at the base size (all without replacement).\n",
    "taken_male_and_black_hair = np.random.choice(overall_indices[male_and_black_hair], replace=False, size=lowest*K)\n",
    "taken_female_and_blond = np.random.choice(overall_indices[female_and_blond], replace=False, size=lowest*K)\n",
    "taken_female_and_black_hair = np.random.choice(overall_indices[female_and_black_hair], replace=False, size=lowest)\n",
    "print(len(taken_male_and_blond), len(taken_male_and_black_hair), len(taken_female_and_blond), len(taken_female_and_black_hair))\n",
    "one_to_seven_train = np.concatenate([taken_male_and_blond, taken_male_and_black_hair, taken_female_and_blond, taken_female_and_black_hair])\n",
    "one_to_seven_train = torch.tensor(one_to_seven_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extreme Split\n",
    "np.random.seed(10)\n",
    "# These two groups are left empty in this split. Give the empty arrays an\n",
    "# integer dtype: plain np.array([]) is float64 and would promote the whole\n",
    "# concatenated index array to floats.\n",
    "taken_male_and_blond = np.array([], dtype=np.int64)\n",
    "taken_female_and_black_hair = np.array([], dtype=np.int64)\n",
    "\n",
    "# All male+black-hair samples are kept; female+blond is sampled to the same\n",
    "# size without replacement.\n",
    "taken_male_and_black_hair = overall_indices[male_and_black_hair]\n",
    "lowest = len(taken_male_and_black_hair)\n",
    "taken_female_and_blond = np.random.choice(overall_indices[female_and_blond], replace=False, size=lowest)\n",
    "print(len(taken_male_and_blond), len(taken_male_and_black_hair), len(taken_female_and_blond), len(taken_female_and_black_hair))\n",
    "extreme_train = np.concatenate([taken_male_and_blond, taken_male_and_black_hair, taken_female_and_blond, taken_female_and_black_hair])\n",
    "extreme_train = torch.tensor(extreme_train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Validation Splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean masks over the VALIDATION split.\n",
    "# NOTE: this rebinds the mask names and overall_indices that the train-split\n",
    "# cells above use, so re-running those cells after this one would sample from val_ds.\n",
    "is_male = val_ds.male_targets == 1\n",
    "is_blond = val_ds.blond_hair_targets == 1\n",
    "is_black_hair = val_ds.black_hair_targets == 1\n",
    "\n",
    "male_and_blond = is_male & is_blond\n",
    "male_and_black_hair = is_male & (is_black_hair)\n",
    "female_and_blond = (~is_male) & is_blond\n",
    "female_and_black_hair = (~is_male) & (is_black_hair)\n",
    "\n",
    "overall_indices = torch.arange(len(val_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Equal split\n",
    "np.random.seed(10)\n",
    "# All male+blond samples are kept; each other group is sampled down to that\n",
    "# same size without replacement.\n",
    "taken_male_and_blond = overall_indices[male_and_blond]\n",
    "lowest = len(taken_male_and_blond)\n",
    "taken_male_and_black_hair = np.random.choice(overall_indices[male_and_black_hair], replace=False, size=lowest)\n",
    "taken_female_and_blond = np.random.choice(overall_indices[female_and_blond], replace=False, size=lowest)\n",
    "taken_female_and_black_hair = np.random.choice(overall_indices[female_and_black_hair], replace=False, size=lowest)\n",
    "print(len(taken_male_and_blond), len(taken_male_and_black_hair), len(taken_female_and_blond), len(taken_female_and_black_hair))\n",
    "equal_val = np.concatenate([taken_male_and_blond, taken_male_and_black_hair, taken_female_and_blond, taken_female_and_black_hair])\n",
    "equal_val = torch.tensor(equal_val)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every validation position; overall_indices was last bound to torch.arange(len(val_ds)).\n",
    "all_val = overall_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save one index file per (train split, val split) combination.\n",
    "train_splits = (\n",
    "    (one_to_seven_train, \"one_to_seven_train\"),\n",
    "    (extreme_train, 'extreme_train'),\n",
    ")\n",
    "val_splits = (\n",
    "    (all_val, \"all_val\"),\n",
    "    (equal_val, 'equal_val'),\n",
    ")\n",
    "for train_indices, train_name in train_splits:\n",
    "    for val_indices, val_name in val_splits:\n",
    "        save_path = f\"index_files/celeba_{train_name}_{val_name}.pt\"\n",
    "        # Indices are stored as int64 tensors.\n",
    "        indices_dict = {\n",
    "            'val_indices': val_indices.long(),\n",
    "            'train_indices': train_indices.long(),\n",
    "        }\n",
    "        torch.save(indices_dict, save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Dataset Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These four imports repeat ones already made in the first cell.\n",
    "import yaml\n",
    "from src.config_parsing import ffcv_read_check_override_config\n",
    "import pprint\n",
    "from src.ffcv_utils import get_training_loaders\n",
    "\n",
    "# Same steps as test_dataset(), written out inline so that `hparams` stays\n",
    "# in scope: the visualisation cell below reads hparams['mean'] and hparams['std'].\n",
    "with open(\"dataset_configs/celeba.yaml\", 'r') as file:\n",
    "    hparams = yaml.safe_load(file)\n",
    "hparams = ffcv_read_check_override_config(hparams)\n",
    "print(\"=========== Current Config ==================\")\n",
    "pprint.pprint(hparams, indent=4)\n",
    "train_loader, val_loader, test_loader = get_training_loaders(hparams, pipeline_subset=['image', 'label', 'spurious'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Breaking immediately leaves the first training batch bound to `batch` for the next cell.\n",
    "for batch in train_loader:\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Undo the normalisation with the config's mean/std before displaying.\n",
    "InvNorm = ffcv_utils.inv_norm(hparams['mean'], hparams['std'])\n",
    "# Show at most 100 samples; capping at the batch length avoids an IndexError\n",
    "# when the batch holds fewer than 100 images.\n",
    "num_to_show = min(100, len(batch[0]))\n",
    "for b in range(num_to_show):\n",
    "    img = batch[0][b]\n",
    "    # batch[1] and batch[2] are the 'label' and 'spurious' fields requested above.\n",
    "    print(batch[1][b].item(), (batch[2][b].item()))\n",
    "    display(torchvision.transforms.ToPILImage()(InvNorm(img)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 -ffcv\n",
   "language": "python",
   "name": "ffcv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}