{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "import json\n",
                "import logging\n",
                "import pandas as pd\n",
                "import numpy as np\n",
                "\n",
                "# The root logger defaults to WARNING, so without this the logging.info()\n",
                "# progress messages emitted by the helpers below are silently dropped.\n",
                "logging.basicConfig(level=logging.INFO)\n",
                "logging.getLogger().setLevel(logging.INFO)\n",
                "\n",
                "# Root of the raw datasets (the CelebA files are read from its CelebFaces/ folder).\n",
                "DATA_DIR = \"/resource/database/\"\n",
                "# Root under which generated metadata files are written.\n",
                "OUTPUT_DIR = \"//output/div_explore\"\n",
                "# Directory holding the task folders and the metadata files inspected below.\n",
                "celeba_dir = os.path.join(OUTPUT_DIR, \"celeba/celeba_v2\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "#### Define metadata generator and task generator"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "metadata": {},
            "outputs": [],
            "source": [
                "def generate_metadata_celeba(data_path, output_path, y_col=9, a_col=20):\n",
                "    \"\"\"Write the ``id,filename,split,y,a`` metadata CSV for one CelebA label/attribute pair.\n",
                "\n",
                "    Reads ``CelebFaces/list_eval_partition.txt`` and ``CelebFaces/list_attr_celeba.txt``\n",
                "    below ``data_path`` and writes\n",
                "    ``<output_path>/celeba/metadata_celeba_y{y_col}_a{a_col}.csv``.\n",
                "\n",
                "    Args:\n",
                "        data_path: directory that contains the ``CelebFaces`` folder.\n",
                "        output_path: output root; its ``celeba`` sub-directory is created if missing.\n",
                "        y_col: 0-based position (among the attribute columns) used as label ``y``.\n",
                "        a_col: 0-based position (among the attribute columns) used as attribute ``a``.\n",
                "\n",
                "    Raises:\n",
                "        ValueError: if the two annotation files do not list the same images in the same order.\n",
                "    \"\"\"\n",
                "    logging.info(\"Generating metadata for CelebA...\")\n",
                "    with open(os.path.join(data_path, \"CelebFaces/list_eval_partition.txt\"), \"r\") as f:\n",
                "        splits = [line for line in f if line.strip()]\n",
                "\n",
                "    with open(os.path.join(data_path, \"CelebFaces/list_attr_celeba.txt\"), \"r\") as f:\n",
                "        # the first two lines are skipped (header), blank lines are ignored\n",
                "        attrs = [line for line in f.readlines()[2:] if line.strip()]\n",
                "\n",
                "    if len(splits) != len(attrs):\n",
                "        raise ValueError(\n",
                "            f\"Partition file lists {len(splits)} images but attribute file lists {len(attrs)}.\"\n",
                "        )\n",
                "\n",
                "    # Build every row before touching the output so a malformed input never leaves\n",
                "    # a half-written metadata file behind.\n",
                "    rows = [\"id,filename,split,y,a\\n\"]\n",
                "    for i, (split, attr) in enumerate(zip(splits, attrs)):\n",
                "        fi, si = split.strip().split()\n",
                "        fields = attr.strip().split()\n",
                "        if fields[0] != fi:\n",
                "            raise ValueError(\n",
                "                f\"Row {i + 1}: partition file has {fi!r} but attribute file has {fields[0]!r}.\"\n",
                "            )\n",
                "        ai = fields[1:]\n",
                "        # attributes are encoded as \"1\" / \"-1\"; map them to 1 / 0\n",
                "        yi = 1 if ai[y_col] == \"1\" else 0\n",
                "        gi = 1 if ai[a_col] == \"1\" else 0\n",
                "        rows.append(\"{},{},{},{},{}\\n\".format(i + 1, fi, si, yi, gi))\n",
                "\n",
                "    out_dir = os.path.join(output_path, \"celeba\")\n",
                "    os.makedirs(out_dir, exist_ok=True)\n",
                "    with open(os.path.join(out_dir, f\"metadata_celeba_y{y_col}_a{a_col}.csv\"), \"w\") as f:\n",
                "        f.writelines(rows)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "metadata": {},
            "outputs": [],
            "source": [
                "class TaskGenerator:\n",
                "    \"\"\"\n",
                "    Given a dataset represented by a csv metadata file (from download.py), generate a set of tasks.\n",
                "    The entries of the metadata file are normally of the form (id, filename, split, label, attribute).\n",
                "    Each resulting task is also represented by a list of (id, filename, split, label, attribute) tuples.\n",
                "\n",
                "    A task is described by three fractions of the requested datasize:\n",
                "    ``sc`` (rows with y == a), ``ci`` (rows with y == 0) and ``ai`` (rows with a == 0).\n",
                "    \"\"\"\n",
                "\n",
                "    # (y, a) group that each column of ``group_matrix`` refers to.\n",
                "    GROUP_ORDER = ((0, 0), (0, 1), (1, 0), (1, 1))\n",
                "\n",
                "    def __init__(self, task_grid: dict, seed: int):\n",
                "        \"\"\"Store the grid (keys \"sc\", \"ci\", \"ai\", each a list) and seed numpy's global RNG.\"\"\"\n",
                "        self.seed = seed\n",
                "        np.random.seed(self.seed)\n",
                "\n",
                "        self.task_grid = task_grid\n",
                "\n",
                "        num_tasks = len(self.task_grid[\"sc\"]) * len(self.task_grid[\"ci\"]) * len(\n",
                "            self.task_grid[\"ai\"]\n",
                "        )\n",
                "        print(\n",
                "            \"Task generator initialized. {} tasks will be generated.\".format(\n",
                "                num_tasks\n",
                "            )\n",
                "        )\n",
                "\n",
                "        # define the group matrix for solving group size in different conditions.\n",
                "        # columns: {(0, 0), (0, 1), (1, 0), (1, 1)}\n",
                "        # rows: y == a (sc), a == 0 (ai), y == 0 (ci), all rows (datasize)\n",
                "        self.group_matrix = np.array(\n",
                "            [\n",
                "                [1, 0, 0, 1],\n",
                "                [1, 0, 1, 0],\n",
                "                [1, 1, 0, 0],\n",
                "                [1, 1, 1, 1],\n",
                "            ]\n",
                "        )\n",
                "\n",
                "    def pipeline(self, dataname, dataset_info, add_test=True, output_path=None, datasize=5000):\n",
                "        \"\"\"Traverse the task grid and generate all tasks.\n",
                "\n",
                "        Args:\n",
                "            dataname: name used in the task file names.\n",
                "            dataset_info: path of the metadata csv (columns split, y, a are read here).\n",
                "            add_test: also append a group-balanced subsample of the test split (split == 2).\n",
                "            output_path: directory for the task csv files (created if missing); None skips saving.\n",
                "            datasize: requested size of the training and of the validation part.\n",
                "\n",
                "        After the call ``self.task`` and ``self.static_metafeatures`` hold the last generated task.\n",
                "        \"\"\"\n",
                "        all_metadata = pd.read_csv(dataset_info)\n",
                "\n",
                "        tr_metadata = all_metadata[all_metadata[\"split\"] == 0]\n",
                "        tr_group_dict = tr_metadata.groupby([\"y\", \"a\"]).groups\n",
                "        print(\"Original training group size: \", {k: len(v) for k, v in tr_group_dict.items()})\n",
                "\n",
                "        val_metadata = all_metadata[all_metadata[\"split\"] == 1]\n",
                "        val_group_dict = val_metadata.groupby([\"y\", \"a\"]).groups\n",
                "        print(\"Original validation group size: \", {k: len(v) for k, v in val_group_dict.items()})\n",
                "\n",
                "        test_metadata = all_metadata[all_metadata[\"split\"] == 2]\n",
                "        test_group_dict = test_metadata.groupby([\"y\", \"a\"]).groups\n",
                "        print(\"Original test group size: \", {k: len(v) for k, v in test_group_dict.items()})\n",
                "\n",
                "        # The test subsample only depends on (test_metadata, datasize, seed), so it is\n",
                "        # drawn once instead of once per grid point.\n",
                "        te_task = self.subsample(test_metadata, int(datasize/2), balanced=True) if add_test else None\n",
                "\n",
                "        if output_path is not None:\n",
                "            os.makedirs(output_path, exist_ok=True)\n",
                "\n",
                "        for sc in self.task_grid[\"sc\"]:\n",
                "            for ci in self.task_grid[\"ci\"]:\n",
                "                for ai in self.task_grid[\"ai\"]:\n",
                "                    print(\"sc:\", sc, \"ci:\", ci, \"ai:\", ai)\n",
                "\n",
                "                    # .sample(frac=1) shuffles the rows of each part\n",
                "                    tr_task = self.generate_single_task(sc, ci, ai, tr_group_dict, tr_metadata, datasize, split=\"tr\").sample(frac = 1)\n",
                "                    val_task = self.generate_single_task(sc, ci, ai, val_group_dict, val_metadata, datasize, split=\"va\").sample(frac = 1)\n",
                "\n",
                "                    self.task = pd.concat([tr_task, val_task])\n",
                "\n",
                "                    if add_test:\n",
                "                        self.task = pd.concat([self.task, te_task])\n",
                "\n",
                "                    self.static_metafeatures = self.compute_static_meta_features(tr_task[\"y\"], tr_task[\"a\"])\n",
                "                    self.static_metafeatures[\"sc\"] = sc\n",
                "                    self.static_metafeatures[\"ci\"] = ci\n",
                "                    self.static_metafeatures[\"ai\"] = ai\n",
                "\n",
                "                    if output_path is not None:\n",
                "                        # format with two decimal places\n",
                "                        metadata_file = os.path.join(\n",
                "                            output_path,\n",
                "                            \"task_{}_sc{:.2f}_ci{:.2f}_ai{:.2f}.csv\".format(dataname, sc, ci, ai),\n",
                "                        )\n",
                "                        self.save(metadata_file)\n",
                "\n",
                "    def subsample(self, metadata, datasize, balanced=False):\n",
                "        \"\"\"Subsample the metadata, either balanced over the (y, a) groups or uniformly over rows.\n",
                "\n",
                "        With ``balanced=True`` every group contributes ``int(datasize/4)`` rows, or the size of\n",
                "        the smallest group if that is smaller. Otherwise ``datasize`` rows are drawn\n",
                "        (all rows if fewer are available).\n",
                "        \"\"\"\n",
                "        split_metadata = metadata\n",
                "        if balanced:\n",
                "            group_dict = split_metadata.groupby([\"y\", \"a\"]).groups\n",
                "            min_group_size = min(len(v) for v in group_dict.values())\n",
                "            sampled_group_size = min(int(datasize/4), min_group_size)\n",
                "            group_sample = pd.concat(\n",
                "                [split_metadata.loc[v].sample(sampled_group_size, random_state=self.seed) for v in group_dict.values()]\n",
                "            )\n",
                "        else:\n",
                "            # sample() raises when asked for more rows than exist, so cap the request\n",
                "            group_sample = split_metadata.sample(min(datasize, len(split_metadata)), random_state=self.seed)\n",
                "        return group_sample\n",
                "\n",
                "    def generate_single_task(self, sc, ci, ai, group_dict, metadata, datasize, split):\n",
                "        \"\"\"Main function to generate a single task given sc/ci/ai statistics.\n",
                "\n",
                "        Solves ``group_matrix @ n = [sc, ai, ci, 1] * datasize`` for the four group sizes and\n",
                "        draws that many random rows from every (y, a) group of ``metadata``.\n",
                "\n",
                "        Args:\n",
                "            sc, ci, ai: target fractions (see the class docstring).\n",
                "            group_dict: ``metadata.groupby([\"y\", \"a\"]).groups``; it is not modified.\n",
                "            metadata: frame the group indices refer to.\n",
                "            datasize: requested number of rows.\n",
                "            split: label used in the printed warnings only.\n",
                "\n",
                "        Returns:\n",
                "            The selected rows. Fewer than ``datasize`` rows are returned (with a warning) when\n",
                "            the solved sizes are fractional or a group is too small.\n",
                "\n",
                "        Raises:\n",
                "            ValueError: if the requested fractions need a negative group size.\n",
                "        \"\"\"\n",
                "        # b = [sc * datasize, max(ai, 1-ai) * datasize, min(ci, 1-ci) * datasize, datasize]\n",
                "        b = [sc * datasize, ai * datasize, ci * datasize, datasize]\n",
                "        num_per_group = np.linalg.solve(self.group_matrix, b)\n",
                "        print(\"num_per_group: \", num_per_group)\n",
                "\n",
                "        # Round down, but tolerate solver noise: 249.99999999999997 must count as 250\n",
                "        # and -1e-15 as 0.\n",
                "        counts = np.floor(num_per_group + 1e-6).astype(int)\n",
                "        if (counts < 0).any():\n",
                "            # a negative stop index would silently slice from the end of the group instead\n",
                "            raise ValueError(\n",
                "                f\"Infeasible {split} task (sc={sc}, ci={ci}, ai={ai}): \"\n",
                "                f\"negative group sizes {counts.tolist()}.\"\n",
                "            )\n",
                "\n",
                "        if len(group_dict) == len(self.GROUP_ORDER):\n",
                "            # all four groups present: keep the order of the groupby dict\n",
                "            # (presumably sorted by (y, a), i.e. identical to GROUP_ORDER)\n",
                "            keys = list(group_dict)\n",
                "        else:\n",
                "            # a group is absent: look the groups up by key so that the remaining ones\n",
                "            # stay aligned with the columns of group_matrix\n",
                "            keys = self.GROUP_ORDER\n",
                "\n",
                "        # sampling from the group_dict with the solved counts\n",
                "        pieces = []\n",
                "        for key, count in zip(keys, counts):\n",
                "            if key in group_dict:\n",
                "                # permutation() shuffles a copy; the index stored in group_dict stays untouched\n",
                "                pool = np.random.permutation(np.asarray(group_dict[key]))\n",
                "            else:\n",
                "                pool = np.empty(0, dtype=int)\n",
                "            if count > len(pool):\n",
                "                print(f\"Warning: {split} group {key} has only {len(pool)} rows but {count} were requested.\")\n",
                "            # extract from metadata with the selected index labels\n",
                "            pieces.append(metadata.loc[pool[:count]])\n",
                "        task = pd.concat(pieces)\n",
                "\n",
                "        if len(task) < datasize:\n",
                "            print(f\"Warning: {split} task size is less than datasize by {datasize - len(task)}.\")\n",
                "        return task\n",
                "\n",
                "    def save(self, output_file):\n",
                "        \"\"\"Save the tasks to a csv file.\"\"\"\n",
                "        self.task.to_csv(output_file, index=False)\n",
                "        # with open(output_file.replace(\"task_\", \"mfeature_\").replace(\".csv\", \".json\"), \"w\") as f:\n",
                "        #     json.dump(self.static_metafeatures, f)\n",
                "\n",
                "    # meta-features\n",
                "    def compute_static_meta_features(self, y, a):\n",
                "        \"\"\"Return a dict of label/attribute statistics of one task.\n",
                "\n",
                "        Keys: mi, nmi, cramer, entropy_y, entropy_a, n_entropy_y, n_entropy_a, diff_y, diff_a.\n",
                "        A variable with a single observed class gets normalized entropy 0 and cramer 0\n",
                "        instead of a division by zero.\n",
                "        \"\"\"\n",
                "        # kept local so the class can be defined without scikit-learn / scipy installed\n",
                "        from sklearn.metrics import mutual_info_score, normalized_mutual_info_score\n",
                "        from scipy.stats import chi2_contingency\n",
                "        from sklearn.metrics.cluster import contingency_matrix\n",
                "\n",
                "        def calculate_entropy(s):\n",
                "            _, counts = np.unique(s, return_counts=True)\n",
                "            probabilities = counts / counts.sum()\n",
                "            return -np.sum(probabilities * np.log2(probabilities))\n",
                "\n",
                "        def normalize_entropy(entropy, n_classes):\n",
                "            # log2(1) == 0, so a single-class variable cannot be normalized by division\n",
                "            return entropy / np.log2(n_classes) if n_classes > 1 else 0.0\n",
                "\n",
                "        n_classes_y = len(np.unique(y))\n",
                "        n_classes_a = len(np.unique(a))\n",
                "\n",
                "        df = {}\n",
                "        # MI, NMI, Cramer, Tschuprow\n",
                "        mi = mutual_info_score(y, a)\n",
                "        nmi = normalized_mutual_info_score(y, a)\n",
                "        cm = contingency_matrix(y, a)\n",
                "        min_dim = min(n_classes_y, n_classes_a)\n",
                "        if min_dim > 1:\n",
                "            # NOTE(review): chi2_contingency applies Yates' correction to 2x2 tables by\n",
                "            # default, which lowers chi2 - confirm that this is intended for Cramer's V.\n",
                "            chi2, _, _, _ = chi2_contingency(cm)\n",
                "            cramer = np.sqrt(chi2 / (len(y) * (min_dim - 1)))\n",
                "        else:\n",
                "            cramer = 0.0\n",
                "        # tschuprow = np.sqrt(chi2 / (cm.shape[0] * cm.shape[1]))\n",
                "\n",
                "        # entropy and normalized entropy\n",
                "        entropy_y = calculate_entropy(y)\n",
                "        entropy_a = calculate_entropy(a)\n",
                "        n_entropy_y = normalize_entropy(entropy_y, n_classes_y)\n",
                "        n_entropy_a = normalize_entropy(entropy_a, n_classes_a)\n",
                "\n",
                "        # Difference between the probability of the most frequent class and the probability of the least frequent class\n",
                "        counts_y = np.bincount(y)\n",
                "        counts_a = np.bincount(a)\n",
                "        diff_y = np.abs(np.max(counts_y) - np.min(counts_y)) / len(y)\n",
                "        diff_a = np.abs(np.max(counts_a) - np.min(counts_a)) / len(a)\n",
                "\n",
                "        df[\"mi\"] = mi\n",
                "        df[\"nmi\"] = nmi\n",
                "        df[\"cramer\"] = cramer\n",
                "        # df[\"tschuprow\"] = tschuprow\n",
                "        df[\"entropy_y\"] = entropy_y\n",
                "        df[\"entropy_a\"] = entropy_a\n",
                "        df[\"n_entropy_y\"] = n_entropy_y\n",
                "        df[\"n_entropy_a\"] = n_entropy_a\n",
                "        df[\"diff_y\"] = diff_y\n",
                "        df[\"diff_a\"] = diff_a\n",
                "        return df\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "#### Generate metadata"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "By default, the attributes and labels are [hair_color, gender]. In DivDis paper (Lee et al., 2023), they consider several others:\n",
                "<!-- make the image smaller -->\n",
                "<img src=\"assets/image.png\" width=\"500\">\n",
                "\n",
                "As we will see, some of the tasks are more close to evenly distributed while others not.  \n",
                "- Balanced: [`Mouth_Slightly_Open`, `Wearing_Lipstick`], [`Attractive`, `Smiling`]\n",
                "- Imbalanced: [`Blond_Hair`, `Male`] (Default), [`Wavy_Hair`, `High_Cheekbones`], [`Heavy_Makeup`, `Big_Lips`]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "!ls $celeba_dir"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def show_statstics(metadata_file, split=0):\n",
                "    \"\"\"Print the first rows, the (y, a) group counts and the sc/ci/ai fractions of a metadata csv.\n",
                "\n",
                "    Args:\n",
                "        metadata_file: csv with at least the columns split, y and a.\n",
                "        split: only rows with this split value are used; None keeps all rows.\n",
                "    \"\"\"\n",
                "    # load metadata\n",
                "    metadata = pd.read_csv(metadata_file)\n",
                "    if split is not None:\n",
                "        metadata = metadata[metadata[\"split\"] == split]\n",
                "    # show the first few rows\n",
                "    print(metadata.head())\n",
                "    print(\"-----------------------------\")\n",
                "    # show joint group counts of y and a\n",
                "    group_sizes = metadata.groupby([\"y\", \"a\"]).size()\n",
                "    print(group_sizes)\n",
                "    total = len(metadata)\n",
                "    if total == 0:\n",
                "        print(\"No rows for split\", split, \"- sc/ci/ai are undefined.\")\n",
                "        return\n",
                "    # Look every group up by key: a positional lookup breaks (or silently shifts)\n",
                "    # as soon as one of the four (y, a) groups is absent.\n",
                "    all_groups = pd.MultiIndex.from_product([[0, 1], [0, 1]])\n",
                "    n00, n01, n10, n11 = group_sizes.reindex(all_groups, fill_value=0).tolist()\n",
                "    print(\"sc:\", (n00 + n11)/total, \"ci:\", (n00 + n01)/total, \"ai:\", (n00 + n10)/total)\n",
                "\n",
                "# load txt with pandas (sep=r\"\\s+\" replaces the deprecated delim_whitespace=True)\n",
                "all_attr = pd.read_csv(os.path.join(DATA_DIR, \"CelebFaces/list_attr_celeba.txt\"), header=1, sep=r\"\\s+\")\n",
                "# all_attr.columns.get_loc(\"Blond_Hair\") gives the index of a single attribute\n",
                "\n",
                "# print column index, column name and the number of 1s / -1s in the column\n",
                "for i, col in enumerate(all_attr.columns):\n",
                "    print(i, col, (all_attr[col]==1).sum(), (all_attr[col]==-1).sum())\n",
                "\n",
                "# candidate (y, a) pairs:\n",
                "# attractive, smiling\n",
                "# mouth slightly open, wearing lipsticks\n",
                "# heavy makeup, big lips\n",
                "# male, black hair\n",
                "# Oval_Face, High_Cheekbones"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "y_col, a_col = 25, 19\n",
                "generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "show_statstics(os.path.join(OUTPUT_DIR, 'celeba', f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "y_col, a_col = 9, 20\n",
                "# generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "show_statstics(os.path.join(celeba_dir, f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 12,
            "metadata": {},
            "outputs": [],
            "source": [
                "metadata = pd.read_csv(os.path.join(celeba_dir, f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))\n",
                "# random sample 4000\n",
                "metadata = metadata.sample(4000, random_state=0)\n",
                "metadata.to_csv(os.path.join(celeba_dir, f\"metadata_celeba_y{y_col}_a{a_col}_sample4000.csv\"), index=False)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "y_col, a_col = 21, 36\n",
                "generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "# generate_metadata_celeba writes below OUTPUT_DIR/celeba, so read the statistics from there\n",
                "show_statstics(os.path.join(OUTPUT_DIR, \"celeba\", f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "y_col, a_col = 2, 31\n",
                "generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "# generate_metadata_celeba writes below OUTPUT_DIR/celeba, so read the statistics from there\n",
                "show_statstics(os.path.join(OUTPUT_DIR, \"celeba\", f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "y_col, a_col = 33, 19\n",
                "generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "# generate_metadata_celeba writes below OUTPUT_DIR/celeba, so read the statistics from there\n",
                "show_statstics(os.path.join(OUTPUT_DIR, \"celeba\", f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "y_col, a_col = 18, 6\n",
                "generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "# generate_metadata_celeba writes below OUTPUT_DIR/celeba, so read the statistics from there\n",
                "show_statstics(os.path.join(OUTPUT_DIR, \"celeba\", f\"metadata_celeba_y{y_col}_a{a_col}.csv\"))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "#### Generate tasks"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "We start with more balanced tasks. [`Mouth_Slightly_Open` 21, `Wearing_Lipstick` 36], [`Attractive` 2, `Smiling` 31]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "metadata": {},
            "outputs": [],
            "source": [
                "seed = 0\n",
                "dataset = \"CelebA\"\n",
                "# y_col, a_col = 9, 20\n",
                "# y_col, a_col = 18, 6\n",
                "y_col, a_col = 25, 19\n",
                "\n",
                "# map the column indices used as y_col / a_col to the corresponding attribute names\n",
                "# (each index listed once: a repeated dict key silently overrides the earlier entry)\n",
                "attr_map = {\n",
                "    21: \"Mouth_Slightly_Open\",\n",
                "    36: \"Wearing_Lipstick\",\n",
                "    2: \"Attractive\",\n",
                "    31: \"Smiling\",\n",
                "    9: \"Blond_Hair\",\n",
                "    20: \"Male\",\n",
                "    33: \"Wavy_Hair\",\n",
                "    19: \"High_Cheekbones\",\n",
                "    18: \"Heavy_Makeup\",\n",
                "    6: \"Big_Lips\",\n",
                "    8: \"Black_Hair\",\n",
                "    25: \"Oval_Face\",\n",
                "}"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# single shift: one of sc / ci / ai runs over shift_levels while the other two stay at 0.50\n",
                "datasizes = [200, 500, 1000, 2000, 5000, 10000]\n",
                "shift_levels = [0.01, 0.05, 0.10, 0.30, 0.50, 0.70, 0.90, 0.95, 0.99]\n",
                "dataset_info = os.path.join(OUTPUT_DIR, dataset.lower(), f\"metadata_celeba_y{y_col}_a{a_col}.csv\")\n",
                "for datasize in datasizes:\n",
                "    task_dir = os.path.join(celeba_dir, f\"tasks_y{y_col}_a{a_col}_datasize{datasize}_seed{seed}\")\n",
                "    # unlike `!mkdir`, this also creates missing parents and does not fail when re-run\n",
                "    os.makedirs(task_dir, exist_ok=True)\n",
                "\n",
                "    for shifted in (\"sc\", \"ci\", \"ai\"):\n",
                "        task_grid = {\"sc\": [0.50], \"ci\": [0.50], \"ai\": [0.50]}\n",
                "        task_grid[shifted] = shift_levels\n",
                "        task_gen = TaskGenerator(task_grid, seed)\n",
                "        task_gen.pipeline(\n",
                "            dataset.lower(),\n",
                "            dataset_info=dataset_info,\n",
                "            output_path=task_dir,\n",
                "            datasize=datasize,\n",
                "        )"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# 3 shifts: sc, ci and ai change together; task i uses (sc_values[i], ci_values[i], ai_values[i])\n",
                "datasizes = [200, 500, 1000, 2000, 5000, 10000]\n",
                "task_list = [[2, 31], [21, 36], [8, 20], [25, 19]]\n",
                "sc_values = [\n",
                "    0.68, 0.5, 0.3, 0.51, 0.4, 0.57, 0.17, 0.51, 0.55, 0.47,\n",
                "    0.49, 0.52, 0.23, 0.12, 0.36, 0.15, 0.5, 0.49, 0.48, 0.26,\n",
                "    0.28, 0.68, 0.27, 0.25, 0.23, 0.59, 0.72, 0.67, 0.64, 0.7,\n",
                "]\n",
                "ci_values = [\n",
                "    0.35, 0.47, 0.49, 0.36, 0.11, 0.46, 0.4, 0.65, 0.24, 0.22,\n",
                "    0.29, 0.42, 0.24, 0.26, 0.39, 0.39, 0.7, 0.65, 0.4, 0.67,\n",
                "    0.38, 0.81, 0.39, 0.81, 0.6, 0.42, 0.58, 0.85, 0.24, 0.42,\n",
                "]\n",
                "ai_values = [\n",
                "    0.47, 0.31, 0.71, 0.27, 0.65, 0.36, 0.49, 0.7, 0.47, 0.52,\n",
                "    0.55, 0.16, 0.64, 0.8, 0.45, 0.68, 0.65, 0.66, 0.57, 0.17,\n",
                "    0.75, 0.69, 0.53, 0.12, 0.21, 0.46, 0.61, 0.69, 0.36, 0.33,\n",
                "]\n",
                "# the three lists are consumed in lockstep; zip() would silently drop a missing entry\n",
                "assert len(sc_values) == len(ci_values) == len(ai_values)\n",
                "\n",
                "num = 0\n",
                "for y_col, a_col in task_list:\n",
                "    dataset_info = os.path.join(OUTPUT_DIR, dataset.lower(), f\"metadata_celeba_y{y_col}_a{a_col}.csv\")\n",
                "    if not os.path.exists(dataset_info):\n",
                "        # not every pair of task_list is generated by an earlier cell (e.g. [8, 20])\n",
                "        generate_metadata_celeba(DATA_DIR, OUTPUT_DIR, y_col=y_col, a_col=a_col)\n",
                "    for datasize in datasizes:\n",
                "        task_dir = os.path.join(celeba_dir, f\"tasks_y{y_col}_a{a_col}_datasize{datasize}_seed{seed}\")\n",
                "        # unlike `!mkdir`, this also creates missing parents and does not fail when re-run\n",
                "        os.makedirs(task_dir, exist_ok=True)\n",
                "        for sc, ci, ai in zip(sc_values, ci_values, ai_values):\n",
                "            task_grid = {\"sc\": [sc], \"ci\": [ci], \"ai\": [ai]}\n",
                "            task_gen = TaskGenerator(task_grid, seed)\n",
                "            task_gen.pipeline(\n",
                "                dataset.lower(),\n",
                "                dataset_info=dataset_info,\n",
                "                output_path=task_dir,\n",
                "                datasize=datasize,\n",
                "            )\n",
                "            num += 1\n",
                "print(num)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "!ls $task_dir"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# statistics of one generated task for every split value (0, 1 and 2)\n",
                "task_file = os.path.join(task_dir, \"task_celeba_sc0.30_ci0.50_ai0.50.csv\")\n",
                "for split in (0, 1, 2):\n",
                "    show_statstics(task_file, split=split)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import matplotlib.image as mpimg\n",
                "\n",
                "# display one image as a sanity check of the image folder layout\n",
                "image_folder = os.path.join(DATA_DIR, \"CelebFaces/img_align_celeba/193\")\n",
                "img = mpimg.imread(os.path.join(image_folder, \"193810.jpg\"))\n",
                "plt.imshow(img)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "#### Check an example task"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import matplotlib.image as mpimg\n",
                "\n",
                "# TaskGenerator.pipeline names the files with two decimals (\"{:.2f}\"): \"0.50\", not \"0.5\"\n",
                "task_file = os.path.join(task_dir, \"task_celeba_sc0.50_ci0.50_ai0.50.csv\")\n",
                "task_example = pd.read_csv(task_file)\n",
                "print(task_file)\n",
                "task_example = task_example[task_example[\"split\"] == 2]\n",
                "task_groups = task_example.groupby([\"y\", \"a\"])\n",
                "# plot 10 images for each of 4 group (0,0), (0,1), (1,0), (1,1)\n",
                "fig, axs = plt.subplots(4, 10, figsize=(20, 10))\n",
                "for ax in axs.ravel():\n",
                "    # hide the frame of every panel, including those of groups with fewer than 10 images\n",
                "    ax.axis(\"off\")\n",
                "for i, (group, rows) in enumerate(task_groups):\n",
                "    for j, (_, row) in enumerate(rows.head(10).iterrows()):\n",
                "        # images are stored in sub-folders named after the first three characters of the file name\n",
                "        img = mpimg.imread(os.path.join(DATA_DIR, f\"CelebFaces/img_align_celeba/{row['filename'][:3]}\", row[\"filename\"]))\n",
                "        axs[i, j].imshow(img)\n",
                "        axs[i, j].set_title(f\"y: {row['y']}, a: {row['a']}\")\n",
                "# add a big title\n",
                "fig.suptitle(\"y = \" + attr_map[y_col] + \" | \" + \"a = \" + attr_map[a_col], fontsize=16)\n",
                "# add number of samples in each group as subtitle\n",
                "fig.text(0.5, 0.04, \"Number of samples: \" + str({k: len(v) for k, v in task_groups.groups.items()}), ha=\"center\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# training rows (split == 0) of the Attractive / Smiling metadata, located via the\n",
                "# configured OUTPUT_DIR instead of a second hard-coded copy of the path\n",
                "task1 = pd.read_csv(os.path.join(OUTPUT_DIR, \"celeba\", \"metadata_celeba_y2_a31.csv\"))\n",
                "task1 = task1[task1[\"split\"] == 0]\n",
                "print(task1.head())"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "#### Check an example attribute"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "metadata": {},
            "outputs": [],
            "source": [
                "import pandas as pd\n",
                "# Load the whitespace-separated attribute table with pandas.\n",
                "# header=1 takes the column names from the file's second line (generate_metadata_celeba\n",
                "# likewise skips the first two lines of this file before the data rows).\n",
                "# sep=r\"\\s+\" replaces delim_whitespace=True, which is deprecated in recent pandas releases.\n",
                "all_attr = pd.read_csv(os.path.join(DATA_DIR, \"CelebFaces/list_attr_celeba.txt\"), header=1, sep=r\"\\s+\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Report the frame's size and preview only its first rows instead of rendering the whole table.\n",
                "print(all_attr.shape)\n",
                "all_attr.head()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Column labels of all_attr, i.e. the attribute names read from list_attr_celeba.txt.\n",
                "all_attr.columns"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import matplotlib.pyplot as plt\n",
                "import matplotlib.image as mpimg\n",
                "\n",
                "# For every attribute column of all_attr, show up to n_examples randomly sampled images\n",
                "# taken from the rows where that attribute equals 1.\n",
                "n_examples = 10\n",
                "\n",
                "# all_attr is indexed by image filename, so filenames are read from the index instead of\n",
                "# adding an 'image_id' column; this keeps the shared frame unmodified by this cell.\n",
                "# An 'image_id' column left over from a previous version of this cell is skipped.\n",
                "attr_names = [col for col in all_attr.columns if col != \"image_id\"]\n",
                "\n",
                "fig, axs = plt.subplots(len(attr_names), n_examples, figsize=(30, 70), squeeze=False)\n",
                "# Hide every axis up front so unused panels stay blank when an attribute has few positive rows.\n",
                "for ax in axs.flat:\n",
                "    ax.axis(\"off\")\n",
                "for i, attr in enumerate(attr_names):\n",
                "    # Filter once and reuse the result for both the count and the sample.\n",
                "    positives = all_attr[all_attr[attr] == 1]\n",
                "    print(attr, len(positives))\n",
                "    # Clamp the sample size: DataFrame.sample raises when asked for more rows than exist.\n",
                "    rows = positives.sample(min(n_examples, len(positives)), random_state=0)\n",
                "    for j, image_id in enumerate(rows.index):\n",
                "        # Images are stored in sub-folders named after the first three characters of the filename.\n",
                "        img = mpimg.imread(os.path.join(DATA_DIR, f\"CelebFaces/img_align_celeba/{image_id[:3]}\", image_id))\n",
                "        axs[i, j].imshow(img)\n",
                "        axs[i, j].set_title(attr)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "subpop_bench",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.18"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}