"""Pyglove algorithm for LayerNAS.

The algorithm targets AutoML problems where the model is constructed from
multiple layers and each layer has a range of search options. The algorithm
solves the problem by treating it as a knapsack problem:
  Opt[i+1, c] = max{Acc(model[i, c-cost[i+1, k]]) for k in search_options[i+1]}
  where Opt[i+1, c] is the best architecture of the model when layers up to
  i+1 have been searched under a target cost of c.

"""

import enum
import random
import time
from typing import Optional, Tuple, Union, List
from absl import logging
import pyglove as pg


# Keys used to store LayerNAS bookkeeping in a DNA's metadata.
_DNA_METADATA_PROPOSAL_ID = 'proposal_id'
_DNA_METADATA_GROUP_ID = 'group_id'
_DNA_METADATA_SEARCH_ID = 'search_id'
_DNA_METADATA_REWARD = 'reward'
_DNA_METADATA_COST = 'cost'
_DNA_METADATA_BUCKET = 'bucket'
_DNA_METADATA_PARENT = 'parent'
_DNA_METADATA_PROPOSED_TIMESTAMP = 'proposed_timestamp'

# Registry names of the built-in pluggable functions. LayerNAS looks up its
# builder / bucket / filter functions in `Registry` by these names.
LAYERWISE_BUILDER_NAME = 'builtin:layerwise_append_search_option_fn'
COST_BUCKET_NAME = 'builtin:cost_bucket_fn'
UNIQUE_BUCKET_NAME = 'builtin:unique_bucket_fn'
COST_IN_RANGER_NAME = 'builtin:is_descendant_in_cost_range'
ALWAYS_SEARCH_NAME = 'builtin:always_search_fn'


class Registry:
  """Name-keyed tables of the pluggable functions used by LayerNAS.

  Each `register_*_fn(name)` call returns a decorator that stores the
  decorated function in the matching table under `name` and returns the
  function unchanged. Registering a name twice in the same table fails the
  duplicate-name assertion.
  """
  filter_fn = {}
  builder_fn = {}
  bucket_fn = {}

  @staticmethod
  def _registrar(table, name):
    """Builds a decorator that records its target in `table` under `name`."""
    def record(func):
      assert name not in table
      table[name] = func
      return func
    return record

  @classmethod
  def register_filter_fn(cls, name):
    return cls._registrar(cls.filter_fn, name)

  @classmethod
  def register_builder_fn(cls, name):
    return cls._registrar(cls.builder_fn, name)

  @classmethod
  def register_bucket_fn(cls, name):
    return cls._registrar(cls.bucket_fn, name)


