"""Algorithm for NATS-Bench topology search.

It is meant to be used with
https://github.com/google/pyglove/blob/main/examples/automl/natsbench/natsbench.py:
add `topo_search_algorithm` as an option in that script's `get_algorithm`.
"""

from typing import Optional, List

import pyglove as pg

import layer_nas_oss as layer_nas


# Constant specs of the NATS-Bench topology search space.

# Number of candidate op indices per edge (0..NUM_OP_CHOICES-1). Edges that are
# not assigned yet hold op 0 in the seed DNAs of `topo_search_algorithm`;
# searched edges only receive ops 1..NUM_OP_CHOICES-1 (see `_get_op_choice`).
NUM_OP_CHOICES = 5

# DNA position -> (from_node, to_node) edge of the 4-node cell. A DNA holds one
# op index per entry of this list, in this order.
INDEX_TO_EDGE = [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]

# Number of decisions in a DNA: one per edge. Derived from INDEX_TO_EDGE so the
# two cannot drift apart.
DNA_LENGTH = len(INDEX_TO_EDGE)


def edge_to_index(from_node: int, to_node: int) -> int:
  """Returns the DNA position of the edge `from_node -> to_node`.

  Args:
    from_node: Source node of the edge.
    to_node: Destination node of the edge.

  Returns:
    Index of `(from_node, to_node)` in `INDEX_TO_EDGE`.

  Raises:
    ValueError: if the pair is not an edge listed in `INDEX_TO_EDGE`.
  """
  target_edge = (from_node, to_node)
  position = INDEX_TO_EDGE.index(target_edge)
  return position


def _get_op_choice(choice: int, num_ops: int) -> List[int]:
  """Decodes `choice` into a list of `num_ops` op indices.

  `choice` is read as a number in base `NUM_OP_CHOICES - 1`, least significant
  digit first. Each digit d becomes op index d + 1, so every returned op lies
  in [1, NUM_OP_CHOICES - 1] and op 0 is never produced. Digits beyond the
  first `num_ops` are ignored.
  """
  radix = NUM_OP_CHOICES - 1
  decoded = []
  remainder = choice
  for _ in range(num_ops):
    remainder, digit = divmod(remainder, radix)
    decoded.append(digit + 1)
  return decoded


# Edges assigned at each LayerNAS search layer. `edge_build_fn` decodes a
# layer's `choice` with `_get_op_choice` and applies the resulting ops to these
# edges in list order (first edge <- least significant digit). The order of
# edges inside each layer is therefore part of the choice encoding.
LAYER_EDGES = [
    [(0, 1)],
    [(0, 3), (1, 3)],
    [(0, 2), (1, 2), (2, 3)],
]

# Registry key under which `edge_build_fn` is registered as the LayerNAS
# builder function.
EDGE_BUILD_NAME = 'natsbench:edge_build_fn'

# Registry key under which `dna_encode_bucket_fn` is registered as the LayerNAS
# bucket function.
DNA_ENCODE_NAME = 'natsbench:dna_encode_fn'


@layer_nas.Registry.register_builder_fn(EDGE_BUILD_NAME)
def edge_build_fn(
    parent: layer_nas.Proposal,
    choice: int,
    algo: layer_nas.LayerNAS) -> Optional[layer_nas.Proposal]:
  """Builds the single child of `parent` selected by `choice`.

  The child belongs to the next layer, `parent.get_search_id() + 1`. `choice`
  is decoded by `_get_op_choice` into one op per edge of that layer's
  `LAYER_EDGES` entry. Those ops overwrite the matching positions of a copy
  of the parent's DNA.

  Args:
    parent: Proposal to extend.
    choice: Index of the op combination for the next layer's edges (digits in
      base `NUM_OP_CHOICES - 1`, least significant digit for the first edge).
    algo: Unused; accepted to match the builder-function signature.

  Returns:
    The child proposal, whose `search_id` and `group_id` are the new layer,
    or None when the child is pruned.

  Raises:
    IndexError: if `parent` is already at the last layer of `LAYER_EDGES`.
  """
  del algo  # Unused.
  layer = parent.get_search_id() + 1
  layer_edges = LAYER_EDGES[layer]

  new_dna = list(parent.dna.to_numbers())
  op_choices = _get_op_choice(choice, len(layer_edges))
  for (from_node, to_node), op in zip(layer_edges, op_choices):
    new_dna[edge_to_index(from_node, to_node)] = op

  # At layer 1, children with op 1 on edge (1, 3) are pruned. The seed DNAs in
  # `topo_search_algorithm` already put op 1 on that edge; the exact rationale
  # for pruning is not visible here -- TODO(review): confirm.
  if layer == 1 and new_dna[edge_to_index(1, 3)] == 1:
    return None

  return layer_nas.Proposal(
      pg.DNA.from_numbers(new_dna, parent.dna.spec),
      parent=parent.dna,
      search_id=layer,
      group_id=layer)


