{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "This is the source codes of Multi-Fidelity Neural Architecture Search (MF-NAS) on NAS-Bench-201.\n",
        "\n",
        "The content is presented as:\n",
        "\n",
        "*   The utility functions are presented in Section 'Utilities'.\n",
        "*   Codes for loading benchmark databases are in Section 'Load Benchmark'.\n",
        "*   Codes of First-Improvement Local Search, Random Search, Successive Halving, MF-NAS are presented in the Section 'Algorithms' and Section 'MF-NAS'.\n",
        "*   Codes for running Random Search, Local Search, SH, REA, REA+W, and MF-NAS variants in Section 'Run'.\n",
        "\n",
        "All results (for NAS-Bench-201) in the paper have already presented in Section 'Run'.\n",
        "\n",
        "Executing all cells if you want to reproduce our results in the paper."
      ],
      "metadata": {
        "id": "SkRJT4K4dbis"
      },
      "id": "SkRJT4K4dbis"
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Utilities"
      ],
      "metadata": {
        "id": "njsh04MeTDnm"
      },
      "id": "njsh04MeTDnm"
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "f20ae2a9",
      "metadata": {
        "id": "f20ae2a9"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "import pickle as p\n",
        "import numpy as np\n",
        "import random\n",
        "from copy import deepcopy\n",
        "import itertools\n",
        "import json\n",
        "import os\n",
        "from scipy import stats\n",
        "import math\n",
        "\n",
        "OP_NAMES_NB201 = ['skip_connect', 'none', 'nor_conv_3x3', 'nor_conv_1x1', 'avg_pool_3x3']\n",
        "EDGE_LIST = ((1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4))\n",
        "available_ops = [0, 1, 2, 3, 4]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "821c6f87",
      "metadata": {
        "id": "821c6f87"
      },
      "outputs": [],
      "source": [
        "def get_indices(state, distance):\n",
        "    return list(itertools.combinations(range(len(state)), distance))\n",
        "\n",
        "def get_all_neighbors(state, ids):\n",
        "    list_neighbors = []\n",
        "    list_available_ops = []\n",
        "    for i in ids:\n",
        "        # Get all neighbors at the index-i (i in list of indices ids)\n",
        "        # In case of distance == 1, ids only has 1 index\n",
        "        _available_ops = available_ops.copy()\n",
        "        _available_ops.remove(state[i])\n",
        "        list_available_ops.append(_available_ops)\n",
        "    list_ops = list(itertools.product(*list_available_ops))\n",
        "    ids = np.array(ids)\n",
        "    for ops in list_ops:\n",
        "        neighbor = np.array(state).copy()\n",
        "        neighbor[ids] = ops\n",
        "        list_neighbors.append(neighbor.tolist())\n",
        "    np.random.shuffle(list_neighbors)\n",
        "    return list_neighbors\n",
        "\n",
        "def encode_int_list_2_ori_input(int_list):\n",
        "    list_ops = np.array(['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'])\n",
        "    list_int_ops = np.array(int_list)\n",
        "    list_str_ops = list_ops[list_int_ops]\n",
        "    return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(*list_str_ops)\n",
        "\n",
        "def convert_str_to_op_indices(str_encoding):\n",
        "    \"\"\"\n",
        "    Converts NB201 string representation to op_indices\n",
        "    \"\"\"\n",
        "    nodes = str_encoding.split('+')\n",
        "\n",
        "    def get_op(x):\n",
        "        return x.split('~')[0]\n",
        "\n",
        "    node_ops = [list(map(get_op, n.strip()[1:-1].split('|'))) for n in nodes]\n",
        "\n",
        "    enc = []\n",
        "    for u, v in EDGE_LIST:\n",
        "        enc.append(OP_NAMES_NB201.index(node_ops[v - 2][u - 1]))\n",
        "\n",
        "    return tuple(enc)\n",
        "\n",
        "def train_evaluate(state, metric):\n",
        "    hashKey = ''.join(map(str, state))\n",
        "    if metric == 'test_acc':\n",
        "        return benchmark['200'][hashKey]['test_acc'][-1], 0.0\n",
        "    else:\n",
        "        try:\n",
        "            iepoch = int(metric.split('_')[-1])\n",
        "            metric = '_'.join(metric.split('_')[:-1])\n",
        "            score = benchmark['200'][hashKey][metric][iepoch-1]\n",
        "            time = benchmark['200'][hashKey]['train_time'] * iepoch\n",
        "        except KeyError:\n",
        "            print(iepoch, metric, hashKey)\n",
        "        if 'loss' in metric:\n",
        "            score *= -1\n",
        "        return score, time\n",
        "\n",
        "def zc_evaluate(state, metric, dataset='cifar10'):\n",
        "    str_input = encode_int_list_2_ori_input(state)\n",
        "    op_indices = str(convert_str_to_op_indices(str_input))\n",
        "    if metric in ['flops', 'params']:\n",
        "        hashKey = ''.join(map(str, state))\n",
        "        if metric == 'flops':\n",
        "            return benchmark['200'][hashKey]['FLOPs'], zc_benchmark[dataset][op_indices]['flops']['time']\n",
        "        return benchmark['200'][hashKey]['params'], zc_benchmark[dataset][op_indices]['params']['time']\n",
        "    return zc_benchmark[dataset][op_indices][metric]['score'], zc_benchmark[dataset][op_indices][metric]['time']\n",
        "\n",
        "def evaluate(state, metric, dataset='cifar10'):\n",
        "    if 'acc' in metric or 'loss' in metric:\n",
        "        res = train_evaluate(state, metric)\n",
        "    else:\n",
        "        res = zc_evaluate(state, metric, dataset)\n",
        "    return res[0], res[1]\n",
        "\n",
        "def evaluate_trend_best_state(trend_best_state):\n",
        "    return [evaluate(state, 'test_acc')[0] for state in trend_best_state]\n",
        "\n",
        "def random_arch():\n",
        "    return np.random.choice(available_ops, 6)\n",
        "\n",
        "def mutate_spec(old_arch, mutation_rate=1.0):\n",
        "    new_arch = copy.deepcopy(old_arch)\n",
        "\n",
        "    op_mutation_prob = mutation_rate / len(old_arch)\n",
        "    for ind in range(len(old_arch)):\n",
        "        if random.random() < op_mutation_prob:\n",
        "            available = [o for o in available_ops if o != new_arch[ind]]\n",
        "            new_arch[ind] = random.choice(available)\n",
        "    return new_arch\n",
        "\n",
        "def random_combination(iterable, sample_size):\n",
        "    \"\"\"Random selection from itertools.combinations(iterable, r).\"\"\"\n",
        "    pool = tuple(iterable)\n",
        "    n = len(pool)\n",
        "    indices = sorted(random.sample(range(n), sample_size))\n",
        "    return tuple(pool[i] for i in indices)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load Benchmark"
      ],
      "metadata": {
        "id": "p3F54IGCn9BG"
      },
      "id": "p3F54IGCn9BG"
    },
    {
      "cell_type": "code",
      "source": [
        "!gdown https://drive.google.com/uc?id=1Ud4GF-3264R0rkqOmBK46L9S_kalN01z\n",
        "!gdown https://drive.google.com/uc?id=1mow6Cqwgs3DnKngF6reF7STsVa7KS15R\n",
        "!gdown https://drive.google.com/uc?id=1LSqvNkFzsRldW2kgCaSao9uCdYpN525H\n",
        "!gdown https://drive.google.com/uc?id=1Zbb_StAGgX6HYHHuoZWPY8PLb0DhDzKY"
      ],
      "metadata": {
        "id": "hCag7QKimI-K",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ca25bd6f-67be-4313-e941-9a81399af049"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1Ud4GF-3264R0rkqOmBK46L9S_kalN01z\n",
            "To: /content/[CIFAR-10]_data.p\n",
            "100% 383M/383M [00:07<00:00, 49.9MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1mow6Cqwgs3DnKngF6reF7STsVa7KS15R\n",
            "To: /content/[CIFAR-100]_data.p\n",
            "100% 266M/266M [00:05<00:00, 47.8MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1LSqvNkFzsRldW2kgCaSao9uCdYpN525H\n",
            "To: /content/[ImageNet16-120]_data.p\n",
            "100% 266M/266M [00:02<00:00, 89.7MB/s]\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1Zbb_StAGgX6HYHHuoZWPY8PLb0DhDzKY\n",
            "To: /content/zc_nasbench201.json\n",
            "100% 44.3M/44.3M [00:01<00:00, 29.2MB/s]\n"
          ]
        }
      ],
      "id": "hCag7QKimI-K"
    },
    {
      "cell_type": "code",
      "source": [
        "zc_benchmark = json.load(open('zc_nasbench201.json'))"
      ],
      "metadata": {
        "id": "aG7xBZtWn6tj"
      },
      "execution_count": 4,
      "outputs": [],
      "id": "aG7xBZtWn6tj"
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Algorithms"
      ],
      "metadata": {
        "id": "B5u0GCCMjV4R"
      },
      "id": "B5u0GCCMjV4R"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## First-improvement Local Search"
      ],
      "metadata": {
        "id": "PeH_GXxNjuIp"
      },
      "id": "PeH_GXxNjuIp"
    },
    {
      "cell_type": "code",
      "source": [
        "def first_improvement_ls(init_state, max_eval=np.inf, max_time=np.inf, metric=None, dataset=None):\n",
        "    # Stage 1: Zero-cost-proxy-guided local search\n",
        "    assert metric is not None, 'Missing the evaluation metric!'\n",
        "    assert init_state is not None, 'Missing the initial state!'\n",
        "    assert dataset is not None, 'Missing the dataset!'\n",
        "\n",
        "    n_eval = 0\n",
        "    total_time = 0\n",
        "    curr_state = init_state.copy()\n",
        "    f_curr_state, time = evaluate(init_state, metric, dataset)\n",
        "    total_time += time\n",
        "    n_eval += 1\n",
        "\n",
        "    best_state = curr_state.copy()\n",
        "    f_best_state = f_curr_state\n",
        "\n",
        "    trend_best_state = [best_state]\n",
        "    trend_time = [total_time]\n",
        "\n",
        "    state_history, f_state_history = [curr_state], [f_curr_state]\n",
        "    while (n_eval <= max_eval) and (total_time <= max_time):\n",
        "        improved = False\n",
        "        list_ids = get_indices(curr_state, 1)\n",
        "        while len(list_ids) != 0:\n",
        "            idx = np.random.choice(range(len(list_ids)))\n",
        "            ids = list_ids[idx]\n",
        "            list_ids.remove(list_ids[idx])\n",
        "\n",
        "            list_neighbors = get_all_neighbors(state=curr_state, ids=ids)\n",
        "            for neighbor in list_neighbors:\n",
        "                f_neighbor, time = evaluate(neighbor, metric, dataset)\n",
        "                state_history.append(neighbor)\n",
        "                f_state_history.append(f_neighbor)\n",
        "\n",
        "                total_time += time\n",
        "                trend_time.append(total_time)\n",
        "                n_eval += 1\n",
        "                # Update the current solution\n",
        "                if f_neighbor >= f_curr_state:\n",
        "                    curr_state = neighbor.copy()\n",
        "                    f_curr_state = f_neighbor\n",
        "\n",
        "                    # Update the best solution so far\n",
        "                    if f_neighbor > f_best_state:\n",
        "                        best_state = neighbor.copy()\n",
        "                        f_best_state = f_neighbor\n",
        "                    trend_best_state.append(best_state)\n",
        "                    improved = True\n",
        "                    break\n",
        "                else:\n",
        "                    trend_best_state.append(best_state)\n",
        "            if improved:\n",
        "                break\n",
        "\n",
        "        # If the current solution cannot be improved, the algorithm is stuck\n",
        "        # Therefore, we perform the escape operator.\n",
        "        if not improved:\n",
        "            list_ids = get_indices(curr_state, 2)\n",
        "            idx = np.random.choice(range(len(list_ids)))\n",
        "            ids = list_ids[idx]\n",
        "\n",
        "            list_neighbors = get_all_neighbors(curr_state, ids)\n",
        "            curr_state = list_neighbors[np.random.choice(len(list_neighbors))]\n",
        "\n",
        "            f_curr_state, time = evaluate(curr_state, metric, dataset)\n",
        "            state_history.append(curr_state)\n",
        "            f_state_history.append(f_curr_state)\n",
        "            total_time += time\n",
        "            n_eval += 1\n",
        "\n",
        "            if f_curr_state > f_best_state:\n",
        "                best_state = curr_state.copy()\n",
        "                f_best_state = f_curr_state\n",
        "            trend_best_state.append(best_state)\n",
        "            trend_time.append(total_time)\n",
        "    return trend_best_state, trend_time, state_history, f_state_history\n"
      ],
      "metadata": {
        "id": "zA97AwKQjXSx"
      },
      "id": "zA97AwKQjXSx",
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Random Search"
      ],
      "metadata": {
        "id": "thcVugGOjzPz"
      },
      "id": "thcVugGOjzPz"
    },
    {
      "cell_type": "code",
      "source": [
        "def random_search(search_space, max_eval, zc_metric=None, dataset=None):\n",
        "    assert zc_metric is not None, 'Missing the evaluation metric!'\n",
        "    assert dataset is not None, 'Missing the dataset!'\n",
        "\n",
        "    total_time = 0\n",
        "\n",
        "    list_state = np.random.choice(search_space, max_eval)\n",
        "\n",
        "    # Evaluate the first individual\n",
        "    state = list(map(int, list(list_state[0])))\n",
        "\n",
        "    f_state, time = evaluate(state, zc_metric, dataset)\n",
        "    total_time += time\n",
        "\n",
        "    best_state, f_best_state = state.copy(), f_state\n",
        "    trend_best_state, trend_time = [best_state], [total_time]\n",
        "    state_history, f_state_history = [state], [f_state]\n",
        "\n",
        "    # For each next individual, evaluate and compare to the best individual so far.\n",
        "    for state in list_state[1:]:\n",
        "        state = list(map(int, list(state)))\n",
        "        f_state, time = evaluate(state, metric=zc_metric, dataset=dataset)\n",
        "        total_time += time\n",
        "        state_history.append(state)\n",
        "        f_state_history.append(f_state)\n",
        "\n",
        "        ## If current state is better than the best current state, replace\n",
        "        if f_state > f_best_state:\n",
        "            best_state = state.copy()\n",
        "            f_best_state = f_state\n",
        "        trend_best_state.append(best_state)\n",
        "        trend_time.append(total_time)\n",
        "    return trend_best_state, trend_time, state_history, f_state_history"
      ],
      "metadata": {
        "id": "Bz9Kdooaj0e_"
      },
      "id": "Bz9Kdooaj0e_",
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Successive Halving"
      ],
      "metadata": {
        "id": "0Y2qS2Qmj2Sz"
      },
      "id": "0Y2qS2Qmj2Sz"
    },
    {
      "cell_type": "code",
      "source": [
        "def succesive_halving(list_candidates, metric, max_budget, dataset, list_epochs):\n",
        "    assert len(list_epochs) != 0, 'Have not set the checkpoints for evaluation!'\n",
        "    checkpoint = 0\n",
        "    iepoch = list_epochs[checkpoint]\n",
        "\n",
        "    total_time, total_epoch = 0.0, 0\n",
        "    list_candidates = np.array(list_candidates)\n",
        "\n",
        "    last_iepoch = 0\n",
        "    best_candidate, f_best_candidate, best_candidate_rank = None, -np.inf, None\n",
        "\n",
        "    true_performance = -1 * np.array(evaluate_trend_best_state(list_candidates))\n",
        "    order = true_performance.argsort()\n",
        "    true_rank = order.argsort() + 1\n",
        "\n",
        "    last = False\n",
        "    while total_time < max_budget:\n",
        "        evaluated_candidates = []\n",
        "\n",
        "        f_candidates = []\n",
        "        for i, candidate in enumerate(list_candidates):\n",
        "            score, time = evaluate(candidate, f'{metric}_{int(iepoch)}', dataset)\n",
        "\n",
        "            diff_epoch = int(iepoch) - last_iepoch\n",
        "            total_time += time / int(iepoch) * diff_epoch\n",
        "            total_epoch += diff_epoch\n",
        "\n",
        "            evaluated_candidates.append(candidate)\n",
        "            f_candidates.append(score)\n",
        "\n",
        "            if total_time >= max_budget:\n",
        "                evaluated_candidates = evaluated_candidates[:-1]\n",
        "                f_candidates = f_candidates[:-1]\n",
        "                total_time -= time / int(iepoch) * diff_epoch\n",
        "                total_epoch -= diff_epoch\n",
        "\n",
        "                return best_candidate, total_time, total_epoch, best_candidate_rank\n",
        "\n",
        "            if score > f_best_candidate:\n",
        "                f_best_candidate = score\n",
        "                best_candidate = candidate\n",
        "                best_candidate_rank = true_rank[i]\n",
        "\n",
        "        ids = np.flip(np.argsort(f_candidates))\n",
        "        list_candidates = np.array(evaluated_candidates)[ids]\n",
        "        list_candidates = list_candidates[:math.ceil(len(list_candidates) / 2)]\n",
        "\n",
        "        true_rank = true_rank[ids]\n",
        "        true_rank = true_rank[:math.ceil(len(f_candidates) / 2)]\n",
        "        if len(list_candidates) == 1 or last:\n",
        "            return best_candidate, total_time, total_epoch, best_candidate_rank\n",
        "\n",
        "        checkpoint += 1\n",
        "        last_iepoch = int(iepoch)\n",
        "        iepoch = list_epochs[checkpoint]\n",
        "        if iepoch == list_epochs[-1]:\n",
        "            last = True"
      ],
      "metadata": {
        "id": "18ppuaQqj4Js"
      },
      "id": "18ppuaQqj4Js",
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Others"
      ],
      "metadata": {
        "id": "rDPFwN6xjEn1"
      },
      "id": "rDPFwN6xjEn1"
    },
    {
      "cell_type": "code",
      "source": [
        "def run_warm_up(n_sample, k):\n",
        "    list_arch, list_scores = [], []\n",
        "    for _ in range(n_sample):\n",
        "        arch = random_arch()\n",
        "        score, _ = zc_evaluate(arch, 'synflow', dataset)\n",
        "        list_arch.append(arch)\n",
        "        list_scores.append(score)\n",
        "    list_arch = np.array(list_arch)\n",
        "    list_scores = np.array(list_scores)\n",
        "    ids = np.flip(np.argsort(list_scores))\n",
        "    list_arch = list_arch[ids]\n",
        "    return list_arch[:k]\n",
        "\n",
        "def run_random_search(seed=0, max_time_budget=5e6, iepoch=200):\n",
        "    \"\"\"Run a single roll-out of random search to a fixed time budget.\"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "\n",
        "    times, best_valids, best_archs = [0.0], [0.0], [None]\n",
        "    while True:\n",
        "        arch = random_arch()\n",
        "        val_acc, time = evaluate(arch, f'val_acc_{iepoch}')\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_archs.append(arch)\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "            best_archs.append(best_archs[-1])\n",
        "\n",
        "        times.append(time)\n",
        "        if sum(times) > max_time_budget:\n",
        "            # Break the first time we exceed the budget.\n",
        "            best_candidate = best_archs[-1]\n",
        "            f_best_candidate = evaluate(best_candidate, 'test_acc')[0]\n",
        "            return best_candidate, f_best_candidate, sum(times), None\n",
        "\n",
        "def run_evolution_search(seed=0, max_time_budget=5e6, population_size=20,\n",
        "                         tournament_size=4, mutation_rate=1.0, iepoch=200, warm_up=False, n_warmup=2000):\n",
        "    \"\"\"Run a single roll-out of regularized evolution to a fixed time budget.\"\"\"\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "\n",
        "    times, best_valids = [0.0], [0.0]\n",
        "    population = []   # (validation, spec) tuples\n",
        "    trend_time = [0.0]\n",
        "    best_test_arch = None\n",
        "    # For the first population_size individuals, seed the population with randomly\n",
        "    # generated cells.\n",
        "\n",
        "    if not warm_up:\n",
        "        list_arch = [random_arch() for _ in range(population_size)]\n",
        "    else:\n",
        "        list_arch = run_warm_up(n_warmup, population_size)\n",
        "    for arch in list_arch:\n",
        "        val_acc, time = evaluate(arch, f'val_acc_{iepoch}')\n",
        "\n",
        "        times.append(time)\n",
        "        trend_time.append(sum(times))\n",
        "        population.append((val_acc, arch))\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_test_arch = arch\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "\n",
        "        if sum(times) > max_time_budget:\n",
        "            best_candidate = best_test_arch\n",
        "            f_best_candidate = evaluate(best_candidate, 'test_acc')[0]\n",
        "            return best_candidate, f_best_candidate, trend_time[-1], None\n",
        "\n",
        "    # After the population is seeded, proceed with evolving the population.\n",
        "    while True:\n",
        "        sample = random_combination(population, tournament_size)\n",
        "        best_arch = sorted(sample, key=lambda i:i[0])[-1][1]\n",
        "        new_arch = mutate_spec(best_arch, mutation_rate)\n",
        "\n",
        "        val_acc, time = evaluate(new_arch, f'val_acc_{iepoch}')\n",
        "        times.append(time)\n",
        "        trend_time.append(sum(times))\n",
        "\n",
        "        # In regularized evolution, we kill the oldest individual in the population.\n",
        "        population.append((val_acc, new_arch))\n",
        "        population.pop(0)\n",
        "\n",
        "        if val_acc > best_valids[-1]:\n",
        "            best_valids.append(val_acc)\n",
        "            best_test_arch = new_arch\n",
        "        else:\n",
        "            best_valids.append(best_valids[-1])\n",
        "\n",
        "        if sum(times) > max_time_budget:\n",
        "            best_candidate = best_test_arch\n",
        "            f_best_candidate = evaluate(best_candidate, 'test_acc')[0]\n",
        "            return best_candidate, f_best_candidate, trend_time[-1], None\n",
        "\n",
        "def run_ils(seed, max_budget, iepoch, dataset):\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    init_state = random_arch()\n",
        "    trend_best_state, trend_time, _, _ = first_improvement_ls(init_state, max_time=max_budget, metric=f'val_acc_{iepoch}', dataset=dataset)\n",
        "    best_candidate = trend_best_state[-1]\n",
        "    f_best_candidate = evaluate(best_candidate, 'test_acc')[0]\n",
        "    return best_candidate, f_best_candidate, trend_time[-1], 0.0\n",
        "\n",
        "def run_succesive_halving(seed, N, max_budget, dataset, list_epochs):\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    list_candidates = [random_arch() for _ in range(N)]\n",
        "    best_candidate, total_time, total_epoch, _ = succesive_halving(list_candidates, 'val_acc', max_budget, dataset, list_epochs)\n",
        "    f_best_candidate = evaluate(best_candidate, 'test_acc')[0]\n",
        "    return best_candidate, f_best_candidate, total_time, total_epoch"
      ],
      "metadata": {
        "id": "gIfnDQrYjF_y"
      },
      "id": "gIfnDQrYjF_y",
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# MF-NAS"
      ],
      "metadata": {
        "id": "X37x-iSgqTHY"
      },
      "id": "X37x-iSgqTHY"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Local Search"
      ],
      "metadata": {
        "id": "XQ72nto5o4_T"
      },
      "id": "XQ72nto5o4_T"
    },
    {
      "cell_type": "code",
      "source": [
        "def ILS_SH(zc_metric, metric, k, n_run, max_eval, allowed_time, list_epochs, dataset, verbose=False):\n",
        "    all_trend_best_solutions_LS, all_trend_f_best_solutions_LS, all_trend_cost_LS = [], [], []\n",
        "    true_best_f_LS = []\n",
        "\n",
        "    all_best_solution_SH, all_f_best_solution_SH = [], []\n",
        "    all_cost_SH = []\n",
        "    all_total_epoch_SH = []\n",
        "    true_best_f_topk_SH = []\n",
        "\n",
        "    list_allowed_budget = []\n",
        "\n",
        "    list_rank = []\n",
        "    for seed in range(1, n_run + 1):\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "        init_solution = list(map(int, list(np.random.choice(all_sol))))\n",
        "\n",
        "        # Stage 1: Explore search space with Zero-cost-proxy-guided Local Search\n",
        "        trend_best_solutions_LS, trend_cost_LS, found_solutions_LS, f_found_solutions_LS = first_improvement_ls(init_solution, metric=zc_metric, max_eval=max_eval, dataset=dataset)\n",
        "\n",
        "        ## Evaluate best solutions obtained during the search (test_acc) (for analyzing)\n",
        "        trend_cost_LS = trend_cost_LS[:max_eval]\n",
        "\n",
        "        trend_best_solutions_LS = trend_best_solutions_LS[:max_eval]\n",
        "        trend_f_best_solution_LS = evaluate_trend_best_state(trend_best_solutions_LS)\n",
        "\n",
        "        ## Evaluate all states obtained during the search (test_acc) (for analyzing)\n",
        "        true_f_found_solutions_LS = evaluate_trend_best_state(found_solutions_LS[:max_eval])\n",
        "        true_best_f_LS.append(max(true_f_found_solutions_LS))\n",
        "\n",
        "        all_trend_best_solutions_LS.append(trend_best_solutions_LS)\n",
        "        all_trend_f_best_solutions_LS.append(trend_f_best_solution_LS)\n",
        "        all_trend_cost_LS.append(trend_cost_LS)\n",
        "\n",
        "        ## Remove duplicates in the memory\n",
        "        found_solutions_LS, f_found_solutions_LS = found_solutions_LS[:max_eval], f_found_solutions_LS[:max_eval]\n",
        "        found_solutions_LS, ids = np.unique(found_solutions_LS, axis=0, return_index=True)\n",
        "        f_found_solutions_LS = np.array(f_found_solutions_LS)[ids]\n",
        "\n",
        "        ## Sort the fitness in descending order\n",
        "        ids = np.flip(np.argsort(f_found_solutions_LS))\n",
        "        found_solutions_LS = found_solutions_LS[ids]\n",
        "\n",
        "        ## Filter out top k% best architectures\n",
        "        topk_found_solutions = found_solutions_LS[:k]\n",
        "\n",
        "        ## Evaluate test_performance of top-k% (for analyzing)\n",
        "        f_topk = evaluate_trend_best_state(topk_found_solutions)\n",
        "        true_best_f_topk_SH.append(max(f_topk))\n",
        "\n",
        "        # Stage 2: Selection the best one by Succesive Halving\n",
        "        best_solution_SH, total_cost_SH, total_epoch_SH, best_candidate_rank = succesive_halving(topk_found_solutions, metric, allowed_time - trend_cost_LS[-1], dataset, list_epochs)\n",
        "\n",
        "        f_best_solution_SH = evaluate_trend_best_state([best_solution_SH])[-1]\n",
        "        all_best_solution_SH.append(best_solution_SH)\n",
        "        all_f_best_solution_SH.append(f_best_solution_SH)\n",
        "        all_cost_SH.append(total_cost_SH)\n",
        "        all_total_epoch_SH.append(total_epoch_SH)\n",
        "\n",
        "        if verbose:\n",
        "            print('ID:', seed)\n",
        "            print(f'-> Best solution: {encode_int_list_2_ori_input(best_solution_SH)}')\n",
        "            print(f'-> Test accuracy: {f_best_solution_SH*100} %')\n",
        "\n",
        "        list_rank.append(best_candidate_rank)\n",
        "\n",
        "    all_trend_best_solutions_LS = np.array(all_trend_best_solutions_LS)\n",
        "    all_trend_f_best_solutions_LS = np.array(all_trend_f_best_solutions_LS)\n",
        "    all_trend_cost_LS = np.array(all_trend_cost_LS)\n",
        "\n",
        "    all_best_solution_SH = np.array(all_best_solution_SH)\n",
        "    all_f_best_solution_SH = np.array(all_f_best_solution_SH)\n",
        "    all_cost_SH = np.array(all_cost_SH)\n",
        "    all_total_epoch_SH = np.array(all_total_epoch_SH)\n",
        "\n",
        "    print('-'*100)\n",
        "    print(f'Dataset:', dataset)\n",
        "    print(f'#Runs:', n_run)\n",
        "    print(f'- Stage 1 [ILS (metric={zc_metric}, max_eval={max_eval})]:', np.round(np.mean(all_trend_f_best_solutions_LS, axis=0)[-1] * 100, 2), np.round(np.std(all_trend_f_best_solutions_LS, axis=0)[-1] * 100, 2))\n",
        "    print(f'- Stage 2 [Successive Halving (top-{k}, metric={metric}, budget={max_budget} sec)]:', np.round(np.mean(all_f_best_solution_SH) * 100, 2), np.round(np.std(all_f_best_solution_SH) * 100, 2))\n",
        "    print(f'- Total Cost: {int(np.mean(all_trend_cost_LS, axis=0)[-1])} + {int(np.mean(all_cost_SH))} = {int(np.mean(all_trend_cost_LS, axis=0)[-1]) + int(np.mean(all_cost_SH))} seconds')\n",
        "    print(f'- Total Epochs: {int(np.mean(all_total_epoch_SH))}')\n",
        "    print(f'- Best visited solution (all):', np.round(np.mean(true_best_f_LS) * 100, 2), np.round(np.std(true_best_f_LS) * 100, 2))\n",
        "    print(f'- Best visited solution (top-{k}):', np.round(np.mean(true_best_f_topk_SH) * 100, 2), np.round(np.std(true_best_f_topk_SH) * 100, 2))\n",
        "    print(f'- Average Ranked of SH selection:', np.round(np.mean(list_rank), 2))\n",
        "\n",
        "    return all_best_solution_SH, all_f_best_solution_SH, all_cost_SH,  all_total_epoch_SH"
      ],
      "metadata": {
        "id": "Oxj6CvVhutvd"
      },
      "id": "Oxj6CvVhutvd",
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Random Search"
      ],
      "metadata": {
        "id": "wvcbOh1VpBg3"
      },
      "id": "wvcbOh1VpBg3"
    },
    {
      "cell_type": "code",
      "source": [
        "def RS_SH(zc_metric, metric, k, n_run, max_eval, allowed_time, list_epochs, dataset, verbose=False, search_space=None):\n",
        "    \"\"\"Two-stage MF-NAS with a zero-cost Random Search in Stage 1.\n",
        "\n",
        "    Stage 1 runs `random_search` scored by the zero-cost proxy `zc_metric` and keeps\n",
        "    the first `max_eval` evaluations. Stage 2 runs Successive Halving (ranked by\n",
        "    `metric`) on the top-`k` distinct architectures found, using whatever is left of\n",
        "    `allowed_time` after the Stage-1 cost.\n",
        "\n",
        "    Args:\n",
        "        zc_metric: zero-cost proxy used to score architectures in Stage 1.\n",
        "        metric: performance metric used by Successive Halving (e.g. 'val_acc').\n",
        "        k: number of top-ranked distinct architectures passed to Stage 2 (a count).\n",
        "        n_run: number of independent runs; run i uses seed i (1..n_run).\n",
        "        max_eval: number of Stage-1 evaluations kept per run.\n",
        "        allowed_time: total budget in seconds shared by Stage 1 and Stage 2.\n",
        "        list_epochs: epoch schedule forwarded to `succesive_halving`.\n",
        "        dataset: dataset name forwarded to the search and SH routines.\n",
        "        verbose: if True, print the selected architecture of every run.\n",
        "        search_space: candidate architectures for Stage 1; defaults to the\n",
        "            notebook-level `all_sol` when None (keeps existing calls working).\n",
        "\n",
        "    Returns:\n",
        "        (best solutions, their test accuracies, SH costs, SH epochs), each a numpy\n",
        "        array with one entry per run.\n",
        "    \"\"\"\n",
        "    if search_space is None:\n",
        "        search_space = all_sol  # backward-compatible fallback to the global search space\n",
        "\n",
        "    all_trend_best_solutions_RS, all_trend_f_best_solutions_RS, all_trend_cost_RS = [], [], []\n",
        "    true_best_f_RS = []\n",
        "\n",
        "    all_best_solution_SH, all_f_best_solution_SH = [], []\n",
        "    all_cost_SH = []\n",
        "    all_total_epoch_SH = []\n",
        "    true_best_f_topk_SH = []\n",
        "\n",
        "    for seed in range(1, n_run + 1):\n",
        "        np.random.seed(seed)\n",
        "        random.seed(seed)\n",
        "\n",
        "        # Stage 1: Explore search space with Zero-cost Random Search\n",
        "        trend_best_solutions_RS, trend_cost_RS, found_solutions_RS, f_found_solutions_RS = random_search(search_space, zc_metric=zc_metric, max_eval=max_eval, dataset=dataset)\n",
        "\n",
        "        ## Evaluate best solutions obtained during the search (test_acc) (for analyzing)\n",
        "        trend_cost_RS = trend_cost_RS[:max_eval]\n",
        "\n",
        "        trend_best_solutions_RS = trend_best_solutions_RS[:max_eval]\n",
        "        trend_f_best_solution_RS = evaluate_trend_best_state(trend_best_solutions_RS)\n",
        "\n",
        "        ## Evaluate all states obtained during the search (test_acc) (for analyzing)\n",
        "        true_f_found_solutions_RS = evaluate_trend_best_state(found_solutions_RS[:max_eval])\n",
        "        true_best_f_RS.append(max(true_f_found_solutions_RS))\n",
        "\n",
        "        all_trend_best_solutions_RS.append(trend_best_solutions_RS)\n",
        "        all_trend_f_best_solutions_RS.append(trend_f_best_solution_RS)\n",
        "        all_trend_cost_RS.append(trend_cost_RS)\n",
        "\n",
        "        ## Remove duplicates in the memory (np.unique sorts the rows; `ids` keeps fitness aligned)\n",
        "        found_solutions_RS, f_found_solutions_RS = found_solutions_RS[:max_eval], f_found_solutions_RS[:max_eval]\n",
        "        found_solutions_RS, ids = np.unique(found_solutions_RS, axis=0, return_index=True)\n",
        "        f_found_solutions_RS = np.array(f_found_solutions_RS)[ids]\n",
        "\n",
        "        ## Sort by fitness in descending order, keeping solutions and fitness aligned\n",
        "        ids = np.flip(np.argsort(f_found_solutions_RS))\n",
        "        found_solutions_RS = found_solutions_RS[ids]\n",
        "        f_found_solutions_RS = f_found_solutions_RS[ids]\n",
        "\n",
        "        ## Keep the top-k (a count, not a percentage) best architectures\n",
        "        topk_found_solutions = found_solutions_RS[:k]\n",
        "\n",
        "        ## Evaluate test performance of the top-k (for analyzing)\n",
        "        f_topk = evaluate_trend_best_state(topk_found_solutions)\n",
        "        true_best_f_topk_SH.append(max(f_topk))\n",
        "\n",
        "        # Stage 2: Select the best one by Successive Halving with the budget left after Stage 1\n",
        "        best_solution_SH, total_cost_SH, total_epoch_SH, best_candidate_rank = succesive_halving(topk_found_solutions, metric, allowed_time - trend_cost_RS[-1], dataset, list_epochs)\n",
        "\n",
        "        f_best_solution_SH = evaluate_trend_best_state([best_solution_SH])[-1]\n",
        "        all_best_solution_SH.append(best_solution_SH)\n",
        "        all_f_best_solution_SH.append(f_best_solution_SH)\n",
        "        all_cost_SH.append(total_cost_SH)\n",
        "        all_total_epoch_SH.append(total_epoch_SH)\n",
        "\n",
        "        if verbose:\n",
        "            print('ID:', seed)\n",
        "            print(f'-> Best solution: {encode_int_list_2_ori_input(best_solution_SH)}')\n",
        "            print(f'-> Test accuracy: {f_best_solution_SH*100} %')\n",
        "\n",
        "    all_trend_best_solutions_RS = np.array(all_trend_best_solutions_RS)\n",
        "    all_trend_f_best_solutions_RS = np.array(all_trend_f_best_solutions_RS)\n",
        "    all_trend_cost_RS = np.array(all_trend_cost_RS)\n",
        "\n",
        "    all_best_solution_SH = np.array(all_best_solution_SH)\n",
        "    all_f_best_solution_SH = np.array(all_f_best_solution_SH)\n",
        "    all_cost_SH = np.array(all_cost_SH)\n",
        "    all_total_epoch_SH = np.array(all_total_epoch_SH)\n",
        "\n",
        "    print('Dataset:', dataset)\n",
        "    print('#Runs:', n_run)\n",
        "    print(f'- Stage 1 [Random Search (metric={zc_metric}, max_eval={max_eval})]:', np.round(np.mean(all_trend_f_best_solutions_RS, axis=0)[-1] * 100, 2), np.round(np.std(all_trend_f_best_solutions_RS, axis=0)[-1] * 100, 2))\n",
        "    # Report the budget this call actually received, not the notebook-global max_budget.\n",
        "    print(f'- Stage 2 [Successive Halving (top-{k}, metric={metric}, budget={allowed_time} sec)]:', np.round(np.mean(all_f_best_solution_SH) * 100, 2), np.round(np.std(all_f_best_solution_SH) * 100, 2))\n",
        "    print(f'- Total Cost: {int(np.mean(all_trend_cost_RS, axis=0)[-1])} + {int(np.mean(all_cost_SH))} = {int(np.mean(all_trend_cost_RS, axis=0)[-1]) + int(np.mean(all_cost_SH))} seconds')\n",
        "    print(f'- Total Epochs: {int(np.mean(all_total_epoch_SH))}')\n",
        "    print(f'- Best visited solution (all):', np.round(np.mean(true_best_f_RS) * 100, 2), np.round(np.std(true_best_f_RS) * 100, 2))\n",
        "    print(f'- Best visited solution (top-{k}):', np.round(np.mean(true_best_f_topk_SH) * 100, 2), np.round(np.std(true_best_f_topk_SH) * 100, 2))\n",
        "\n",
        "    return all_best_solution_SH, all_f_best_solution_SH, all_cost_SH,  all_total_epoch_SH"
      ],
      "metadata": {
        "id": "1wuFabESo9cA"
      },
      "id": "1wuFabESo9cA",
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run"
      ],
      "metadata": {
        "id": "bVJW2l3_s9rp"
      },
      "id": "bVJW2l3_s9rp"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CIFAR-10"
      ],
      "metadata": {
        "id": "6oLoPYXMuyOa"
      },
      "id": "6oLoPYXMuyOa"
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the pre-computed benchmark data for CIFAR-10; the file is closed right after reading.\n",
        "# NOTE: unpickling can execute arbitrary code -- only load this file from a trusted source.\n",
        "with open('[CIFAR-10]_data.p', 'rb') as benchmark_file:\n",
        "    benchmark = p.load(benchmark_file)"
      ],
      "metadata": {
        "id": "VH76R8K8vAIN"
      },
      "id": "VH76R8K8vAIN",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "dataset = 'cifar10'\n",
        "n_run = 500  # Independent runs (seeded 1..n_run)\n",
        "max_eval = 2000  # Stage 1: max zero-cost evaluations (Local Search or Random Search)\n",
        "max_budget = 20000  # Total budget in seconds; Stage 2 (SH) receives what remains after Stage 1\n",
        "list_epochs = [12, 25, 50, 100, 200]  # Epoch schedule passed to Successive Halving (first entry = starting epoch)\n",
        "k = 32  # Top-k: number of Stage-1 architectures passed to Successive Halving\n",
        "all_sol = list(benchmark['200'].keys())  # Search space: all architectures in benchmark['200'] (presumably the 200-epoch table)\n",
        "verbose = False"
      ],
      "metadata": {
        "id": "xFChor0ivD52"
      },
      "id": "xFChor0ivD52",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### MF-NAS (Local Search)"
      ],
      "metadata": {
        "id": "aMwIq2lXtAIP"
      },
      "id": "aMwIq2lXtAIP"
    },
    {
      "cell_type": "code",
      "source": [
        "# MF-NAS (Local Search) with Stage-2 Successive Halving using metric='val_acc'.\n",
        "metric = 'val_acc'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip', 'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    # Use the configured `dataset` (not a hard-coded 'cifar10') so this cell follows the config cell.\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose)\n",
        "    print()"
      ],
      "metadata": {
        "id": "NvU3W-daq50P",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "63f65a7d-1834-4c60-cfe4-7c8ae18f7238"
      },
      "id": "NvU3W-daq50P",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=jacov, max_eval=2000)]: 91.93 1.46\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.63 0.43\n",
            "- Total Cost: 2054 + 17333 = 19387 seconds\n",
            "- Total Epochs: 825\n",
            "- Best visited solution (all): 94.31 0.07\n",
            "- Best visited solution (top-32): 93.82 0.2\n",
            "- Average Ranked of SH selection: 2.18\n",
            "\n",
            "ZC METRIC: PLAIN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=plain, max_eval=2000)]: 88.07 4.48\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 92.11 0.96\n",
            "- Total Cost: 2083 + 17351 = 19434 seconds\n",
            "- Total Epochs: 940\n",
            "- Best visited solution (all): 94.06 0.16\n",
            "- Best visited solution (top-32): 92.49 0.59\n",
            "- Average Ranked of SH selection: 2.32\n",
            "\n",
            "ZC METRIC: GRASP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grasp, max_eval=2000)]: 83.87 0.47\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 92.77 0.32\n",
            "- Total Cost: 7018 + 12824 = 19842 seconds\n",
            "- Total Epochs: 562\n",
            "- Best visited solution (all): 94.33 0.07\n",
            "- Best visited solution (top-32): 93.11 0.11\n",
            "- Average Ranked of SH selection: 2.39\n",
            "\n",
            "ZC METRIC: FISHER\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=fisher, max_eval=2000)]: 83.82 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 90.98 1.07\n",
            "- Total Cost: 2135 + 17604 = 19739 seconds\n",
            "- Total Epochs: 785\n",
            "- Best visited solution (all): 94.26 0.09\n",
            "- Best visited solution (top-32): 92.35 0.47\n",
            "- Average Ranked of SH selection: 15.26\n",
            "\n",
            "ZC METRIC: EPE_NAS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=epe_nas, max_eval=2000)]: 91.48 1.96\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.4 0.61\n",
            "- Total Cost: 2027 + 17459 = 19486 seconds\n",
            "- Total Epochs: 871\n",
            "- Best visited solution (all): 94.31 0.09\n",
            "- Best visited solution (top-32): 93.89 0.21\n",
            "- Average Ranked of SH selection: 5.7\n",
            "\n",
            "ZC METRIC: GRAD_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grad_norm, max_eval=2000)]: 84.07 1.09\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 92.58 0.44\n",
            "- Total Cost: 2080 + 17583 = 19663 seconds\n",
            "- Total Epochs: 757\n",
            "- Best visited solution (all): 94.31 0.07\n",
            "- Best visited solution (top-32): 93.13 0.06\n",
            "- Average Ranked of SH selection: 2.07\n",
            "\n",
            "ZC METRIC: SNIP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=snip, max_eval=2000)]: 83.82 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.13 0.0\n",
            "- Total Cost: 2093 + 17525 = 19618 seconds\n",
            "- Total Epochs: 750\n",
            "- Best visited solution (all): 94.31 0.07\n",
            "- Best visited solution (top-32): 93.13 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n",
            "ZC METRIC: SYNFLOW\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 93.76 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 94.36 0.0\n",
            "- Total Cost: 1459 + 18175 = 19634 seconds\n",
            "- Total Epochs: 617\n",
            "- Best visited solution (all): 94.37 0.0\n",
            "- Best visited solution (top-32): 94.37 0.0\n",
            "- Average Ranked of SH selection: 1.95\n",
            "\n",
            "ZC METRIC: L2_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=l2_norm, max_eval=2000)]: 92.93 0.06\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.32 0.03\n",
            "- Total Cost: 408 + 19490 = 19898 seconds\n",
            "- Total Epochs: 616\n",
            "- Best visited solution (all): 94.37 0.01\n",
            "- Best visited solution (top-32): 93.76 0.02\n",
            "- Average Ranked of SH selection: 14.74\n",
            "\n",
            "ZC METRIC: ZEN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=zen, max_eval=2000)]: 90.64 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 88.35 0.55\n",
            "- Total Cost: 1533 + 17724 = 19257 seconds\n",
            "- Total Epochs: 1042\n",
            "- Best visited solution (all): 94.37 0.02\n",
            "- Best visited solution (top-32): 90.74 0.0\n",
            "- Average Ranked of SH selection: 20.8\n",
            "\n",
            "ZC METRIC: NWOT\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=nwot, max_eval=2000)]: 93.32 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.35 0.07\n",
            "- Total Cost: 1999 + 17723 = 19722 seconds\n",
            "- Total Epochs: 566\n",
            "- Best visited solution (all): 94.37 0.01\n",
            "- Best visited solution (top-32): 93.89 0.02\n",
            "- Average Ranked of SH selection: 14.14\n",
            "\n",
            "ZC METRIC: PARAMS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 93.76 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 94.36 0.0\n",
            "- Total Cost: 760 + 18695 = 19455 seconds\n",
            "- Total Epochs: 592\n",
            "- Best visited solution (all): 94.37 0.0\n",
            "- Best visited solution (top-32): 94.36 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n",
            "ZC METRIC: FLOPS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 93.76 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 94.36 0.0\n",
            "- Total Cost: 765 + 18696 = 19461 seconds\n",
            "- Total Epochs: 592\n",
            "- Best visited solution (all): 94.37 0.0\n",
            "- Best visited solution (top-32): 94.36 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# MF-NAS (Local Search) with Stage-2 Successive Halving using metric='train_loss'.\n",
        "metric = 'train_loss'\n",
        "for zc_metric in ('jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip',\n",
        "                  'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops'):\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    (all_best_solution_SH, all_f_best_solution_SH,\n",
        "     all_cost_SH, all_total_epoch_SH) = ILS_SH(\n",
        "        zc_metric=zc_metric, metric=metric, k=k, n_run=n_run,\n",
        "        max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs,\n",
        "        dataset=dataset, verbose=verbose)\n",
        "    print()"
      ],
      "metadata": {
        "id": "pQC3trXdspq5",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "333d3b5a-4aae-456d-cadc-cf70bc60c1fb"
      },
      "id": "pQC3trXdspq5",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=jacov, max_eval=2000)]: 91.93 1.46\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.78 0.23\n",
            "- Total Cost: 2054 + 17324 = 19378 seconds\n",
            "- Total Epochs: 820\n",
            "- Best visited solution (all): 94.31 0.07\n",
            "- Best visited solution (top-32): 93.82 0.2\n",
            "- Average Ranked of SH selection: 1.31\n",
            "\n",
            "ZC METRIC: PLAIN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=plain, max_eval=2000)]: 88.07 4.48\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 92.42 0.63\n",
            "- Total Cost: 2083 + 17345 = 19428 seconds\n",
            "- Total Epochs: 918\n",
            "- Best visited solution (all): 94.06 0.16\n",
            "- Best visited solution (top-32): 92.49 0.59\n",
            "- Average Ranked of SH selection: 1.4\n",
            "\n",
            "ZC METRIC: GRASP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grasp, max_eval=2000)]: 83.87 0.47\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.09 0.12\n",
            "- Total Cost: 7018 + 12836 = 19854 seconds\n",
            "- Total Epochs: 554\n",
            "- Best visited solution (all): 94.33 0.07\n",
            "- Best visited solution (top-32): 93.11 0.11\n",
            "- Average Ranked of SH selection: 1.13\n",
            "\n",
            "ZC METRIC: FISHER\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=fisher, max_eval=2000)]: 83.82 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 92.21 0.62\n",
            "- Total Cost: 2135 + 17610 = 19745 seconds\n",
            "- Total Epochs: 775\n",
            "- Best visited solution (all): 94.26 0.09\n",
            "- Best visited solution (top-32): 92.35 0.47\n",
            "- Average Ranked of SH selection: 4.54\n",
            "\n",
            "ZC METRIC: EPE_NAS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=epe_nas, max_eval=2000)]: 91.48 1.96\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.83 0.29\n",
            "- Total Cost: 2027 + 17442 = 19469 seconds\n",
            "- Total Epochs: 851\n",
            "- Best visited solution (all): 94.31 0.09\n",
            "- Best visited solution (top-32): 93.89 0.21\n",
            "- Average Ranked of SH selection: 1.65\n",
            "\n",
            "ZC METRIC: GRAD_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grad_norm, max_eval=2000)]: 84.07 1.09\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.13 0.06\n",
            "- Total Cost: 2080 + 17487 = 19567 seconds\n",
            "- Total Epochs: 749\n",
            "- Best visited solution (all): 94.31 0.07\n",
            "- Best visited solution (top-32): 93.13 0.06\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n",
            "ZC METRIC: SNIP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=snip, max_eval=2000)]: 83.82 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.13 0.0\n",
            "- Total Cost: 2093 + 17578 = 19671 seconds\n",
            "- Total Epochs: 746\n",
            "- Best visited solution (all): 94.31 0.07\n",
            "- Best visited solution (top-32): 93.13 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n",
            "ZC METRIC: SYNFLOW\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 93.76 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 94.35 0.04\n",
            "- Total Cost: 1459 + 18204 = 19663 seconds\n",
            "- Total Epochs: 617\n",
            "- Best visited solution (all): 94.37 0.0\n",
            "- Best visited solution (top-32): 94.37 0.0\n",
            "- Average Ranked of SH selection: 1.99\n",
            "\n",
            "ZC METRIC: L2_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=l2_norm, max_eval=2000)]: 92.93 0.06\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.67 0.04\n",
            "- Total Cost: 408 + 19055 = 19463 seconds\n",
            "- Total Epochs: 600\n",
            "- Best visited solution (all): 94.37 0.01\n",
            "- Best visited solution (top-32): 93.76 0.02\n",
            "- Average Ranked of SH selection: 3.05\n",
            "\n",
            "ZC METRIC: ZEN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=zen, max_eval=2000)]: 90.64 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 90.64 0.0\n",
            "- Total Cost: 1533 + 18087 = 19620 seconds\n",
            "- Total Epochs: 989\n",
            "- Best visited solution (all): 94.37 0.02\n",
            "- Best visited solution (top-32): 90.74 0.0\n",
            "- Average Ranked of SH selection: 2.0\n",
            "\n",
            "ZC METRIC: NWOT\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=nwot, max_eval=2000)]: 93.32 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 93.56 0.01\n",
            "- Total Cost: 1999 + 17820 = 19819 seconds\n",
            "- Total Epochs: 566\n",
            "- Best visited solution (all): 94.37 0.01\n",
            "- Best visited solution (top-32): 93.89 0.02\n",
            "- Average Ranked of SH selection: 8.62\n",
            "\n",
            "ZC METRIC: PARAMS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 93.76 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 94.36 0.0\n",
            "- Total Cost: 760 + 18774 = 19534 seconds\n",
            "- Total Epochs: 592\n",
            "- Best visited solution (all): 94.37 0.0\n",
            "- Best visited solution (top-32): 94.36 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n",
            "ZC METRIC: FLOPS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 93.76 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=20000 sec)]: 94.36 0.0\n",
            "- Total Cost: 765 + 18776 = 19541 seconds\n",
            "- Total Epochs: 592\n",
            "- Best visited solution (all): 94.37 0.0\n",
            "- Best visited solution (top-32): 94.36 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### MF-NAS (Random Search)"
      ],
      "metadata": {
        "id": "mfnKrBl4tDC5"
      },
      "id": "mfnKrBl4tDC5"
    },
    {
      "cell_type": "code",
      "source": [
        "# MF-NAS (Random Search) with Stage-2 Successive Halving using metric='val_acc'.\n",
        "metric = 'val_acc'\n",
        "for zc_metric in ('jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip',\n",
        "                  'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops'):\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    (all_best_solution_SH, all_f_best_solution_SH,\n",
        "     all_cost_SH, all_total_epoch_SH) = RS_SH(\n",
        "        zc_metric=zc_metric, metric=metric, k=k, n_run=n_run,\n",
        "        max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs,\n",
        "        dataset=dataset, verbose=verbose)\n",
        "    print()"
      ],
      "metadata": {
        "id": "1aiOeYMLYecC",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ac061af4-7a21-4be7-c09e-621ce7747303"
      },
      "id": "1aiOeYMLYecC",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=jacov, max_eval=2000)]: 92.02 1.18\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.35 0.56\n",
            "- Total Cost: 2056 + 17431 = 19487 seconds\n",
            "- Total Epochs: 863\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.85 0.22\n",
            "\n",
            "ZC METRIC: PLAIN\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=plain, max_eval=2000)]: 88.57 4.67\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 92.45 0.88\n",
            "- Total Cost: 2034 + 17445 = 19479 seconds\n",
            "- Total Epochs: 920\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 92.88 0.48\n",
            "\n",
            "ZC METRIC: GRASP\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=grasp, max_eval=2000)]: 89.48 2.56\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.25 0.31\n",
            "- Total Cost: 6259 + 13539 = 19798 seconds\n",
            "- Total Epochs: 571\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.58 0.12\n",
            "\n",
            "ZC METRIC: FISHER\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=fisher, max_eval=2000)]: 89.12 2.44\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.19 0.35\n",
            "- Total Cost: 2141 + 17501 = 19642 seconds\n",
            "- Total Epochs: 763\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.48 0.18\n",
            "\n",
            "ZC METRIC: EPE_NAS\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=epe_nas, max_eval=2000)]: 91.48 1.84\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.2 0.7\n",
            "- Total Cost: 2052 + 17441 = 19493 seconds\n",
            "- Total Epochs: 877\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.87 0.21\n",
            "\n",
            "ZC METRIC: GRAD_NORM\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=grad_norm, max_eval=2000)]: 89.2 2.47\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.39 0.32\n",
            "- Total Cost: 2050 + 17617 = 19667 seconds\n",
            "- Total Epochs: 706\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.67 0.16\n",
            "\n",
            "ZC METRIC: SNIP\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=snip, max_eval=2000)]: 89.37 2.52\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.38 0.29\n",
            "- Total Cost: 2063 + 17597 = 19660 seconds\n",
            "- Total Epochs: 705\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.63 0.14\n",
            "\n",
            "ZC METRIC: SYNFLOW\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=synflow, max_eval=2000)]: 93.44 0.66\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 94.01 0.26\n",
            "- Total Cost: 1289 + 18396 = 19685 seconds\n",
            "- Total Epochs: 690\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 94.24 0.14\n",
            "\n",
            "ZC METRIC: L2_NORM\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=l2_norm, max_eval=2000)]: 93.02 0.94\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.92 0.46\n",
            "- Total Cost: 362 + 19273 = 19635 seconds\n",
            "- Total Epochs: 675\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 94.07 0.26\n",
            "\n",
            "ZC METRIC: ZEN\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=zen, max_eval=2000)]: 89.19 1.31\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 88.77 1.11\n",
            "- Total Cost: 1616 + 17772 = 19388 seconds\n",
            "- Total Epochs: 982\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 90.42 0.28\n",
            "\n",
            "ZC METRIC: NWOT\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=nwot, max_eval=2000)]: 93.18 0.55\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.41 0.3\n",
            "- Total Cost: 1627 + 18041 = 19668 seconds\n",
            "- Total Epochs: 640\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 93.79 0.15\n",
            "\n",
            "ZC METRIC: PARAMS\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=params, max_eval=2000)]: 93.55 0.33\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.98 0.32\n",
            "- Total Cost: 721 + 18916 = 19637 seconds\n",
            "- Total Epochs: 673\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 94.15 0.22\n",
            "\n",
            "ZC METRIC: FLOPS\n",
            "Dataset: cifar10\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=flops, max_eval=2000)]: 93.53 0.29\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=20000 sec)]: 93.99 0.3\n",
            "- Total Cost: 720 + 18918 = 19638 seconds\n",
            "- Total Epochs: 673\n",
            "- Best visited solution (all): 94.27 0.1\n",
            "- Best visited solution (top-32): 94.15 0.22\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Random Search"
      ],
      "metadata": {
        "id": "_oQPEPM2Cwfp"
      },
      "id": "_oQPEPM2Cwfp"
    },
    {
      "cell_type": "code",
      "source": [
        "all_test_acc = []\n",
        "iepoch = 12\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_random_search(seed=run_id, iepoch=iepoch, max_time_budget=max_budget)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ex7l4EKZCyGS",
        "outputId": "ea4489a0-5699-4697-ee2e-7deac0462112"
      },
      "id": "ex7l4EKZCyGS",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.35   Std: 0.66\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Local Search"
      ],
      "metadata": {
        "id": "hB1ToAbrAjXk"
      },
      "id": "hB1ToAbrAjXk"
    },
    {
      "cell_type": "code",
      "source": [
        "all_test_acc = []\n",
        "iepoch = 12\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_ils(seed=run_id, iepoch=iepoch, max_budget=max_budget, dataset=dataset)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I5h-PMqkAk3G",
        "outputId": "83941e87-ee57-44f1-a7ed-c547c38ea4da"
      },
      "id": "I5h-PMqkAk3G",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.64   Std: 0.52\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Successive Halving"
      ],
      "metadata": {
        "id": "pEAl-SjPlCjt"
      },
      "id": "pEAl-SjPlCjt"
    },
    {
      "cell_type": "code",
      "source": [
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, total_epoch = run_succesive_halving(seed=run_id, N=k, max_budget=max_budget, dataset=dataset, list_epochs=list_epochs)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "imNuNlKTlFAp",
        "outputId": "eb26f7ff-c71a-4713-da88-7a357d06c8eb"
      },
      "id": "imNuNlKTlFAp",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.06   Std: 0.74\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### REA"
      ],
      "metadata": {
        "id": "WFWZlrX0EMpl"
      },
      "id": "WFWZlrX0EMpl"
    },
    {
      "cell_type": "code",
      "source": [
        "population_size, tournament_size = 10, 10\n",
        "warm_up = False\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_evolution_search(seed=run_id, max_time_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kNlHI00qESNY",
        "outputId": "a5b53494-f046-422d-d1a3-cd6e30b4d421"
      },
      "id": "kNlHI00qESNY",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.45   Std: 0.68\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### REA + W"
      ],
      "metadata": {
        "id": "gp_IhrAGHqNI"
      },
      "id": "gp_IhrAGHqNI"
    },
    {
      "cell_type": "code",
      "source": [
        "population_size, tournament_size = 10, 10\n",
        "n_warmup = 2000\n",
        "warm_up = True\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_evolution_search(seed=run_id, max_time_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up, n_warmup=n_warmup)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "id": "4rMD-jM8Hpqz",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "dd3554ac-1acf-4369-bf93-2c578e25bf07"
      },
      "id": "4rMD-jM8Hpqz",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 93.86   Std: 0.36\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CIFAR-100"
      ],
      "metadata": {
        "id": "btPhKXxGuzc8"
      },
      "id": "btPhKXxGuzc8"
    },
    {
      "cell_type": "code",
      "source": [
        "benchmark = p.load(open('[CIFAR-100]_data.p', 'rb'))"
      ],
      "metadata": {
        "id": "2nSJRlfa34dT"
      },
      "execution_count": 11,
      "outputs": [],
      "id": "2nSJRlfa34dT"
    },
    {
      "cell_type": "code",
      "source": [
        "dataset = 'cifar100'\n",
        "n_run = 500\n",
        "max_eval = 2000  # Stage 1 (Local Search)\n",
        "max_budget = 40000  # Stage 2 (SH, seconds)\n",
        "list_epochs = [12, 25, 50, 100, 200]  # First epoch\n",
        "k = 32  # Top-k\n",
        "all_sol = list(benchmark['200'].keys())\n",
        "verbose = False"
      ],
      "metadata": {
        "id": "stlToHvs34dT"
      },
      "execution_count": 12,
      "outputs": [],
      "id": "stlToHvs34dT"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### MF-NAS (Local Search)"
      ],
      "metadata": {
        "id": "aKpOveV9uUkc"
      },
      "id": "aKpOveV9uUkc"
    },
    {
      "cell_type": "code",
      "source": [
        "metric = 'val_acc'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip', 'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose)\n",
        "    print()"
      ],
      "metadata": {
        "id": "kDwCc0_IuUkd",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9860191b-5673-47e8-8e60-90e827468665"
      },
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=jacov, max_eval=2000)]: 68.34 0.74\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 71.73 0.67\n",
            "- Total Cost: 2005 + 25760 = 27765 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.17 0.27\n",
            "- Best visited solution (top-32): 71.91 0.45\n",
            "- Average Ranked of SH selection: 1.88\n",
            "\n",
            "ZC METRIC: PLAIN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=plain, max_eval=2000)]: 62.66 4.5\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 68.93 1.15\n",
            "- Total Cost: 2078 + 24568 = 26646 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 72.45 0.62\n",
            "- Best visited solution (top-32): 69.14 1.1\n",
            "- Average Ranked of SH selection: 1.49\n",
            "\n",
            "ZC METRIC: GRASP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grasp, max_eval=2000)]: 50.64 3.07\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 69.42 0.17\n",
            "- Total Cost: 7232 + 27710 = 34942 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.16 0.25\n",
            "- Best visited solution (top-32): 69.47 0.17\n",
            "- Average Ranked of SH selection: 1.17\n",
            "\n",
            "ZC METRIC: FISHER\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=fisher, max_eval=2000)]: 49.81 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 65.67 0.91\n",
            "- Total Cost: 2125 + 28196 = 30321 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 72.87 0.4\n",
            "- Best visited solution (top-32): 66.72 0.54\n",
            "- Average Ranked of SH selection: 5.98\n",
            "\n",
            "ZC METRIC: EPE_NAS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=epe_nas, max_eval=2000)]: 64.26 4.14\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 71.51 1.14\n",
            "- Total Cost: 1996 + 23787 = 25783 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 72.99 0.36\n",
            "- Best visited solution (top-32): 71.76 0.72\n",
            "- Average Ranked of SH selection: 2.05\n",
            "\n",
            "ZC METRIC: GRAD_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grad_norm, max_eval=2000)]: 50.31 2.52\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 68.82 0.77\n",
            "- Total Cost: 2064 + 28768 = 30832 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.13 0.37\n",
            "- Best visited solution (top-32): 68.85 0.64\n",
            "- Average Ranked of SH selection: 1.13\n",
            "\n",
            "ZC METRIC: SNIP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=snip, max_eval=2000)]: 50.52 2.74\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 69.28 0.34\n",
            "- Total Cost: 2090 + 29463 = 31553 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.23 0.32\n",
            "- Best visited solution (top-32): 69.33 0.0\n",
            "- Average Ranked of SH selection: 1.07\n",
            "\n",
            "ZC METRIC: SYNFLOW\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 71.11 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 73.51 0.05\n",
            "- Total Cost: 1089 + 34045 = 35134 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.51 0.0\n",
            "- Best visited solution (top-32): 73.51 0.0\n",
            "- Average Ranked of SH selection: 1.01\n",
            "\n",
            "ZC METRIC: L2_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=l2_norm, max_eval=2000)]: 71.28 0.05\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 70.8 0.37\n",
            "- Total Cost: 413 + 37922 = 38335 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.45 0.11\n",
            "- Best visited solution (top-32): 71.59 0.16\n",
            "- Average Ranked of SH selection: 10.44\n",
            "\n",
            "ZC METRIC: ZEN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=zen, max_eval=2000)]: 68.1 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 60.7 0.76\n",
            "- Total Cost: 1555 + 18983 = 20538 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.5 0.05\n",
            "- Best visited solution (top-32): 68.55 0.0\n",
            "- Average Ranked of SH selection: 21.65\n",
            "\n",
            "ZC METRIC: NWOT\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=nwot, max_eval=2000)]: 69.84 0.04\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 71.16 0.46\n",
            "- Total Cost: 1959 + 37422 = 39381 seconds\n",
            "- Total Epochs: 1177\n",
            "- Best visited solution (all): 73.38 0.15\n",
            "- Best visited solution (top-32): 71.74 0.0\n",
            "- Average Ranked of SH selection: 5.24\n",
            "\n",
            "ZC METRIC: PARAMS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 71.11 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 73.51 0.0\n",
            "- Total Cost: 778 + 37593 = 38371 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.51 0.0\n",
            "- Best visited solution (top-32): 73.51 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n",
            "ZC METRIC: FLOPS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 71.11 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 73.51 0.0\n",
            "- Total Cost: 773 + 37587 = 38360 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.51 0.0\n",
            "- Best visited solution (top-32): 73.51 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "\n"
          ]
        }
      ],
      "id": "kDwCc0_IuUkd"
    },
    {
      "cell_type": "code",
      "source": [
        "metric = 'train_loss'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip', 'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose)"
      ],
      "metadata": {
        "id": "ds9YHUhuuUke",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f27ed25e-2f64-404c-aebb-924d8d16110c"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=jacov, max_eval=2000)]: 68.34 0.74\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 71.79 0.5\n",
            "- Total Cost: 2005 + 26091 = 28096 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.17 0.27\n",
            "- Best visited solution (top-32): 71.91 0.45\n",
            "- Average Ranked of SH selection: 1.63\n",
            "ZC METRIC: PLAIN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=plain, max_eval=2000)]: 62.66 4.5\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 68.61 1.18\n",
            "- Total Cost: 2078 + 24877 = 26955 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 72.45 0.62\n",
            "- Best visited solution (top-32): 69.14 1.1\n",
            "- Average Ranked of SH selection: 2.01\n",
            "ZC METRIC: GRASP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grasp, max_eval=2000)]: 50.64 3.07\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 69.35 0.13\n",
            "- Total Cost: 7232 + 28144 = 35376 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.16 0.25\n",
            "- Best visited solution (top-32): 69.47 0.17\n",
            "- Average Ranked of SH selection: 2.05\n",
            "ZC METRIC: FISHER\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=fisher, max_eval=2000)]: 49.81 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 65.72 0.73\n",
            "- Total Cost: 2125 + 28298 = 30423 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 72.87 0.4\n",
            "- Best visited solution (top-32): 66.72 0.54\n",
            "- Average Ranked of SH selection: 5.37\n",
            "ZC METRIC: EPE_NAS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=epe_nas, max_eval=2000)]: 64.26 4.14\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 71.58 0.98\n",
            "- Total Cost: 1996 + 24612 = 26608 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 72.99 0.36\n",
            "- Best visited solution (top-32): 71.76 0.72\n",
            "- Average Ranked of SH selection: 1.72\n",
            "ZC METRIC: GRAD_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grad_norm, max_eval=2000)]: 50.31 2.52\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 68.25 1.59\n",
            "- Total Cost: 2064 + 28814 = 30878 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.13 0.37\n",
            "- Best visited solution (top-32): 68.85 0.64\n",
            "- Average Ranked of SH selection: 2.31\n",
            "ZC METRIC: SNIP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=snip, max_eval=2000)]: 50.52 2.74\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 69.33 0.0\n",
            "- Total Cost: 2090 + 29626 = 31716 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.23 0.32\n",
            "- Best visited solution (top-32): 69.33 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: SYNFLOW\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 71.11 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 73.51 0.0\n",
            "- Total Cost: 1089 + 34356 = 35445 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.51 0.0\n",
            "- Best visited solution (top-32): 73.51 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: L2_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=l2_norm, max_eval=2000)]: 71.28 0.05\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 71.15 0.12\n",
            "- Total Cost: 413 + 37980 = 38393 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.45 0.11\n",
            "- Best visited solution (top-32): 71.59 0.16\n",
            "- Average Ranked of SH selection: 4.63\n",
            "ZC METRIC: ZEN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=zen, max_eval=2000)]: 68.1 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 66.11 0.0\n",
            "- Total Cost: 1555 + 22630 = 24185 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.5 0.05\n",
            "- Best visited solution (top-32): 68.55 0.0\n",
            "- Average Ranked of SH selection: 9.34\n",
            "ZC METRIC: NWOT\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=nwot, max_eval=2000)]: 69.84 0.04\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 71.28 0.25\n",
            "- Total Cost: 1959 + 37776 = 39735 seconds\n",
            "- Total Epochs: 1189\n",
            "- Best visited solution (all): 73.38 0.15\n",
            "- Best visited solution (top-32): 71.74 0.0\n",
            "- Average Ranked of SH selection: 4.6\n",
            "ZC METRIC: PARAMS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 71.11 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 73.51 0.0\n",
            "- Total Cost: 778 + 37328 = 38106 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.51 0.0\n",
            "- Best visited solution (top-32): 73.51 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: FLOPS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 71.11 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=40000 sec)]: 73.51 0.0\n",
            "- Total Cost: 773 + 37326 = 38099 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.51 0.0\n",
            "- Best visited solution (top-32): 73.51 0.0\n",
            "- Average Ranked of SH selection: 1.0\n"
          ]
        }
      ],
      "id": "ds9YHUhuuUke"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### MF-NAS (Random Search)"
      ],
      "metadata": {
        "id": "C11jLvxHuUke"
      },
      "id": "C11jLvxHuUke"
    },
    {
      "cell_type": "code",
      "source": [
        "metric = 'val_acc'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip', 'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = RS_SH(zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose)"
      ],
      "metadata": {
        "id": "EXnMP6VvuUkf",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5e42817e-0b61-4e77-954f-fdfeef6fb4dd"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=jacov, max_eval=2000)]: 68.38 1.87\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 71.45 1.0\n",
            "- Total Cost: 2046 + 24589 = 26635 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 71.74 0.7\n",
            "ZC METRIC: PLAIN\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=plain, max_eval=2000)]: 63.34 5.03\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 69.64 1.01\n",
            "- Total Cost: 2010 + 25151 = 27161 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 69.95 0.83\n",
            "ZC METRIC: GRASP\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=grasp, max_eval=2000)]: 62.45 5.4\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 70.63 0.68\n",
            "- Total Cost: 6256 + 30777 = 37033 seconds\n",
            "- Total Epochs: 1191\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 71.05 0.48\n",
            "ZC METRIC: FISHER\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=fisher, max_eval=2000)]: 60.87 4.92\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 70.07 0.63\n",
            "- Total Cost: 2154 + 29089 = 31243 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 70.42 0.51\n",
            "ZC METRIC: EPE_NAS\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=epe_nas, max_eval=2000)]: 65.22 4.05\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 71.02 1.14\n",
            "- Total Cost: 2003 + 23344 = 25347 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 71.36 0.68\n",
            "ZC METRIC: GRAD_NORM\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=grad_norm, max_eval=2000)]: 61.27 5.01\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 70.6 0.64\n",
            "- Total Cost: 2041 + 31343 = 33384 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 71.05 0.44\n",
            "ZC METRIC: SNIP\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=snip, max_eval=2000)]: 61.12 5.03\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 70.62 0.61\n",
            "- Total Cost: 2059 + 31701 = 33760 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 71.13 0.44\n",
            "ZC METRIC: SYNFLOW\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=synflow, max_eval=2000)]: 70.59 1.93\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 72.61 0.61\n",
            "- Total Cost: 1005 + 31234 = 32239 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 72.91 0.43\n",
            "ZC METRIC: L2_NORM\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=l2_norm, max_eval=2000)]: 69.85 1.47\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 72.17 1.0\n",
            "- Total Cost: 374 + 34176 = 34550 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 72.39 0.79\n",
            "ZC METRIC: ZEN\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=zen, max_eval=2000)]: 63.73 4.07\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 60.24 0.99\n",
            "- Total Cost: 1611 + 20153 = 21764 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 66.91 1.2\n",
            "ZC METRIC: NWOT\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=nwot, max_eval=2000)]: 70.0 0.97\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 70.82 0.48\n",
            "- Total Cost: 1610 + 33994 = 35604 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 71.42 0.29\n",
            "ZC METRIC: PARAMS\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=params, max_eval=2000)]: 70.57 0.83\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 72.53 0.85\n",
            "- Total Cost: 733 + 33321 = 34054 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 72.7 0.68\n",
            "ZC METRIC: FLOPS\n",
            "Dataset: cifar100\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=flops, max_eval=2000)]: 70.57 0.83\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=40000 sec)]: 72.52 0.88\n",
            "- Total Cost: 729 + 33281 = 34010 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 73.09 0.32\n",
            "- Best visited solution (top-32): 72.7 0.68\n"
          ]
        }
      ],
      "id": "EXnMP6VvuUkf"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Random Search"
      ],
      "metadata": {
        "id": "uD8jle3GJDAa"
      },
      "id": "uD8jle3GJDAa"
    },
    {
      "cell_type": "code",
      "source": [
        "all_test_acc = []\n",
        "iepoch = 12\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_random_search(seed=run_id, iepoch=iepoch, max_time_budget=max_budget)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3beeda54-f742-4ffc-8cdf-412370ca0a8b",
        "id": "-gycjL1dI-05"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 70.93   Std: 1.18\n"
          ]
        }
      ],
      "id": "-gycjL1dI-05"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Local Search"
      ],
      "metadata": {
        "id": "Wn7ZipcbI-05"
      },
      "id": "Wn7ZipcbI-05"
    },
    {
      "cell_type": "code",
      "source": [
        "all_test_acc = []\n",
        "iepoch = 12\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_ils(seed=run_id, iepoch=iepoch, max_budget=max_budget, dataset=dataset)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "39c7b0da-2901-40f7-ae59-cf686152032b",
        "id": "V01RRZjkI-06"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 71.43   Std: 0.84\n"
          ]
        }
      ],
      "id": "V01RRZjkI-06"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Succesive Halving"
      ],
      "metadata": {
        "id": "MutClg0OI-06"
      },
      "id": "MutClg0OI-06"
    },
    {
      "cell_type": "code",
      "source": [
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, total_epoch = run_succesive_halving(seed=run_id, N=k, max_budget=max_budget, dataset=dataset, list_epochs=list_epochs)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b40dd47b-0f83-4c41-c200-68a8f83d267d",
        "id": "7Ru_IpX4I-06"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 70.36   Std: 1.18\n"
          ]
        }
      ],
      "id": "7Ru_IpX4I-06"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### REA"
      ],
      "metadata": {
        "id": "kM_Al5oCI-06"
      },
      "id": "kM_Al5oCI-06"
    },
    {
      "cell_type": "code",
      "source": [
        "population_size, tournament_size = 10, 10\n",
        "warm_up = False\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_evolution_search(seed=run_id, max_time_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6c0c7d00-e5b9-4807-a565-972ddb1fa7e0",
        "id": "H1twlRkJI-07"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 71.2   Std: 1.16\n"
          ]
        }
      ],
      "id": "H1twlRkJI-07"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### REA + W"
      ],
      "metadata": {
        "id": "6L6n6hFXI-07"
      },
      "id": "6L6n6hFXI-07"
    },
    {
      "cell_type": "code",
      "source": [
        "population_size, tournament_size = 10, 10\n",
        "n_warmup = 2000\n",
        "warm_up = True\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_evolution_search(seed=run_id, max_time_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up, n_warmup=n_warmup)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "486c22d4-78fb-48f2-afe9-18145e555861",
        "id": "AmM_3lAQI-08"
      },
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 71.69   Std: 0.73\n"
          ]
        }
      ],
      "id": "AmM_3lAQI-08"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## ImageNet16-120"
      ],
      "metadata": {
        "id": "hfKJWyowqYBb"
      },
      "id": "hfKJWyowqYBb"
    },
    {
      "cell_type": "code",
      "source": [
        "benchmark = p.load(open('[ImageNet16-120]_data.p', 'rb'))"
      ],
      "metadata": {
        "id": "MxmVo9bhu-Pf"
      },
      "id": "MxmVo9bhu-Pf",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "dataset = 'ImageNet16-120'\n",
        "n_run = 500\n",
        "max_eval = 2000  # Stage 1 (Local Search)\n",
        "max_budget = 120000  # Stage 2 (SH, seconds)\n",
        "list_epochs = [12, 25, 50, 100, 200]  # First epoch\n",
        "k = 32  # Top-k\n",
        "all_sol = list(benchmark['200'].keys())\n",
        "verbose = False"
      ],
      "metadata": {
        "id": "T9pjmiViq_Su"
      },
      "id": "T9pjmiViq_Su",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### MF-NAS (Local Search)"
      ],
      "metadata": {
        "id": "2skbzrjRuYZs"
      },
      "id": "2skbzrjRuYZs"
    },
    {
      "cell_type": "code",
      "source": [
        "metric = 'val_acc'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip', 'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a3322c64-72c0-4041-dd3f-a28d4849cff7",
        "id": "HdMKOlVmuYZs"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=jacov, max_eval=2000)]: 40.67 3.33\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 44.4 0.9\n",
            "- Total Cost: 2037 + 74477 = 76514 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.88 0.29\n",
            "- Best visited solution (top-32): 45.33 0.35\n",
            "- Average Ranked of SH selection: 4.29\n",
            "ZC METRIC: PLAIN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=plain, max_eval=2000)]: 25.33 6.3\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 42.74 2.35\n",
            "- Total Cost: 2000 + 74949 = 76949 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.45 0.42\n",
            "- Best visited solution (top-32): 43.4 1.7\n",
            "- Average Ranked of SH selection: 1.87\n",
            "ZC METRIC: GRASP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grasp, max_eval=2000)]: 24.5 5.23\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 42.77 1.42\n",
            "- Total Cost: 3278 + 85105 = 88383 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.06 0.22\n",
            "- Best visited solution (top-32): 42.87 1.13\n",
            "- Average Ranked of SH selection: 1.16\n",
            "ZC METRIC: FISHER\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=fisher, max_eval=2000)]: 28.6 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 35.97 0.0\n",
            "- Total Cost: 2071 + 80612 = 82683 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.92 0.27\n",
            "- Best visited solution (top-32): 38.83 0.0\n",
            "- Average Ranked of SH selection: 2.92\n",
            "ZC METRIC: EPE_NAS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=epe_nas, max_eval=2000)]: 39.29 4.06\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 44.12 1.19\n",
            "- Total Cost: 2011 + 71555 = 73566 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.81 0.4\n",
            "- Best visited solution (top-32): 44.94 0.53\n",
            "- Average Ranked of SH selection: 3.07\n",
            "ZC METRIC: GRAD_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grad_norm, max_eval=2000)]: 24.38 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 35.97 0.0\n",
            "- Total Cost: 2016 + 81085 = 83101 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.99 0.23\n",
            "- Best visited solution (top-32): 38.83 0.0\n",
            "- Average Ranked of SH selection: 3.09\n",
            "ZC METRIC: SNIP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=snip, max_eval=2000)]: 20.56 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 35.97 0.0\n",
            "- Total Cost: 2027 + 83090 = 85117 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.96 0.24\n",
            "- Best visited solution (top-32): 38.83 0.0\n",
            "- Average Ranked of SH selection: 3.5\n",
            "ZC METRIC: SYNFLOW\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 41.44 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 46.34 0.0\n",
            "- Total Cost: 1831 + 104532 = 106363 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.08 0.2\n",
            "- Best visited solution (top-32): 46.48 0.01\n",
            "- Average Ranked of SH selection: 2.55\n",
            "ZC METRIC: L2_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=l2_norm, max_eval=2000)]: 45.7 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 46.48 0.21\n",
            "- Total Cost: 973 + 113733 = 114706 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.27 0.1\n",
            "- Best visited solution (top-32): 46.53 0.0\n",
            "- Average Ranked of SH selection: 1.32\n",
            "ZC METRIC: ZEN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=zen, max_eval=2000)]: 40.77 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 40.77 0.0\n",
            "- Total Cost: 1833 + 61417 = 63250 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.12 0.23\n",
            "- Best visited solution (top-32): 40.77 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: NWOT\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=nwot, max_eval=2000)]: 45.48 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.32 0.06\n",
            "- Total Cost: 1941 + 112605 = 114546 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.02 0.25\n",
            "- Best visited solution (top-32): 46.55 0.0\n",
            "- Average Ranked of SH selection: 12.98\n",
            "ZC METRIC: PARAMS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 41.44 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 46.34 0.0\n",
            "- Total Cost: 1240 + 113945 = 115185 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.08 0.2\n",
            "- Best visited solution (top-32): 46.34 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: FLOPS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 41.44 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 46.34 0.0\n",
            "- Total Cost: 1250 + 113917 = 115167 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.06 0.2\n",
            "- Best visited solution (top-32): 46.34 0.0\n",
            "- Average Ranked of SH selection: 1.0\n"
          ]
        }
      ],
      "id": "HdMKOlVmuYZs"
    },
    {
      "cell_type": "code",
      "source": [
        "metric = 'train_loss'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip', 'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    all_best_solution_SH, all_f_best_solution_SH, all_cost_SH, all_total_epoch_SH = ILS_SH(zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval, allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose)"
      ],
      "metadata": {
        "id": "Pwo312WXuYZt",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "930eab97-17e0-49b3-bf24-bdcd54919623"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=jacov, max_eval=2000)]: 40.67 3.33\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 44.47 1.01\n",
            "- Total Cost: 2037 + 77998 = 80035 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.88 0.29\n",
            "- Best visited solution (top-32): 45.33 0.35\n",
            "- Average Ranked of SH selection: 4.29\n",
            "ZC METRIC: PLAIN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=plain, max_eval=2000)]: 25.33 6.3\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 42.83 2.04\n",
            "- Total Cost: 2000 + 76976 = 78976 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.45 0.42\n",
            "- Best visited solution (top-32): 43.4 1.7\n",
            "- Average Ranked of SH selection: 1.69\n",
            "ZC METRIC: GRASP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grasp, max_eval=2000)]: 24.5 5.23\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 42.87 1.13\n",
            "- Total Cost: 3278 + 85219 = 88497 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.06 0.22\n",
            "- Best visited solution (top-32): 42.87 1.13\n",
            "- Average Ranked of SH selection: 1.04\n",
            "ZC METRIC: FISHER\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=fisher, max_eval=2000)]: 28.6 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 38.83 0.0\n",
            "- Total Cost: 2071 + 80531 = 82602 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.92 0.27\n",
            "- Best visited solution (top-32): 38.83 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: EPE_NAS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=epe_nas, max_eval=2000)]: 39.29 4.06\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 44.34 1.02\n",
            "- Total Cost: 2011 + 74502 = 76513 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.81 0.4\n",
            "- Best visited solution (top-32): 44.94 0.53\n",
            "- Average Ranked of SH selection: 2.59\n",
            "ZC METRIC: GRAD_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=grad_norm, max_eval=2000)]: 24.38 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 38.83 0.0\n",
            "- Total Cost: 2016 + 81165 = 83181 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.99 0.23\n",
            "- Best visited solution (top-32): 38.83 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: SNIP\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=snip, max_eval=2000)]: 20.56 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 38.83 0.0\n",
            "- Total Cost: 2027 + 83383 = 85410 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.96 0.24\n",
            "- Best visited solution (top-32): 38.83 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: SYNFLOW\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=synflow, max_eval=2000)]: 41.44 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 46.34 0.0\n",
            "- Total Cost: 1831 + 104297 = 106128 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.08 0.2\n",
            "- Best visited solution (top-32): 46.48 0.01\n",
            "- Average Ranked of SH selection: 2.55\n",
            "ZC METRIC: L2_NORM\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=l2_norm, max_eval=2000)]: 45.7 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 46.53 0.04\n",
            "- Total Cost: 973 + 113804 = 114777 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.27 0.1\n",
            "- Best visited solution (top-32): 46.53 0.0\n",
            "- Average Ranked of SH selection: 1.03\n",
            "ZC METRIC: ZEN\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=zen, max_eval=2000)]: 40.77 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 40.32 0.0\n",
            "- Total Cost: 1833 + 67942 = 69775 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.12 0.23\n",
            "- Best visited solution (top-32): 40.77 0.0\n",
            "- Average Ranked of SH selection: 2.7\n",
            "ZC METRIC: NWOT\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=nwot, max_eval=2000)]: 45.48 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 46.53 0.05\n",
            "- Total Cost: 1941 + 112793 = 114734 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.02 0.25\n",
            "- Best visited solution (top-32): 46.55 0.0\n",
            "- Average Ranked of SH selection: 2.02\n",
            "ZC METRIC: PARAMS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=params, max_eval=2000)]: 41.44 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 46.34 0.0\n",
            "- Total Cost: 1240 + 114035 = 115275 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.08 0.2\n",
            "- Best visited solution (top-32): 46.34 0.0\n",
            "- Average Ranked of SH selection: 1.0\n",
            "ZC METRIC: FLOPS\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [ILS (metric=flops, max_eval=2000)]: 41.44 0.0\n",
            "- Stage 2 [Successive Halving (top-32, metric=train_loss, budget=120000 sec)]: 46.34 0.0\n",
            "- Total Cost: 1250 + 114025 = 115275 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 47.06 0.2\n",
            "- Best visited solution (top-32): 46.34 0.0\n",
            "- Average Ranked of SH selection: 1.0\n"
          ]
        }
      ],
      "id": "Pwo312WXuYZt"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### MF-NAS (Random Search)"
      ],
      "metadata": {
        "id": "2VmOSPQMuYZt"
      },
      "id": "2VmOSPQMuYZt"
    },
    {
      "cell_type": "code",
      "source": [
        "# MF-NAS variant: Random Search on each zero-cost metric (stage 1), then Successive Halving on val_acc (stage 2).\n",
        "metric = 'val_acc'\n",
        "for zc_metric in ['jacov', 'plain', 'grasp', 'fisher', 'epe_nas', 'grad_norm', 'snip',\n",
        "                  'synflow', 'l2_norm', 'zen', 'nwot', 'params', 'flops']:\n",
        "    print('ZC METRIC:', zc_metric.upper())\n",
        "    # Each iteration rebinds these names; only the last metric's results remain after the loop.\n",
        "    (all_best_solution_SH, all_f_best_solution_SH,\n",
        "     all_cost_SH, all_total_epoch_SH) = RS_SH(\n",
        "        zc_metric=zc_metric, metric=metric, k=k, n_run=n_run, max_eval=max_eval,\n",
        "        allowed_time=max_budget, list_epochs=list_epochs, dataset=dataset, verbose=verbose,\n",
        "    )"
      ],
      "metadata": {
        "id": "UyIoOE8BuYZv",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "775b6020-1e69-40ff-86a6-a35ac65901b1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ZC METRIC: JACOV\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=jacov, max_eval=2000)]: 40.87 3.46\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.13 0.97\n",
            "- Total Cost: 2017 + 75366 = 77383 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 45.64 0.55\n",
            "ZC METRIC: PLAIN\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=plain, max_eval=2000)]: 29.73 10.66\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 43.29 1.89\n",
            "- Total Cost: 1986 + 76737 = 78723 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 44.33 1.38\n",
            "ZC METRIC: GRASP\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=grasp, max_eval=2000)]: 28.51 8.96\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.27 0.84\n",
            "- Total Cost: 3050 + 92101 = 95151 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 45.66 0.62\n",
            "ZC METRIC: FISHER\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=fisher, max_eval=2000)]: 23.74 10.69\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 44.79 1.09\n",
            "- Total Cost: 2078 + 86102 = 88180 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 45.0 0.95\n",
            "ZC METRIC: EPE_NAS\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=epe_nas, max_eval=2000)]: 39.14 4.02\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 44.31 1.14\n",
            "- Total Cost: 2003 + 71176 = 73179 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 44.97 0.55\n",
            "ZC METRIC: GRAD_NORM\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=grad_norm, max_eval=2000)]: 25.13 9.22\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.49 0.72\n",
            "- Total Cost: 1998 + 95416 = 97414 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 45.78 0.62\n",
            "ZC METRIC: SNIP\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=snip, max_eval=2000)]: 23.82 10.87\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.46 0.81\n",
            "- Total Cost: 2012 + 95462 = 97474 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 45.78 0.65\n",
            "ZC METRIC: SYNFLOW\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=synflow, max_eval=2000)]: 42.42 3.74\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 46.24 0.51\n",
            "- Total Cost: 1762 + 95123 = 96885 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 46.64 0.34\n",
            "ZC METRIC: L2_NORM\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=l2_norm, max_eval=2000)]: 44.61 1.79\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.94 0.65\n",
            "- Total Cost: 950 + 103200 = 104150 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 46.35 0.34\n",
            "ZC METRIC: ZEN\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=zen, max_eval=2000)]: 36.82 3.31\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 35.66 2.3\n",
            "- Total Cost: 1838 + 63115 = 64953 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 39.1 1.23\n",
            "ZC METRIC: NWOT\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=nwot, max_eval=2000)]: 44.51 1.94\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.52 0.53\n",
            "- Total Cost: 1816 + 102082 = 103898 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 46.09 0.34\n",
            "ZC METRIC: PARAMS\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=params, max_eval=2000)]: 41.67 2.28\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.89 0.81\n",
            "- Total Cost: 1208 + 101074 = 102282 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 46.3 0.43\n",
            "ZC METRIC: FLOPS\n",
            "Dataset: ImageNet16-120\n",
            "#Runs: 500\n",
            "- Stage 1 [Random Search (metric=flops, max_eval=2000)]: 41.67 2.28\n",
            "- Stage 2 [Successive Halving (top-32, metric=val_acc, budget=120000 sec)]: 45.93 0.8\n",
            "- Total Cost: 1217 + 101048 = 102265 seconds\n",
            "- Total Epochs: 1192\n",
            "- Best visited solution (all): 46.83 0.3\n",
            "- Best visited solution (top-32): 46.3 0.45\n"
          ]
        }
      ],
      "id": "UyIoOE8BuYZv"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Random Search"
      ],
      "metadata": {
        "id": "i9xrDWH9JLQ6"
      },
      "id": "i9xrDWH9JLQ6"
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline: Random Search under the shared time budget (max_budget); reports mean/std test accuracy (%) over n_run seeds.\n",
        "iepoch = 12\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_acc, total_time, _ = run_random_search(\n",
        "        seed=run_id, iepoch=iepoch, max_time_budget=max_budget,\n",
        "    )\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(100 * test_acc)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bc2d5089-4842-4edb-ce95-80e523d4b77b",
        "id": "Cn1oQnaCJLQ7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 44.82   Std: 1.23\n"
          ]
        }
      ],
      "id": "Cn1oQnaCJLQ7"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Local Search"
      ],
      "metadata": {
        "id": "g-OtxcoUJLQ8"
      },
      "id": "g-OtxcoUJLQ8"
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline: first-improvement Local Search (run_ils) under the shared time budget; mean/std test accuracy (%) over n_run seeds.\n",
        "iepoch = 12\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_acc, total_time, _ = run_ils(\n",
        "        seed=run_id, iepoch=iepoch, max_budget=max_budget, dataset=dataset,\n",
        "    )\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(100 * test_acc)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0bf94a30-6f9e-40f7-8884-1d48491a1db0",
        "id": "Lb-_PGDWJLQ9"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 44.98   Std: 0.61\n"
          ]
        }
      ],
      "id": "Lb-_PGDWJLQ9"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Successive Halving"
      ],
      "metadata": {
        "id": "0YiP8KGMJLQ9"
      },
      "id": "0YiP8KGMJLQ9"
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline: Successive Halving with the shared epoch schedule (list_epochs); N=k is forwarded as the SH pool size (presumably the number of starting candidates).\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run + 1):\n",
        "    arch, test_acc, total_time, total_epoch = run_succesive_halving(\n",
        "        seed=run_id, N=k, max_budget=max_budget, dataset=dataset, list_epochs=list_epochs,\n",
        "    )\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(100 * test_acc)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "dd8b77c3-eaf7-4ad2-e5b2-8c43c7a4da64",
        "id": "SLSFqfvdJLQ-"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 43.96   Std: 1.5\n"
          ]
        }
      ],
      "id": "SLSFqfvdJLQ-"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### REA"
      ],
      "metadata": {
        "id": "mCKD7Ft5JLQ-"
      },
      "id": "mCKD7Ft5JLQ-"
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline: REA (run_evolution_search) without warm-up; mean/std test accuracy (%) over n_run seeds.\n",
        "# iepoch is set explicitly (same value as the Random/Local Search cells) so this cell does not\n",
        "# silently depend on an earlier cell having been run.\n",
        "iepoch = 12\n",
        "population_size, tournament_size = 20, 4\n",
        "warm_up = False\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_evolution_search(seed=run_id, max_time_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6c85adab-8a57-48bf-c11b-3a12ff084d8f",
        "id": "CUNzDVyFJLQ-"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 45.2   Std: 0.98\n"
          ]
        }
      ],
      "id": "CUNzDVyFJLQ-"
    },
    {
      "cell_type": "markdown",
      "source": [
        "### REA + W"
      ],
      "metadata": {
        "id": "J3RBNsK8JLQ_"
      },
      "id": "J3RBNsK8JLQ_"
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline: REA + W, i.e. run_evolution_search with warm_up=True and n_warmup=2000\n",
        "# (n_warmup presumably counts warm-up candidates; confirm in run_evolution_search).\n",
        "# iepoch is set explicitly (same value as the Random/Local Search cells) so this cell does not\n",
        "# silently depend on an earlier cell having been run.\n",
        "iepoch = 12\n",
        "population_size, tournament_size = 20, 4\n",
        "n_warmup = 2000\n",
        "warm_up = True\n",
        "all_test_acc = []\n",
        "for run_id in range(1, n_run+1):\n",
        "    arch, test_acc, total_time, _ = run_evolution_search(seed=run_id, max_time_budget=max_budget, iepoch=iepoch, population_size=population_size, tournament_size=tournament_size, warm_up=warm_up, n_warmup=n_warmup)\n",
        "    if verbose:\n",
        "        print('ID:', run_id)\n",
        "        print(f'-> Best solution: {encode_int_list_2_ori_input(arch)}')\n",
        "        print(f'-> Test accuracy: {test_acc*100} %')\n",
        "    all_test_acc.append(test_acc*100)\n",
        "print('Mean:', np.round(np.mean(all_test_acc), 2), '  Std:', np.round(np.std(all_test_acc), 2))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e7805feb-5816-48fe-99eb-3fd0be62a5a6",
        "id": "Yu6rReeSJLQ_"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mean: 45.74   Std: 0.71\n"
          ]
        }
      ],
      "id": "Yu6rReeSJLQ_"
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.3"
    },
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "njsh04MeTDnm",
        "p3F54IGCn9BG",
        "B5u0GCCMjV4R",
        "PeH_GXxNjuIp",
        "thcVugGOjzPz",
        "0Y2qS2Qmj2Sz",
        "rDPFwN6xjEn1",
        "X37x-iSgqTHY",
        "XQ72nto5o4_T",
        "wvcbOh1VpBg3"
      ],
      "toc_visible": true
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}