{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4MWfQKlXwOU1"
      },
      "source": [
        "This notebook contains the source code of Multi-Fidelity Neural Architecture Search (MF-NAS) on NAS-Bench-101.\n",
        "\n",
        "The notebook is organized as follows:\n",
        "\n",
        "*   Utility functions are in Section 'Utilities'.\n",
        "*   Code for loading the benchmark databases is in Section 'Load Benchmark'.\n",
        "*   Code for First-Improvement Local Search, Random Search, Successive Halving, and MF-NAS is in Sections 'Algorithms' and 'MF-NAS'.\n",
        "*   Code for running Random Search, Local Search, SH, REA, REA+W, and the MF-NAS variants is in Section 'Run'.\n",
        "\n",
        "All NAS-Bench-101 results reported in the paper are already shown in Section 'Run'.\n",
        "\n",
        "Execute all cells to reproduce the results reported in the paper."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YEEe3vfm7h4m"
      },
      "source": [
        "## Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "f20ae2a9"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "import pickle as p\n",
        "import numpy as np\n",
        "import random\n",
        "from copy import deepcopy\n",
        "import itertools\n",
        "import json\n",
        "import math\n",
        "from scipy import stats\n",
        "from tqdm import tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "fO7mt2YyuKK4"
      },
      "outputs": [],
      "source": [
        "# Search-space constants (NAS-Bench-101 cell): a DAG over `num_vertices`\n",
        "# nodes whose edges live in the strict upper triangle of the adjacency matrix.\n",
        "allowed_ops = ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3']\n",
        "allowed_edges = [0, 1]  # each edge is either absent (0) or present (1)\n",
        "num_vertices = 7\n",
        "max_edges = 9\n",
        "# (row, col) index arrays of the strict upper triangle; `encode`/`decode`\n",
        "# use this order for the edge bits of the flat vector.\n",
        "edge_spots_idx = np.triu_indices(num_vertices, 1)\n",
        "edge_spots = num_vertices * (num_vertices - 1) // 2  # number of edge positions\n",
        "op_spots = num_vertices - 2  # interior nodes only; input/output ops are fixed"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "xArHM0PauFca"
      },
      "outputs": [],
      "source": [
        "def encode(arch: dict):\n",
        "    \"\"\"Flatten an architecture dict into a 1-D integer vector.\n",
        "\n",
        "    The vector is the edge bits of ``arch['matrix']`` read at\n",
        "    ``edge_spots_idx``, followed by the index in ``allowed_ops`` of every\n",
        "    interior op (the first and last entries of ``arch['ops']`` are skipped).\n",
        "    An op that is not in ``allowed_ops`` raises IndexError.\n",
        "    \"\"\"\n",
        "    op_table = np.array(allowed_ops)\n",
        "    edge_bits = np.array(arch['matrix'])[edge_spots_idx]\n",
        "    op_ids = np.empty(num_vertices - 2)\n",
        "    for pos, op_name in enumerate(arch['ops'][1:-1]):\n",
        "        # Position of the (unique) matching entry in the op table.\n",
        "        op_ids[pos] = np.flatnonzero(op_table == op_name)[0]\n",
        "    return np.concatenate((edge_bits, op_ids)).astype(int)\n",
        "\n",
        "def decode(x: np.ndarray):\n",
        "    \"\"\"Rebuild the ``{'matrix', 'ops'}`` dict from a vector made by `encode`.\n",
        "\n",
        "    ``x`` may be an ndarray or a plain list. Its first ``edge_spots`` entries\n",
        "    fill the strict upper triangle of a square int matrix, and its last\n",
        "    ``op_spots`` entries index into ``allowed_ops``.\n",
        "    \"\"\"\n",
        "    adjacency = np.zeros((num_vertices, num_vertices), dtype=int)\n",
        "    adjacency[edge_spots_idx] = x[:edge_spots]\n",
        "    interior_ops = [allowed_ops[k] for k in x[-op_spots:]]\n",
        "    return {'matrix': adjacency, 'ops': ['input', *interior_ops, 'output']}\n",
        "\n",
        "def sample(phenotype=True):\n",
        "    \"\"\"Draw a random architecture; it is NOT guaranteed to be valid.\n",
        "\n",
        "    Every edge bit and op is drawn independently and uniformly with the global\n",
        "    ``np.random`` state. Returns the ``{'matrix', 'ops'}`` dict when\n",
        "    ``phenotype`` is true, otherwise the integer vector from `encode`.\n",
        "    \"\"\"\n",
        "    # Keep the two draws in this order: seeded runs depend on it.\n",
        "    raw_edges = np.random.choice(allowed_edges, size=(num_vertices, num_vertices))\n",
        "    raw_ops = np.random.choice(allowed_ops, size=num_vertices).tolist()\n",
        "    arch = {\n",
        "        'matrix': np.triu(raw_edges, 1),\n",
        "        'ops': ['input'] + raw_ops[1:-1] + ['output'],\n",
        "    }\n",
        "    return arch if phenotype else encode(arch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "7ECtcuxBx-wo"
      },
      "outputs": [],
      "source": [
        "import hashlib\n",
        "\n",
        "def hash_module(matrix, labeling):\n",
        "    \"\"\"Computes an isomorphism-invariant MD5 hash of the matrix and label pair.\n",
        "\n",
        "    Each vertex starts from a hash of (out-degree, in-degree, label). The hashes\n",
        "    are then refined ``vertices`` times by combining each vertex's hash with the\n",
        "    sorted hashes of its in- and out-neighbours, and the sorted final hashes are\n",
        "    hashed once more. The result is used as the lookup key into the benchmark\n",
        "    tables, so the exact string formats below must not change.\n",
        "\n",
        "    Args:\n",
        "      matrix: np.ndarray square upper-triangular adjacency matrix.\n",
        "      labeling: list of int labels of length equal to both dimensions of\n",
        "        matrix.\n",
        "\n",
        "    Returns:\n",
        "      MD5 hash (hex digest string) of the matrix and labeling.\n",
        "    \"\"\"\n",
        "    vertices = np.shape(matrix)[0]\n",
        "    in_edges = np.sum(matrix, axis=0).tolist()\n",
        "    out_edges = np.sum(matrix, axis=1).tolist()\n",
        "\n",
        "    assert len(in_edges) == len(out_edges) == len(labeling)\n",
        "    # Initial per-vertex hash from (out-degree, in-degree, label).\n",
        "    hashes = list(zip(out_edges, in_edges, labeling))\n",
        "    hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]\n",
        "    # Computing this up to the diameter is probably sufficient but since the\n",
        "    # operation is fast, it is okay to repeat more times.\n",
        "    for _ in range(vertices):\n",
        "        new_hashes = []\n",
        "        for v in range(vertices):\n",
        "            in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]\n",
        "            out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]\n",
        "            # Sorting makes the combination independent of vertex numbering.\n",
        "            new_hashes.append(hashlib.md5(\n",
        "                (''.join(sorted(in_neighbors)) + '|' +\n",
        "                 ''.join(sorted(out_neighbors)) + '|' +\n",
        "                 hashes[v]).encode('utf-8')).hexdigest())\n",
        "        hashes = new_hashes\n",
        "    # Sort before the final hash so the fingerprint ignores vertex order.\n",
        "    fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()\n",
        "    return fingerprint\n",
        "\n",
        "class ModelSpec(object):\n",
        "    \"\"\"Model specification given adjacency matrix and labeling.\n",
        "\n",
        "    After construction, ``matrix``/``ops`` describe the pruned graph (both are\n",
        "    None and ``valid_spec`` is False when input and output are disconnected),\n",
        "    while ``original_matrix``/``original_ops`` keep unpruned copies.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, matrix, ops, data_format='channels_last'):\n",
        "        \"\"\"Initialize the module spec.\n",
        "\n",
        "        Args:\n",
        "          matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.\n",
        "          ops: V-length list of labels for the base ops used. The first and last\n",
        "            elements are ignored because they are the input and output vertices\n",
        "            which have no operations. The elements are retained to keep consistent\n",
        "            indexing.\n",
        "          data_format: channels_last or channels_first.\n",
        "\n",
        "        Raises:\n",
        "          ValueError: invalid matrix or ops\n",
        "        \"\"\"\n",
        "        if not isinstance(matrix, np.ndarray):\n",
        "            matrix = np.array(matrix)\n",
        "        shape = np.shape(matrix)\n",
        "        if len(shape) != 2 or shape[0] != shape[1]:\n",
        "            raise ValueError('matrix must be square')\n",
        "        if shape[0] != len(ops):\n",
        "            raise ValueError('length of ops must match matrix dimensions')\n",
        "        # Edges may only go from a lower to a higher vertex index (acyclic).\n",
        "        if not is_upper_triangular(matrix):\n",
        "            raise ValueError('matrix must be upper triangular')\n",
        "\n",
        "        # Both the original and pruned matrices are deep copies of the matrix and\n",
        "        # ops so any changes to those after initialization are not recognized by the\n",
        "        # spec.\n",
        "        self.original_matrix = copy.deepcopy(matrix)\n",
        "        self.original_ops = copy.deepcopy(ops)\n",
        "\n",
        "        self.matrix = copy.deepcopy(matrix)\n",
        "        self.ops = copy.deepcopy(ops)\n",
        "        # _prune() may reset matrix/ops to None and valid_spec to False.\n",
        "        self.valid_spec = True\n",
        "        self._prune()\n",
        "\n",
        "        # Stored only; no method of this class reads it.\n",
        "        self.data_format = data_format\n",
        "\n",
        "    def _prune(self):\n",
        "        \"\"\"Prune the extraneous parts of the graph.\n",
        "\n",
        "        General procedure:\n",
        "          1) Remove parts of graph not connected to input.\n",
        "          2) Remove parts of graph not connected to output.\n",
        "          3) Reorder the vertices so that they are consecutive after steps 1 and 2.\n",
        "\n",
        "        These 3 steps can be combined by deleting the rows and columns of the\n",
        "        vertices that are not reachable from both the input and output (in reverse).\n",
        "        \"\"\"\n",
        "        # Local count; shadows the module-level num_vertices inside this method.\n",
        "        num_vertices = np.shape(self.original_matrix)[0]\n",
        "\n",
        "        # DFS forward from input\n",
        "        # (upper-triangular, so successors of `top` have larger indices).\n",
        "        visited_from_input = set([0])\n",
        "        frontier = [0]\n",
        "        while frontier:\n",
        "            top = frontier.pop()\n",
        "            for v in range(top + 1, num_vertices):\n",
        "                if self.original_matrix[top, v] and v not in visited_from_input:\n",
        "                    visited_from_input.add(v)\n",
        "                    frontier.append(v)\n",
        "\n",
        "        # DFS backward from output\n",
        "        # (predecessors of `top` have smaller indices).\n",
        "        visited_from_output = set([num_vertices - 1])\n",
        "        frontier = [num_vertices - 1]\n",
        "        while frontier:\n",
        "            top = frontier.pop()\n",
        "            for v in range(0, top):\n",
        "                if self.original_matrix[v, top] and v not in visited_from_output:\n",
        "                    visited_from_output.add(v)\n",
        "                    frontier.append(v)\n",
        "\n",
        "        # Any vertex that isn't connected to both input and output is extraneous to\n",
        "        # the computation graph.\n",
        "        extraneous = set(range(num_vertices)).difference(\n",
        "            visited_from_input.intersection(visited_from_output))\n",
        "\n",
        "        # If the non-extraneous graph is less than 2 vertices, the input is not\n",
        "        # connected to the output and the spec is invalid.\n",
        "        if len(extraneous) > num_vertices - 2:\n",
        "            self.matrix = None\n",
        "            self.ops = None\n",
        "            self.valid_spec = False\n",
        "            return\n",
        "\n",
        "        self.matrix = np.delete(self.matrix, list(extraneous), axis=0)\n",
        "        self.matrix = np.delete(self.matrix, list(extraneous), axis=1)\n",
        "        # Delete from the highest index down so earlier indices stay valid.\n",
        "        for index in sorted(extraneous, reverse=True):\n",
        "            del self.ops[index]\n",
        "\n",
        "    def hash_spec(self, canonical_ops):\n",
        "        \"\"\"Computes the isomorphism-invariant graph hash of this spec.\n",
        "\n",
        "        Args:\n",
        "          canonical_ops: list of operations in the canonical ordering which they\n",
        "            were assigned (i.e. the order provided in the config['available_ops']).\n",
        "\n",
        "        Returns:\n",
        "          MD5 hash of this spec which can be used to query the dataset.\n",
        "        \"\"\"\n",
        "        # Invert the operations back to integer label indices used in graph gen.\n",
        "        # Input/output get the sentinel labels -1/-2. Requires a valid spec: on an\n",
        "        # invalid one self.ops is None and this raises TypeError.\n",
        "        labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]\n",
        "        return hash_module(self.matrix, labeling)\n",
        "\n",
        "\n",
        "def is_upper_triangular(matrix):\n",
        "    \"\"\"Return True iff every entry on or below the main diagonal is zero.\n",
        "\n",
        "    Scans row by row and stops at the first offending entry.\n",
        "    \"\"\"\n",
        "    n_rows = np.shape(matrix)[0]\n",
        "    return not any(\n",
        "        matrix[row, col] != 0\n",
        "        for row in range(n_rows)\n",
        "        for col in range(row + 1)\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "p3gaWv1Dyomv"
      },
      "outputs": [],
      "source": [
        "class OutOfDomainError(Exception):\n",
        "    \"\"\"Raised by `check_spec` when a graph lies outside the benchmark's search space.\"\"\"\n",
        "\n",
        "def is_valid(state):\n",
        "    \"\"\"Return True if the encoded ``state`` decodes to an in-domain cell.\n",
        "\n",
        "    The decoded graph is pruned by ``ModelSpec_`` and then checked with\n",
        "    `check_spec`; an OutOfDomainError means the state is invalid.\n",
        "    \"\"\"\n",
        "    decoded = decode(state)\n",
        "    spec = ModelSpec_(decoded['matrix'], decoded['ops'])\n",
        "    try:\n",
        "        check_spec(spec)\n",
        "    except OutOfDomainError:\n",
        "        in_domain = False\n",
        "    else:\n",
        "        in_domain = True\n",
        "    return in_domain\n",
        "\n",
        "def get_hashKey(state):\n",
        "    \"\"\"Return the isomorphism-invariant graph hash used to index the benchmark tables.\"\"\"\n",
        "    canonical_ops = ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3']\n",
        "    decoded = decode(state)\n",
        "    spec = ModelSpec_(decoded['matrix'], decoded['ops'])\n",
        "    return spec.hash_spec(canonical_ops)\n",
        "\n",
        "def check_spec(model_spec):\n",
        "    \"\"\"Raise OutOfDomainError unless ``model_spec`` is within the dataset.\n",
        "\n",
        "    Checks, in this order: the pruned graph is connected, has at most 7\n",
        "    vertices and at most 9 edges, starts with 'input', ends with 'output', and\n",
        "    uses only the three supported interior ops. Returns None on success.\n",
        "    \"\"\"\n",
        "    if not model_spec.valid_spec:\n",
        "        raise OutOfDomainError('invalid spec, provided graph is disconnected.')\n",
        "\n",
        "    spec_ops = model_spec.ops\n",
        "    n_vertices, n_edges = len(spec_ops), np.sum(model_spec.matrix)\n",
        "    if n_vertices > 7:\n",
        "        raise OutOfDomainError('too many vertices')\n",
        "    if n_edges > 9:\n",
        "        raise OutOfDomainError('too many edges')\n",
        "    if spec_ops[0] != 'input':\n",
        "        raise OutOfDomainError('first operation should be \\'input\\'')\n",
        "    if spec_ops[-1] != 'output':\n",
        "        raise OutOfDomainError('last operation should be \\'output\\'')\n",
        "\n",
        "    supported_ops = ('conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3')\n",
        "    if any(op not in supported_ops for op in spec_ops[1:-1]):\n",
        "        raise OutOfDomainError('unsupported op')\n",
        "\n",
        "# Alias used by the helpers in this notebook; presumably kept so an alternative\n",
        "# ModelSpec implementation can be swapped in by rebinding this one name.\n",
        "ModelSpec_ = ModelSpec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "821c6f87"
      },
      "outputs": [],
      "source": [
        "def get_indices(state, distance):\n",
        "    \"\"\"List every ``distance``-sized tuple of positions of ``state``, in lexicographic order.\"\"\"\n",
        "    positions = range(len(state))\n",
        "    return [combo for combo in itertools.combinations(positions, distance)]\n",
        "\n",
        "def get_all_neighbors(state, ids):\n",
        "    \"\"\"Return every state that differs from ``state`` at exactly the positions ``ids``.\n",
        "\n",
        "    Positions below ``edge_spots`` are edge bits (values in ``allowed_edges``);\n",
        "    the others are op indices into ``allowed_ops``. Each selected position takes\n",
        "    every value except its current one, and all combinations are produced.\n",
        "\n",
        "    Args:\n",
        "      state: encoded architecture (ndarray or list).\n",
        "      ids: tuple of positions to change, e.g. from `get_indices`.\n",
        "\n",
        "    Returns:\n",
        "      List of neighbours as Python lists, shuffled with ``np.random.shuffle``\n",
        "      (so this consumes the global NumPy RNG).\n",
        "    \"\"\"\n",
        "    list_available_ops = []\n",
        "    for i in ids:\n",
        "        # Derive each position's value domain from the search-space constants\n",
        "        # rather than hard-coding 21 edge positions and 3 ops.\n",
        "        if i < edge_spots:\n",
        "            _available_ops = list(allowed_edges)\n",
        "        else:\n",
        "            _available_ops = list(range(len(allowed_ops)))\n",
        "        _available_ops.remove(state[i])\n",
        "        list_available_ops.append(_available_ops)\n",
        "    ids = np.array(ids)\n",
        "    list_neighbors = []\n",
        "    for ops in itertools.product(*list_available_ops):\n",
        "        neighbor = np.array(state)  # np.array copies, so `state` is untouched\n",
        "        neighbor[ids] = ops\n",
        "        list_neighbors.append(neighbor.tolist())\n",
        "    np.random.shuffle(list_neighbors)\n",
        "    return list_neighbors"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "77031a12"
      },
      "outputs": [],
      "source": [
        "def train_evaluate(state, metric):\n",
        "    \"\"\"Look up a training-based score for ``state`` in ``benchmark``.\n",
        "\n",
        "    Args:\n",
        "      state: encoded architecture (see `encode`).\n",
        "      metric: 'test_acc' for the test accuracy at 108 epochs, or a name ending\n",
        "        in '_EPOCH' such as 'val_acc_12'. Only the trailing epoch is parsed:\n",
        "        the validation accuracy at that epoch is returned whatever the prefix.\n",
        "\n",
        "    Returns:\n",
        "      (score, time): the accuracy and the recorded training time at that\n",
        "      epoch. For 'test_acc' the time is 0.0, presumably because test accuracy\n",
        "      is only used for reporting.\n",
        "    \"\"\"\n",
        "    h = get_hashKey(state)\n",
        "    if metric == 'test_acc':\n",
        "        return benchmark['108'][h]['test_acc'], 0.0\n",
        "    iepoch = int(metric.rsplit('_', 1)[-1])\n",
        "    record = benchmark[f'{iepoch}'][h]\n",
        "    return record['val_acc'], record['train_time']\n",
        "\n",
        "def zc_evaluate(state, metric):\n",
        "    \"\"\"Look up a zero-cost proxy score for ``state`` plus its fixed cost.\n",
        "\n",
        "    'params' is read from ``benchmark['108']``, 'flops' from\n",
        "    ``flops_benchmark`` and any other name from ``zc_benchmark[h][metric]``.\n",
        "    Each metric is charged a constant cost (presumably an average measured\n",
        "    runtime in seconds -- TODO confirm); unknown names use the fallback value.\n",
        "    \"\"\"\n",
        "    h = get_hashKey(state)\n",
        "    proxy_cost = {\n",
        "        'synflow': 1.4356617034300945,\n",
        "        'params': 0.30215823150349097,\n",
        "        'grad_norm': 2.035946015246093,\n",
        "        'grasp': 5.570546795804546,\n",
        "        'jacob_cov': 2.5207841626097856,\n",
        "        'snip': 2.028758352457235,\n",
        "    }\n",
        "    if metric == 'params':\n",
        "        return benchmark['108'][h]['n_params'], proxy_cost['params']\n",
        "    if metric == 'flops':\n",
        "        # FLOPs are charged the same cost as counting parameters.\n",
        "        return flops_benchmark[h], 0.30215823150349097\n",
        "    return zc_benchmark[h][metric], proxy_cost.get(metric, 2.610283957422675)\n",
        "\n",
        "def evaluate(state, metric):\n",
        "    \"\"\"Score ``state`` under ``metric`` and return ``(score, cost)``.\n",
        "\n",
        "    Metric names containing 'acc' go to `train_evaluate`; all others are\n",
        "    treated as zero-cost proxies and go to `zc_evaluate`.\n",
        "    \"\"\"\n",
        "    lookup = train_evaluate if 'acc' in metric else zc_evaluate\n",
        "    result = lookup(state, metric)\n",
        "    return result[0], result[1]\n",
        "\n",
        "def evaluate_trend_best_state(trend_best_state):\n",
        "    \"\"\"Map each state of a best-so-far trajectory to its test accuracy.\"\"\"\n",
        "    test_accs = []\n",
        "    for state in trend_best_state:\n",
        "        test_accs.append(evaluate(state, 'test_acc')[0])\n",
        "    return test_accs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M_OXQMY9wWK6"
      },
      "source": [
        "# Load Benchmark"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mdi6SJJ_wY4P",
        "outputId": "3b27ddb0-4443-4b6f-f4b3-be49e78f096e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1y9pxpnYkhOsg-3xeNwhJYg9gVtOwipHT\n",
            "To: /content/data.p\n",
            "100% 339M/339M [00:06<00:00, 51.0MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1V5Nz9SlMvwYSnQq1-1bJNvob8i1o4HiK\n",
            "To: /content/zc_101.p\n",
            "100% 48.7M/48.7M [00:00<00:00, 63.9MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1--RQXQU1r6UZOVF9kI2IKaofETOkcTRG\n",
            "To: /content/flops_all_arch_101.p\n",
            "100% 22.9M/22.9M [00:00<00:00, 49.4MB/s]\n"
          ]
        }
      ],
      "source": [
        "# Download the pickled lookup tables loaded in the next cell (benchmark\n",
        "# results, zero-cost proxy scores and FLOPs) into the working directory.\n",
        "!gdown https://drive.google.com/uc?id=1y9pxpnYkhOsg-3xeNwhJYg9gVtOwipHT\n",
        "!gdown https://drive.google.com/uc?id=1V5Nz9SlMvwYSnQq1-1bJNvob8i1o4HiK\n",
        "!gdown https://drive.google.com/uc?id=1--RQXQU1r6UZOVF9kI2IKaofETOkcTRG"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "lJH6asvI1q5K"
      },
      "outputs": [],
      "source": [
        "# NOTE: pickle.load can execute arbitrary code -- only load these files from\n",
        "# the trusted source downloaded above. `with` closes each handle after loading.\n",
        "with open('zc_101.p', 'rb') as fh:\n",
        "    zc_benchmark = p.load(fh)  # hash -> {proxy name: score}\n",
        "with open('data.p', 'rb') as fh:\n",
        "    benchmark = p.load(fh)  # epoch string (e.g. '108') -> hash -> record dict\n",
        "with open('flops_all_arch_101.p', 'rb') as fh:\n",
        "    flops_benchmark = p.load(fh)  # hash -> FLOPs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OOKQTE0k7lhY"
      },
      "source": [
        "# Algorithms"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "densE8Qu1NJC"
      },
      "source": [
        "## First-improvement Local Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "rtZc_2001HeH"
      },
      "outputs": [],
      "source": [
        "def first_improvement_ls(init_state, max_eval=np.inf, max_time=np.inf, metric=None):\n",
        "    \"\"\"First-improvement local search with a distance-2 perturbation on stagnation.\n",
        "\n",
        "    Single positions of the current state are visited in random order, and the\n",
        "    first valid neighbour scoring at least as well as the current state is\n",
        "    accepted. When no such neighbour exists, the search jumps to a random valid\n",
        "    state that differs from the local optimum in exactly two positions.\n",
        "\n",
        "    Args:\n",
        "      init_state: encoded, valid starting architecture.\n",
        "      max_eval: evaluation budget; no new evaluation starts once the count\n",
        "        exceeds it.\n",
        "      max_time: budget on the summed cost returned by `evaluate`; no new\n",
        "        evaluation starts once the total exceeds it.\n",
        "      metric: metric name forwarded to `evaluate`.\n",
        "\n",
        "    Returns:\n",
        "      (trend_best_state, trend_time, state_history, f_state_history): the best\n",
        "      state and cumulative cost after every evaluation, plus every evaluated\n",
        "      state and its score.\n",
        "    \"\"\"\n",
        "    assert metric is not None, 'Missing the evaluation metric!'\n",
        "    assert init_state is not None, 'Missing the initial state!'\n",
        "\n",
        "    n_eval = 0\n",
        "    total_time = 0\n",
        "\n",
        "    def within_budget():\n",
        "        # Checked before every evaluation, not only once per outer iteration,\n",
        "        # so a full neighbourhood sweep cannot run far past the budget.\n",
        "        return n_eval <= max_eval and total_time <= max_time\n",
        "\n",
        "    curr_state = init_state.copy()\n",
        "    f_curr_state, time = evaluate(init_state, metric)\n",
        "    total_time += time\n",
        "    n_eval += 1\n",
        "\n",
        "    best_state = curr_state.copy()\n",
        "    f_best_state = f_curr_state\n",
        "\n",
        "    trend_best_state = [best_state]\n",
        "    trend_time = [total_time]\n",
        "\n",
        "    state_history, f_state_history = [curr_state], [f_curr_state]\n",
        "    while within_budget():\n",
        "        improved = False\n",
        "        list_ids = get_indices(curr_state, 1)\n",
        "        while list_ids and not improved and within_budget():\n",
        "            idx = np.random.choice(range(len(list_ids)))\n",
        "            ids = list_ids.pop(idx)\n",
        "\n",
        "            for neighbor in get_all_neighbors(state=curr_state, ids=ids):\n",
        "                if not within_budget():\n",
        "                    break\n",
        "                if not is_valid(neighbor):\n",
        "                    continue\n",
        "                f_neighbor, time = evaluate(neighbor, metric)\n",
        "                state_history.append(neighbor)\n",
        "                f_state_history.append(f_neighbor)\n",
        "\n",
        "                total_time += time\n",
        "                trend_time.append(total_time)\n",
        "                n_eval += 1\n",
        "                if f_neighbor >= f_curr_state:\n",
        "                    # Ties are accepted so the search can move across plateaus.\n",
        "                    curr_state = neighbor.copy()\n",
        "                    f_curr_state = f_neighbor\n",
        "                    if f_neighbor > f_best_state:\n",
        "                        best_state = neighbor.copy()\n",
        "                        f_best_state = f_neighbor\n",
        "                    improved = True\n",
        "                trend_best_state.append(best_state)\n",
        "                if improved:\n",
        "                    break\n",
        "\n",
        "        if not improved and within_budget():\n",
        "            # Local optimum: perturb two random positions. Every retry starts from\n",
        "            # the local optimum itself, so an invalid draw is discarded instead of\n",
        "            # becoming the base of the next perturbation.\n",
        "            list_ids = get_indices(curr_state, 2)\n",
        "            while True:\n",
        "                idx = np.random.choice(range(len(list_ids)))\n",
        "                list_neighbors = get_all_neighbors(curr_state, list_ids[idx])\n",
        "                candidate = list_neighbors[np.random.choice(len(list_neighbors))]\n",
        "                if is_valid(candidate):\n",
        "                    break\n",
        "            curr_state = candidate\n",
        "            f_curr_state, time = evaluate(curr_state, metric)\n",
        "            state_history.append(curr_state)\n",
        "            f_state_history.append(f_curr_state)\n",
        "            total_time += time\n",
        "            n_eval += 1\n",
        "            if f_curr_state > f_best_state:\n",
        "                best_state = curr_state.copy()\n",
        "                f_best_state = f_curr_state\n",
        "            trend_best_state.append(best_state)\n",
        "            trend_time.append(total_time)\n",
        "    return trend_best_state, trend_time, state_history, f_state_history"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K7hDoccf1SLc"
      },
      "source": [
        "## Random Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "b3aOqQ301VRv"
      },
      "outputs": [],
      "source": [
        "def random_search(max_eval, metric=None):\n",
        "    \"\"\"Plain random search over valid architectures.\n",
        "\n",
        "    Evaluates random valid states while the evaluation counter is at most\n",
        "    ``max_eval`` (i.e. ``max_eval + 1`` evaluations for an integer budget).\n",
        "\n",
        "    Returns:\n",
        "      (trend_best_state, trend_time, state_history, f_state_history): the best\n",
        "      state and cumulative cost after each evaluation, plus every evaluated\n",
        "      state and its score.\n",
        "    \"\"\"\n",
        "    assert metric is not None, 'Missing the evaluation metric!'\n",
        "\n",
        "    total_time = 0\n",
        "    best_state, f_best_state = None, None\n",
        "    trend_best_state, trend_time = [], []\n",
        "    state_history, f_state_history = [], []\n",
        "\n",
        "    n_done = 0\n",
        "    while n_done <= max_eval:\n",
        "        state = sample(False)\n",
        "        while not is_valid(state):\n",
        "            state = sample(False)\n",
        "        f_state, time = evaluate(state, metric)\n",
        "        total_time += time\n",
        "\n",
        "        state_history.append(state)\n",
        "        f_state_history.append(f_state)\n",
        "        # The first evaluated state becomes the incumbent unconditionally.\n",
        "        if n_done == 0 or f_state > f_best_state:\n",
        "            best_state, f_best_state = state.copy(), f_state\n",
        "        trend_best_state.append(best_state)\n",
        "        trend_time.append(total_time)\n",
        "        n_done += 1\n",
        "\n",
        "    return trend_best_state, trend_time, state_history, f_state_history"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i81aVIOI1TmP"
      },
      "source": [
        "## Successive Halving"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "9OPyuRaoDHxr"
      },
      "outputs": [],
      "source": [
        "def succesive_halving(list_candidates, metric, max_budget, list_epochs):\n",
        "    \"\"\"Successive halving over ``list_candidates`` using benchmark lookups.\n",
        "\n",
        "    Rung k scores every surviving candidate at epoch ``list_epochs[k]`` and\n",
        "    keeps the better half (rounded up). Cost is charged incrementally: an\n",
        "    architecture already evaluated at an earlier epoch only pays the increase\n",
        "    in its recorded train time, as if training were resumed.\n",
        "\n",
        "    Args:\n",
        "      list_candidates: encoded architectures.\n",
        "      metric: metric prefix; candidates are scored with f'{metric}_{epoch}'.\n",
        "      max_budget: total cost budget.\n",
        "      list_epochs: epoch checkpoints, one per rung.\n",
        "\n",
        "    Returns:\n",
        "      (best_candidate, total_time, total_epoch): the candidate with the highest\n",
        "      score seen in any rung (None if the budget runs out before the first\n",
        "      evaluation fits), the cost actually charged, and the summed epoch\n",
        "      increments.\n",
        "    \"\"\"\n",
        "    assert len(list_epochs) != 0, 'Have not set the checkpoints for evaluation!'\n",
        "    checkpoint = 0\n",
        "    iepoch = list_epochs[checkpoint]\n",
        "    last_iepoch = 0\n",
        "    total_time, total_epoch = 0.0, 0.0\n",
        "\n",
        "    # architecture hash -> train time already charged for it\n",
        "    time_tracking = {}\n",
        "\n",
        "    best_candidate, f_best_candidate = None, 0.0\n",
        "    last = False\n",
        "    while total_time < max_budget:\n",
        "        evaluated_candidates = []\n",
        "        f_candidates = []\n",
        "        for candidate in list_candidates:\n",
        "            score, time = evaluate(candidate, f'{metric}_{iepoch}')\n",
        "            h = get_hashKey(candidate)\n",
        "\n",
        "            # Pay only for the training beyond what this architecture already had.\n",
        "            diff_time = time - time_tracking.get(h, 0)\n",
        "            time_tracking[h] = time\n",
        "            diff = iepoch - last_iepoch\n",
        "\n",
        "            total_time += diff_time\n",
        "            total_epoch += diff\n",
        "            if total_time >= max_budget:\n",
        "                # This evaluation does not fit: undo its charge and stop.\n",
        "                total_time -= diff_time\n",
        "                total_epoch -= diff\n",
        "                return best_candidate, total_time, total_epoch\n",
        "\n",
        "            evaluated_candidates.append(candidate)\n",
        "            f_candidates.append(score)\n",
        "            if score > f_best_candidate:\n",
        "                f_best_candidate = score\n",
        "                best_candidate = candidate\n",
        "\n",
        "        # Keep the better half (rounded up), best first.\n",
        "        ids = np.flip(np.argsort(f_candidates))\n",
        "        list_candidates = np.array(evaluated_candidates)[ids]\n",
        "        list_candidates = list_candidates[:math.ceil(len(list_candidates) / 2)]\n",
        "\n",
        "        # Stop with a single survivor, after the last checkpoint, or when there\n",
        "        # is no further checkpoint to promote to (e.g. a one-entry list_epochs).\n",
        "        if len(list_candidates) == 1 or last or checkpoint + 1 >= len(list_epochs):\n",
        "            return best_candidate, total_time, total_epoch\n",
        "\n",
        "        checkpoint += 1\n",
        "        last_iepoch = int(iepoch)\n",
        "        iepoch = list_epochs[checkpoint]\n",
        "        if iepoch == list_epochs[-1]:\n",
        "            last = True\n",
        "\n",
        "    # Only reached when the loop never starts (max_budget <= 0); return the\n",
        "    # same triple instead of an implicit None that callers cannot unpack.\n",
        "    return best_candidate, total_time, total_epoch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fQrD0rQtXC0W"
      },
      "source": [
        "## Others"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "emflnq67XD0I"
      },
      "outputs": [],
      "source": [
        "def run_succesive_halving(seed, N, max_budget, list_epochs):\n",
        "    \"\"\"Seed the RNGs, draw ``N`` valid random candidates and run successive halving.\n",
        "\n",
        "    Candidates are ranked by 'val_acc'. Returns (best_candidate, its test\n",
        "    accuracy, total_time, total_epoch).\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "\n",
        "    list_candidates = []\n",
        "    for _ in range(N):\n",
        "        arch = sample(False)\n",
        "        while not is_valid(arch):\n",
        "            arch = sample(False)\n",
        "        list_candidates.append(arch)\n",
        "\n",
        "    best_candidate, total_time, total_epoch = succesive_halving(\n",
        "        list_candidates, 'val_acc', max_budget, list_epochs)\n",
        "    test_acc = evaluate(best_candidate, 'test_acc')[0]\n",
        "    return best_candidate, test_acc, total_time, total_epoch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "FNeXHwH0YdGH"
      },
      "outputs": [],
      "source": [
        "def run_ils(seed, max_budget, iepoch):\n",
        "    \"\"\"Seed the RNGs and run `first_improvement_ls` from a random valid state.\n",
        "\n",
        "    The search optimises validation accuracy at ``iepoch`` epochs under a time\n",
        "    budget of ``max_budget``. Returns (best state, its test accuracy, time\n",
        "    recorded at the last evaluation, 0.0).\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "\n",
        "    init_state = sample(False)\n",
        "    while not is_valid(init_state):\n",
        "        init_state = sample(False)\n",
        "\n",
        "    trend_best_state, trend_time, _, _ = first_improvement_ls(\n",
        "        init_state, max_time=max_budget, metric=f'val_acc_{iepoch}')\n",
        "    final_best = trend_best_state[-1]\n",
        "    return final_best, evaluate(final_best, 'test_acc')[0], trend_time[-1], 0.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "3LP9SDtjZS7F"
      },
      "outputs": [],
      "source": [
        "def run_random_search(seed, max_time_budget=5e6, iepoch=108):\n",
        "    \"\"\"Run a single roll-out of random search to a fixed time budget.\n",
        "\n",
        "    Valid random architectures are scored by validation accuracy at ``iepoch``\n",
        "    epochs until the summed training time first exceeds ``max_time_budget``.\n",
        "\n",
        "    Returns:\n",
        "      (best_candidate, test accuracy of best_candidate, total time, None).\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    # Running totals instead of re-summing a growing list on every step; the\n",
        "    # additions happen in the same order, so the float results are identical.\n",
        "    total_time = 0.0\n",
        "    best_valid, best_arch = 0.0, None\n",
        "    while True:\n",
        "        arch = sample(False)\n",
        "        while not is_valid(arch):\n",
        "            arch = sample(False)\n",
        "        val_acc, time = evaluate(arch, f'val_acc_{iepoch}')\n",
        "\n",
        "        # It's important to select models only based on validation accuracy, test\n",
        "        # accuracy is used only for comparing different search trajectories.\n",
        "        if val_acc > best_valid:\n",
        "            best_valid, best_arch = val_acc, arch\n",
        "\n",
        "        total_time += time\n",
        "        if total_time > max_time_budget:\n",
        "            # Break the first time we exceed the budget.\n",
        "            f_best_candidate = evaluate(best_arch, 'test_acc')[0]\n",
        "            return best_arch, f_best_candidate, total_time, None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "eqNz_76Bagvj"
      },
      "outputs": [],
      "source": [
        "INPUT = 'input'\n",
        "OUTPUT = 'output'\n",
        "CONV3X3 = 'conv3x3-bn-relu'\n",
        "CONV1X1 = 'conv1x1-bn-relu'\n",
        "MAXPOOL3X3 = 'maxpool3x3'\n",
        "NUM_VERTICES = 7\n",
        "MAX_EDGES = 9\n",
        "EDGE_SPOTS = NUM_VERTICES * (NUM_VERTICES - 1) // 2   # Upper-triangular slots: an integer count (21)\n",
        "OP_SPOTS = NUM_VERTICES - 2   # Input/output vertices are fixed\n",
        "ALLOWED_OPS = [CONV3X3, CONV1X1, MAXPOOL3X3]\n",
        "ALLOWED_EDGES = [0, 1]   # Binary adjacency matrix\n",
        "\n",
        "def check_valid(model_spec):\n",
        "    \"\"\"Return True when `check_spec` accepts `model_spec`, False when it raises OutOfDomainError.\"\"\"\n",
        "    try:\n",
        "        check_spec(model_spec)\n",
        "    except OutOfDomainError:\n",
        "        is_ok = False\n",
        "    else:\n",
        "        is_ok = True\n",
        "    return is_ok\n",
        "\n",
        "def random_spec():\n",
        "    \"\"\"Rejection-sample a random valid ModelSpec_.\n",
        "\n",
        "    Each attempt draws a binary upper-triangular adjacency matrix and one op per vertex\n",
        "    (first/last forced to INPUT/OUTPUT); it is returned once `check_valid` accepts it.\n",
        "    \"\"\"\n",
        "    while True:\n",
        "        adjacency = np.triu(np.random.choice(ALLOWED_EDGES, size=(NUM_VERTICES, NUM_VERTICES)), 1)\n",
        "        vertex_ops = np.random.choice(ALLOWED_OPS, size=(NUM_VERTICES)).tolist()\n",
        "        vertex_ops[0], vertex_ops[-1] = INPUT, OUTPUT\n",
        "        candidate = ModelSpec_(matrix=adjacency, ops=vertex_ops)\n",
        "        if check_valid(candidate):\n",
        "            return candidate\n",
        "\n",
        "def mutate_spec(old_spec, mutation_rate=1.0):\n",
        "    \"\"\"Computes a valid mutated spec from the old_spec.\n",
        "\n",
        "    Starting from `old_spec.original_matrix` / `original_ops`, each upper-triangular edge is\n",
        "    flipped with probability mutation_rate / NUM_VERTICES and each interior op is resampled\n",
        "    (always to a different op from ALLOWED_OPS) with probability mutation_rate / OP_SPOTS.\n",
        "    Retries until `check_valid` accepts the result.\n",
        "    \"\"\"\n",
        "    edge_mutation_prob = mutation_rate / NUM_VERTICES\n",
        "    op_mutation_prob = mutation_rate / OP_SPOTS\n",
        "    while True:\n",
        "        new_matrix = copy.deepcopy(old_spec.original_matrix)\n",
        "        new_ops = copy.deepcopy(old_spec.original_ops)\n",
        "\n",
        "        # There are NUM_VERTICES * (NUM_VERTICES - 1) / 2 edge slots, so in expectation\n",
        "        # (NUM_VERTICES - 1) / 2 * mutation_rate edges are flipped (3 for V=7, rate=1);\n",
        "        # note that many of them end up being pruned.\n",
        "        for src in range(0, NUM_VERTICES - 1):\n",
        "            for dst in range(src + 1, NUM_VERTICES):\n",
        "                if random.random() < edge_mutation_prob:\n",
        "                    new_matrix[src, dst] = 1 - new_matrix[src, dst]\n",
        "\n",
        "        # In expectation, mutation_rate interior ops are resampled.\n",
        "        for ind in range(1, NUM_VERTICES - 1):\n",
        "            if random.random() < op_mutation_prob:\n",
        "                available = [o for o in ALLOWED_OPS if o != new_ops[ind]]\n",
        "                new_ops[ind] = random.choice(available)\n",
        "\n",
        "        new_spec = ModelSpec_(new_matrix, new_ops)\n",
        "        if check_valid(new_spec):\n",
        "            return new_spec\n",
        "\n",
        "def random_combination(iterable, sample_size):\n",
        "    \"\"\"Random selection from itertools.combinations(iterable, r).\n",
        "\n",
        "    Picks `sample_size` distinct positions uniformly at random and returns the chosen\n",
        "    items as a tuple, in their original order.\n",
        "    \"\"\"\n",
        "    items = tuple(iterable)\n",
        "    picked = random.sample(range(len(items)), sample_size)\n",
        "    picked.sort()\n",
        "    return tuple([items[pos] for pos in picked])\n",
        "\n",
        "\n",
        "def custom_evaluate(spec, iepoch):\n",
        "    \"\"\"Look up `spec` in `benchmark`.\n",
        "\n",
        "    Returns (val_acc, train_time) from the `iepoch` table and test_acc from the '108' table.\n",
        "    \"\"\"\n",
        "    spec_hash = spec.hash_spec(['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'])\n",
        "    at_epoch = benchmark[f'{iepoch}'][spec_hash]\n",
        "    val_acc, train_time = at_epoch['val_acc'], at_epoch['train_time']\n",
        "    return val_acc, train_time, benchmark['108'][spec_hash]['test_acc']\n",
        "\n",
        "\n",
        "def run_evolution_search(max_time_budget=5e6,\n",
        "                         population_size=50,\n",
        "                         tournament_size=10,\n",
        "                         mutation_rate=1.0,\n",
        "                         iepoch=108, warm_up=False, n_warmup=2000):\n",
        "    \"\"\"Run a single roll-out of regularized evolution to a fixed time budget.\n",
        "\n",
        "    Architectures are selected by validation accuracy at `iepoch` epochs (via\n",
        "    `custom_evaluate`); the paired test accuracy is only recorded. The initial population\n",
        "    is `population_size` random specs or, with `warm_up=True`, the output of\n",
        "    `run_warm_up(n_warmup, population_size)`.\n",
        "\n",
        "    Returns:\n",
        "        (times, trend_time, best_valids, best_tests, best_arch): per-evaluation times and\n",
        "        their running totals (both starting with 0.0), the incumbent's validation and test\n",
        "        accuracy after every evaluation, and the incumbent spec (None if no evaluation\n",
        "        scored above 0.0). The same 5-tuple is returned whether the budget runs out while\n",
        "        seeding the population or during evolution.\n",
        "    \"\"\"\n",
        "    times, best_valids, best_tests, best_arch = [0.0], [0.0], [0.0], None\n",
        "    population = []   # (validation, spec) tuples, oldest first\n",
        "    trend_time = [0.0]\n",
        "\n",
        "    # Seed the population with randomly generated (or warm-started) cells.\n",
        "    if not warm_up:\n",
        "        list_spec = [random_spec() for _ in range(population_size)]\n",
        "    else:\n",
        "        list_spec = run_warm_up(n_warmup, population_size)\n",
        "    for spec in list_spec:\n",
        "        val_acc, time, test_acc = custom_evaluate(spec, iepoch=iepoch)\n",
        "        times.append(time)\n",
        "        trend_time.append(sum(times))\n",
        "        population.append((val_acc, spec))\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_tests.append(test_acc)\n",
        "            best_arch = spec\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "            best_tests.append(best_tests[-1])\n",
        "\n",
        "        if trend_time[-1] > max_time_budget:\n",
        "            # Same 5-tuple as the evolution loop so callers can always unpack five values.\n",
        "            return times, trend_time, best_valids, best_tests, best_arch\n",
        "\n",
        "    # After the population is seeded, proceed with evolving the population.\n",
        "    while True:\n",
        "        tournament = random_combination(population, tournament_size)\n",
        "        # Stable sort + [-1]: among tied accuracies the last-listed candidate wins.\n",
        "        parent = sorted(tournament, key=lambda i: i[0])[-1][1]\n",
        "        new_spec = mutate_spec(parent, mutation_rate)\n",
        "\n",
        "        val_acc, time, test_acc = custom_evaluate(new_spec, iepoch=iepoch)\n",
        "        times.append(time)\n",
        "        trend_time.append(sum(times))\n",
        "\n",
        "        # In regularized evolution, we kill the oldest individual in the population.\n",
        "        population.append((val_acc, new_spec))\n",
        "        population.pop(0)\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_tests.append(test_acc)\n",
        "            best_arch = new_spec\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "            best_tests.append(best_tests[-1])\n",
        "\n",
        "        if trend_time[-1] > max_time_budget:\n",
        "            return times, trend_time, best_valids, best_tests, best_arch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "KCHxODBBbjxN"
      },
      "outputs": [],
      "source": [
        "def run_warm_up(n_sample, k):\n",
        "    \"\"\"Sample `n_sample` random valid specs and keep the `k` with the highest synflow score.\n",
        "\n",
        "    Scores are read from `zc_benchmark[hash]['synflow']`. Returns a numpy array of specs\n",
        "    in descending score order.\n",
        "    \"\"\"\n",
        "    ops_vocab = ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3']\n",
        "    specs, scores = [], []\n",
        "    for _ in range(n_sample):\n",
        "        candidate = random_spec()\n",
        "        candidate_score = zc_benchmark[candidate.hash_spec(ops_vocab)]['synflow']\n",
        "        specs.append(candidate)\n",
        "        scores.append(candidate_score)\n",
        "    ranking = np.flip(np.argsort(np.array(scores)))\n",
        "    return np.array(specs)[ranking][:k]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VQOUCK7225hT"
      },
      "source": [
        "# MF-NAS"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xmog5VC65pW8"
      },
      "source": [
        "## Local Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "1bwdBM242_Eg"
      },
      "outputs": [],
      "source": [
        "def ILS_SH(zc_metric, k, n_run, max_eval, allowed_time, list_epochs, verbose):\n",
        "    \"\"\"MF-NAS: training-free local search (Stage 1) followed by Successive Halving (Stage 2).\n",
        "\n",
        "    For each seed in 1..n_run, `first_improvement_ls` explores the space with the zero-cost\n",
        "    metric `zc_metric` (its outputs are truncated to `max_eval` entries). The k best unique\n",
        "    architectures it visited, ranked by the scores it returned, are passed to\n",
        "    `succesive_halving` with metric 'val_acc', the fidelities in `list_epochs`, and a budget\n",
        "    of `allowed_time` minus the Stage 1 cost. Test accuracies are computed only for\n",
        "    reporting. Summary statistics over all runs are printed at the end.\n",
        "\n",
        "    Returns:\n",
        "        (all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH):\n",
        "        numpy arrays with one entry per run (selected architecture, its test accuracy,\n",
        "        and the cost and total epochs reported by `succesive_halving`).\n",
        "    \"\"\"\n",
        "    all_trend_best_solutions_LS, all_trend_f_best_solutions_LS, all_trend_cost_LS = [], [], []\n",
        "    true_best_f_LS = []\n",
        "\n",
        "    all_best_solution_SH, all_f_best_solution_SH = [], []\n",
        "    all_cost_SH = []\n",
        "    all_total_epoch_SH = []\n",
        "    true_best_f_topk_SH = []\n",
        "\n",
        "    for seed in tqdm(range(1, n_run + 1)):\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "        while True:\n",
        "            init_solution = sample(False)\n",
        "            if is_valid(init_solution):\n",
        "                break\n",
        "\n",
        "        # Stage 1: Explore search space with Training-free Local Search\n",
        "        trend_best_solutions_LS, trend_cost_LS, found_solutions_LS, f_found_solutions_LS = first_improvement_ls(init_solution, metric=zc_metric, max_eval=max_eval)\n",
        "\n",
        "        ## Evaluate best solutions obtained during the search (test_acc) (just for analyzing)\n",
        "        trend_cost_LS = trend_cost_LS[:max_eval]\n",
        "\n",
        "        trend_best_solutions_LS = trend_best_solutions_LS[:max_eval]\n",
        "        trend_f_best_solution_LS = evaluate_trend_best_state(trend_best_solutions_LS)\n",
        "\n",
        "        ## Evaluate all states obtained during the search (test_acc) (just for analyzing)\n",
        "        true_f_found_solutions_LS = evaluate_trend_best_state(found_solutions_LS[:max_eval])\n",
        "        true_best_f_LS.append(max(true_f_found_solutions_LS))\n",
        "\n",
        "        all_trend_best_solutions_LS.append(trend_best_solutions_LS)\n",
        "        all_trend_f_best_solutions_LS.append(trend_f_best_solution_LS)\n",
        "        all_trend_cost_LS.append(trend_cost_LS)\n",
        "\n",
        "        ## Remove duplicate architectures (same ModelSpec_ hash), keeping the first occurrence of each\n",
        "        found_solutions_LS, f_found_solutions_LS = found_solutions_LS[:max_eval], f_found_solutions_LS[:max_eval]\n",
        "\n",
        "        list_h = []\n",
        "        for state in found_solutions_LS:\n",
        "            arch = decode(state)\n",
        "            modelspec = ModelSpec_(arch['matrix'], arch['ops'])\n",
        "            h = modelspec.hash_spec(['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'])\n",
        "            list_h.append(h)\n",
        "\n",
        "        _, ids = np.unique(list_h, return_index=True)\n",
        "\n",
        "        found_solutions_LS = np.array(found_solutions_LS)[ids]\n",
        "        f_found_solutions_LS = np.array(f_found_solutions_LS)[ids]\n",
        "\n",
        "        ## Sort architectures by their Stage 1 score in descending order (only the solutions are reordered)\n",
        "        ids = np.flip(np.argsort(f_found_solutions_LS))\n",
        "        found_solutions_LS = found_solutions_LS[ids]\n",
        "\n",
        "        ## Keep the top-k architectures (k is a count, not a percentage)\n",
        "        topk_found_solutions = found_solutions_LS[:k]\n",
        "\n",
        "        ## Evaluate test performance of the top-k (just for analyzing)\n",
        "        f_topk = evaluate_trend_best_state(topk_found_solutions)\n",
        "        true_best_f_topk_SH.append(max(f_topk))\n",
        "\n",
        "        # Stage 2: Select the best one by Successive Halving, using the budget left after Stage 1\n",
        "        best_solution_SH, total_cost_SH, total_epoch_SH = succesive_halving(topk_found_solutions, 'val_acc', allowed_time - trend_cost_LS[-1], list_epochs)\n",
        "\n",
        "\n",
        "        f_best_solution_SH = evaluate_trend_best_state([best_solution_SH])[-1]\n",
        "\n",
        "        all_best_solution_SH.append(best_solution_SH)\n",
        "        all_f_best_solution_SH.append(f_best_solution_SH)\n",
        "        all_cost_SH.append(total_cost_SH)\n",
        "        all_total_epoch_SH.append(total_epoch_SH)\n",
        "\n",
        "        if verbose:\n",
        "            print('ID:', seed)\n",
        "            print(f'-> Best solution: {decode(best_solution_SH)}')\n",
        "            print(f'-> Test accuracy: {f_best_solution_SH*100}%')\n",
        "\n",
        "\n",
        "    all_trend_best_solutions_LS = np.array(all_trend_best_solutions_LS)\n",
        "    all_trend_f_best_solutions_LS = np.array(all_trend_f_best_solutions_LS)\n",
        "    all_trend_cost_LS = np.array(all_trend_cost_LS)\n",
        "\n",
        "    all_best_solution_SH = np.array(all_best_solution_SH)\n",
        "    all_f_best_solution_SH = np.array(all_f_best_solution_SH)\n",
        "    all_cost_SH = np.array(all_cost_SH)\n",
        "    all_total_epoch_SH = np.array(all_total_epoch_SH)\n",
        "\n",
        "    print(f'#Runs:', n_run)\n",
        "    print(f'- Stage 1 [ILS (metric={zc_metric}, max_eval={max_eval})]:', np.round(np.mean(all_trend_f_best_solutions_LS, axis=0)[-1] * 100, 2), np.round(np.std(all_trend_f_best_solutions_LS, axis=0)[-1] * 100, 2))\n",
        "    print(f'- Stage 2 [Successive Halving (top-{k}, metric=val_acc, budget={allowed_time} sec)]:', np.round(np.mean(all_f_best_solution_SH) * 100, 2), np.round(np.std(all_f_best_solution_SH) * 100, 2))\n",
        "    print(f'- Total Cost: {int(np.mean(all_trend_cost_LS, axis=0)[-1])} + {int(np.mean(all_cost_SH))} = {int(np.mean(all_trend_cost_LS, axis=0)[-1]) + int(np.mean(all_cost_SH))} seconds')\n",
        "    print(f'- Total Epochs: {int(np.mean(all_total_epoch_SH))}')\n",
        "    print(f'- Best visited solution (all):', np.round(np.mean(true_best_f_LS) * 100, 2), np.round(np.std(true_best_f_LS) * 100, 2))\n",
        "    print(f'- Best visited solution (top-{k}):', np.round(np.mean(true_best_f_topk_SH) * 100, 2), np.round(np.std(true_best_f_topk_SH) * 100, 2))\n",
        "\n",
        "    return all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vrDgQ8555mVn"
      },
      "source": [
        "## Random Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "eEQI34jQ5oRg"
      },
      "outputs": [],
      "source": [
        "def RS_SH(zc_metric, k, n_run, max_eval, allowed_time, list_epochs, verbose):\n",
        "    \"\"\"MF-NAS variant: training-free Random Search (Stage 1) + Successive Halving (Stage 2).\n",
        "\n",
        "    Mirrors ILS_SH, but Stage 1 calls `random_search(metric=zc_metric, max_eval=max_eval)`\n",
        "    instead of local search. The k best unique architectures found in Stage 1 (ranked by\n",
        "    the scores random_search returned) are passed to `succesive_halving` with metric\n",
        "    'val_acc', the fidelities in `list_epochs`, and a budget of `allowed_time` minus the\n",
        "    Stage 1 cost. Test accuracies are computed only for reporting.\n",
        "\n",
        "    Returns:\n",
        "        (all_trend_best_solutions_RS, all_trend_f_best_solutions_RS, all_best_solution_SH,\n",
        "         all_f_best_solution_SH, all_trend_cost_RS, all_cost_SH, all_total_epoch_SH,\n",
        "         true_best_f_RS, true_best_f_topk_SH)\n",
        "    \"\"\"\n",
        "    all_trend_best_solutions_RS, all_trend_f_best_solutions_RS, all_trend_cost_RS = [], [], []\n",
        "    true_best_f_RS = []\n",
        "\n",
        "    all_best_solution_SH, all_f_best_solution_SH = [], []\n",
        "    all_cost_SH = []\n",
        "    all_total_epoch_SH = []\n",
        "    true_best_f_topk_SH = []\n",
        "\n",
        "    for seed in tqdm(range(1, n_run + 1)):\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "\n",
        "        # Stage 1: Explore search space with Training-free Random Search\n",
        "        trend_best_solutions_RS, trend_cost_RS, found_solutions_RS, f_found_solutions_RS = random_search(metric=zc_metric, max_eval=max_eval)\n",
        "\n",
        "        ## Evaluate best solutions obtained during the search (test_acc) (just for analyzing)\n",
        "        trend_cost_RS = trend_cost_RS[:max_eval]\n",
        "        trend_best_solutions_RS = trend_best_solutions_RS[:max_eval]\n",
        "        trend_f_best_solution_RS = evaluate_trend_best_state(trend_best_solutions_RS)\n",
        "\n",
        "        ## Evaluate all states obtained during the search (test_acc) (just for analyzing)\n",
        "        true_f_found_solutions_RS = evaluate_trend_best_state(found_solutions_RS[:max_eval])\n",
        "        true_best_f_RS.append(max(true_f_found_solutions_RS))\n",
        "\n",
        "        all_trend_best_solutions_RS.append(trend_best_solutions_RS)\n",
        "        all_trend_f_best_solutions_RS.append(trend_f_best_solution_RS)\n",
        "        all_trend_cost_RS.append(trend_cost_RS)\n",
        "\n",
        "        ## Remove duplicate architectures (same ModelSpec_ hash), keeping the first occurrence of each\n",
        "        found_solutions_RS, f_found_solutions_RS = found_solutions_RS[:max_eval], f_found_solutions_RS[:max_eval]\n",
        "\n",
        "        list_h = []\n",
        "        for state in found_solutions_RS:\n",
        "            arch = decode(state)\n",
        "            modelspec = ModelSpec_(arch['matrix'], arch['ops'])\n",
        "            list_h.append(modelspec.hash_spec(['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3']))\n",
        "\n",
        "        _, ids = np.unique(list_h, return_index=True)\n",
        "        found_solutions_RS = np.array(found_solutions_RS)[ids]\n",
        "        f_found_solutions_RS = np.array(f_found_solutions_RS)[ids]\n",
        "\n",
        "        ## Rank unique architectures by their Stage 1 score (descending) and keep the top-k\n",
        "        ids = np.flip(np.argsort(f_found_solutions_RS))\n",
        "        topk_found_solutions = found_solutions_RS[ids][:k]\n",
        "\n",
        "        ## Evaluate test performance of the top-k (just for analyzing)\n",
        "        f_topk = evaluate_trend_best_state(topk_found_solutions)\n",
        "        true_best_f_topk_SH.append(max(f_topk))\n",
        "\n",
        "        # Stage 2: Select the best one by Successive Halving, using the budget left after Stage 1\n",
        "        best_solution_SH, total_cost_SH, total_epoch_SH = succesive_halving(topk_found_solutions, 'val_acc', allowed_time - trend_cost_RS[-1], list_epochs)\n",
        "\n",
        "        f_best_solution_SH = evaluate_trend_best_state([best_solution_SH])[-1]\n",
        "        all_best_solution_SH.append(best_solution_SH)\n",
        "        all_f_best_solution_SH.append(f_best_solution_SH)\n",
        "        all_cost_SH.append(total_cost_SH)\n",
        "        all_total_epoch_SH.append(total_epoch_SH)\n",
        "\n",
        "        if verbose:\n",
        "            print('ID:', seed)\n",
        "            print(f'-> Best solution: {decode(best_solution_SH)}')\n",
        "            print(f'-> Test accuracy: {f_best_solution_SH*100}%')\n",
        "\n",
        "    all_trend_best_solutions_RS = np.array(all_trend_best_solutions_RS)\n",
        "    all_trend_f_best_solutions_RS = np.array(all_trend_f_best_solutions_RS)\n",
        "    all_trend_cost_RS = np.array(all_trend_cost_RS)\n",
        "\n",
        "    all_best_solution_SH = np.array(all_best_solution_SH)\n",
        "    all_f_best_solution_SH = np.array(all_f_best_solution_SH)\n",
        "    all_cost_SH = np.array(all_cost_SH)\n",
        "    all_total_epoch_SH = np.array(all_total_epoch_SH)\n",
        "\n",
        "    mean_cost_RS = int(np.mean(all_trend_cost_RS, axis=0)[-1])\n",
        "    mean_cost_SH = int(np.mean(all_cost_SH))\n",
        "    print('#Runs:', n_run)\n",
        "    print(f'- Stage 1 [Random Search (metric={zc_metric}, max_eval={max_eval})]:', np.round(np.mean(all_trend_f_best_solutions_RS, axis=0)[-1] * 100, 2), np.round(np.std(all_trend_f_best_solutions_RS, axis=0)[-1] * 100, 2))\n",
        "    print(f'- Stage 2 [Successive Halving (top-{k}, metric=val_acc, budget={allowed_time} sec)]:', np.round(np.mean(all_f_best_solution_SH) * 100, 2), np.round(np.std(all_f_best_solution_SH) * 100, 2))\n",
        "    print(f'- Total Cost: {mean_cost_RS} + {mean_cost_SH} = {mean_cost_RS + mean_cost_SH} seconds')\n",
        "    print(f'- Total Epochs: {int(np.mean(all_total_epoch_SH))}')\n",
        "    print('- Best visited solution (all):', np.round(np.mean(true_best_f_RS) * 100, 2), np.round(np.std(true_best_f_RS) * 100, 2))\n",
        "    print(f'- Best visited solution (top-{k}):', np.round(np.mean(true_best_f_topk_SH) * 100, 2), np.round(np.std(true_best_f_topk_SH) * 100, 2))\n",
        "\n",
        "    return all_trend_best_solutions_RS, all_trend_f_best_solutions_RS, all_best_solution_SH, all_f_best_solution_SH, all_trend_cost_RS, all_cost_SH, all_total_epoch_SH, true_best_f_RS, true_best_f_topk_SH\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GdWMlbdE9xQV"
      },
      "source": [
        "# Run"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "eywOy7BI93h8"
      },
      "outputs": [],
      "source": [
        "n_run = 500  # Independent runs (seeds 1..n_run) per experiment\n",
        "max_eval = 2000  # Stage 1: max number of training-free (zero-cost) evaluations\n",
        "max_budget = 20000  # Total time budget in seconds (MF-NAS: Stage 1 cost + Stage 2 SH; also the baselines' budget)\n",
        "list_epochs = [4, 12, 36, 108]  # Epoch fidelities for Successive Halving (presumably one per rung)\n",
        "k = 16  # Number of top Stage 1 architectures handed to Successive Halving\n",
        "verbose = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fknx0P1TAWS5"
      },
      "source": [
        "## Main metrics (Synflow, FLOPs, Params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true,
          "base_uri": "https://localhost:8080/",
          "height": 427
        },
        "id": "PQ3hQTTX-C0M",
        "outputId": "ec3e8cc3-e352-4072-9145-10eda7433c13"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "ZC METRIC: SYNFLOW\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 500/500 [23:58<00:00,  2.88s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 93.15 1.31\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_acc, budget=20000 sec)]: 93.82 0.56\n",
            "- Total Cost: 2871 + 11871 = 14742 seconds\n",
            "- Total Epochs: 367\n",
            "- Best visited solution (all): 94.03 0.29\n",
            "- Best visited solution (top-16): 93.92 0.47\n",
            "****************************************************************************************************\n",
            "ZC METRIC: PARAMS\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 500/500 [22:48<00:00,  2.74s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 93.94 0.32\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_acc, budget=20000 sec)]: 93.89 0.25\n",
            "- Total Cost: 604 + 12913 = 13517 seconds\n",
            "- Total Epochs: 368\n",
            "- Best visited solution (all): 94.13 0.16\n",
            "- Best visited solution (top-16): 94.09 0.19\n",
            "****************************************************************************************************\n",
            "ZC METRIC: FLOPS\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 58%|█████▊    | 289/500 [13:08<09:07,  2.60s/it]"
          ]
        }
      ],
      "source": [
        "for zc_metric in ('synflow', 'params', 'flops'):\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    (all_best_solution_SH, all_f_best_solution_SH,\n",
        "     all_cost_SH, all_total_epoch_SH) = ILS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval,\n",
        "                                               allowed_time=max_budget, list_epochs=list_epochs, verbose=verbose)\n",
        "    print('*' * 100)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Re-run of the FLOPs experiment (its run in the previous cell was interrupted).\n",
        "for zc_metric in ('flops',):\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    (all_best_solution_SH, all_f_best_solution_SH,\n",
        "     all_cost_SH, all_total_epoch_SH) = ILS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval,\n",
        "                                               allowed_time=max_budget, list_epochs=list_epochs, verbose=verbose)\n",
        "    print('*' * 100)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "P0eRuAHx7rkw",
        "outputId": "1d54416b-09af-4112-d321-a80a1ce420b9"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: FLOPS\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 500/500 [25:07<00:00,  3.01s/it]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 93.93 0.32\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_acc, budget=20000 sec)]: 93.88 0.25\n",
            "- Total Cost: 604 + 12883 = 13487 seconds\n",
            "- Total Epochs: 368\n",
            "- Best visited solution (all): 94.13 0.17\n",
            "- Best visited solution (top-16): 94.09 0.19\n",
            "****************************************************************************************************\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sF0Z9l-NW_Iy"
      },
      "source": [
        "## Successive Halving"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GApaXgm5XAnh",
        "outputId": "29f900e1-311b-4b2e-8123-3dab5e4f0a40"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mean: 93.19   Std: 0.46\n"
          ]
        }
      ],
      "source": [
        "# Baseline: Successive Halving over k random architectures with the full time budget.\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_acc, total_time, total_epoch = run_succesive_halving(\n",
        "        seed=run_id, N=k, max_budget=max_budget, list_epochs=list_epochs)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {decode(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc * 100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2),\n",
        "      '  Std:', np.round(np.std(all_test_acc), 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EzsWTpJkYVtw"
      },
      "source": [
        "## Local Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "o6c7PxUnYXKb",
        "outputId": "74ecab4b-98e0-4155-c371-0dc37d8559cb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mean: 93.16   Std: 0.56\n"
          ]
        }
      ],
      "source": [
        "# Baseline: first-improvement local search guided by 12-epoch validation accuracy.\n",
        "iepoch = 12\n",
        "verbose = False\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_acc, total_time, _ = run_ils(seed=run_id, iepoch=iepoch, max_budget=max_budget)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {decode(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc * 100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2),\n",
        "      '  Std:', np.round(np.std(all_test_acc), 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QSWNo6tcZQYs"
      },
      "source": [
        "## Random Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "duZ319oqZ7rG",
        "outputId": "e6485b7b-a6b3-44db-aab1-5245a58d0eed"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mean: 93.16   Std: 0.26\n"
          ]
        }
      ],
      "source": [
        "# Baseline: random search selected by 12-epoch validation accuracy.\n",
        "iepoch = 12\n",
        "verbose = False\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_acc, total_time, _ = run_random_search(seed=run_id, iepoch=iepoch, max_time_budget=max_budget)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {decode(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc * 100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2),\n",
        "      '  Std:', np.round(np.std(all_test_acc), 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ViLd-nmZE3a"
      },
      "source": [
        "## REA"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gNBXMzqVa0iN",
        "outputId": "97d6b3af-d3fb-4374-b97d-1a91841cac4e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.24   Std: 0.27\n"
          ]
        }
      ],
      "source": [
        "# Regularized evolution (REA) selected by 12-epoch validation accuracy.\n",
        "population_size, tournament_size = 10, 10\n",
        "iepoch = 12\n",
        "all_test_acc = []\n",
        "\n",
        "for seed in range(1, n_run + 1):\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    _, _, _, best_test, best_spec = run_evolution_search(\n",
        "        max_budget, population_size=population_size, tournament_size=tournament_size, iepoch=iepoch)\n",
        "    all_test_acc.append(best_test[-1] * 100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2),\n",
        "      '  Std:', np.round(np.std(all_test_acc), 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GBNxx295ZGO4"
      },
      "source": [
        "## REA + W"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "Q7Gm8vVzcNn0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "86517ebc-0192-4b00-a8d3-0946ab4f67fe"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.22   Std: 0.28\n"
          ]
        }
      ],
      "source": [
        "# REA + W: initial population warm-started from the top synflow-scored random cells.\n",
        "population_size, tournament_size = 10, 10\n",
        "warm_up, n_warmup = True, 2000\n",
        "iepoch = 12\n",
        "all_test_acc = []\n",
        "\n",
        "for seed in range(1, n_run + 1):\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    _, _, _, best_test, best_spec = run_evolution_search(\n",
        "        max_budget, population_size=population_size, tournament_size=tournament_size,\n",
        "        iepoch=iepoch, warm_up=warm_up, n_warmup=n_warmup)\n",
        "    all_test_acc.append(best_test[-1] * 100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2),\n",
        "      '  Std:', np.round(np.std(all_test_acc), 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "65npxV5kAawo"
      },
      "source": [
        "## Other metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t3lnh_P6AhJH"
      },
      "outputs": [],
      "source": [
        "for zc_metric in ['jacob_cov', 'grasp', 'fisher', 'grad_norm', 'snip']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, verbose=True)\n",
        "    print('*'*100)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a8uzy-L8BFqH"
      },
      "source": [
        "## Compare to RankNOSH"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "spQqEmr89mbs",
        "outputId": "05cb4e0f-e0de-42b9-ad98-17dd0455d36d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "ZC METRIC: PARAMS\n",
            "ID: 1\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 1, 1, 1, 1],\n",
            "       [0, 0, 1, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 1, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 1, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3', 'output']}\n",
            "-> Test accuracy: 93.95%\n",
            "ID: 2\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 1, 1, 0, 1],\n",
            "       [0, 0, 0, 0, 1, 0, 0],\n",
            "       [0, 0, 0, 1, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 1, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'output']}\n",
            "-> Test accuracy: 94.23%\n",
            "ID: 3\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 1, 0, 1, 1],\n",
            "       [0, 0, 0, 1, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3', 'conv3x3-bn-relu', 'output']}\n",
            "-> Test accuracy: 94.23%\n",
            "ID: 4\n",
            "-> Best solution: {'matrix': array([[0, 0, 1, 1, 1, 1, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 1, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'maxpool3x3', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'output']}\n",
            "-> Test accuracy: 94.14%\n",
            "ID: 5\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 1, 0, 1, 1],\n",
            "       [0, 0, 1, 1, 0, 1, 0],\n",
            "       [0, 0, 0, 1, 1, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']}\n",
            "-> Test accuracy: 94.22%\n",
            "ID: 6\n",
            "-> Best solution: {'matrix': array([[0, 0, 1, 1, 1, 1, 1],\n",
            "       [0, 0, 0, 1, 1, 0, 0],\n",
            "       [0, 0, 0, 1, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']}\n",
            "-> Test accuracy: 94.23%\n",
            "ID: 7\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 1, 1, 0, 1],\n",
            "       [0, 0, 0, 0, 1, 0, 0],\n",
            "       [0, 0, 0, 1, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 1, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']}\n",
            "-> Test accuracy: 94.23%\n",
            "ID: 8\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 1, 1, 0, 1],\n",
            "       [0, 0, 1, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 1, 1, 1],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3', 'maxpool3x3', 'maxpool3x3', 'output']}\n",
            "-> Test accuracy: 93.58%\n",
            "ID: 9\n",
            "-> Best solution: {'matrix': array([[0, 0, 1, 1, 1, 1, 1],\n",
            "       [0, 0, 1, 1, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'output']}\n",
            "-> Test accuracy: 93.97%\n",
            "ID: 10\n",
            "-> Best solution: {'matrix': array([[0, 1, 1, 0, 0, 1, 1],\n",
            "       [0, 0, 1, 0, 1, 0, 0],\n",
            "       [0, 0, 0, 0, 1, 0, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 1, 0],\n",
            "       [0, 0, 0, 0, 0, 0, 1],\n",
            "       [0, 0, 0, 0, 0, 0, 0]]), 'ops': ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']}\n",
            "-> Test accuracy: 93.92%\n",
            "#Runs: 10\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 94.08 0.21\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_acc, budget=9999999 sec)]: 94.07 0.2\n",
            "- Total Cost: 604 + 42320 = 42924 seconds\n",
            "- Total Epochs: 1152\n",
            "- Best visited solution (all): 94.18 0.14\n",
            "- Best visited solution (top-16): 94.14 0.14\n",
            "****************************************************************************************************\n"
          ]
        }
      ],
      "source": [
        "list_epochs = [36, 108]\n",
        "max_budget = 9999999\n",
        "n_run = 10\n",
        "for zc_metric in ['params']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, verbose=True)\n",
        "    print('*'*100)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "YEEe3vfm7h4m",
        "M_OXQMY9wWK6",
        "densE8Qu1NJC",
        "K7hDoccf1SLc",
        "i81aVIOI1TmP",
        "fQrD0rQtXC0W",
        "Xmog5VC65pW8",
        "vrDgQ8555mVn"
      ],
      "toc_visible": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}