@layer_nas.Registry.register_bucket_fn(DNA_ENCODE_NAME)
def dna_encode_bucket_fn(proposal: layer_nas.Proposal) -> int:
  """Returns an integer bucket id for `proposal` derived from its DNA.

  The DNA numbers are read as digits of a base-`NUM_OP_CHOICES` number, with
  the first DNA position as the most significant digit. As long as every DNA
  value is below `NUM_OP_CHOICES`, distinct DNAs get distinct bucket ids.
  """
  digits = proposal.dna.to_numbers()
  bucket = 0
  for digit in digits:
    bucket *= NUM_OP_CHOICES
    bucket += digit
  return bucket


def topo_search_algorithm(
    num_search_per_layer: int,
    num_children_per_search: int,
    num_sample_per_search: int,
    fill_all_rows: bool) -> layer_nas.LayerNAS:
  """Returns the LayerNAS algorithm for NATS-Bench topology (edge) search.

  Args:
    num_search_per_layer: Forwarded to `layer_nas.LayerNAS`.
    num_children_per_search: Forwarded to `layer_nas.LayerNAS`.
    num_sample_per_search: Forwarded to `layer_nas.LayerNAS`.
    fill_all_rows: Forwarded to `layer_nas.LayerNAS`.

  Returns:
    A `layer_nas.LayerNAS` instance with the following setup:
      * It is seeded with one DNA per searchable op on edge (0, 1).
      * It builds children with the function registered under
        `EDGE_BUILD_NAME`.
      * It buckets proposals with the function registered under
        `DNA_ENCODE_NAME`.
  """
  num_searchable_ops = NUM_OP_CHOICES - 1

  # Seed DNAs: edge (0, 1) takes each searchable op in turn, edge (1, 3) is set
  # to op 1, and every other edge stays at op 0. (Presumably this gives the
  # smallest cell linking node 0 to node 3 -- confirm.)
  init_proposals = []
  for first_op in range(1, NUM_OP_CHOICES):
    seed = [0] * DNA_LENGTH
    seed[edge_to_index(0, 1)] = first_op
    seed[edge_to_index(1, 3)] = 1
    init_proposals.append(seed)

  # num_choice[i] is the number of op combinations for the edges of
  # LAYER_EDGES[i]. That is num_searchable_ops ** len(LAYER_EDGES[i]), which
  # matches how `edge_build_fn` decodes `choice`. Deriving it keeps the two in
  # sync.
  num_choice = [num_searchable_ops ** len(edges) for edges in LAYER_EDGES]

  return layer_nas.LayerNAS(
      num_choice=num_choice,
      init_dna=init_proposals,
      filter_fn_name=layer_nas.ALWAYS_SEARCH_NAME,
      bucket_fn_name=DNA_ENCODE_NAME,
      builder_fn_name=EDGE_BUILD_NAME,
      fill_all_rows=fill_all_rows,
      num_search_per_layer=num_search_per_layer,
      num_children_per_search=num_children_per_search,
      num_sample_per_search=num_sample_per_search)
