"""LayerNAS search algorithm for the NATS-Bench size search space.

It is meant to be used with
https://github.com/google/pyglove/blob/main/examples/automl/natsbench/natsbench.py
by registering `size_search_algorithm` in that script's `get_algorithm`.
"""

import pyglove as pg

import layer_nas_oss as layernas


# Search-space dimensions fixed by the NATS-Bench size search space: there are
# NUM_LAYERS searchable layers, and each one offers NUM_CHOICES candidate
# sizes, namely (8, 16, 24, 32, 40, 48, 56, 64).
NUM_LAYERS, NUM_CHOICES = 5, 8


# Per-layer cost table, indexed as COST[dataset][layer][choice]: NUM_LAYERS
# rows of NUM_CHOICES entries each. Within a dataset, choice 0 has the same
# cost in every row, and costs are non-increasing along each row.
# NOTE(review): the unit of cost is not stated here; confirm against the
# upstream benchmark numbers before relying on absolute values.
COST = {
    'cifar10': [
        [274.40, 228.63, 188.76, 154.79, 126.71, 104.54, 88.26, 77.88],
        [274.40, 269.43, 263.72, 258.29, 253.17, 248.33, 243.79, 239.55],
        [274.40, 265.99, 257.43, 249.74, 242.94, 237.03, 232.00, 227.85],
        [274.40, 273.16, 271.73, 270.37, 269.09, 267.88, 266.75, 265.69],
        # Choice 0 must equal the shared 274.40 baseline, and the last entry
        # must keep the row non-increasing (matches the 'cifar100' row).
        [274.40, 272.38, 270.32, 268.48, 266.86, 265.47, 264.29, 263.34],
    ],
    'cifar100': [
        [274.41, 228.64, 188.77, 154.79, 126.72, 104.54, 88.27, 77.89],
        # NOTE(review): 254.17 differs from the 'cifar10' value (253.17) by
        # 1.00, while every other entry agrees to within 0.01 -- likely a
        # typo for 253.17/253.18; verify against the benchmark.
        [274.41, 269.44, 263.72, 258.30, 254.17, 248.34, 243.80, 239.56],
        [274.41, 266.00, 257.43, 249.75, 242.95, 237.03, 232.00, 227.86],
        [274.41, 273.16, 271.74, 270.38, 269.10, 267.89, 266.75, 265.69],
        [274.41, 272.39, 270.32, 268.48, 266.87, 265.47, 264.29, 263.34],
    ],
    'ImageNet16-120': [
        [68.61, 57.17, 47.20, 38.70, 31.69, 26.14, 22.07, 19.48],
        [68.61, 67.37, 65.94, 64.58, 63.30, 62.09, 60.96, 59.90],
        [68.61, 66.51, 64.36, 62.44, 60.74, 59.26, 58.01, 56.97],
        [68.61, 68.30, 67.94, 67.60, 67.28, 66.98, 66.69, 66.43],
        [68.61, 68.10, 67.59, 67.13, 66.72, 66.37, 66.07, 65.84],
    ],
}

# Per-layer reward table, indexed as REWARD[dataset][layer][choice]: NUM_LAYERS
# rows of NUM_CHOICES entries each, in the same layout as COST. Within a
# dataset, choice 0 has the same reward in every row (84.956 / 61.06 / 39.13).
# NOTE(review): values look like accuracies in percent -- confirm upstream.
REWARD = {
    'cifar10': [
        [84.956, 84.464, 84.467, 84.36, 83.564, 83.624, 84.048, 83.256],
        [84.956, 83.984, 84.124, 84.216, 83.824, 83.652, 83.54, 83.14],
        [84.956, 84.204, 84.576, 84.256, 84.104, 83.572, 82.956, 81.74],
        [84.956, 84.884, 84.444, 84.556, 83.98, 83.428, 83.084, 81.776],
        [84.956, 84.668, 84.672, 84.524, 84.716, 84.592, 84.236, 82.58],
    ],
    'cifar100': [
        [61.06, 60.16, 60.58, 60.12, 59.46, 59.24, 60.04, 59.42],
        [61.06, 59.20, 59.42, 58.68, 60.08, 59.38, 60.26, 57.10],
        [61.06, 60.90, 60.24, 59.06, 59.02, 58.38, 57.74, 56.26],
        [61.06, 60.04, 60.58, 59.44, 59.16, 57.88, 57.50, 55.72,],
        [61.06, 60.70, 59.84, 58.44, 57.94, 56.56, 54.60, 44.90],
    ],
    'ImageNet16-120': [
        [39.13, 38.47, 37.03, 37.10, 37.23, 36.67, 36.60, 34.73],
        [39.13, 37.97, 38.37, 37.47, 35.77, 36.50, 36.60, 34.70],
        [39.13, 38.13, 38.07, 38.03, 37.37, 36.53, 36.57, 33.70],
        [39.13, 38.40, 37.90, 37.83, 37.23, 37.77, 36.07, 35.17],
        [39.13, 38.13, 37.53, 37.53, 37.3, 34.67, 33.13, 28.03],
    ],
}


def size_search_algorithm(
    dataset: str,
    cost_min: float,
    cost_max: float,
    num_search_per_layer: int,
    num_children_per_search: int,
    num_sample_per_search: int,
    fill_all_rows: bool) -> pg.DNAGenerator:
  """Returns a LayerNAS DNA generator for the NATS-Bench size search space.

  Args:
    dataset: Key into the module-level COST and REWARD tables; one of
      'cifar10', 'cifar100' or 'ImageNet16-120'.
    cost_min: Lower bound of the target cost range, forwarded to LayerNAS as
      `target_cost_min`.
    cost_max: Upper bound of the target cost range, forwarded to LayerNAS as
      `target_cost_max`.
    num_search_per_layer: Forwarded unchanged to LayerNAS.
    num_children_per_search: Forwarded unchanged to LayerNAS.
    num_sample_per_search: Forwarded unchanged to LayerNAS.
    fill_all_rows: Forwarded unchanged to LayerNAS.

  Returns:
    A `layernas.LayerNAS` instance configured with the dataset's per-layer
    reward/cost tables, the unique-bucket strategy and the initial proposals
    described below.

  Raises:
    KeyError: If `dataset` has no entry in the COST/REWARD tables.
  """
  if dataset not in COST or dataset not in REWARD:
    raise KeyError(
        f'Unknown dataset {dataset!r}; expected one of {sorted(COST)}.')

  # Seed the search with one proposal per choice of the first layer; every
  # other layer starts at choice 0.
  init_proposals = [
      [first] + [0] * (NUM_LAYERS - 1) for first in range(NUM_CHOICES)
  ]

  # The helper module is bound as `layernas` by the file's import.
  return layernas.LayerNAS(
      num_choice=[NUM_CHOICES] * NUM_LAYERS,
      reward_per_choice=REWARD[dataset],
      cost_per_choice=COST[dataset],
      target_cost_min=cost_min,
      target_cost_max=cost_max,
      init_dna=init_proposals,
      # should_search_fn is left at LayerNAS's default (believed to be
      # cost_in_range -- TODO confirm); uncomment to override it.
      # should_search_fn=layernas.ALWAYS_SEARCH_NAME,
      bucket_fn_name=layernas.UNIQUE_BUCKET_NAME,
      fill_all_rows=fill_all_rows,
      num_search_per_layer=num_search_per_layer,
      num_children_per_search=num_children_per_search,
      num_sample_per_search=num_sample_per_search)
