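"""Search cell topologies in NASBench by growing the graph step by step.

Each search step either adds an edge between existing nodes or inserts a new
op node on an existing edge. It works with
https://github.com/google/pyglove/blob/main/examples/automl/nasbench/nasbench.py
by adding `edge_search_algorithm` to its `create_search_algorithm`.
"""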

from typing import Optional, List
from absl import logging

import pyglove as pg

import layer_nas_oss as layer_nas


# Constant specs defined by NASBench.
# Cell nodes are numbered INPUT_IDX (0) .. OUTPUT_IDX (6); the NUM_OPS nodes
# in between are op nodes. The op of node k is stored at DNA position k - 1.
NUM_OPS = 5
INPUT_IDX = 0
OUTPUT_IDX = 6
# Number of op types an op node can take.
NUM_OP_CHOICES = 3
# Maximum number of edges allowed in a cell.
MAX_NUM_EDGE = 9

# Every forward edge (from_node, to_node) of the cell DAG, in row-major
# upper-triangular order. The 0/1 flag for the k-th edge lives at DNA
# position NUM_OPS + k.
_EDGE_INDEX = [(src, dst)
               for src in range(INPUT_IDX, OUTPUT_IDX)
               for dst in range(src + 1, OUTPUT_IDX + 1)]

# One op slot per op node followed by one flag per possible edge (= 26).
DNA_LENGTH = NUM_OPS + len(_EDGE_INDEX)


def edge_to_index(from_node: int, to_node: int) -> int:
  """Returns the DNA position of the flag for edge `from_node -> to_node`.

  The first `NUM_OPS` DNA positions hold op choices; the edge flags follow
  in `_EDGE_INDEX` order.

  Raises:
    AssertionError: If the edge does not point forward within
      `[INPUT_IDX, OUTPUT_IDX]`. With assertions disabled, `list.index`
      raises ValueError for such edges instead.
  """
  assert INPUT_IDX <= from_node < to_node <= OUTPUT_IDX
  position_in_edges = _EDGE_INDEX.index((from_node, to_node))
  return NUM_OPS + position_in_edges


def add_op(
    parent_dna: List[int], parent_op_index: int,
    op: int, from_node: int, to_node: int) -> Optional[List[int]]:
  """Inserts a new op node on the edge `from_node -> to_node`.

  The new node takes index `to_node`. Every op node at index `to_node` or
  above is renumbered to `index + 1`, while the input and output nodes keep
  their indices. The edge `from_node -> to_node` is replaced by the path
  `from_node -> new node -> original to_node`. Every other parent op and
  edge is carried over under this renumbering; nothing else is added.

  A node pushed past the last op slot (node `NUM_OPS`) is dropped together
  with its edges. `build_proposal` only inserts while fewer than `NUM_OPS`
  ops are in use, so that node is expected to be unused.

  DNA layout: the op of node k is at position k - 1; edge flags start at
  position `NUM_OPS` (see `edge_to_index`).

  Args:
    parent_dna: Numbers of the parent DNA. Not modified.
    parent_op_index: Number of op nodes used by the parent; only logged.
    op: Op choice for the new node.
    from_node: Source node of the edge to split.
    to_node: Destination node of the edge to split. Must be an op node, not
      `OUTPUT_IDX`; `build_proposal` routes that case to `append_op`.

  Returns:
    The numbers of the child DNA.
  """
  logging.info('EdgeSearch insert %d-th node into: %d->%d',
               parent_op_index+1, from_node, to_node)

  def renumber(node: int) -> Optional[int]:
    # Index of `node` after the insertion, or None if it has no op slot left.
    if node < to_node or node == OUTPUT_IDX:
      return node
    if node + 1 <= NUM_OPS:
      return node + 1
    return None

  dna_number = [0] * len(parent_dna)

  # Carry over the ops of the existing op nodes (node k <-> position k - 1).
  for node in range(1, NUM_OPS + 1):
    moved = renumber(node)
    if moved is not None:
      dna_number[moved - 1] = parent_dna[node - 1]

  # Carry over every edge except the one being split. `renumber` is
  # injective and order-preserving, so each edge lands in a distinct slot.
  for edge_from, edge_to in _EDGE_INDEX:
    if (edge_from, edge_to) == (from_node, to_node):
      continue
    new_from, new_to = renumber(edge_from), renumber(edge_to)
    if new_from is None or new_to is None:
      continue
    dna_number[edge_to_index(new_from, new_to)] = parent_dna[
        edge_to_index(edge_from, edge_to)]

  # The new node sits at index `to_node`, i.e. op position `to_node - 1`.
  dna_number[to_node - 1] = op
  # Route from_node -> new node -> original to_node (now at to_node + 1).
  dna_number[edge_to_index(from_node, to_node)] = 1
  dna_number[edge_to_index(to_node, to_node + 1)] = 1
  return dna_number



# DNA metadata key recording how many op nodes the DNA currently uses.
_CUSTOM_METADATA_NUM_OPS = 'custom_metadata_num_ops'
# Op count assumed when that metadata is absent, e.g. for the single-op seed
# cells built in `edge_search_algorithm`.
_INIT_METADATA_NUM_OPS = 1


def build_proposal(parent: layer_nas.Proposal,
                   op: int,
                   from_node: int,
                   to_node: int) -> Optional[layer_nas.Proposal]:
  """Builds the child of `parent` for one (op, edge) choice.

  If `op < NUM_OP_CHOICES`, a new node with that op is inserted on the edge
  `from_node -> to_node`. `append_op` handles the case where `to_node` is
  the output node, and `add_op` handles all other cases. Any larger `op`
  instead adds the edge itself via `add_edge`. The number of op nodes in use
  is tracked in the DNA metadata under `_CUSTOM_METADATA_NUM_OPS`.

  Args:
    parent: Proposal to extend.
    op: Op choice for a new node, or `>= NUM_OP_CHOICES` to add an edge.
    from_node: Source node of the edge.
    to_node: Destination node of the edge.

  Returns:
    The child proposal, with `search_id` and `group_id` set to its edge
    count. Returns None in any of these cases:
      * an endpoint is not an existing node;
      * all `NUM_OPS` op slots are in use;
      * `add_edge` rejects the edge;
      * the child would have more than `MAX_NUM_EDGE` edges.
  """
  num_ops = parent.dna.metadata.get(
      _CUSTOM_METADATA_NUM_OPS, _INIT_METADATA_NUM_OPS)
  # Both endpoints must already be in the graph; the output always is.
  if from_node > num_ops or (to_node != OUTPUT_IDX and to_node > num_ops):
    return None

  inserts_node = op < NUM_OP_CHOICES
  if inserts_node and num_ops == NUM_OPS:
    return None  # Every op slot is already in use.

  parent_numbers = parent.dna.to_numbers()
  if not inserts_node:
    child_numbers = add_edge(parent_numbers, num_ops, from_node, to_node)
    if not child_numbers:
      return None
    child_num_ops = num_ops
  elif to_node == OUTPUT_IDX:
    child_numbers = append_op(parent_numbers, num_ops, op, from_node)
    child_num_ops = num_ops + 1
  else:
    child_numbers = add_op(parent_numbers, num_ops, op, from_node, to_node)
    child_num_ops = num_ops + 1

  child_dna = pg.geno.DNA.from_numbers(child_numbers, parent.dna.spec)
  child_dna.set_metadata(_CUSTOM_METADATA_NUM_OPS, child_num_ops)

  num_edges = sum(child_numbers[NUM_OPS:])
  if num_edges > MAX_NUM_EDGE:
    return None
  logging.info('EdgeSearch try to propose: %s from %s for group %d ',
               str(child_numbers), str(parent.dna), num_edges)
  return layer_nas.Proposal(
      child_dna,
      parent=parent.dna,
      search_id=num_edges,
      group_id=num_edges)


# Registry name of `edge_build_fn`; `edge_search_algorithm` passes it as
# `builder_fn_name`.
EDGE_BUILD_NAME = 'nasbench:edge_build_fn'


@layer_nas.Registry.register_builder_fn(EDGE_BUILD_NAME)
def edge_build_fn(
    parent: layer_nas.Proposal,
    choice: int,
    algo: layer_nas.LayerNAS) -> Optional[layer_nas.Proposal]:
  """Builds the child proposal of `parent` for a single choice.

  A choice jointly encodes an edge of `_EDGE_INDEX` and an action:
  `choice = edge_position * (NUM_OP_CHOICES + 1) + action`.
    * Actions `0 .. NUM_OP_CHOICES - 1` insert a new node with that op on
      the edge.
    * Action `NUM_OP_CHOICES` adds the edge itself.
  This matches the `num_choice` computed in `edge_search_algorithm`.

  Returns:
    The child proposal, or None when `build_proposal` rejects the choice.
  """
  del algo  # Unused.
  # Decode with the same radix used to size `num_choice`, rather than a
  # hard-coded 4, so encoding and decoding cannot drift apart.
  edge_position, op_choice = divmod(choice, NUM_OP_CHOICES + 1)
  from_node, to_node = _EDGE_INDEX[edge_position]
  return build_proposal(parent, op_choice, from_node, to_node)


def append_op(
    parent_dna: List[int], parent_op_index: int,
    op: int, from_node: int) -> Optional[List[int]]:
  """Appends a new op node between `from_node` and the output node.

  The new node is `parent_op_index + 1`, the first unused op node when
  `parent_op_index` is the parent's op count (as `build_proposal` passes).
  It gets op choice `op`. The direct edge `from_node -> OUTPUT_IDX` is
  cleared, and `from_node -> new node -> OUTPUT_IDX` is set.

  Args:
    parent_dna: Numbers of the parent DNA; not modified (a copy is edited).
    parent_op_index: Number of op nodes used by the parent.
    op: Op choice for the new node.
    from_node: Existing node that feeds the new node.

  Returns:
    The numbers of the child DNA.
  """
  new_node = parent_op_index + 1
  logging.info('EdgeSearch add node: %d->%d->%d',
               from_node, new_node, OUTPUT_IDX)
  child = parent_dna.copy()
  # The op of node k is stored at DNA position k - 1.
  child[parent_op_index] = op
  edge_updates = (
      ((from_node, OUTPUT_IDX), 0),
      ((from_node, new_node), 1),
      ((new_node, OUTPUT_IDX), 1),
  )
  for (src, dst), flag in edge_updates:
    child[edge_to_index(src, dst)] = flag
  return child


def add_edge(
    parent_dna: List[int], parent_op_index: int,
    from_node: int, to_node: int) -> Optional[List[int]]:
  """Returns a copy of `parent_dna` with the edge `from_node -> to_node` set.

  Args:
    parent_dna: Numbers of the parent DNA; not modified.
    parent_op_index: Number of op nodes in use; nodes above it (other than
      the output node) do not exist yet.
    from_node: Source node of the edge.
    to_node: Destination node of the edge.

  Returns:
    The child numbers, or None if the edge is already set or an endpoint is
    not an existing node.
  """
  position = edge_to_index(from_node, to_node)
  if (parent_dna[position] == 1
      or from_node > parent_op_index
      or (to_node > parent_op_index and to_node != OUTPUT_IDX)):
    return None
  logging.info('EdgeSearch add edge directly: %d->%d', from_node, to_node)
  child = parent_dna.copy()
  child[position] = 1
  return child


def edge_search_algorithm(
    num_search_per_layer: int,
    num_children_per_search: int,
    num_sample_per_search: int):
  """Returns the LayerNAS algorithm for edge search in NASBench.

  The search is seeded with one minimal cell per op choice,
  `input -> node 1 -> output`, where node 1 uses that op. `num_choice` has
  `MAX_NUM_EDGE + 1` entries (presumably one per LayerNAS layer). Each entry
  offers `(NUM_OP_CHOICES + 1) * len(_EDGE_INDEX)` choices, which
  `edge_build_fn` (registered as `EDGE_BUILD_NAME`) turns into children.

  Args:
    num_search_per_layer: Forwarded to `layer_nas.LayerNAS`.
    num_children_per_search: Forwarded to `layer_nas.LayerNAS`.
    num_sample_per_search: Forwarded to `layer_nas.LayerNAS`.

  Returns:
    The configured `layer_nas.LayerNAS` instance.
  """
  def seed_numbers(op_choice):
    numbers = [0] * DNA_LENGTH
    numbers[0] = op_choice  # Op of node 1.
    numbers[edge_to_index(INPUT_IDX, 1)] = 1
    numbers[edge_to_index(1, OUTPUT_IDX)] = 1
    return numbers

  seeds = [seed_numbers(op_choice) for op_choice in range(NUM_OP_CHOICES)]
  choices_per_layer = len(_EDGE_INDEX) * (NUM_OP_CHOICES + 1)
  return layer_nas.LayerNAS(
      num_choice=[choices_per_layer for _ in range(MAX_NUM_EDGE + 1)],
      init_dna=seeds,
      filter_fn_name=layer_nas.ALWAYS_SEARCH_NAME,
      bucket_fn_name=layer_nas.UNIQUE_BUCKET_NAME,
      builder_fn_name=EDGE_BUILD_NAME,
      fill_all_rows=False,
      num_search_per_layer=num_search_per_layer,
      num_children_per_search=num_children_per_search,
      num_sample_per_search=num_sample_per_search)
