{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "This notebook contains the source code of Multi-Fidelity Neural Architecture Search (MF-NAS) on NAS-Bench-ASR.\n",
        "\n",
        "The content is presented as:\n",
        "\n",
        "\n",
        "*   The utility functions are presented in Section 'Utilities'.\n",
        "*   Code for loading the benchmark databases is in Section 'Load Benchmark'.\n",
        "*   Code for First-Improvement Local Search, Random Search, Successive Halving, and MF-NAS is presented in Sections 'Algorithms' and 'MF-NAS'.\n",
        "*   Code for running Random Search, Local Search, SH, REA, REA+W, and the MF-NAS variants is in Section 'Run'.\n",
        "\n",
        "All results for NAS-Bench-ASR reported in the paper are already shown in Section 'Run'.\n",
        "\n",
        "Execute all cells to reproduce the results reported in the paper.\n"
      ],
      "metadata": {
        "id": "hDFZA_GZCQem"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YEEe3vfm7h4m"
      },
      "source": [
        "## Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "f20ae2a9"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "import pickle as p\n",
        "import numpy as np\n",
        "import random\n",
        "from copy import deepcopy\n",
        "import itertools\n",
        "import json\n",
        "import math\n",
        "from scipy import stats\n",
        "import matplotlib\n",
        "import hashlib"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "SkHiMJtPYwb2"
      },
      "outputs": [],
      "source": [
        "MAIN_OPS = [0, 1, 2, 3, 4, 5]\n",
        "IDX_MAIN_OPS = [0, 2, 5]\n",
        "\n",
        "SKIP_OPS = [0, 1]\n",
        "IDX_SKIP_OPS = [1, 3, 4, 6, 7, 8]\n",
        "maxLength = 9\n",
        "\n",
        "all_ops = ['linear', 'conv5', 'conv5d2', 'conv7', 'conv7d2', 'zero']\n",
        "\n",
        "def arch_int_to_vec(arch_int):\n",
        "    arch_vec = [(arch_int[0], arch_int[1]), (arch_int[2], arch_int[3], arch_int[4]),\n",
        "                (arch_int[5], arch_int[6], arch_int[7], arch_int[8])]\n",
        "    return arch_vec\n",
        "\n",
        "def get_model_graph(arch_vec, ops=None, minimize=True, keep_dims=False):\n",
        "    \"\"\"Build the DAG (adjacency matrix + node labels) described by arch_vec.\n",
        "\n",
        "    Node 0 is 'input', node j+1 holds ops[arch_vec[j][0]] and the last node is\n",
        "    'output'. Consecutive nodes are always connected; a set skip flag\n",
        "    arch_vec[j][1 + i] adds an edge from node i into node j+2 (the next\n",
        "    operation, or the output for the last entry).\n",
        "\n",
        "    Args:\n",
        "        arch_vec: list of node tuples, as produced by arch_int_to_vec.\n",
        "        ops: operation names indexed by op id; defaults to all_ops.\n",
        "        minimize: if True, drop every edge touching a 'zero' node and remove\n",
        "            nodes that are not both reachable from the input and able to reach\n",
        "            the output.\n",
        "        keep_dims: with minimize, zero out removed nodes (label set to None)\n",
        "            instead of deleting their rows/columns.\n",
        "\n",
        "    Returns:\n",
        "        ((mat, labels), orig) where orig is a (mat, labels) copy taken before\n",
        "        minimization, or None when minimize is False.\n",
        "    \"\"\"\n",
        "    if ops is None:\n",
        "        ops = all_ops\n",
        "    num_nodes = len(arch_vec)\n",
        "    mat = np.zeros((num_nodes+2, num_nodes+2))\n",
        "    labels = ['input']\n",
        "    prev_skips = []\n",
        "    for nidx, node in enumerate(arch_vec):\n",
        "        op = node[0]\n",
        "        labels.append(ops[op])\n",
        "        # Chain edge from the previous node (or the input) into this node.\n",
        "        mat[nidx, nidx+1] = 1\n",
        "        # Skip flags of the previous entry connect earlier nodes into this node.\n",
        "        for i, sc in enumerate(prev_skips):\n",
        "            if sc:\n",
        "                mat[i, nidx+1] = 1\n",
        "        prev_skips = node[1:]\n",
        "    labels.append('output')\n",
        "    # The last entry feeds the output, and its skip flags also target the output.\n",
        "    mat[num_nodes, num_nodes+1] = 1\n",
        "    for i, sc in enumerate(prev_skips):\n",
        "        if sc:\n",
        "            mat[i, num_nodes+1] = 1\n",
        "    orig = None\n",
        "    if minimize:\n",
        "        # Snapshot of the unminimized graph, returned to the caller as orig.\n",
        "        orig = copy.copy(mat), copy.copy(labels)\n",
        "        # Disconnect every edge touching a node labelled 'zero'.\n",
        "        for n in range(len(mat)):\n",
        "            if labels[n] == 'zero':\n",
        "                for n2 in range(len(mat)):\n",
        "                    if mat[n, n2]:\n",
        "                        mat[n, n2] = 0\n",
        "                    if mat[n2, n]:\n",
        "                        mat[n2, n] = 0\n",
        "        # Reachability search. q.pop() takes from the end, so the traversal is\n",
        "        # actually depth-first; the resulting visited set is the same as a BFS.\n",
        "        def bfs(src, mat, backward):\n",
        "            visited = np.zeros(len(mat))\n",
        "            q = [src]\n",
        "            visited[src] = 1\n",
        "            while q:\n",
        "                n = q.pop()\n",
        "                for n2 in range(len(mat)):\n",
        "                    if visited[n2]:\n",
        "                        continue\n",
        "                    if (backward and mat[n2, n]) or (not backward and mat[n, n2]):\n",
        "                        q.append(n2)\n",
        "                        visited[n2] = 1\n",
        "            return visited\n",
        "        # v == 2 only for nodes reachable from the input AND reaching the output;\n",
        "        # every other node is dangling. If the input cannot reach the output at\n",
        "        # all, every node (input and output included) is dangling.\n",
        "        vfw = bfs(0, mat, False)\n",
        "        vbw = bfs(len(mat)-1, mat, True)\n",
        "        v = vfw + vbw\n",
        "        dangling = (v < 2).nonzero()[0]\n",
        "        if dangling.size:\n",
        "            if keep_dims:\n",
        "                mat[dangling, :] = 0\n",
        "                mat[:, dangling] = 0\n",
        "                for i in dangling:\n",
        "                    labels[i] = None\n",
        "            else:\n",
        "                mat = np.delete(mat, dangling, axis=0)\n",
        "                mat = np.delete(mat, dangling, axis=1)\n",
        "                # Delete from the back so earlier indices stay valid.\n",
        "                for i in sorted(dangling, reverse=True):\n",
        "                    del labels[i]\n",
        "    return (mat, labels), orig\n",
        "\n",
        "def graph_hash(g):\n",
        "    \"\"\"Return an isomorphism-invariant MD5 hex fingerprint of g = (matrix, labels).\n",
        "\n",
        "    String labels are mapped to ints before hashing: the first label (the input)\n",
        "    becomes -1, the last (the output) -2, and every other label its index in\n",
        "    all_ops. An empty label list (every node removed by minimization) is hashed\n",
        "    as an empty graph.\n",
        "    NOTE(review): graphs built with keep_dims=True contain None labels, for which\n",
        "    all_ops.index raises ValueError, so they cannot be hashed here.\n",
        "    \"\"\"\n",
        "    m, l = g\n",
        "    def hash_module(matrix, labelling):\n",
        "        \"\"\"Computes a graph-invariance MD5 hash of the matrix and label pair.\n",
        "        Args:\n",
        "            matrix: np.ndarray square upper-triangular adjacency matrix.\n",
        "            labelling: list of int labels of length equal to both dimensions of\n",
        "                matrix.\n",
        "        Returns:\n",
        "            MD5 hash of the matrix and labelling.\n",
        "        \"\"\"\n",
        "        vertices = np.shape(matrix)[0]\n",
        "        in_edges = np.sum(matrix, axis=0).tolist()\n",
        "        out_edges = np.sum(matrix, axis=1).tolist()\n",
        "        assert len(in_edges) == len(out_edges) == len(labelling), f'{labelling} {matrix}'\n",
        "        # The degrees are floats for a float matrix (as built by get_model_graph)\n",
        "        # and their str() form is hashed, so changing the matrix dtype would change\n",
        "        # every fingerprint and presumably break lookups in the precomputed\n",
        "        # benchmark dictionaries.\n",
        "        hashes = list(zip(out_edges, in_edges, labelling))\n",
        "        hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]\n",
        "        # Computing this up to the diameter is probably sufficient but since the\n",
        "        # operation is fast, it is okay to repeat more times.\n",
        "        for _ in range(vertices):\n",
        "            new_hashes = []\n",
        "            for v in range(vertices):\n",
        "                in_neighbours = [hashes[w] for w in range(vertices) if matrix[w, v]]\n",
        "                out_neighbours = [hashes[w] for w in range(vertices) if matrix[v, w]]\n",
        "                new_hashes.append(hashlib.md5(\n",
        "                        (''.join(sorted(in_neighbours)) + '|' +\n",
        "                        ''.join(sorted(out_neighbours)) + '|' +\n",
        "                        hashes[v]).encode('utf-8')).hexdigest())\n",
        "            hashes = new_hashes\n",
        "        # Sorting makes the fingerprint independent of node order.\n",
        "        fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()\n",
        "        return fingerprint\n",
        "    labels = []\n",
        "    if l:\n",
        "        labels = [-1] + [all_ops.index(op) for op in l[1:-1]] + [-2]\n",
        "    return hash_module(m, labels)\n",
        "\n",
        "def get_model_hash(arch_vec, ops=None, minimize=True):\n",
        "    \"\"\"Return the graph fingerprint of the architecture arch_vec.\n",
        "\n",
        "    Configurations that describe the same (isomorphic) graph map to the same\n",
        "    fingerprint, so it serves as a canonical key for an architecture.\n",
        "    \"\"\"\n",
        "    (matrix, labels), _unminimized = get_model_graph(arch_vec, ops=ops, minimize=minimize)\n",
        "    return graph_hash((matrix, labels))\n",
        "\n",
        "def get_hashKey(arch_int):\n",
        "    # Flat integer encoding -> grouped node tuples -> graph fingerprint.\n",
        "    return get_model_hash(arch_int_to_vec(arch_int))\n",
        "\n",
        "def sample_arch():\n",
        "    # Uniformly random architecture: an op id at each IDX_MAIN_OPS position and a\n",
        "    # 0/1 skip flag at each IDX_SKIP_OPS position (drawn in that order).\n",
        "    main_draw = np.random.choice(MAIN_OPS, len(IDX_MAIN_OPS))\n",
        "    skip_draw = np.random.choice(SKIP_OPS, len(IDX_SKIP_OPS))\n",
        "    arch = np.zeros(maxLength, dtype=np.int8)\n",
        "    arch[IDX_MAIN_OPS] = main_draw\n",
        "    arch[IDX_SKIP_OPS] = skip_draw\n",
        "    return arch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "821c6f87"
      },
      "outputs": [],
      "source": [
        "def get_indices(state, distance):\n",
        "    return list(itertools.combinations(range(len(state)), distance))\n",
        "\n",
        "def get_all_neighbors(state, ids):\n",
        "    \"\"\"Return every neighbor of state that differs from it exactly at the positions in ids.\n",
        "\n",
        "    Each listed position takes every allowed value other than its current one\n",
        "    (MAIN_OPS at IDX_MAIN_OPS positions, SKIP_OPS elsewhere), and all combinations\n",
        "    across the positions are produced. Neighbors are returned as Python lists,\n",
        "    shuffled in place with np.random.shuffle.\n",
        "    \"\"\"\n",
        "    list_available_ops = []\n",
        "    for i in ids:\n",
        "        # Use the shared search-space constants (copied, since remove() mutates)\n",
        "        # instead of re-hard-coding them here.\n",
        "        _available_ops = list(MAIN_OPS) if i in IDX_MAIN_OPS else list(SKIP_OPS)\n",
        "        # remove() also rejects states holding a value outside the search space.\n",
        "        _available_ops.remove(state[i])\n",
        "        list_available_ops.append(_available_ops)\n",
        "    ids = np.array(ids)\n",
        "    list_neighbors = []\n",
        "    for ops in itertools.product(*list_available_ops):\n",
        "        neighbor = np.array(state).copy()\n",
        "        neighbor[ids] = ops\n",
        "        list_neighbors.append(neighbor.tolist())\n",
        "    np.random.shuffle(list_neighbors)\n",
        "    return list_neighbors\n",
        "\n",
        "def mutate_spec(old_arch, mutation_rate=1.0):\n",
        "    \"\"\"Return a mutated deep copy of old_arch (the input is left unchanged).\n",
        "\n",
        "    Each position mutates independently with probability\n",
        "    mutation_rate / len(old_arch), so mutation_rate is the expected number of\n",
        "    changed positions. A mutated position takes a uniformly random value\n",
        "    different from its current one (MAIN_OPS at IDX_MAIN_OPS, SKIP_OPS elsewhere).\n",
        "    \"\"\"\n",
        "    new_arch = copy.deepcopy(old_arch)\n",
        "\n",
        "    op_mutation_prob = mutation_rate / len(old_arch)\n",
        "    for ind in range(len(old_arch)):\n",
        "        # Shared search-space constants instead of duplicated literals.\n",
        "        _available_ops = MAIN_OPS if ind in IDX_MAIN_OPS else SKIP_OPS\n",
        "        if random.random() < op_mutation_prob:\n",
        "            available = [o for o in _available_ops if o != new_arch[ind]]\n",
        "            new_arch[ind] = random.choice(available)\n",
        "    return new_arch\n",
        "\n",
        "def random_combination(iterable, sample_size):\n",
        "    \"\"\"Random selection from itertools.combinations(iterable, r).\"\"\"\n",
        "    pool = tuple(iterable)\n",
        "    n = len(pool)\n",
        "    indices = sorted(random.sample(range(n), sample_size))\n",
        "    return tuple(pool[i] for i in indices)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "77031a12"
      },
      "outputs": [],
      "source": [
        "def train_evaluate(state, metric):\n",
        "    \"\"\"Look up the trained-model score of state in benchmark.\n",
        "\n",
        "    metric is 'test_per', or any name ending in '_<epoch>' (1-based) to read the\n",
        "    validation PER at that epoch. The PER is negated so that larger is better;\n",
        "    the reported cost is always 0.0.\n",
        "    \"\"\"\n",
        "    key = get_hashKey(state)\n",
        "    if metric != 'test_per':\n",
        "        epoch_idx = int(metric.split('_')[-1]) - 1\n",
        "        return -benchmark[key]['val_per'][epoch_idx], 0.0\n",
        "    return -benchmark[key]['test_per'], 0.0\n",
        "\n",
        "def zc_evaluate(state, metric):\n",
        "    \"\"\"Return (proxy score, 0.0) for state.\n",
        "\n",
        "    'params' and 'FLOPs' are read from benchmark; every other metric (e.g.\n",
        "    'synflow') from zc_benchmark. The reported cost is always 0.0.\n",
        "    \"\"\"\n",
        "    key = get_hashKey(state)\n",
        "    table = benchmark if metric in ('params', 'FLOPs') else zc_benchmark\n",
        "    return table[key][metric], 0.0\n",
        "\n",
        "def evaluate(state, metric):\n",
        "    # Metrics containing 'per' (e.g. 'test_per', 'val_per_40') come from the trained\n",
        "    # benchmark; anything else is looked up as a proxy by zc_evaluate.\n",
        "    evaluator = train_evaluate if 'per' in metric else zc_evaluate\n",
        "    result = evaluator(state, metric)\n",
        "    return result[0], result[1]\n",
        "\n",
        "def evaluate_trend_best_state(trend_best_state):\n",
        "    # Negated test PER of every state in trend_best_state, in order.\n",
        "    scores = []\n",
        "    for state in trend_best_state:\n",
        "        scores.append(evaluate(state, 'test_per')[0])\n",
        "    return scores\n",
        "\n",
        "\n",
        "def evaluate_trend_best_state1(trend_best_state):\n",
        "    # Negated epoch-40 validation PER of every state in trend_best_state, in order.\n",
        "    scores = []\n",
        "    for state in trend_best_state:\n",
        "        scores.append(evaluate(state, 'val_per_40')[0])\n",
        "    return scores"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load Benchmark"
      ],
      "metadata": {
        "id": "Bp51Mc6_CoiJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the NAS-Bench-ASR benchmark database ([TIMIT]_data.p) and the\n",
        "# zero-cost proxy scores (zc_asr.p) that the next cells load.\n",
        "!gdown https://drive.google.com/uc?id=1WX8F4bSIUqac3uUB8OtqH8YB8Ui61fCb\n",
        "!gdown https://drive.google.com/uc?id=1Z7K1YtqeBt6xO1D63dRjSUGCMwrTNDL1"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4EwzolvSG3Mh",
        "outputId": "c6d5369a-ba68-47fa-9e81-65ef86f1a564"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1WX8F4bSIUqac3uUB8OtqH8YB8Ui61fCb\n",
            "To: /content/[TIMIT]_data.p\n",
            "100% 3.43M/3.43M [00:00<00:00, 223MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1Z7K1YtqeBt6xO1D63dRjSUGCMwrTNDL1\n",
            "To: /content/zc_asr.p\n",
            "100% 1.13M/1.13M [00:00<00:00, 102MB/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "lJH6asvI1q5K"
      },
      "outputs": [],
      "source": [
        "# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
        "with open('zc_asr.p', 'rb') as f:\n",
        "    zc_benchmark = p.load(f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "yn2g9Rlm5DU3"
      },
      "outputs": [],
      "source": [
        "# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
        "with open('[TIMIT]_data.p', 'rb') as f:\n",
        "    benchmark = p.load(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Algorithms"
      ],
      "metadata": {
        "id": "SjHVUaYpCdef"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## First-improvement Local Search"
      ],
      "metadata": {
        "id": "ll-JyaRDC-np"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def first_improvement_ls(init_state, max_eval, metric=None):\n",
        "    \"\"\"First-improvement local search that maximizes evaluate(state, metric).\n",
        "\n",
        "    Single-position neighborhoods are scanned in random order and the first\n",
        "    neighbor scoring >= the current state is accepted (ties are accepted, so the\n",
        "    search can drift across plateaus). When no such neighbor exists, the search\n",
        "    jumps to a random neighbor that differs in two positions.\n",
        "\n",
        "    Args:\n",
        "        init_state: initial architecture (integer encoding).\n",
        "        max_eval: evaluation budget; checked before each neighborhood scan.\n",
        "        metric: metric name passed to evaluate (required).\n",
        "\n",
        "    Returns:\n",
        "        (trend_best_state, trend_time, state_history, f_state_history): the best\n",
        "        state and the cumulative evaluation cost after every evaluation, plus\n",
        "        every evaluated state and its score.\n",
        "\n",
        "    NOTE(review): the budget is only checked between scans, so a run can\n",
        "    overshoot max_eval by up to one neighborhood scan; truncate the returned\n",
        "    lists when an exact budget is required.\n",
        "    \"\"\"\n",
        "    assert metric is not None, 'Missing the evaluation metric!'\n",
        "    assert init_state is not None, 'Missing the initial state!'\n",
        "\n",
        "    n_eval = 0\n",
        "    total_time = 0\n",
        "    curr_state = init_state.copy()\n",
        "    f_curr_state, time = evaluate(init_state, metric)\n",
        "    total_time += time\n",
        "    n_eval += 1\n",
        "\n",
        "    best_state = curr_state.copy()\n",
        "    f_best_state = f_curr_state\n",
        "\n",
        "    trend_best_state = [best_state]\n",
        "    trend_time = [total_time]\n",
        "    state_history, f_state_history = [curr_state], [f_curr_state]\n",
        "\n",
        "    while n_eval <= max_eval:\n",
        "        improved = False\n",
        "        # Positions are drawn without replacement, in random order.\n",
        "        list_ids = get_indices(curr_state, 1)\n",
        "        while len(list_ids) != 0:\n",
        "            idx = np.random.choice(range(len(list_ids)))\n",
        "            ids = list_ids.pop(idx)\n",
        "\n",
        "            for neighbor in get_all_neighbors(state=curr_state, ids=ids):\n",
        "                f_neighbor, time = evaluate(neighbor, metric)\n",
        "                state_history.append(neighbor)\n",
        "                f_state_history.append(f_neighbor)\n",
        "                total_time += time\n",
        "                trend_time.append(total_time)\n",
        "                n_eval += 1\n",
        "\n",
        "                if f_neighbor >= f_curr_state:\n",
        "                    curr_state = neighbor.copy()\n",
        "                    f_curr_state = f_neighbor\n",
        "                    if f_neighbor > f_best_state:\n",
        "                        best_state = neighbor.copy()\n",
        "                        f_best_state = f_neighbor\n",
        "                    improved = True\n",
        "                trend_best_state.append(best_state)\n",
        "                if improved:\n",
        "                    break\n",
        "            if improved:\n",
        "                break\n",
        "\n",
        "        if not improved:\n",
        "            # Local optimum reached: perturb with a random distance-2 move.\n",
        "            list_ids = get_indices(curr_state, 2)\n",
        "            idx = np.random.choice(range(len(list_ids)))\n",
        "            ids = list_ids[idx]\n",
        "\n",
        "            list_neighbors = get_all_neighbors(curr_state, ids)\n",
        "            curr_state = list_neighbors[np.random.choice(len(list_neighbors))]\n",
        "\n",
        "            f_curr_state, time = evaluate(curr_state, metric)\n",
        "            state_history.append(curr_state)\n",
        "            f_state_history.append(f_curr_state)\n",
        "            total_time += time\n",
        "            n_eval += 1\n",
        "            if f_curr_state > f_best_state:\n",
        "                best_state = curr_state.copy()\n",
        "                f_best_state = f_curr_state\n",
        "            trend_best_state.append(best_state)\n",
        "            trend_time.append(total_time)\n",
        "    return trend_best_state, trend_time, state_history, f_state_history"
      ],
      "metadata": {
        "id": "c9PiBeseCekb"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Random Search"
      ],
      "metadata": {
        "id": "h17YXq-2DBzB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def random_search(max_eval, metric=None):\n",
        "    \"\"\"Sample random architectures and track the best one by evaluate(state, metric).\n",
        "\n",
        "    Draws samples while the counter (starting at 0) is <= max_eval. Returns\n",
        "    (trend_best_state, trend_time, state_history, f_state_history): the best\n",
        "    state and cumulative cost after every sample, plus every sample and its score.\n",
        "    \"\"\"\n",
        "    assert metric is not None, 'Missing the evaluation metric!'\n",
        "\n",
        "    total_time = 0\n",
        "    best_state, f_best_state = None, None\n",
        "    trend_best_state, trend_time = [], []\n",
        "    state_history, f_state_history = [], []\n",
        "\n",
        "    n_sampled = 0\n",
        "    while n_sampled <= max_eval:\n",
        "        state = sample_arch()\n",
        "        f_state, cost = evaluate(state, metric)\n",
        "        total_time += cost\n",
        "        state_history.append(state)\n",
        "        f_state_history.append(f_state)\n",
        "\n",
        "        # The first sample always becomes the incumbent; afterwards only a\n",
        "        # strictly better score replaces it.\n",
        "        if f_best_state is None or f_state > f_best_state:\n",
        "            best_state, f_best_state = state.copy(), f_state\n",
        "        trend_best_state.append(best_state)\n",
        "        trend_time.append(total_time)\n",
        "        n_sampled += 1\n",
        "\n",
        "    return trend_best_state, trend_time, state_history, f_state_history"
      ],
      "metadata": {
        "id": "mewqY1WiDDB3"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Successive Halving"
      ],
      "metadata": {
        "id": "IC9UfYfxDGPq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def succesive_halving(list_candidates, metric, max_budget, list_epochs, max_epoch=40):\n",
        "    \"\"\"Pick the best candidate with successive halving over training checkpoints.\n",
        "\n",
        "    At checkpoint c every remaining candidate is scored with\n",
        "    evaluate(candidate, f'{metric}_{list_epochs[c]}'); the better half (rounded up)\n",
        "    survives to the next checkpoint. The search stops when one candidate is left,\n",
        "    when the checkpoint equal to list_epochs[-1] has been evaluated, or when no\n",
        "    checkpoints remain.\n",
        "\n",
        "    Args:\n",
        "        list_candidates: architectures (integer encodings) to compare.\n",
        "        metric: metric prefix, e.g. 'val_per'; the epoch is appended to it.\n",
        "        max_budget: currently unused (kept for interface compatibility).\n",
        "        list_epochs: checkpoint epochs, presumably increasing.\n",
        "        max_epoch: currently unused (kept for interface compatibility).\n",
        "\n",
        "    Returns:\n",
        "        (best_candidate, total_time, total_epoch): the candidate with the highest\n",
        "        score seen at any checkpoint, the summed cost reported by evaluate, and\n",
        "        the training epochs charged (each evaluation is charged only the epochs\n",
        "        since the previous checkpoint).\n",
        "    \"\"\"\n",
        "    assert len(list_epochs) != 0, 'Have not set the checkpoints for evaluation!'\n",
        "    checkpoint = 0\n",
        "    iepoch = list_epochs[checkpoint]\n",
        "\n",
        "    last_iepoch = 0\n",
        "    total_time, total_epoch = 0.0, 0.0\n",
        "\n",
        "    best_candidate, f_best_candidate = None, -np.inf\n",
        "    last = False\n",
        "    while True:\n",
        "        evaluated_candidates = []\n",
        "        f_candidates = []\n",
        "        for candidate in list_candidates:\n",
        "            score, time = evaluate(candidate, f'{metric}_{iepoch}')\n",
        "            # Accumulate the reported evaluation cost.\n",
        "            total_time += time\n",
        "            total_epoch += int(iepoch) - last_iepoch\n",
        "\n",
        "            evaluated_candidates.append(candidate)\n",
        "            f_candidates.append(score)\n",
        "\n",
        "            if score > f_best_candidate:\n",
        "                f_best_candidate = score\n",
        "                best_candidate = candidate\n",
        "\n",
        "        last_iepoch = int(iepoch)\n",
        "\n",
        "        # Keep the better half (rounded up), best first.\n",
        "        ids = np.flip(np.argsort(f_candidates))\n",
        "        list_candidates = np.array(evaluated_candidates)[ids]\n",
        "        list_candidates = list_candidates[:math.ceil(len(list_candidates) / 2)]\n",
        "\n",
        "        # The bound check prevents an IndexError when list_epochs has a single\n",
        "        # entry (the final-checkpoint flag below is never set in that case).\n",
        "        if len(list_candidates) == 1 or last or checkpoint + 1 >= len(list_epochs):\n",
        "            return best_candidate, total_time, total_epoch\n",
        "\n",
        "        checkpoint += 1\n",
        "        iepoch = list_epochs[checkpoint]\n",
        "        if iepoch == list_epochs[-1]:\n",
        "            last = True"
      ],
      "metadata": {
        "id": "-L19zoYKDF4P"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Others"
      ],
      "metadata": {
        "id": "NdE_B2kDLikQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def run_warm_up(n_sample, k):\n",
        "    \"\"\"Sample n_sample random architectures and keep the k with the highest synflow score.\n",
        "\n",
        "    Returns a numpy array of the selected architectures (one per row), best first.\n",
        "    \"\"\"\n",
        "    pool, synflow_scores = [], []\n",
        "    for _ in range(n_sample):\n",
        "        candidate = sample_arch()\n",
        "        synflow_scores.append(zc_evaluate(candidate, 'synflow')[0])\n",
        "        pool.append(candidate)\n",
        "    ranking = np.argsort(np.array(synflow_scores))[::-1]\n",
        "    return np.array(pool)[ranking][:k]\n",
        "\n",
        "def run_random_search(seed=0, max_time_budget=5e6, max_eval_budget=100, iepoch=200):\n",
        "    \"\"\"Run a single roll-out of random search to a fixed time budget.\n",
        "\n",
        "    Architectures are sampled until the summed evaluation cost exceeds\n",
        "    max_time_budget or more than max_eval_budget have been evaluated; the\n",
        "    incumbent is chosen by the epoch-iepoch validation score.\n",
        "\n",
        "    Returns (incumbent, its negated test PER, total cost, None).\n",
        "    NOTE(review): other code in this notebook uses at most epoch 40\n",
        "    ('val_per_40', max_epoch=40); confirm callers pass iepoch explicitly, since\n",
        "    the default 200 may index past the stored validation curve.\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "\n",
        "    costs = [0.0]\n",
        "    incumbent, incumbent_score = None, -np.inf\n",
        "    n_eval = 0\n",
        "    while True:\n",
        "        arch = sample_arch()\n",
        "        val_score, cost = evaluate(arch, f'val_per_{iepoch}')\n",
        "        n_eval += 1\n",
        "        if val_score > incumbent_score:\n",
        "            incumbent, incumbent_score = arch, val_score\n",
        "        costs.append(cost)\n",
        "        if sum(costs) > max_time_budget or n_eval > max_eval_budget:\n",
        "            # Stop the first time either budget is exceeded.\n",
        "            return incumbent, evaluate(incumbent, 'test_per')[0], sum(costs), None\n",
        "\n",
        "def run_evolution_search(seed=0, max_time_budget=5e6, max_eval_budget=100, population_size=20,\n",
        "                         tournament_size=4, mutation_rate=1.0, iepoch=200, warm_up=False, n_warmup=2000):\n",
        "    \"\"\"Run a single roll-out of regularized evolution to a fixed time budget.\n",
        "\n",
        "    The population is seeded with population_size architectures (random ones, or\n",
        "    the top-synflow ones from run_warm_up when warm_up is True) and then evolved by\n",
        "    tournament selection + mutate_spec, discarding the oldest member each step.\n",
        "    Fitness is the negated validation PER at epoch iepoch. The run stops the first\n",
        "    time the summed cost exceeds max_time_budget or more than max_eval_budget\n",
        "    architectures have been evaluated.\n",
        "\n",
        "    Returns:\n",
        "        (best architecture by validation fitness, its negated test PER,\n",
        "         cumulative cost, None)\n",
        "\n",
        "    NOTE(review): other code in this notebook uses at most epoch 40\n",
        "    ('val_per_40', max_epoch=40); confirm callers pass iepoch explicitly.\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "\n",
        "    times, best_valids = [0.0], [-np.inf]\n",
        "    population = []   # (validation, spec) tuples\n",
        "    trend_time = [0.0]\n",
        "    # Despite its name, this tracks the best architecture by *validation* fitness.\n",
        "    best_test_arch = None\n",
        "    # Seed the population: random architectures, or (warm_up=True) the\n",
        "    # population_size best-synflow ones out of n_warmup random samples.\n",
        "    n_eval = 0\n",
        "    if not warm_up:\n",
        "        list_arch = [sample_arch() for _ in range(population_size)]\n",
        "    else:\n",
        "        list_arch = run_warm_up(n_warmup, population_size)\n",
        "    for arch in list_arch:\n",
        "        # val_acc holds the negated validation PER returned by evaluate.\n",
        "        val_acc, time = evaluate(arch, f'val_per_{iepoch}')\n",
        "        n_eval += 1\n",
        "        times.append(time)\n",
        "        trend_time.append(sum(times))\n",
        "        population.append((val_acc, arch))\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_test_arch = arch\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "\n",
        "        if sum(times) > max_time_budget or n_eval > max_eval_budget:\n",
        "            best_candidate = best_test_arch\n",
        "            f_best_candidate = evaluate(best_candidate, 'test_per')[0]\n",
        "            return best_candidate, f_best_candidate, trend_time[-1], None\n",
        "\n",
        "    # After the population is seeded, proceed with evolving the population.\n",
        "    while True:\n",
        "        # Tournament: the fittest of tournament_size random members is mutated.\n",
        "        # NOTE(review): random_combination raises ValueError if the population\n",
        "        # holds fewer than tournament_size members.\n",
        "        sample = random_combination(population, tournament_size)\n",
        "        best_arch = sorted(sample, key=lambda i:i[0])[-1][1]\n",
        "        new_arch = mutate_spec(best_arch, mutation_rate)\n",
        "        n_eval += 1\n",
        "        val_acc, time = evaluate(new_arch, f'val_per_{iepoch}')\n",
        "        times.append(time)\n",
        "        trend_time.append(sum(times))\n",
        "\n",
        "        # In regularized evolution, we kill the oldest individual in the population.\n",
        "        population.append((val_acc, new_arch))\n",
        "        population.pop(0)\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_test_arch = new_arch\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "\n",
        "        if sum(times) > max_time_budget or n_eval > max_eval_budget:\n",
        "            best_candidate = best_test_arch\n",
        "            f_best_candidate = evaluate(best_candidate, 'test_per')[0]\n",
        "            return best_candidate, f_best_candidate, trend_time[-1], None\n",
        "\n",
        "def run_ils(seed, max_budget, iepoch):\n",
        "    \"\"\"Seeded local-search baseline optimizing the epoch-iepoch validation score.\n",
        "\n",
        "    Returns (final best architecture, its negated test PER, cumulative cost, 0.0).\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    start = sample_arch()\n",
        "    best_trend, cost_trend, _states, _scores = first_improvement_ls(\n",
        "        start, max_eval=max_budget, metric=f'val_per_{iepoch}')\n",
        "    final_arch = best_trend[-1]\n",
        "    return final_arch, evaluate(final_arch, 'test_per')[0], cost_trend[-1], 0.0\n",
        "\n",
        "def run_succesive_halving(seed, N, max_budget, list_epochs):\n",
        "    \"\"\"Run successive halving on N random architectures and report the winner's test PER.\n",
        "\n",
        "    Returns (best architecture, its negated test PER, total cost, total epochs).\n",
        "    \"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    list_candidates = [sample_arch() for _ in range(N)]\n",
        "    # succesive_halving returns exactly three values; unpacking four raised ValueError.\n",
        "    best_candidate, total_time, total_epoch = succesive_halving(list_candidates, 'val_per', max_budget, list_epochs)\n",
        "    f_best_candidate = evaluate(best_candidate, 'test_per')[0]\n",
        "    return best_candidate, f_best_candidate, total_time, total_epoch"
      ],
      "metadata": {
        "id": "XDHVvqCcLkrv"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OOKQTE0k7lhY"
      },
      "source": [
        "# MF-NAS"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## with Local Search"
      ],
      "metadata": {
        "id": "9daa-2JSDRPh"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "9OPyuRaoDHxr"
      },
      "outputs": [],
      "source": [
        "def ILS_SH(zc_metric, k, n_run, max_eval, allowed_time, list_epochs, verbose):\n",
        "    \"\"\"MF-NAS with local search: training-free local search, then successive halving.\n",
        "\n",
        "    For each seed in 1..n_run:\n",
        "      Stage 1: first_improvement_ls maximizes zc_metric for max_eval evaluations.\n",
        "      Stage 2: the k best distinct visited architectures (by zc_metric) are\n",
        "               compared by succesive_halving on 'val_per' at list_epochs.\n",
        "    Test PER of intermediate results is computed only for reporting.\n",
        "\n",
        "    Args:\n",
        "        zc_metric: metric used by the local search (e.g. 'synflow').\n",
        "        k: number of top architectures passed to stage 2 (a count, not a %).\n",
        "        n_run: number of independent runs (seeds 1..n_run).\n",
        "        max_eval: evaluation budget of the local search.\n",
        "        allowed_time: total time budget; allowed_time minus the stage-1 cost is\n",
        "            passed to succesive_halving as max_budget (which it does not use).\n",
        "        list_epochs: successive-halving checkpoints.\n",
        "        verbose: print the selected architecture and its test PER for every run.\n",
        "\n",
        "    Returns:\n",
        "        Arrays over runs: selected architectures, their test PER (positive),\n",
        "        stage-2 costs and stage-2 training epochs.\n",
        "    \"\"\"\n",
        "    all_trend_best_solutions_LS, all_trend_f_best_solutions_LS, all_trend_cost_LS = [], [], []\n",
        "    true_best_f_LS = []\n",
        "\n",
        "    all_best_solution_SH, all_f_best_solution_SH = [], []\n",
        "    all_cost_SH = []\n",
        "    all_total_epoch_SH = []\n",
        "    true_best_f_topk_SH = []\n",
        "\n",
        "    for seed in range(1, n_run + 1):\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "\n",
        "        init_solution = sample_arch()\n",
        "\n",
        "        # Stage 1: Explore search space with Training-free Local Search\n",
        "        trend_best_solutions_LS, trend_cost_LS, found_solutions_LS, f_found_solutions_LS = first_improvement_ls(init_solution, metric=zc_metric, max_eval=max_eval)\n",
        "\n",
        "        ## Test PER of the best-so-far solution after each evaluation (analysis only).\n",
        "        ## Lists are truncated because the local search may overshoot max_eval.\n",
        "        trend_cost_LS = trend_cost_LS[:max_eval]\n",
        "\n",
        "        trend_best_solutions_LS = trend_best_solutions_LS[:max_eval]\n",
        "        trend_f_best_solution_LS = evaluate_trend_best_state(trend_best_solutions_LS)\n",
        "\n",
        "        ## Test PER of every state visited within the budget (analysis only)\n",
        "        true_f_found_solutions_LS = evaluate_trend_best_state(found_solutions_LS[:max_eval])\n",
        "        true_best_f_LS.append(-max(true_f_found_solutions_LS))\n",
        "\n",
        "        all_trend_best_solutions_LS.append(trend_best_solutions_LS)\n",
        "        all_trend_f_best_solutions_LS.append(trend_f_best_solution_LS)\n",
        "        all_trend_cost_LS.append(trend_cost_LS)\n",
        "\n",
        "        ## Remove duplicate encodings. np.unique compares raw encodings, so\n",
        "        ## isomorphic architectures with different encodings are not merged.\n",
        "        found_solutions_LS, f_found_solutions_LS = found_solutions_LS[:max_eval], f_found_solutions_LS[:max_eval]\n",
        "        found_solutions_LS, ids = np.unique(found_solutions_LS, axis=0, return_index=True)\n",
        "        f_found_solutions_LS = np.array(f_found_solutions_LS)[ids]\n",
        "\n",
        "        ## Sort by zc_metric score, best first\n",
        "        ids = np.flip(np.argsort(f_found_solutions_LS))\n",
        "        found_solutions_LS = found_solutions_LS[ids]\n",
        "\n",
        "        ## Keep the top-k architectures (k is a count)\n",
        "        topk_found_solutions = found_solutions_LS[:k]\n",
        "\n",
        "        ## Best test PER among the top-k (analysis only)\n",
        "        f_topk = evaluate_trend_best_state(topk_found_solutions)\n",
        "        true_best_f_topk_SH.append(-max(f_topk))\n",
        "\n",
        "        # Stage 2: select the best one by Successive Halving\n",
        "        best_solution_SH, total_cost_SH, total_epoch_SH = succesive_halving(topk_found_solutions, 'val_per', allowed_time - trend_cost_LS[-1], list_epochs)\n",
        "\n",
        "\n",
        "        f_best_solution_SH = evaluate_trend_best_state([best_solution_SH])[-1]\n",
        "        all_best_solution_SH.append(best_solution_SH)\n",
        "        all_f_best_solution_SH.append(f_best_solution_SH)\n",
        "        all_cost_SH.append(total_cost_SH)\n",
        "        all_total_epoch_SH.append(total_epoch_SH)\n",
        "\n",
        "        if verbose:\n",
        "            print('ID:', seed)\n",
        "            print(f'-> Best solution: {arch_int_to_vec(best_solution_SH)}')\n",
        "            print(f'-> Test PER: {-round(f_best_solution_SH*100, 2)}%')\n",
        "\n",
        "    # Negate the stored scores so the reported values are positive PER.\n",
        "    all_trend_best_solutions_LS = np.array(all_trend_best_solutions_LS)\n",
        "    all_trend_f_best_solutions_LS = -np.array(all_trend_f_best_solutions_LS)\n",
        "    all_trend_cost_LS = np.array(all_trend_cost_LS)\n",
        "\n",
        "    all_best_solution_SH = np.array(all_best_solution_SH)\n",
        "    all_f_best_solution_SH = -np.array(all_f_best_solution_SH)\n",
        "    all_cost_SH = np.array(all_cost_SH)\n",
        "    all_total_epoch_SH = np.array(all_total_epoch_SH)\n",
        "\n",
        "    print(f'#Runs:', n_run)\n",
        "    print(f'- Stage 1 [ILS (metric={zc_metric}, max_eval={max_eval})]:', np.round(np.mean(all_trend_f_best_solutions_LS, axis=0)[-1] * 100, 2), np.round(np.std(all_trend_f_best_solutions_LS, axis=0)[-1] * 100, 2))\n",
        "    print(f'- Stage 2 [Successive Halving (top-{k}, metric=val_per, budget={allowed_time} sec)]:', np.round(np.mean(all_f_best_solution_SH) * 100, 2), np.round(np.std(all_f_best_solution_SH) * 100, 2))\n",
        "    print(f'- Total Epochs: {int(np.mean(all_total_epoch_SH))}')\n",
        "    print(f'- Best visited solution (all):', np.round(np.mean(true_best_f_LS) * 100, 2), np.round(np.std(true_best_f_LS) * 100, 2))\n",
        "    print(f'- Best visited solution (top-{k}):', np.round(np.mean(true_best_f_topk_SH) * 100, 2), np.round(np.std(true_best_f_topk_SH) * 100, 2))\n",
        "\n",
        "    return all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## with Random Search"
      ],
      "metadata": {
        "id": "8J36WdGYGfsQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def RS_SH(zc_metric, k, n_run, max_eval, allowed_time, list_epochs, verbose):\n",
        "    \"\"\"Run MF-NAS with Zero-cost Random Search (Stage 1) and Successive Halving (Stage 2).\n",
        "\n",
        "    For every seed in 1..n_run, Stage 1 calls `random_search` with `zc_metric` and keeps\n",
        "    the first `max_eval` visited architectures. The unique ones are ranked by the Stage-1\n",
        "    fitness returned by `random_search` (descending) and the top-k (k is a count, not a\n",
        "    percentage) are handed to `succesive_halving`, which selects one using 'val_per'.\n",
        "\n",
        "    Args:\n",
        "        zc_metric: zero-cost metric name forwarded to `random_search`.\n",
        "        k: number of top-ranked unique architectures forwarded to Stage 2.\n",
        "        n_run: number of independent runs; run i seeds numpy and random with i.\n",
        "        max_eval: number of Stage-1 evaluations kept per run.\n",
        "        allowed_time: total budget per run (printed as seconds); Stage 2 receives\n",
        "            allowed_time minus the last entry of the Stage-1 cost trend.\n",
        "        list_epochs: epoch schedule forwarded to `succesive_halving`.\n",
        "        verbose: if True, print the selected architecture and its test PER per run.\n",
        "\n",
        "    Returns:\n",
        "        Tuple of per-run numpy arrays (all_best_solution_SH, all_f_best_solution_SH,\n",
        "        all_cost_SH, all_total_epoch_SH). all_f_best_solution_SH holds the negation of\n",
        "        the values returned by `evaluate_trend_best_state`, i.e. the test PER as a fraction.\n",
        "    \"\"\"\n",
        "    # Per-run Stage-1 traces (for analysis only)\n",
        "    all_trend_best_solutions_RS, all_trend_f_best_solutions_RS, all_trend_cost_RS = [], [], []\n",
        "    true_best_f_RS = []\n",
        "\n",
        "    # Per-run Stage-2 results\n",
        "    all_best_solution_SH, all_f_best_solution_SH = [], []\n",
        "    all_cost_SH = []\n",
        "    all_total_epoch_SH = []\n",
        "    true_best_f_topk_SH = []\n",
        "\n",
        "    for seed in range(1, n_run + 1):\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "\n",
        "        # Stage 1: Explore search space with Zero-cost Random Search\n",
        "        trend_best_solutions_RS, trend_cost_RS, found_solutions_RS, f_found_solutions_RS = random_search(zc_metric=zc_metric, max_eval=max_eval)\n",
        "        trend_cost_RS = trend_cost_RS[:max_eval]\n",
        "\n",
        "        ## Evaluate the best-so-far solutions along the search (test PER, for analysis only)\n",
        "        trend_best_solutions_RS = trend_best_solutions_RS[:max_eval]\n",
        "        trend_f_best_solution_RS = evaluate_trend_best_state(trend_best_solutions_RS)\n",
        "\n",
        "        ## Evaluate all visited solutions (test PER, for analysis only)\n",
        "        true_f_found_solutions_RS = evaluate_trend_best_state(found_solutions_RS[:max_eval])\n",
        "        true_best_f_RS.append(-max(true_f_found_solutions_RS))\n",
        "\n",
        "        all_trend_best_solutions_RS.append(trend_best_solutions_RS)\n",
        "        all_trend_f_best_solutions_RS.append(trend_f_best_solution_RS)\n",
        "        all_trend_cost_RS.append(trend_cost_RS)\n",
        "\n",
        "        ## Remove duplicates in the memory (np.unique sorts rows; ids maps them back to their fitness)\n",
        "        found_solutions_RS, f_found_solutions_RS = found_solutions_RS[:max_eval], f_found_solutions_RS[:max_eval]\n",
        "        found_solutions_RS, ids = np.unique(found_solutions_RS, axis=0, return_index=True)\n",
        "        f_found_solutions_RS = np.array(f_found_solutions_RS)[ids]\n",
        "\n",
        "        ## Sort architectures by Stage-1 fitness in descending order\n",
        "        ids = np.flip(np.argsort(f_found_solutions_RS))\n",
        "        found_solutions_RS = found_solutions_RS[ids]\n",
        "\n",
        "        ## Keep the top-k architectures (k is a count, not a percentage)\n",
        "        topk_found_solutions = found_solutions_RS[:k]\n",
        "\n",
        "        ## Evaluate test PER of the top-k (for analysis only)\n",
        "        f_topk = evaluate_trend_best_state(topk_found_solutions)\n",
        "        true_best_f_topk_SH.append(-max(f_topk))\n",
        "\n",
        "        # Stage 2: Select the best one by Successive Halving with the budget left after Stage 1\n",
        "        best_solution_SH, total_cost_SH, total_epoch_SH = succesive_halving(topk_found_solutions, 'val_per', allowed_time - trend_cost_RS[-1], list_epochs)\n",
        "\n",
        "        f_best_solution_SH = evaluate_trend_best_state([best_solution_SH])[-1]\n",
        "        all_best_solution_SH.append(best_solution_SH)\n",
        "        all_f_best_solution_SH.append(f_best_solution_SH)\n",
        "        all_cost_SH.append(total_cost_SH)\n",
        "        all_total_epoch_SH.append(total_epoch_SH)\n",
        "\n",
        "        if verbose:\n",
        "            print('ID:', seed)\n",
        "            print(f'-> Best solution: {arch_int_to_vec(best_solution_SH)}')\n",
        "            print(f'-> Test PER: {-round(f_best_solution_SH*100, 2)}%')\n",
        "\n",
        "    all_trend_best_solutions_RS = np.array(all_trend_best_solutions_RS)\n",
        "    all_trend_f_best_solutions_RS = -np.array(all_trend_f_best_solutions_RS)\n",
        "    all_trend_cost_RS = np.array(all_trend_cost_RS)\n",
        "\n",
        "    all_best_solution_SH = np.array(all_best_solution_SH)\n",
        "    all_f_best_solution_SH = -np.array(all_f_best_solution_SH)\n",
        "    all_cost_SH = np.array(all_cost_SH)\n",
        "    all_total_epoch_SH = np.array(all_total_epoch_SH)\n",
        "\n",
        "    print('#Runs:', n_run)\n",
        "    print(f'- Stage 1 [Random Search (metric={zc_metric}, max_eval={max_eval})]:', np.round(np.mean(all_trend_f_best_solutions_RS, axis=0)[-1] * 100, 2), np.round(np.std(all_trend_f_best_solutions_RS, axis=0)[-1] * 100, 2))\n",
        "    print(f'- Stage 2 [Successive Halving (top-{k}, metric=val_per, budget={allowed_time} sec)]:', np.round(np.mean(all_f_best_solution_SH) * 100, 2), np.round(np.std(all_f_best_solution_SH) * 100, 2))\n",
        "    print(f'- Total Epochs: {int(np.mean(all_total_epoch_SH))}')\n",
        "    print('- Best visited solution (all):', np.round(np.mean(true_best_f_RS) * 100, 2), np.round(np.std(true_best_f_RS) * 100, 2))\n",
        "    print(f'- Best visited solution (top-{k}):', np.round(np.mean(true_best_f_topk_SH) * 100, 2), np.round(np.std(true_best_f_topk_SH) * 100, 2))\n",
        "\n",
        "    return all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH"
      ],
      "metadata": {
        "id": "WGnwPqQ7GhSC"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run"
      ],
      "metadata": {
        "id": "hSQMBjfXGmM-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Number of independent runs (seeds 1..n_run)\n",
        "n_run = 500\n",
        "# Stage 1: number of zero-cost evaluations kept per run\n",
        "max_eval = 2000\n",
        "# Stage 2: epoch schedule passed to Successive Halving\n",
        "list_epochs = [10, 20, 30, 40]\n",
        "# Number of top-ranked Stage-1 architectures forwarded to Stage 2\n",
        "k = 16\n",
        "verbose = False  # True -> print per-run results"
      ],
      "metadata": {
        "id": "hpRii4rrHZN5"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## MF-NAS"
      ],
      "metadata": {
        "id": "raSvfgFvNenQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Total per-run budget in seconds for MF-NAS (Stage 1 + Stage 2); this is the budget behind\n",
        "# the results reported below. It is set here instead of reading max_budget, which the\n",
        "# config cell does not set and which the baseline cells below reassign to 25.\n",
        "mfnas_budget = 20000\n",
        "for zc_metric in ['synflow', 'params', 'FLOPs']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=mfnas_budget, list_epochs=list_epochs, verbose=verbose)\n",
        "    print('*'*100)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2brTPwyGHWee",
        "outputId": "838fcb88-4fc6-4182-a84f-c09f53d510d8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: SYNFLOW\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 21.79 0.0\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_per, budget=20000 sec)]: 21.77 0.0\n",
            "- Total Epochs: 300\n",
            "- Best visited solution (all): 21.41 0.01\n",
            "- Best visited solution (top-16): 21.68 0.0\n",
            "****************************************************************************************************\n",
            "ZC METRIC: PARAMS\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 37.18 26.1\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_per, budget=20000 sec)]: 21.81 0.26\n",
            "- Total Epochs: 300\n",
            "- Best visited solution (all): 21.41 0.01\n",
            "- Best visited solution (top-16): 21.65 0.18\n",
            "****************************************************************************************************\n",
            "ZC METRIC: FLOPS\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=FLOPs, max_eval=2000)]: 88.4 0.0\n",
            "- Stage 2 [Successive Halving (top-16, metric=val_per, budget=20000 sec)]: 21.8 0.37\n",
            "- Total Epochs: 300\n",
            "- Best visited solution (all): 21.45 0.03\n",
            "- Best visited solution (top-16): 21.58 0.21\n",
            "****************************************************************************************************\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Random Search"
      ],
      "metadata": {
        "id": "iTug0uBqNgNc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Random Search baseline\n",
        "iepoch = 12\n",
        "max_budget = 25  # budget passed to run_random_search (unit defined there)\n",
        "all_test_per = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_per, total_time, _ = run_random_search(seed=run_id, iepoch=iepoch, max_eval_budget=max_budget)\n",
        "    per_pct = -test_per * 100  # flip sign to get the PER percentage reported below\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {arch_int_to_vec(arch)}')\n",
        "        print(f'-> Test PER: {per_pct} %')\n",
        "    all_test_per.append(per_pct)\n",
        "mean_per, std_per = np.round(np.mean(all_test_per), 2), np.round(np.std(all_test_per), 2)\n",
        "print('Mean:', mean_per, '  Std:', std_per)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oSjshST3Nnjs",
        "outputId": "187d94b4-29e3-4f9b-ea75-b3416907227b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 22.15   Std: 0.45\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Local Search"
      ],
      "metadata": {
        "id": "x2nkQ7tAPna_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Local Search baseline\n",
        "iepoch = 12\n",
        "max_budget = 25  # budget passed to run_ils (unit defined there)\n",
        "all_test_per = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_per, total_time, _ = run_ils(seed=run_id, iepoch=iepoch, max_budget=max_budget)\n",
        "    per_pct = -test_per * 100  # flip sign to get the PER percentage reported below\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {arch_int_to_vec(arch)}')\n",
        "        print(f'-> Test PER: {per_pct} %')\n",
        "    all_test_per.append(per_pct)\n",
        "mean_per, std_per = np.round(np.mean(all_test_per), 2), np.round(np.std(all_test_per), 2)\n",
        "print('Mean:', mean_per, '  Std:', std_per)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_SFUbNC-Pm9P",
        "outputId": "9e2241a3-1397-4400-f14f-ac6e3829bb76"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 22.51   Std: 2.85\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Successive Halving"
      ],
      "metadata": {
        "id": "dxKUII41QDf9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Successive Halving baseline over N=k architectures with the shared epoch schedule\n",
        "max_budget = 25  # budget passed to run_succesive_halving (unit defined there)\n",
        "all_test_per = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_per, total_time, total_epoch = run_succesive_halving(seed=run_id, N=k, max_budget=max_budget, list_epochs=list_epochs)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {arch_int_to_vec(arch)}')\n",
        "        # The reported value is PER (error rate), as in the other baselines, not accuracy\n",
        "        print(f'-> Test PER: {-test_per*100} %')\n",
        "    all_test_per.append(-test_per*100)\n",
        "print('Mean:', np.round(np.mean(all_test_per), 2), '  Std:', np.round(np.std(all_test_per), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hE4VrJZWQG1V",
        "outputId": "870f2950-6419-448b-96e3-78eda9b174a1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 22.29   Std: 0.6\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## REA"
      ],
      "metadata": {
        "id": "ZkZH-qkHQp0l"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# REA baseline (run_evolution_search) without warm-up\n",
        "population_size, tournament_size = 10, 10\n",
        "warm_up = False\n",
        "all_test_per = []\n",
        "iepoch = 12\n",
        "max_budget = 25  # budget passed to run_evolution_search (unit defined there)\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_per, total_time, _ = run_evolution_search(seed=run_id, max_eval_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {arch_int_to_vec(arch)}')\n",
        "        # The reported value is PER (error rate), as in the other baselines, not accuracy\n",
        "        print(f'-> Test PER: {-test_per*100} %')\n",
        "    all_test_per.append(-test_per*100)\n",
        "print('Mean:', np.round(np.mean(all_test_per), 2), '  Std:', np.round(np.std(all_test_per), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2nhHxYqmQqlW",
        "outputId": "1f0759f5-d67a-4cec-ba47-07eec6952f97"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 22.32   Std: 0.71\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## REA + W"
      ],
      "metadata": {
        "id": "EMBtSB0qUOtk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# REA + W baseline (run_evolution_search with warm-up)\n",
        "population_size, tournament_size = 10, 10\n",
        "warm_up = True\n",
        "n_warmup=2000\n",
        "iepoch = 12\n",
        "max_budget = 25  # budget passed to run_evolution_search (unit defined there)\n",
        "all_test_per = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_per, total_time, _ = run_evolution_search(seed=run_id, max_eval_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up, n_warmup=n_warmup)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {arch_int_to_vec(arch)}')\n",
        "        # The reported value is PER (error rate), as in the other baselines, not accuracy\n",
        "        print(f'-> Test PER: {-test_per*100} %')\n",
        "    all_test_per.append(-test_per*100)\n",
        "print('Mean:', np.round(np.mean(all_test_per), 2), '  Std:', np.round(np.std(all_test_per), 2))"
      ],
      "metadata": {
        "id": "u6_xpFAVS7L1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1a7f42eb-03f4-4ee0-baaf-92792ec846c1"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 22.02   Std: 0.25\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Ablation Experiments"
      ],
      "metadata": {
        "id": "SRdqXLfRa2Ad"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## MF-NAS with other metrics"
      ],
      "metadata": {
        "id": "iaiptiOibE5o"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Same per-run budget (seconds) as the main MF-NAS results. Set explicitly because\n",
        "# max_budget is reassigned to 25 by the baseline cells above.\n",
        "mfnas_budget = 20000\n",
        "for zc_metric in ['jacob_cov', 'plain', 'fisher', 'grad_norm', 'snip', 'l2_norm']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=mfnas_budget, list_epochs=list_epochs, verbose=verbose)\n",
        "    print('*'*100)"
      ],
      "metadata": {
        "id": "n2PRPYUNbDqt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Compare to Random Search"
      ],
      "metadata": {
        "id": "qeiYSnszbNIy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Same per-run budget (seconds) as the main MF-NAS results, so RS+SH is compared on equal\n",
        "# footing. Set explicitly because max_budget is reassigned to 25 by the baseline cells above.\n",
        "mfnas_budget = 20000\n",
        "for zc_metric in ['synflow', 'params', 'FLOPs', 'jacob_cov', 'plain', 'fisher', 'grad_norm', 'snip', 'l2_norm']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = RS_SH(zc_metric=zc_metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=mfnas_budget, list_epochs=list_epochs, verbose=verbose)\n",
        "    print('*'*100)"
      ],
      "metadata": {
        "id": "8B6-YS4ZbBPK"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "YEEe3vfm7h4m",
        "Bp51Mc6_CoiJ",
        "SjHVUaYpCdef",
        "ll-JyaRDC-np",
        "h17YXq-2DBzB",
        "IC9UfYfxDGPq",
        "NdE_B2kDLikQ",
        "OOKQTE0k7lhY",
        "9daa-2JSDRPh",
        "8J36WdGYGfsQ",
        "SRdqXLfRa2Ad"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}