{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.append(\"<anonymized>/hard_label_manifolds\")\n",
    "\n",
    "import os, argparse\n",
    "import numpy as np\n",
    "import json\n",
    "import utils\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from fnmatch import filter\n",
    "from community.numpy_encoder import NumpyEncoder\n",
    "from research_pool.config import datasets, att_pretty, att_raw, dataset_to_path, dataset_to_victim_path, \\\n",
    "                                 dataset_to_ae, dataset_to_compress_mode, ae_compress_modes, imagenet_archs, \\\n",
    "                                 target_bools, order_to_dataset_epsilon, dataset_to_batch, query_limit, \\\n",
    "                                 dataset_to_cc_ae, hlm_architectures, cc_epoch, norms, dataset_to_resizings, SEED, \\\n",
    "                                 dataset_to_classes, dataset_to_database_path, early_stop, RayS_a, RayS_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Imagenet', 'Imagenet_madry8']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dataset names imported from research_pool.config; the config tree below branches on this list.\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup indexed by dataset name; its value is stored as `ix_database_path` in each config.\n",
    "dataset_to_database_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['inf']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Norm orders to sweep; each value becomes the `order` field of a generated config.\n",
    "norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifier of this generation run (value comes from utils.get_time_stamp()); it is used\n",
    "# both for the config output folder here and for `save_dir` in the default state.\n",
    "experiments_id = utils.get_time_stamp()\n",
    "\n",
    "# All generated JSON configs go into a per-run sub-folder of config_autogen.\n",
    "args = {\n",
    "    'base_dir': os.path.join(\n",
    "        '<anonymized>/hard_label_manifolds/research_pool/config_autogen',\n",
    "        experiments_id,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template for every generated config. The root Node wraps this dict and Node.branch()\n",
    "# deep-copies the state for each child, so the values below are the defaults of all leaves.\n",
    "# Placeholder values such as dataset_path, epsilon, test_batch or classes_select are\n",
    "# overwritten later by the tree expansion and by terminating_leafs().\n",
    "default_state = {\n",
    "  \"dataset\": \"\",\n",
    "  \"dataset_path\": \"\",\n",
    "  \"attack\": \"\",\n",
    "  \"victim_architecture\": \"\",\n",
    "  \"targeted\": \"\",\n",
    "  \"random_start\": 0,\n",
    "  \"fd_eta\": 0.5,\n",
    "  \"image_lr\": 0.1,\n",
    "  \"online_lr\": 0.1,\n",
    "  \"exploration\": 0.5,\n",
    "  \"verbose\": 1,\n",
    "  \"batch_size\": 1,\n",
    "  \"test_batch_size\": 1,\n",
    "  \"test_batch\": \"\",\n",
    "  \"epsilon\": 0.0,\n",
    "  \"query_limit\": query_limit,\n",
    "  \"early_stopping\": early_stop,\n",
    "  \"order\": 0,\n",
    "  \"hlm_architecture\": \"\",  \n",
    "  \"compress_mode\": \"\",\n",
    "  \"encoder_resize\": \"\",\n",
    "  \"resize_dim\": \"\",\n",
    "  \"original_size\": \"\",\n",
    "  \"victim_path\": \"\",\n",
    "  \"victim_arch\": \"\",\n",
    "  \"enc_path\": \"\",\n",
    "  \"dec_path\": \"\",\n",
    "  \"classes_path\": \"\",  # Path to class conditional numpy file\n",
    "  \"run_id\": \"\",\n",
    "  \"a\": 1, # RayS specific\n",
    "  \"b\": 1, # RayS specific\n",
    "  \n",
    "  \"classes_select\": [],\n",
    "  \"ix_database_path\" : \"\",  # Path to k_to_ix database (pre-computed using notebook)\n",
    "  \"save_dir\": f\"<anonymized>/analysis/hlm/{experiments_id}\",\n",
    "  \"num_threads\": 8,\n",
    "  \"seed\": SEED,\n",
    "  \"save_suffix\": \"\",\n",
    "  \"margin\": 200.0,\n",
    "  \"log_interval\": 100\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_state(state, f_path, dry_run=False):\n",
    "    \"\"\"Write one config ``state`` dict to ``f_path`` as JSON indented by two spaces.\n",
    "\n",
    "    Args:\n",
    "        state: Config dictionary to serialize; ``NumpyEncoder`` is passed to ``json.dump``\n",
    "            as the encoder class.\n",
    "        f_path: Destination file path. The file is opened in 'w' mode, so an existing file\n",
    "            is overwritten.\n",
    "        dry_run: When True, return without opening or writing anything.\n",
    "    \"\"\"\n",
    "    if dry_run:\n",
    "        # Bail out before open(): opening in 'w' mode would already create/truncate the file.\n",
    "        return\n",
    "\n",
    "    with open(f_path, 'w') as config_file:\n",
    "        json.dump(state, config_file, indent=2, cls=NumpyEncoder)\n",
    "        \n",
    "    \n",
    "class Node(object):\n",
    "    \"\"\"One vertex of the configuration tree.\n",
    "\n",
    "    A node stores a ``state`` dict, a reference to its ``parent`` (``None`` for the root),\n",
    "    a list of ``children`` and a ``name``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, parent, state: dict, name=\"\", children=None):\n",
    "        self.parent = parent\n",
    "        # Any falsy ``children`` argument (None, empty list) results in a fresh list.\n",
    "        self.children = children or []\n",
    "        self.state = state\n",
    "        self.name = name\n",
    "\n",
    "    def is_root(self):\n",
    "        \"\"\"Return True when this node has no parent.\"\"\"\n",
    "        return self.parent is None\n",
    "\n",
    "    def is_leaf(self):\n",
    "        \"\"\"Return True when this node has no children.\"\"\"\n",
    "        return len(self.children) == 0\n",
    "\n",
    "    def branch(self, key, params):\n",
    "        \"\"\"Append one child per entry of ``params``.\n",
    "\n",
    "        Each child gets a deep copy of this node's state with ``key`` set to that entry,\n",
    "        so later changes to a child never leak into the parent or its siblings.\n",
    "        \"\"\"\n",
    "        for value in params:\n",
    "            derived_state = copy.deepcopy(self.state)\n",
    "            derived_state[key] = value\n",
    "            self.children.append(Node(parent=self, state=derived_state))\n",
    "\n",
    "    def apply_state(self, new_dict):\n",
    "        \"\"\"Copy every entry of ``new_dict`` into this node's state, in place.\"\"\"\n",
    "        self.state.update(new_dict)\n",
    "\n",
    "    def __str__(self):\n",
    "        \"\"\"Render the state as one ``key: value`` line per entry.\"\"\"\n",
    "        return \"\".join(f\"{key}: {value}\\n\" for key, value in self.state.items())\n",
    "\n",
    "\n",
    "def terminating_leafs(finalsubtree):\n",
    "    \"\"\"Fill in the dataset- and attack-dependent settings of a fully branched node.\n",
    "\n",
    "    Reads ``dataset``, ``attack``, ``order`` and ``hlm_architecture`` from\n",
    "    ``finalsubtree.state`` and writes the derived entries (paths, epsilon, test batch,\n",
    "    class selection, compress mode, image sizes, HSJA ``gamma``) back into that state.\n",
    "    For ImageNet datasets the node is additionally branched over ``imagenet_archs``.\n",
    "\n",
    "    Args:\n",
    "        finalsubtree: Node whose state already contains the keys listed above.\n",
    "    \"\"\"\n",
    "    # Project-wide params\n",
    "    apply_state = {}\n",
    "    dataset = finalsubtree.state[\"dataset\"]\n",
    "    # Take the attack from the node itself instead of a module-level ``attack`` variable:\n",
    "    # such a global only exists as a side effect of the tree-building loop, which makes the\n",
    "    # result depend on cell/call order.\n",
    "    attack = finalsubtree.state[\"attack\"]\n",
    "    apply_state[\"dataset_path\"] = dataset_to_path[dataset]\n",
    "\n",
    "    finalsubtree.apply_state({\"victim_path\": dataset_to_victim_path[dataset]})\n",
    "\n",
    "    apply_state[\"epsilon\"] = order_to_dataset_epsilon[(finalsubtree.state['order'], dataset)]\n",
    "    apply_state[\"test_batch\"] = dataset_to_batch[dataset]\n",
    "    apply_state[\"classes_select\"] = dataset_to_classes[dataset]\n",
    "    apply_state[\"ix_database_path\"] = dataset_to_database_path[dataset]\n",
    "\n",
    "    if \"HLM_\" in attack:\n",
    "        apply_state[\"enc_path\"] = dataset_to_ae[(dataset, \"enc\")]\n",
    "        apply_state[\"dec_path\"] = dataset_to_ae[(dataset, \"dec\")]\n",
    "\n",
    "        if finalsubtree.state['hlm_architecture'] == 'CC-AE':\n",
    "            # CC-AE replaces the encoder/decoder paths set just above and adds classes_path.\n",
    "            cc_dir = dataset_to_cc_ae[dataset]\n",
    "            apply_state[\"enc_path\"] = os.path.join(cc_dir, f'CX_conv_encoder_epoch-{cc_epoch}.pth')  # Filename template\n",
    "            apply_state[\"dec_path\"] = os.path.join(cc_dir, f'CX_conv_decoder_epoch-{cc_epoch}.pth')\n",
    "            apply_state[\"classes_path\"] = os.path.join(cc_dir, 'classes_ix.npy')\n",
    "\n",
    "    # NOTE(review): this overwrites a `compress_mode` that was chosen by branching over\n",
    "    # ae_compress_modes for HLM_ attacks - confirm that this is intended.\n",
    "    apply_state[\"compress_mode\"] = dataset_to_compress_mode[dataset]\n",
    "\n",
    "    # Dataset-family specific image size and attack hyper-parameters.\n",
    "    is_imagenet = \"Imagenet\" in dataset\n",
    "    if is_imagenet:\n",
    "        apply_state[\"original_size\"] = 224\n",
    "\n",
    "        if \"HLM_\" in attack:\n",
    "            apply_state[\"encoder_resize\"] = 128\n",
    "\n",
    "        if \"HSJA\" in attack:\n",
    "            apply_state[\"gamma\"] = 1000.0\n",
    "\n",
    "    elif \"CIFAR10\" in dataset:\n",
    "        apply_state[\"original_size\"] = 32\n",
    "        if \"HSJA\" in attack:\n",
    "            apply_state[\"gamma\"] = 10.0\n",
    "\n",
    "    elif \"MNIST\" in dataset:\n",
    "        apply_state[\"original_size\"] = 28\n",
    "        if \"HSJA\" in attack:\n",
    "            apply_state[\"gamma\"] = 10.0\n",
    "\n",
    "    # Apply the collected settings for every dataset. Doing this only inside the three\n",
    "    # branches above would silently drop them for any other dataset name, leaving a config\n",
    "    # with victim_path set but dataset_path/epsilon/... still at their placeholders.\n",
    "    finalsubtree.apply_state(apply_state)\n",
    "\n",
    "    if is_imagenet:\n",
    "        # Must come after apply_state(): the children deep-copy the completed state.\n",
    "        finalsubtree.branch(\"victim_architecture\", imagenet_archs)\n",
    "\n",
    "        \n",
    "# Build the configuration tree level by level:\n",
    "#   root -> targeted -> attack -> dataset -> order -> attack-specific levels.\n",
    "# terminating_leafs() is called on the deepest node of every path to fill in derived settings.\n",
    "tree = Node(None, default_state, \"root\")\n",
    "\n",
    "tree.branch(\"targeted\", target_bools)\n",
    "for subtree in tree.children:\n",
    "    subtree.branch(\"attack\", att_raw)\n",
    "    \n",
    "    for s2tree in subtree.children:\n",
    "        s2tree.branch(\"dataset\", datasets)\n",
    "        \n",
    "        for s23tree in s2tree.children:\n",
    "            s23tree.branch(\"order\", norms)\n",
    "            \n",
    "            for s3tree in s23tree.children:\n",
    "                # This loop runs at module level, so `attack` and `dataset` are globals that are\n",
    "                # re-assigned for every order-level node; check terminating_leafs() before\n",
    "                # renaming or moving these two assignments.\n",
    "                attack = s3tree.state[\"attack\"]\n",
    "                dataset = s3tree.state[\"dataset\"]\n",
    "                \n",
    "                # HLM_* attacks get two extra levels: compress_mode, then hlm_architecture.\n",
    "                if \"HLM_\" in attack:\n",
    "                    s3tree.branch(\"compress_mode\", ae_compress_modes)\n",
    "\n",
    "                    # Go down one subtree\n",
    "                    for s4tree in s3tree.children:\n",
    "                        s4tree.branch(\"hlm_architecture\", hlm_architectures)\n",
    "\n",
    "                        # Go down one subtree\n",
    "                        for s5tree in s4tree.children:\n",
    "                            terminating_leafs(s5tree)\n",
    "                \n",
    "                # Exact match only; other Sampling_* attacks are handled by the next branch.\n",
    "                elif \"Sampling_RayS\" == attack:\n",
    "                    # Assume either a is not empty or b is not empty, but not both\n",
    "                    # NOTE(review): if both lists are non-empty, the `a` and `b` children become\n",
    "                    # siblings (no cross product); if both are empty, s3tree gets no children and\n",
    "                    # terminating_leafs() is never called for it.\n",
    "                    if len(RayS_a):\n",
    "                        # add a subtree\n",
    "                        s3tree.branch(\"a\", RayS_a)\n",
    "                    \n",
    "                    if len(RayS_b):\n",
    "                        # add a subtree\n",
    "                        s3tree.branch(\"b\", RayS_b)\n",
    "                        \n",
    "                    # Go down one subtree\n",
    "                    for s4tree in s3tree.children:\n",
    "                        terminating_leafs(s4tree)\n",
    "                    \n",
    "                # Remaining Sampling_* attacks: one extra level over the dataset's resize dims.\n",
    "                elif \"Sampling_\" in attack:\n",
    "                    s3tree.branch(\"resize_dim\", dataset_to_resizings[dataset])\n",
    "                    \n",
    "                    # Go down one subtree\n",
    "                    for s4tree in s3tree.children:\n",
    "                        terminating_leafs(s4tree)\n",
    "                        \n",
    "                else:\n",
    "                    # Stay at this subtree\n",
    "                    terminating_leafs(s3tree)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Depth-first traversal of the config tree: write one JSON file per leaf\n",
    "\n",
    "# Create the per-run output folder; with exist_ok an already existing folder is left as is.\n",
    "os.makedirs(args['base_dir'], exist_ok=True)\n",
    "    \n",
    "    \n",
    "def dfs(subtree, counter=0):\n",
    "    \"\"\"Walk the tree depth-first and write one JSON config per leaf.\n",
    "\n",
    "    Leaves are numbered in visiting order: each leaf stores the current ``counter`` as its\n",
    "    ``run_id`` and is written to ``args['base_dir']`` via ``write_state``.\n",
    "\n",
    "    Args:\n",
    "        subtree: Node to start the traversal from.\n",
    "        counter: ``run_id`` to assign to the next leaf that is written.\n",
    "\n",
    "    Returns:\n",
    "        The next unused ``run_id`` (``counter`` plus the number of leaves below ``subtree``).\n",
    "    \"\"\"\n",
    "    if not subtree.is_leaf():\n",
    "        # Inner node: thread the running id through the children in order.\n",
    "        for child in subtree.children:\n",
    "            counter = dfs(child, counter)\n",
    "        return counter\n",
    "\n",
    "    subtree.apply_state({\"run_id\": counter})\n",
    "    leaf_state = subtree.state\n",
    "\n",
    "    # File name: <targeted>_<dataset>_<attack>_<run_id>_L<order>[_cm<compress_mode>][_<arch>].json\n",
    "    parts = [\n",
    "        str(leaf_state[\"targeted\"]),\n",
    "        leaf_state[\"dataset\"],\n",
    "        leaf_state[\"attack\"],\n",
    "        str(leaf_state['run_id']),\n",
    "        \"L\" + str(leaf_state['order']),\n",
    "    ]\n",
    "    if leaf_state[\"compress_mode\"] != \"\":\n",
    "        parts.append(\"cm\" + str(leaf_state[\"compress_mode\"]))\n",
    "    if leaf_state[\"victim_architecture\"] != \"\":\n",
    "        parts.append(leaf_state[\"victim_architecture\"])\n",
    "\n",
    "    new_name = \"_\".join(parts) + \".json\"\n",
    "    write_state(leaf_state, os.path.join(args['base_dir'], new_name))\n",
    "\n",
    "    print(f\"Write {new_name} as id {leaf_state['run_id']}\")\n",
    "    return counter + 1\n",
    "            \n",
    "# Start at the root with counter 0; the returned counter equals the number of leaves written.\n",
    "num_files = dfs(tree)\n",
    "print(f\"Wrote {num_files} files to {args['base_dir']}.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every written config back and print the run directory derived from it.\n",
    "# NOTE: `filter` is fnmatch.filter (imported in the second cell), not the builtin filter().\n",
    "config_dir = args['base_dir']\n",
    "for state_file in filter(os.listdir(config_dir), \"*.json\"):\n",
    "    with open(os.path.join(config_dir, state_file)) as state_f:\n",
    "        state = json.load(state_f)\n",
    "\n",
    "    run_id = str(state['run_id'])\n",
    "    run_save_dir = os.path.join(state['save_dir'], state['attack'], state['dataset'], run_id)\n",
    "    print(f\"Expect to find {run_id} @ {run_save_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alae-data",
   "language": "python",
   "name": "alae-data"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
