{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import logging\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Root of all experiment outputs; the dataset paths below are derived from it.\n",
    "OUTPUT_DIR = \"//output/div_explore\"\n",
    "# MetaShift COCO cat/dog (indoor/outdoor) split, derived from OUTPUT_DIR so the two cannot drift apart.\n",
    "DATA_DIR = os.path.join(OUTPUT_DIR, \"metashift\", \"COCO-Cat-Dog-indoor-outdoor\")\n",
    "coco_dir = DATA_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the dataset folder (pure Python: portable and no shell interpolation of the path)\n",
    "sorted(os.listdir(coco_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate metadata\n",
    "import pickle\n",
    "import os\n",
    "\n",
    "# Image id -> group annotation(s) shipped with the MetaShift split.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files you trust.\n",
    "pkl_path = os.path.join(coco_dir, \"imageID_to_group.pkl\")\n",
    "with open(pkl_path, \"rb\") as pkl_f:\n",
    "    gt = pickle.load(pkl_f)\n",
    "\n",
    "# split codes: 0 = train, 2 = test (no split == 1 rows are produced here)\n",
    "split_codes = {\"train\": 0, \"val_out_of_domain\": 2}\n",
    "# label codes: dog = 0, cat = 1\n",
    "label_codes = {\"dog\": 0, \"cat\": 1}\n",
    "folder_list = list(split_codes)\n",
    "cls_list = [\"cat\", \"dog\"]\n",
    "\n",
    "metadata_all = {\"filename\": [], \"split\": [], \"y\": [], \"a\": []}\n",
    "for f in folder_list:\n",
    "    for c in cls_list:\n",
    "        # sorted(): os.listdir order is filesystem-dependent, and row order becomes the 'id'\n",
    "        # column and drives task sampling, so fix it for reproducibility.\n",
    "        imgs = sorted(os.listdir(os.path.join(coco_dir, f, c)))\n",
    "        metadata_all[\"filename\"].extend(imgs)\n",
    "        metadata_all[\"split\"].extend([split_codes[f]] * len(imgs))\n",
    "        metadata_all[\"y\"].extend([label_codes[c]] * len(imgs))\n",
    "        # Attribute = gt[<image id>][0][4:-1]; this presumably strips a \"cat(\" / \"dog(\" prefix and a\n",
    "        # \")\" suffix, e.g. \"cat(indoor)\" -> \"indoor\" -- TODO confirm against the pickle contents.\n",
    "        metadata_all[\"a\"].extend([gt[os.path.splitext(img)[0]][0][4:-1] for img in imgs])\n",
    "df_metadata = pd.DataFrame(metadata_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the attribute as binary (1 = 'indoor', 0 = anything else) and add a row id.\n",
    "# A new frame is built instead of overwriting df_metadata['a']: re-running an in-place\n",
    "# mapping on already-encoded values would silently turn every row into 0.\n",
    "df_metadata_encoded = df_metadata.assign(\n",
    "    a=(df_metadata[\"a\"] == \"indoor\").astype(int),\n",
    "    id=df_metadata.index,\n",
    ")\n",
    "df_metadata_encoded.to_csv(os.path.join(coco_dir, \"metadata_all.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row counts per (y, a) group in the training split\n",
    "df_metadata.loc[df_metadata[\"split\"].eq(0)].groupby([\"y\", \"a\"]).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TaskGenerator:\n",
    "    \"\"\"\n",
    "    Given a dataset represented by a csv metadata file (from download.py), generate a set of tasks.\n",
    "    The entry in metadata file is normally in the form of (id, filename, split, label, attribute).\n",
    "    Each resulting task is a DataFrame of such rows: sampled training rows, optionally followed\n",
    "    by the full test split.\n",
    "\n",
    "    A task is specified by three fractions of its N sampled training rows (see group_matrix):\n",
    "    sc = fraction with y == a, ai = fraction with a == 0, ci = fraction with y == 0.\n",
    "    \"\"\"\n",
    "\n",
    "    # (y, a) groups in the column order of group_matrix.\n",
    "    GROUP_KEYS = [(0, 0), (0, 1), (1, 0), (1, 1)]\n",
    "\n",
    "    def __init__(self, task_grid: dict, seed: int):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            task_grid: {\"sc\": [...], \"ci\": [...], \"ai\": [...]}; every combination is one task.\n",
    "            seed: seeds numpy's global RNG (used to shuffle groups) and is the pandas\n",
    "                random_state used by subsample.\n",
    "        \"\"\"\n",
    "        self.seed = seed\n",
    "        np.random.seed(self.seed)\n",
    "\n",
    "        self.task_grid = task_grid\n",
    "\n",
    "        num_tasks = len(self.task_grid[\"sc\"]) * len(self.task_grid[\"ci\"]) * len(self.task_grid[\"ai\"])\n",
    "        print(\"Task generator initialized. {} tasks will be generated.\".format(num_tasks))\n",
    "\n",
    "        # Linear system for the per-group counts n = (n00, n01, n10, n11), ordered as GROUP_KEYS:\n",
    "        #   row 0: n00 + n11             = sc * N   (y == a)\n",
    "        #   row 1: n00 + n10             = ai * N   (a == 0)\n",
    "        #   row 2: n00 + n01             = ci * N   (y == 0)\n",
    "        #   row 3: n00 + n01 + n10 + n11 = N\n",
    "        self.group_matrix = np.array(\n",
    "            [\n",
    "                [1, 0, 0, 1],\n",
    "                [1, 0, 1, 0],\n",
    "                [1, 1, 0, 0],\n",
    "                [1, 1, 1, 1],\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    def pipeline(self, dataname, dataset_info, add_test=True, output_path=None, datasize=5000):\n",
    "        \"\"\"Traverse the task grid and generate (and optionally save) every feasible task.\n",
    "\n",
    "        Args:\n",
    "            dataname: dataset tag used in the output file names.\n",
    "            dataset_info: path of the metadata csv (needs \"split\", \"y\" and \"a\" columns).\n",
    "            add_test: append every split == 2 row, unsampled, to each task.\n",
    "            output_path: directory receiving one csv per task; nothing is written when None.\n",
    "            datasize: number of training rows requested per task.\n",
    "        \"\"\"\n",
    "        all_metadata = pd.read_csv(dataset_info)\n",
    "\n",
    "        tr_metadata = all_metadata[all_metadata[\"split\"] == 0]\n",
    "        tr_group_dict = tr_metadata.groupby([\"y\", \"a\"]).groups\n",
    "        print(\"Original training group size: \", {k: len(v) for k, v in tr_group_dict.items()})\n",
    "\n",
    "        val_metadata = all_metadata[all_metadata[\"split\"] == 1]\n",
    "        val_group_dict = val_metadata.groupby([\"y\", \"a\"]).groups\n",
    "        print(\"Original validation group size: \", {k: len(v) for k, v in val_group_dict.items()})\n",
    "\n",
    "        test_metadata = all_metadata[all_metadata[\"split\"] == 2]\n",
    "        test_group_dict = test_metadata.groupby([\"y\", \"a\"]).groups\n",
    "        print(\"Original test group size: \", {k: len(v) for k, v in test_group_dict.items()})\n",
    "\n",
    "        for sc in self.task_grid[\"sc\"]:\n",
    "            for ci in self.task_grid[\"ci\"]:\n",
    "                for ai in self.task_grid[\"ai\"]:\n",
    "                    print(\"sc:\", sc, \"ci:\", ci, \"ai:\", ai)\n",
    "\n",
    "                    tr_task, flag = self.generate_single_task(\n",
    "                        sc, ci, ai, tr_group_dict, tr_metadata, datasize, split=\"tr\"\n",
    "                    )\n",
    "                    if not flag:\n",
    "                        # infeasible or too small: skip this grid point entirely\n",
    "                        continue\n",
    "                    # Only training rows are sampled; no validation task is built.\n",
    "                    self.task = tr_task\n",
    "\n",
    "                    if add_test:\n",
    "                        # The whole test split is appended unchanged (not subsampled).\n",
    "                        self.task = pd.concat([self.task, test_metadata])\n",
    "\n",
    "                    # Meta-features describe the sampled training rows only.\n",
    "                    self.static_metafeatures = self.compute_static_meta_features(tr_task[\"y\"], tr_task[\"a\"])\n",
    "                    self.static_metafeatures[\"sc\"] = sc\n",
    "                    self.static_metafeatures[\"ci\"] = ci\n",
    "                    self.static_metafeatures[\"ai\"] = ai\n",
    "\n",
    "                    if output_path is not None:\n",
    "                        # format with two decimal places\n",
    "                        metadata_file = os.path.join(\n",
    "                            output_path,\n",
    "                            \"task_{}_sc{:.2f}_ci{:.2f}_ai{:.2f}.csv\".format(dataname, sc, ci, ai),\n",
    "                        )\n",
    "                        self.save(metadata_file)\n",
    "\n",
    "    def subsample(self, metadata, datasize, balanced=False):\n",
    "        \"\"\"Subsample rows, either balanced over (y, a) groups or uniformly.\n",
    "\n",
    "        Balanced: int(datasize / 4) rows per group, capped at the smallest group's size.\n",
    "        Otherwise: ``datasize`` rows drawn uniformly. Both use random_state=self.seed.\n",
    "        \"\"\"\n",
    "        split_metadata = metadata\n",
    "        if balanced:\n",
    "            group_dict = split_metadata.groupby([\"y\", \"a\"]).groups\n",
    "            group_size = {k: len(v) for k, v in group_dict.items()}\n",
    "            min_group_size = min(group_size.values())\n",
    "            sampled_group_size = int(datasize/4) if min_group_size > datasize/4 else min_group_size\n",
    "            group_sample = pd.concat(\n",
    "                [split_metadata.loc[v].sample(sampled_group_size, random_state=self.seed) for v in group_dict.values()]\n",
    "            )\n",
    "        else:\n",
    "            group_sample = split_metadata.sample(datasize, random_state=self.seed)\n",
    "        return group_sample\n",
    "\n",
    "    def generate_single_task(self, sc, ci, ai, group_dict, metadata, datasize, split):\n",
    "        \"\"\"Sample one task whose (y, a) group sizes realize the requested sc / ci / ai.\n",
    "\n",
    "        Args:\n",
    "            sc, ci, ai: target fractions (see group_matrix).\n",
    "            group_dict: mapping (y, a) -> index labels of ``metadata`` rows, e.g.\n",
    "                ``metadata.groupby([\"y\", \"a\"]).groups``. A missing group counts as empty.\n",
    "            metadata: frame the labels in ``group_dict`` refer to.\n",
    "            datasize: requested total number of rows N.\n",
    "            split: tag used only in log messages.\n",
    "\n",
    "        Returns:\n",
    "            (task, flag): the sampled rows, and False when the request is infeasible (a negative\n",
    "            group size) or the sample falls short of ``datasize`` by more than 10 rows.\n",
    "        \"\"\"\n",
    "        flag = True\n",
    "        b = [sc * datasize, ai * datasize, ci * datasize, datasize]\n",
    "        num_per_group = np.linalg.solve(self.group_matrix, b)\n",
    "        print(\"num_per_group: \", num_per_group)\n",
    "        selected = {}\n",
    "        # Iterate in GROUP_KEYS order so each count goes to the group it was solved for,\n",
    "        # even if group_dict lacks a group.\n",
    "        for key, target in zip(self.GROUP_KEYS, num_per_group):\n",
    "            # Round down like a plain int() cast, but with a small tolerance so a float result\n",
    "            # such as 29.999999999 (meaning 30) is not truncated to 29.\n",
    "            n = int(np.floor(target + 1e-6))\n",
    "            if n < 0:\n",
    "                print(\"Error: group size is less than 0.\")\n",
    "                flag = False\n",
    "                n = 0  # a negative slice bound would select \"all but the last |n|\" rows\n",
    "            labels = np.asarray(group_dict.get(key, metadata.index[:0]))\n",
    "            # permutation() shuffles a copy; shuffling the groupby Index data in place would\n",
    "            # mutate pandas' shared index (or fail outright if the view is read-only).\n",
    "            selected[key] = np.random.permutation(labels)[:n]\n",
    "        # extract the selected rows from metadata\n",
    "        task = pd.concat([metadata.loc[v] for v in selected.values()])\n",
    "\n",
    "        if len(task) < datasize:\n",
    "            print(f\"Warning: {split} task size is less than datasize by {datasize - len(task)}.\")\n",
    "            if datasize - len(task) > 10:\n",
    "                flag = False\n",
    "        return task, flag\n",
    "\n",
    "    def save(self, output_file):\n",
    "        \"\"\"Write the current task (self.task) to ``output_file`` as csv, without the index.\"\"\"\n",
    "        self.task.to_csv(output_file, index=False)\n",
    "        # with open(output_file.replace(\"task_\", \"mfeature_\").replace(\".csv\", \".json\"), \"w\") as f:\n",
    "        #     json.dump(self.static_metafeatures, f)\n",
    "\n",
    "    # meta-features\n",
    "    def compute_static_meta_features(self, y, a):\n",
    "        \"\"\"Association and balance statistics between label ``y`` and attribute ``a``.\n",
    "\n",
    "        Args:\n",
    "            y, a: equal-length sequences of non-negative integer codes (np.bincount is applied).\n",
    "\n",
    "        Returns:\n",
    "            dict with keys mi, nmi, cramer, entropy_y, entropy_a, n_entropy_y, n_entropy_a,\n",
    "            diff_y, diff_a. Statistics that are undefined because y or a has a single distinct\n",
    "            value (cramer, the normalized entropies) are NaN instead of raising.\n",
    "        \"\"\"\n",
    "\n",
    "        def calculate_entropy(s):\n",
    "            \"\"\"Shannon entropy (bits) of the empirical distribution of ``s``.\"\"\"\n",
    "            values, counts = np.unique(s, return_counts=True)\n",
    "            probabilities = counts / counts.sum()\n",
    "            entropy = -np.sum(probabilities * np.log2(probabilities))\n",
    "            return entropy\n",
    "\n",
    "        df = {}\n",
    "        # MI, NMI, Cramer's V\n",
    "        from sklearn.metrics import mutual_info_score, normalized_mutual_info_score\n",
    "        from scipy.stats import chi2_contingency\n",
    "        from sklearn.metrics.cluster import contingency_matrix\n",
    "        n_unique_y = len(np.unique(y))\n",
    "        n_unique_a = len(np.unique(a))\n",
    "        mi = mutual_info_score(y, a)\n",
    "        nmi = normalized_mutual_info_score(y, a)\n",
    "        cm = contingency_matrix(y, a)\n",
    "        chi2, _, _, _ = chi2_contingency(cm)\n",
    "        # Cramer's V divides by (k - 1): undefined (ZeroDivisionError) when k == 1.\n",
    "        k = min(n_unique_y, n_unique_a)\n",
    "        cramer = np.sqrt(chi2 / (len(y) * (k - 1))) if k > 1 else float(\"nan\")\n",
    "\n",
    "        # entropy and normalized entropy (log2(1) == 0, so normalization needs >= 2 values)\n",
    "        entropy_y = calculate_entropy(y)\n",
    "        entropy_a = calculate_entropy(a)\n",
    "        n_entropy_y = entropy_y / np.log2(n_unique_y) if n_unique_y > 1 else float(\"nan\")\n",
    "        n_entropy_a = entropy_a / np.log2(n_unique_a) if n_unique_a > 1 else float(\"nan\")\n",
    "\n",
    "        # Difference between the probability of the most frequent class and the probability of the least frequent class\n",
    "        diff_y = np.abs(np.max(np.bincount(y)) - np.min(np.bincount(y))) / len(y)\n",
    "        diff_a = np.abs(np.max(np.bincount(a)) - np.min(np.bincount(a))) / len(a)\n",
    "\n",
    "        df[\"mi\"] = mi\n",
    "        df[\"nmi\"] = nmi\n",
    "        df[\"cramer\"] = cramer\n",
    "        df[\"entropy_y\"] = entropy_y\n",
    "        df[\"entropy_a\"] = entropy_a\n",
    "        df[\"n_entropy_y\"] = n_entropy_y\n",
    "        df[\"n_entropy_a\"] = n_entropy_a\n",
    "        df[\"diff_y\"] = diff_y\n",
    "        df[\"diff_a\"] = diff_a\n",
    "        return df\n",
    "\n",
    "\n",
    "def generate_valid_dist_shifts(num_exps, min_n=200, seed=None):\n",
    "    \"\"\"Rejection-sample ``num_exps`` feasible (sc, ci, ai) triples.\n",
    "\n",
    "    Each value is drawn uniformly from [0, 1] and rounded to 2 decimals. A triple is kept when\n",
    "    c1-c4 hold (together they make every solved group count positive) and every per-group\n",
    "    count solved from group_matrix for a task of ``min_n`` rows is >= 2.\n",
    "\n",
    "    Args:\n",
    "        num_exps: number of triples to return.\n",
    "        min_n: task size used for the feasibility check.\n",
    "        seed: optional seed for a private random.Random, making the result reproducible.\n",
    "            None keeps drawing from the global ``random`` module state.\n",
    "\n",
    "    Returns:\n",
    "        list of [sc, ci, ai] lists.\n",
    "    \"\"\"\n",
    "    import random\n",
    "\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "\n",
    "    # same system as TaskGenerator.group_matrix; b is ordered [sc, ai, ci, N]\n",
    "    group_matrix = np.array(\n",
    "        [\n",
    "            [1, 0, 0, 1],\n",
    "            [1, 0, 1, 0],\n",
    "            [1, 1, 0, 0],\n",
    "            [1, 1, 1, 1],\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    dist_shift_lists = []\n",
    "    while len(dist_shift_lists) < num_exps:\n",
    "        # sample sc, ci, ai from (0,1) randomly and independently\n",
    "        sc, ci, ai = [round(rng.uniform(0, 1), 2) for _ in range(3)]\n",
    "        b = [sc * min_n, ai * min_n, ci * min_n, min_n]\n",
    "        num_per_group = np.linalg.solve(group_matrix, b)\n",
    "        # check constraints\n",
    "        c1 = (sc + ci + ai) > 1\n",
    "        c2 = sc > (ci + ai - 1)\n",
    "        c3 = ci > (sc + ai - 1)\n",
    "        c4 = ai > (sc + ci - 1)\n",
    "        if c1 and c2 and c3 and c4 and (num_per_group >= 2).all():\n",
    "            dist_shift_lists.append([sc, ci, ai])\n",
    "    return dist_shift_lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 30 random feasible triples, transposed to three rows: sc, ci, ai\n",
    "np.asarray(generate_valid_dist_shifts(30, min_n=200)).transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# single shift: sweep one of sc / ci / ai over shift_values while the other two stay at 0.50\n",
    "datasizes = [200, 500, 1000]\n",
    "seed = 0\n",
    "dataset = \"COCO\"\n",
    "shift_values = [0.01, 0.05, 0.10, 0.30, 0.50, 0.70, 0.90, 0.95, 0.99]\n",
    "dataset_info = os.path.join(coco_dir, \"metadata_all.csv\")\n",
    "for datasize in datasizes:\n",
    "    task_dir = os.path.join(coco_dir, 'metadata', f\"datasize{datasize}_seed{seed}\")\n",
    "    # makedirs also creates the parent 'metadata' dir (plain `mkdir` fails without it)\n",
    "    # and is a no-op when the directory already exists, so re-runs are safe.\n",
    "    os.makedirs(task_dir, exist_ok=True)\n",
    "\n",
    "    for shifted in [\"sc\", \"ci\", \"ai\"]:\n",
    "        task_grid = {key: (shift_values if key == shifted else [0.50]) for key in [\"sc\", \"ci\", \"ai\"]}\n",
    "        # a fresh generator re-seeds numpy, so each sweep is reproducible on its own\n",
    "        task_gen = TaskGenerator(task_grid, seed)\n",
    "        task_gen.pipeline(\n",
    "            dataset.lower(),\n",
    "            dataset_info=dataset_info,\n",
    "            output_path=task_dir,\n",
    "            datasize=datasize,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3 shifts: one task per (sc, ci, ai) triple, with all three shifted at once\n",
    "datasizes = [200, 500, 1000]\n",
    "seed = 0\n",
    "dataset = \"COCO\"\n",
    "# Hard-coded for reproducibility. The numpy-style layout suggests these were pasted from the\n",
    "# (unseeded) generate_valid_dist_shifts(30, min_n=200) output -- TODO confirm.\n",
    "sc_values, ci_values, ai_values = ([0.79, 0.34, 0.52, 0.54, 0.12, 0.26, 0.62, 0.53, 0.09, 0.51, 0.64,\n",
    "        0.31, 0.28, 0.4 , 0.3 , 0.8 , 0.45, 0.12, 0.17, 0.71, 0.51, 0.26,\n",
    "        0.28, 0.55, 0.4 , 0.37, 0.12, 0.36, 0.57, 0.54],\n",
    "       [0.46, 0.12, 0.5 , 0.1 , 0.65, 0.46, 0.78, 0.29, 0.2 , 0.43, 0.48,\n",
    "        0.56, 0.5 , 0.47, 0.48, 0.91, 0.45, 0.47, 0.78, 0.8 , 0.55, 0.49,\n",
    "        0.67, 0.5 , 0.13, 0.26, 0.91, 0.72, 0.1 , 0.67],\n",
    "       [0.36, 0.76, 0.39, 0.54, 0.36, 0.52, 0.74, 0.32, 0.87, 0.82, 0.71,\n",
    "        0.58, 0.36, 0.84, 0.67, 0.79, 0.2 , 0.47, 0.22, 0.59, 0.27, 0.66,\n",
    "        0.37, 0.57, 0.65, 0.65, 0.13, 0.28, 0.43, 0.62])\n",
    "assert len(sc_values) == len(ci_values) == len(ai_values)\n",
    "\n",
    "dataset_info = os.path.join(coco_dir, \"metadata_all.csv\")\n",
    "for datasize in datasizes:\n",
    "    task_dir = os.path.join(coco_dir, 'metadata', f\"datasize{datasize}_seed{seed}\")\n",
    "    # makedirs also creates the parent 'metadata' dir and is a no-op on re-runs\n",
    "    os.makedirs(task_dir, exist_ok=True)\n",
    "\n",
    "    for sc, ci, ai in zip(sc_values, ci_values, ai_values):\n",
    "        task_grid = {\"sc\": [sc], \"ci\": [ci], \"ai\": [ai]}\n",
    "        # a fresh generator re-seeds numpy, so every task is reproducible on its own\n",
    "        task_gen = TaskGenerator(task_grid, seed)\n",
    "        task_gen.pipeline(\n",
    "            dataset.lower(),\n",
    "            dataset_info=dataset_info,\n",
    "            output_path=task_dir,\n",
    "            datasize=datasize,\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