class Proposal(pg.symbolic.Object):
  """Wrapper class with extended functionality for DNA.

  All LayerNAS bookkeeping (parent DNA, group id, search id, proposal
  timestamp and, once known, reward, cost and bucket) is stored in the
  wrapped DNA's metadata, so it travels together with the DNA.
  """

  def __init__(self, dna: pg.geno.DNA, parent: Optional[pg.geno.DNA],
               search_id: int, group_id: int):
    """Wraps `dna` and stamps its metadata with lineage and proposal time.

    Args:
      dna: The DNA to wrap. Its metadata is updated via `set_metadata`.
      parent: The DNA this proposal was derived from, or None.
      search_id: Index into the LayerNAS search order; LayerNAS uses -1 for
        initial proposals.
      group_id: Layer (group) index associated with this proposal.
    """
    self.dna = dna
    self.dna.set_metadata(_DNA_METADATA_PARENT, parent)
    self.dna.set_metadata(_DNA_METADATA_GROUP_ID, group_id)
    self.dna.set_metadata(_DNA_METADATA_SEARCH_ID, search_id)
    self.dna.set_metadata(_DNA_METADATA_PROPOSED_TIMESTAMP, time.time())

  def get_search_id(self) -> int:
    """Returns the search-order index this proposal belongs to."""
    return self.dna.metadata[_DNA_METADATA_SEARCH_ID]

  def set_search_id(self, search_id: int) -> None:
    self.dna.set_metadata(_DNA_METADATA_SEARCH_ID, search_id)

  def get_group_id(self) -> int:
    """Returns the layer (group) index associated with this proposal."""
    return self.dna.metadata[_DNA_METADATA_GROUP_ID]

  def set_group_id(self, group_id: int) -> None:
    self.dna.set_metadata(_DNA_METADATA_GROUP_ID, group_id)

  def get_cost(self) -> Optional[float]:
    """Returns the recorded cost, or None if not set yet."""
    return self.dna.metadata.get(_DNA_METADATA_COST, None)

  def set_cost(self, cost: float):
    self.dna.set_metadata(_DNA_METADATA_COST, cost)

  def get_bucket(self) -> Optional[int]:
    """Returns the recorded bucket id, or None if not set yet."""
    return self.dna.metadata.get(_DNA_METADATA_BUCKET, None)

  def set_bucket(self, bucket: int):
    self.dna.set_metadata(_DNA_METADATA_BUCKET, bucket)

  def get_reward(self) -> Optional[float]:
    """Returns the recorded reward, or None if not set yet."""
    return self.dna.metadata.get(_DNA_METADATA_REWARD, None)

  def set_reward(self, reward: float):
    self.dna.set_metadata(_DNA_METADATA_REWARD, reward)

  def get_life_minutes(self) -> Optional[int]:
    """Returns the whole minutes elapsed since proposed, or None if unknown."""
    proposed_time = self.dna.metadata.get(_DNA_METADATA_PROPOSED_TIMESTAMP, -1)
    if proposed_time < 0:
      return None
    # `//` on floats yields a float; convert to honour the int return type.
    return int((time.time() - proposed_time) // 60)

  def equal(self, dna: pg.geno.DNA) -> bool:
    """Returns True if `dna` encodes the same decisions as this proposal."""
    return self.dna.to_numbers() == dna.to_numbers()

  def get_dna_str(self) -> str:
    return str(self.dna)

  def get_key(self) -> str:
    """Returns a key unique per (DNA, search id) pair."""
    return '{}/{}'.format(self.get_dna_str(), self.get_search_id())

  def get_parent(self):
    """Returns the parent DNA, or None for root proposals."""
    return self.dna.metadata.get(_DNA_METADATA_PARENT, None)

  def parent_equal(self, dna: pg.geno.DNA) -> bool:
    """Returns True if this proposal's parent encodes the same DNA as `dna`."""
    parent = self.get_parent()
    if parent is None:
      return False
    return parent.to_numbers() == dna.to_numbers()

  def clone(self) -> 'Proposal':
    """Returns a copy on a fresh DNA object with all metadata copied.

    The copy gets a new proposal timestamp.
    """
    dna = pg.geno.DNA.from_numbers(self.dna.to_numbers(), self.dna.spec)
    newp = Proposal(dna, self.get_parent(), self.get_search_id(),
                    self.get_group_id())
    for k in self.dna.metadata:
      newp.dna.set_metadata(k, self.dna.metadata[k])
    newp.dna.set_metadata(_DNA_METADATA_PROPOSED_TIMESTAMP, time.time())
    return newp

  def is_descendant_of(self, ancestor: 'Proposal') -> bool:
    """Returns True if this proposal stems from `ancestor`.

    That is either the same DNA at the same or a later search id, or a
    direct child (parent DNA equals `ancestor`'s DNA) at a later search id.
    """
    if self.equal(ancestor.dna):
      return self.get_search_id() >= ancestor.get_search_id()
    if self.parent_equal(ancestor.dna):
      return self.get_search_id() > ancestor.get_search_id()
    return False


@Registry.register_filter_fn(COST_IN_RANGER_NAME)
def _is_descendant_in_cost_range(
    proposal: Proposal,
    search_order: List[int],
    cost: List[List[float]],
    target_cost_min: Optional[float],
    target_cost_max: Optional[float]) -> bool:
  """Returns True if a descendant of proposal could be in cost range.

  Costs are treated as additive. `cost[0][0]` is the base cost, and choice
  `k` in group `g` adds `cost[g][k] - cost[0][0]` to it. Groups already
  fixed by the proposal (the first `search_id + 1` entries of
  `search_order`) contribute their chosen cost. The remaining groups span
  from their cheapest to their most expensive choice.

  Args:
    proposal: Candidate whose descendants are considered.
    search_order: Group indices in search order.
    cost: Per-group, per-choice cost table.
    target_cost_min: Lower bound of the target range; None means unbounded.
    target_cost_max: Upper bound of the target range; None means unbounded.

  Returns:
    True if the reachable cost interval overlaps the target range.
  """
  search_id = proposal.get_search_id()
  base_cost = cost[0][0]
  dna = proposal.dna.to_numbers()

  min_cost = base_cost
  for i in range(search_id + 1):
    group_id = search_order[i]
    min_cost += cost[group_id][dna[group_id]] - base_cost
  max_cost = min_cost
  for i in range(search_id + 1, len(search_order)):
    min_cost += min(cost[search_order[i]]) - base_cost
    max_cost += max(cost[search_order[i]]) - base_cost

  # Both bounds are declared noneable with a None default; comparing a float
  # with None would raise TypeError, so a missing bound is unbounded.
  if target_cost_max is not None and min_cost > target_cost_max:
    return False
  if target_cost_min is not None and max_cost < target_cost_min:
    return False
  return True


@Registry.register_builder_fn(LAYERWISE_BUILDER_NAME)
def _layerwise_append_search_option_fn(
    parent: Proposal, choice: int,
    algo: 'LayerNAS') -> Optional[Proposal]:
  """Builds the child of `parent` that sets `choice` on the next layer.

  The next layer is `algo.search_order[parent search id + 1]`. The child
  copies the parent's DNA numbers and overrides that single layer.

  Returns:
    The child proposal, or None for choice 0.
  """
  # Choice 0 was already searched when the initial DNAs were built and has
  # been entered into all memorial-table rows, so it yields no new child.
  if choice == 0:
    return None
  next_search_id = parent.get_search_id() + 1
  next_group = algo.search_order[next_search_id]
  numbers = parent.dna.to_numbers()
  numbers[next_group] = choice
  return Proposal(
      pg.geno.DNA.from_numbers(numbers, parent.dna.spec),
      parent=parent.dna,
      search_id=next_search_id,
      group_id=next_group)


@Registry.register_bucket_fn(COST_BUCKET_NAME)
def _cost_bucket_fn(proposal: Proposal) -> int:
  """Buckets a proposal by its cost, truncated at a resolution of 1e-4."""
  scaled_cost = proposal.get_cost() * 10000
  return int(scaled_cost)


@Registry.register_bucket_fn(UNIQUE_BUCKET_NAME)
def unique_bucket_fn(proposal: Proposal) -> int:
  """Returns a random 64-bit bucket id, independent of the proposal.

  This makes bucket collisions (and thus overrides in the memorial table)
  practically impossible.
  """
  del proposal  # The bucket deliberately ignores the proposal.
  return random.getrandbits(64)


@Registry.register_filter_fn(ALWAYS_SEARCH_NAME)
def always_search_fn(proposal: Proposal,
                     search_order: List[int],
                     cost: List[List[float]],
                     target_cost_min: float,
                     target_cost_max: float) -> bool:
  """Filter that accepts every proposal, disabling cost-based pruning."""
  # All arguments are unused; they exist only to match the filter signature.
  del proposal, search_order, cost
  del target_cost_min, target_cost_max
  return True

@pg.members(
    [
        # Fundamental settings
        ('num_choice', pg.typing.List(
            pg.typing.Int()), 'Num of choices per group'),
        ('target_cost_min', pg.typing.Float().noneable(), 'Min target cost'),
        ('target_cost_max', pg.typing.Float().noneable(), 'Max target cost'),

        # Initial settings
        ('init_dna', pg.typing.List(pg.typing.List(
            pg.typing.Int())).noneable(), 'Initial DNAs to start with'),
        ('reward_per_choice', pg.typing.List(pg.typing.List(
            pg.typing.Float())).noneable(),
         'reward for each choice in each layer'),
        ('cost_per_choice', pg.typing.List(pg.typing.List(
            pg.typing.Float())).noneable(),
         'cost for each choice in each layer'),

        # Changeable logics
        ('filter_fn_name',
         pg.typing.Str(default=COST_IN_RANGER_NAME),
         'Return True if the proposal should be searched'),
        ('bucket_fn_name',
         pg.typing.Str(default=COST_BUCKET_NAME),
         'Returns the bucket id of the candidate'),
        ('builder_fn_name',
         pg.typing.Str(default=LAYERWISE_BUILDER_NAME),
         'Build descendants from a proposal and search option'),

        # DFS params
        ('fill_all_rows', pg.typing.Bool(default=True),
         'Fill all rows in MemorialTable.'),
        ('num_search_per_layer', pg.typing.Int(default=-1),
         'Num of searches for each layer'),
        ('num_sample_per_search', pg.typing.Int(default=-1),
         'Num of sample of the pool to select parent '),
        ('num_children_per_search', pg.typing.Int(default=-1),
         'Num of children to generate from a parent.'),
    ],
    init_arg_list=['num_choice'])
class LayerNAS(pg.geno.DNAGenerator):
  """Search model spec with LayerNAS.

  State overview:
    * `search_order`: group (layer) indices, sorted by descending average
      absolute reward change relative to the base reward.
    * `memorial_table[s]`: proposals for search id `s`, sorted by the bucket
      id from `bucket_fn`. At most one proposal is kept per bucket; an
      incoming proposal replaces the stored one unless its reward is lower
      by 1e-5 or more.
    * `pending_proposals` / `active_proposals`: queued and in-flight
      proposals.

  Initialization:
    * If `reward_per_choice` is set (together with `cost_per_choice`), those
      tables are used directly as the per-choice (reward, cost) estimates.
    * Otherwise the all-zero DNA and every DNA that differs from it in
      exactly one group are proposed first. Their feedback fills the
      estimates, and the layer-wise search starts once all are reported.
  """

  def _setup(self) -> None:
    """Setup the algorithm; `self._dna_spec` must already be available."""
    self.num_group = len(self.num_choice)
    self.memorial_table = [[] for _ in range(self.num_group)]
    self.current_search = 0
    self.pending_proposals = []  # List[Proposal]
    self.active_proposals = []  # List[Proposal]
    self.dna_to_feedback = {}  # str(dna) -> (reward, cost, ...)
    logging.info('COMPRESSOR LayerNAS starts, initial rewards are %s',
                 'given' if self.reward_per_choice else 'searched')

    self.filter_fn = Registry.filter_fn[self.filter_fn_name]
    self.builder_fn = Registry.builder_fn[self.builder_fn_name]
    self.bucket_fn = Registry.bucket_fn[self.bucket_fn_name]

    # DFS helpers
    self.unsearched_children = {}  # parent key -> remaining child choices
    self.num_searched_per_layer = [0] * self.num_group
    self.num_unsearched_per_layer = [0] * self.num_group
    self.num_children_per_layer = [0] * self.num_group
    self.num_dfs_iteration = 0

    if self.reward_per_choice:
      self._setup_with_given_rewards()
    else:
      self._setup_with_searched_rewards()
    self._dump_debug_info()

  def _setup_with_given_rewards(self) -> None:
    """Uses `reward_per_choice` / `cost_per_choice` as initial estimates.

    Raises:
      ValueError: if the tables are missing or their shapes disagree with
        `num_choice`.
    """
    if self.cost_per_choice is None:
      raise ValueError(
          'cost_per_choice must be provided together with reward_per_choice.')
    if (len(self.reward_per_choice) != self.num_group or
        len(self.cost_per_choice) != self.num_group):
      raise ValueError(
          'reward_per_choice and cost_per_choice must have {} rows.'.format(
              self.num_group))
    for i, c in enumerate(self.num_choice):
      if (len(self.reward_per_choice[i]) != c or
          len(self.cost_per_choice[i]) != c):
        raise ValueError(
            'Row {} of reward_per_choice/cost_per_choice must have {} '
            'entries.'.format(i, c))

    self.init_reward = self.reward_per_choice
    self.init_cost = self.cost_per_choice
    # Every single-choice DNA is now known, so no initial proposal needs to
    # be evaluated; the search starts directly from the memorial table.
    self._finish_init()

  def _setup_with_searched_rewards(self) -> None:
    """Queues the base DNA and all single-group variants for evaluation.

    Their feedback (search id -1) is routed by `_feedback` to
    `_update_initialized_proposal_feedback`, which calls `_finish_init` once
    all of them have been reported.
    """
    self.init_reward = [[0.0] * c for c in self.num_choice]
    self.init_cost = [[0.0] * c for c in self.num_choice]
    self._num_initialized_proposals = 0
    self._num_expected_initialized_proposals = 1 + sum(
        max(c - 1, 0) for c in self.num_choice)

    self.pending_proposals.append(Proposal(
        pg.geno.DNA.from_numbers([0] * self.num_group, self._dna_spec),
        None, search_id=-1, group_id=0))
    for i, num in enumerate(self.num_choice):
      for j in range(1, num):
        dna_value = [0] * self.num_group
        dna_value[i] = j
        self.pending_proposals.append(Proposal(
            pg.geno.DNA.from_numbers(dna_value, self._dna_spec),
            None,
            search_id=-1,
            group_id=i))

  def _queue_init_dna(self) -> None:
    """Queues the user-provided `init_dna` as search-id-0 proposals."""
    for dna in self.init_dna or []:
      proposal = Proposal(
          pg.geno.DNA.from_numbers(dna, self._dna_spec),
          None, search_id=0, group_id=0)
      self.pending_proposals.append(proposal)

  @property
  def multi_objective(self) -> bool:
    return True

  def _proposal_index_in_list(self, proposal: Proposal,
                              plist: List[Proposal]) -> int:
    """Returns the index of the first entry with the same DNA, or -1."""
    for i, p in enumerate(plist):
      if proposal.equal(p.dna):
        return i
    return -1

  def _try_to_propose(self, proposal: Proposal) -> bool:
    """Returns True if `proposal` should be appended to the pending queue.

    A proposal that was already evaluated is instead inserted into the
    memorial table with its recorded feedback.
    """
    if not self._should_search(proposal):
      logging.info(
          'COMPRESSOR_SEARCH Skip %s, it should not be searched', proposal.dna)
      return False

    dna_str = proposal.get_dna_str()
    active_proposal_index = self._proposal_index_in_list(
        proposal, self.active_proposals)
    if active_proposal_index >= 0:
      # Proposal is running, update proposals with new metadata.
      # This should not happen.
      logging.error(
          'COMPRESSOR_SEARCH Propose %s from search, which is active',
          proposal.dna)
      return False
    elif dna_str in self.dna_to_feedback:
      # Proposal has been executed, insert to memorial table.
      logging.info(
          'COMPRESSOR_SEARCH Propose %s from search, which has completed',
          proposal.dna)
      self._insert_proposal_to_memorial_table(
          proposal, self.dna_to_feedback[dna_str])
      return False
    else:
      if self._proposal_index_in_list(proposal, self.pending_proposals) >= 0:
        logging.info(
            'COMPRESSOR_SEARCH Already propose %s from search, skip',
            proposal.dna)
        return False
      logging.info(
          'COMPRESSOR_SEARCH Propose %s from search, add to pending',
          proposal.dna)
      return True

  def _should_search(self, parent: Proposal) -> bool:
    """Returns False if `parent` is exhausted or rejected by `filter_fn`."""
    # If the parent has been fully searched, return False.
    parent_key = parent.get_key()
    if parent_key in self.unsearched_children and not self.unsearched_children[
        parent_key]:
      return False

    # Check if the proposal should be searched; it usually checks whether the
    # cost is in range. `init_cost` is used because `cost_per_choice` is None
    # when the initial costs are searched (it is the same table otherwise).
    return self.filter_fn(
        parent, self.search_order, self.init_cost,
        self.target_cost_min, self.target_cost_max)

  def _search_proposal(self, parent: Proposal) -> bool:
    """Generates children of `parent` in the next layer of `search_order`.

    Returns:
      True if at least one new child was added to `pending_proposals`.
    """
    if not self._should_search(parent):
      logging.info('COMPRESSOR_SEARCH %s skip, out of range', parent.dna)
      return False
    parent_search = parent.get_search_id()
    logging.info('COMPRESSOR_SEARCH %s at search_order %d', parent.dna,
                 parent_search)
    # Parent should not be from initializing.
    assert parent_search >= 0
    if parent_search == self.num_group - 1:
      return False

    # Store children choices
    parent_key = parent.get_key()
    if parent_key not in self.unsearched_children:
      num_children_choice = self.num_choice[
          self.search_order[parent_search + 1]]
      self.unsearched_children[parent_key] = list(range(num_children_choice))
      self.num_children_per_layer[parent_search] += num_children_choice
      self.num_unsearched_per_layer[parent_search] += 1
    children = self.unsearched_children[parent_key]
    # With the default num_children_per_search of -1 this is -1, which the
    # loop below never reaches, so all remaining children are tried.
    num_children = min(len(children), self.num_children_per_search)

    # Randomly select children choice, and generate proposal
    new_search_cnt = 0
    children_pop_cnt = 0
    while children:
      if self.num_children_per_search < 0:
        # Not using random to make unit test easier to verify
        choice = children.pop(0)
      else:
        k = random.randint(0, len(children) - 1)
        choice = children.pop(k)
      children_pop_cnt += 1
      p = self.builder_fn(parent, choice, self)
      if p and self._try_to_propose(p):
        self.pending_proposals.append(p)
        new_search_cnt += 1
      if new_search_cnt == num_children:
        break
    if new_search_cnt > 0:
      self.num_searched_per_layer[parent_search] += 1
    if not children:
      self.num_unsearched_per_layer[parent_search] -= 1
    self.num_children_per_layer[parent_search] -= children_pop_cnt
    return new_search_cnt > 0

  def _dump_debug_info(self):
    """Logs a summary of the queues, DFS counters and memorial table."""
    logging.info('-----------------')
    logging.info('\tCurrent group: %d', self.current_search)
    logging.info('\tPending proposals:')
    for p in self.pending_proposals:
      logging.info('\t\t%s', p.dna)
    logging.info('\tActive proposals:')
    for p in self.active_proposals:
      logging.info('\t\t%s: %d min', p.dna, p.get_life_minutes())
    if hasattr(self, 'num_searched_per_layer'):
      logging.info('\t#%d DFS run', self.num_dfs_iteration)
      logging.info('\tSearched per group: %s', str(self.num_searched_per_layer))
    logging.info('\tMemorial table size: %s',
                 str([len(x) for x in self.memorial_table]))
    logging.info('\tMemorial table unsearched candidates: %s',
                 str(self.num_unsearched_per_layer))
    logging.info('\tMemorial table unsearched children: %s',
                 str(self.num_children_per_layer))
    logging.info('-----------------')

  def _update_group_info(self):
    """Advances `current_search`, wrapping to 0 at the last search id."""
    num_current_unsearched = len(
        [p for p in self.memorial_table[self.current_search]
         if self._should_search(p)])
    for p in self.pending_proposals:
      if p.get_search_id() == self.current_search and self._should_search(p):
        num_current_unsearched += 1

    # Move to next layer if:
    #   * no searchable proposal
    #   * searched enough proposals in current layer
    move_to_next = num_current_unsearched == 0
    if self.num_search_per_layer > 0:
      if not move_to_next:
        move_to_next = self.num_searched_per_layer[
            self.current_search] >= self.num_search_per_layer

    if move_to_next:
      logging.info(
          'COMPRESSOR_SEARCH move to next search order %d, '
          '%d have been searched, %d unsearched.',
          self.current_search + 1,
          self.num_searched_per_layer[self.current_search],
          num_current_unsearched)
      self.current_search += 1

    # Move to head if reach the last layer.
    if self.current_search >= self.num_group - 1:
      self.num_dfs_iteration += 1
      self.num_searched_per_layer = [0] * self.num_group
      self.current_search = 0

  def _propose_from_pending(self) -> Optional[Proposal]:
    """Moves the oldest pending proposal to the active list and returns it."""
    if not self.pending_proposals:
      return None
    proposal = self.pending_proposals.pop(0)
    logging.error('COMPRESSOR_PROPOSE activate: %s', proposal.dna)
    self.active_proposals.append(proposal)
    return proposal

  def _propose_from_grid(self, search_id: int) -> Optional[Proposal]:
    """Picks a parent in memorial row `search_id` and proposes a child.

    With `num_sample_per_search == -1` the first searchable entry (lowest
    bucket) is used. Otherwise the highest-reward entry of a random sample is
    used.
    """
    memorial_row = [p for p in self.memorial_table[search_id]
                    if self._should_search(p)]
    logging.info(
        'COMPRESSOR_PROPOSE search from memorial grids,'
        'Current search order %d contains %d elements, %d searchable',
        search_id, len(self.memorial_table[search_id]), len(memorial_row))
    if not memorial_row:
      return None
    if self.num_sample_per_search == -1:
      best = memorial_row[0]
    else:
      sample_pool_size = min(
          self.num_sample_per_search,
          max(1, len(memorial_row) // 2))
      pool = random.sample(memorial_row, sample_pool_size)
      best = None
      for proposal in pool:
        if not best or best.get_reward() < proposal.get_reward():
          best = proposal

    assert best is not None
    logging.info('COMPRESSOR_PROPOSE search memorial grid: %s', best.dna)
    if self._search_proposal(best):
      return self._propose_from_pending()
    else:
      return None

  def _propose(self) -> pg.geno.DNA:
    """Implementation of pg.geno.DNA proposal.

    Raises:
      RuntimeError: if no proposal is found within one extra DFS iteration.
    """
    # If there's available proposal in pending, propose it.
    p = self._propose_from_pending()
    if p:
      return p.dna

    # Keep searching layer by layer because some candidates have been searched.
    # Stop the loop when search more than one iteration.
    previous_dfs_iteration = self.num_dfs_iteration
    while self.num_dfs_iteration <= previous_dfs_iteration + 1:
      p = self._propose_from_grid(self.current_search)
      self._update_group_info()
      if p:
        self._dump_debug_info()
        return p.dna
    raise RuntimeError('COMPRESSOR_PROPOSE no new proposals found')

  def _remove_proposal(self, proposal: Proposal) -> None:
    """Drops pending descendants of `proposal`; active ones are only logged."""
    search_id = proposal.get_search_id()
    logging.error('COMPRESSOR_REMOVE proposal %s in search %d', proposal.dna,
                  search_id)

    # Mark active proposals searched from this one.
    # active_proposals can only be removed from _feedback.
    for p in self.active_proposals:
      if p.is_descendant_of(proposal):
        logging.info('COMPRESSOR_REMOVE %s is active', p.dna)

    # Remove pending proposals searched from this one.
    i = 0
    while i < len(self.pending_proposals):
      p = self.pending_proposals[i]
      if p.is_descendant_of(proposal):
        logging.info('COMPRESSOR_REMOVE %s from pending proposals',
                     self.pending_proposals[i].dna)
        self.pending_proposals.pop(i)
      else:
        i += 1

  def _finish_init(self):
    """Derives `search_order` from the initial estimates and seeds row 0.

    Groups are ordered by descending mean absolute reward difference from the
    base reward `init_reward[0][0]`. Every single-group DNA gets its initial
    (reward, cost) recorded in `dna_to_feedback`. Those of the first group in
    the search order are inserted into the memorial table. `init_dna` is then
    queued.
    """
    logging.info(
        'COMPRESSOR_INIT finish init, cost per group: %s; reward per group %s',
        str(self.init_cost), str(self.init_reward))
    avg_reward = []
    base_reward = self.init_reward[0][0]
    for i in range(len(self.init_reward)):
      reward_delta = [abs(x - base_reward) for x in self.init_reward[i]]
      avg_reward.append(sum(reward_delta) / len(self.init_reward[i]))
    sorted_reward = sorted(
        enumerate(avg_reward), key=lambda x: x[1], reverse=True)
    self.search_order = [i[0] for i in sorted_reward]
    logging.info('COMPRESSOR_INIT Search order: %s', str(self.search_order))

    # Add first level proposals to memorial table,
    # they will be searched in next _propose.
    for i in range(len(self.search_order)):
      for j in range(self.num_choice[self.search_order[i]]):
        dna_number = [0] * self.num_group
        dna_number[self.search_order[i]] = j
        proposal = Proposal(
            pg.geno.DNA.from_numbers(dna_number, self._dna_spec),
            parent=None,
            search_id=i,
            group_id=self.search_order[i])
        self.dna_to_feedback[proposal.get_dna_str()] = (
            self.init_reward[self.search_order[i]][j],
            self.init_cost[self.search_order[i]][j])
        if i == 0:
          self._insert_proposal_to_memorial_table(
              proposal, self.dna_to_feedback[proposal.get_dna_str()])
    self._queue_init_dna()
    self._dump_debug_info()

  def _update_initialized_proposal_feedback(
      self, proposal: Proposal, reward: Union[float, Tuple[float, ...]]):
    """Stores feedback of an initial (search id -1) proposal.

    The all-zero DNA sets column 0 of every group; any other initial DNA sets
    its own (group, choice) entry. Calls `_finish_init` after the last
    expected initial feedback.

    Args:
      proposal: The initial proposal that was evaluated.
      reward: A (reward, cost, ...) tuple.
    """
    # Update base proposal
    dna_number = proposal.dna.to_numbers()
    if sum(dna_number) == 0:
      logging.info('COMPRESSOR_FEEDBACK Receive base proposal.')
      for i in range(self.num_group):
        self.init_reward[i][0] = reward[0]
        self.init_cost[i][0] = reward[1]
    else:
      # Update other init proposals
      group_id = proposal.get_group_id()
      choice = proposal.dna.to_numbers()[group_id]
      logging.info(
          'COMPRESSOR_FEEDBACK Receive %d initial proposal %s, '
          'group: %d, choice: %d', self._num_initialized_proposals,
          proposal.get_dna_str(), group_id, choice)
      self.init_reward[group_id][choice] = reward[0]
      self.init_cost[group_id][choice] = reward[1]
    self._num_initialized_proposals += 1
    if self._num_initialized_proposals == self._num_expected_initialized_proposals:
      self._finish_init()

  def _insert_proposal_to_memorial_table(
      self, proposal: Proposal, reward: Union[float, Tuple[float, ...]]):
    """Inserts copies of `proposal` into the memorial table.

    With `fill_all_rows`, a copy goes into every row from the proposal's
    search id to the last one; otherwise only its own row is used. Within a
    row, entries are kept sorted by bucket id. A same-bucket entry is
    replaced only if the new reward is not lower by 1e-5 or more; the losing
    proposal's pending descendants are removed.

    Args:
      proposal: Evaluated proposal.
      reward: A (reward, cost, ...) tuple.
    """
    dna_search_id = proposal.get_search_id()

    if self.fill_all_rows:
      # For each group row in memorial table, record this proposal
      search_id_range = range(dna_search_id, self.num_group)
    else:
      search_id_range = range(dna_search_id, dna_search_id + 1)
    for search_id in search_id_range:
      proposal_copy = proposal.clone()
      proposal_copy.set_search_id(search_id)
      proposal_copy.set_group_id(self.search_order[search_id])
      proposal_copy.set_reward(reward[0])
      proposal_copy.set_cost(reward[1])
      bucket = self.bucket_fn(proposal_copy)
      proposal_copy.set_bucket(bucket)

      # Find location to insert
      pos = 0
      for i in range(len(self.memorial_table[search_id])):
        if self.memorial_table[search_id][i].get_bucket() < bucket:
          pos += 1
        else:
          break
      logging.info(
          '\tCOMPRESOR_INSERT Try to insert %s to group %d at index %d',
          proposal.dna, search_id, pos)

      # Clear inferior proposals and insert new proposal if possible.
      if pos == len(self.memorial_table[search_id]):
        logging.error('\tCOMPRESOR_INSERT insert without conflict %s',
                      proposal_copy.dna)
        self.memorial_table[search_id].insert(pos, proposal_copy)
        continue
      next_proposal = self.memorial_table[search_id][pos]
      if next_proposal.get_bucket() == bucket:
        if next_proposal.get_reward() < reward[0] + 1e-5:
          logging.error('\tCOMPRESOR_INSERT Override %s by %s for reward %f',
                        next_proposal.dna, proposal_copy.dna, reward[0])
          self._remove_proposal(next_proposal)
          self.memorial_table[search_id][pos] = proposal_copy
        else:
          logging.error('\tCOMPRESOR_INSERT Fail to override %s by %s',
                        next_proposal.dna, proposal_copy.dna)
          self._remove_proposal(proposal_copy)
      else:
        logging.error('\tCOMPRESOR_INSERT insert without conflict %s',
                      proposal_copy.dna)
        self.memorial_table[search_id].insert(pos, proposal_copy)
    self._update_group_info()
    self._dump_debug_info()

  def _feedback(
      self, dna: pg.geno.DNA, reward: Union[float, Tuple[float, ...]]) -> None:
    """Feedback a pg.geno.DNA with its reward.

    Args:
      dna: The evaluated DNA.
      reward: A scalar, a 1-tuple, or a (reward, cost, ...) tuple. A scalar
        or 1-tuple is used as both reward and cost.
    """
    logging.info('COMPRESSOR_FEEDBACK receive %s: %s', dna, str(reward))

    # Retire the matching active proposal first. Otherwise a duplicate
    # feedback would leave it in `active_proposals` forever, which also makes
    # `_try_to_propose` reject that DNA as "active".
    proposal = None
    for i, p in enumerate(self.active_proposals):
      if p.equal(dna):
        proposal = self.active_proposals.pop(i)
        break

    if str(dna) in self.dna_to_feedback:
      logging.error('COMPRESSOR_FEEDBACK DNA executed twice: %s', dna)
      return

    # Any real scalar (int or float) is a single-objective reward; ints have
    # no len() and must not fall through to the tuple branch.
    if isinstance(reward, (int, float)):
      reward = (reward, reward)
    elif len(reward) == 1:
      reward = (reward[0], reward[0])

    self.dna_to_feedback[str(dna)] = reward

    if proposal is None:
      logging.error('COMPRESSOR_FEEDBACK proposal not found in active, '
                    'is it replayed from another algorithm?')
      return

    if proposal.get_search_id() < 0:
      # Initial proposal: it updates the per-choice estimates instead of the
      # memorial table (search order is not known yet).
      self._update_initialized_proposal_feedback(proposal, reward)
      return

    self._insert_proposal_to_memorial_table(proposal, reward)