# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of insertion/deletion-based compound divergence maximization.


Given a list of items, and functions that calculate every item's set of atoms
and set of compounds, the algorithm iteratively assigns an item to, or removes
an item from, one of the two target subsets.  The item is selected by greedily
maximizing an "adequacy score", which combines the atom similarity, the atom
coverage and the compound divergence of the two subsets.  In order to reduce
the total processing time, the implementation below only considers a sample of
candidate items in each iteration.

The similarity and divergence (which is just 1.0 - similarity) are calculated
using a sum over the atoms or compounds.  This particular expression allows
localized efficient update of the similarity and divergence when a new item is
assigned to (or removed from) one of the subsets.  We take advantage of this
expression in the implementation below to speed things up.
"""

from __future__ import annotations

import collections
import dataclasses
import enum
import logging
from typing import Any, Callable, Dict, FrozenSet, Generic, Hashable, Iterable, List, MutableSequence, Optional, Set, Tuple, TypeVar, Union

from language.compgen.nqg.tasks import mcd_utils
import numpy as np

from conceptual_learning.cscan import inputs

# The type of the items.
_T = TypeVar('_T')

# Element types produced by the user-supplied `get_atoms_fn` and
# `get_compounds_fn`.  They must be hashable because they are used as keys of
# `collections.Counter` objects.
Atom = TypeVar('Atom', bound=Hashable)
Compound = TypeVar('Compound', bound=Hashable)
AtomOrCompound = Union[Atom, Compound]

# The `coef` used in the divergence computations.  Each term is weighted as
# `count_1**coef * count_2**(1.0 - coef)`, so 0.5 weights both subsets equally,
# while 0.1 puts most of the weight on the counts of subset 2.
_ATOM_COEF = 0.5
_COMPOUND_COEF = 0.1


class _ChangeType(enum.Enum):
  """Represents the change to be made in a single step of splitting."""
  # Move an item from the pool of unassigned items into subset 1 / subset 2.
  ADD_TO_1 = 'ADD_TO_1'
  ADD_TO_2 = 'ADD_TO_2'
  # Move an item from subset 1 / subset 2 back into the pool of unassigned
  # items.
  REMOVE_FROM_1 = 'REMOVE_FROM_1'
  REMOVE_FROM_2 = 'REMOVE_FROM_2'


def _get_original_numerator(sum_1, sum_2,
                            original_divergence, coef):
  original_denominator = sum_1**coef * sum_2**(1.0 - coef)

  original_similarity = 1.0 - original_divergence
  original_numerator = original_denominator * original_similarity

  return original_numerator


def _get_single_numerator_change(counts_1,
                                 counts_2,
                                 atom_or_compound, coef,
                                 change_type):
  """Returns how much one term of the similarity numerator changes.

  The term for `atom_or_compound` is `count_1**coef * count_2**(1.0 - coef)`.
  Adding (removing) the element to (from) a subset bumps that subset's count
  by +1 (-1), and the difference between the new and old term is returned.

  Raises:
    ValueError: If asked to remove an element whose count in the affected
      subset is not positive.
  """
  count_1 = counts_1[atom_or_compound]
  count_2 = counts_2[atom_or_compound]

  removing_absent_element = (
      (change_type == _ChangeType.REMOVE_FROM_1 and count_1 <= 0) or
      (change_type == _ChangeType.REMOVE_FROM_2 and count_2 <= 0))
  if removing_absent_element:
    raise ValueError('Item to be removed does not have positive count')

  # Work out the counts after the change; only one of them ever moves.
  if change_type == _ChangeType.ADD_TO_1:
    shifted_1, shifted_2 = count_1 + 1, count_2
  elif change_type == _ChangeType.ADD_TO_2:
    shifted_1, shifted_2 = count_1, count_2 + 1
  elif change_type == _ChangeType.REMOVE_FROM_1:
    shifted_1, shifted_2 = count_1 - 1, count_2
  elif change_type == _ChangeType.REMOVE_FROM_2:
    shifted_1, shifted_2 = count_1, count_2 - 1

  complement = 1.0 - coef
  term_before = count_1**coef * count_2**complement
  term_after = shifted_1**coef * shifted_2**complement
  return term_after - term_before


def _get_new_denominator(sum_1, sum_2, num_changes,
                         coef, change_type):
  """Returns the divergence denominator once `change_type` has been applied.

  Moving `num_changes` atoms or compounds into (out of) one subset raises
  (lowers) that subset's total by `num_changes`.  The denominator is then
  `total_1**coef * total_2**(1.0 - coef)` for the shifted totals.
  """
  if change_type == _ChangeType.ADD_TO_1:
    total_1, total_2 = sum_1 + num_changes, sum_2
  elif change_type == _ChangeType.ADD_TO_2:
    total_1, total_2 = sum_1, sum_2 + num_changes
  elif change_type == _ChangeType.REMOVE_FROM_1:
    total_1, total_2 = sum_1 - num_changes, sum_2
  elif change_type == _ChangeType.REMOVE_FROM_2:
    total_1, total_2 = sum_1, sum_2 - num_changes

  return total_1**coef * total_2**(1.0 - coef)


def _compute_divergence_after_change(counts_1, counts_2, sum_1, sum_2,
                                     atoms_or_compounds_to_move,
                                     original_divergence, coef, change_type):
  """Returns the divergence after applying `change_type` to every element.

  Instead of recomputing the divergence from scratch, the similarity numerator
  is recovered from `original_divergence` and then adjusted term by term for
  each moved atom or compound.  Each adjustment is computed against the
  unmodified `counts_1`/`counts_2`, so this presumably expects every element
  to appear at most once in `atoms_or_compounds_to_move` (e.g. a set).
  """
  numerator = _get_original_numerator(sum_1, sum_2, original_divergence, coef)
  for atom_or_compound in atoms_or_compounds_to_move:
    numerator += _get_single_numerator_change(counts_1, counts_2,
                                              atom_or_compound, coef,
                                              change_type)

  denominator = _get_new_denominator(sum_1, sum_2,
                                     len(atoms_or_compounds_to_move), coef,
                                     change_type)
  return 1.0 - numerator / denominator


def _compute_divergence_if_added_1(
    counts_1,
    counts_2, sum_1, sum_2,
    atoms_or_compounds_to_move, original_divergence,
    coef):
  """Returns the divergence if atoms or compounds are added to counts_1."""
  return _compute_divergence_after_change(
      counts_1, counts_2, sum_1, sum_2, atoms_or_compounds_to_move,
      original_divergence, coef, _ChangeType.ADD_TO_1)


def _compute_divergence_if_added_2(
    counts_1,
    counts_2, sum_1, sum_2,
    atoms_or_compounds_to_move, original_divergence,
    coef):
  """Returns the divergence if atoms or compounds are added to counts_2."""
  return _compute_divergence_after_change(
      counts_1, counts_2, sum_1, sum_2, atoms_or_compounds_to_move,
      original_divergence, coef, _ChangeType.ADD_TO_2)


def _compute_divergence_if_removed_1(
    counts_1,
    counts_2, sum_1, sum_2,
    atoms_or_compounds_to_move, original_divergence,
    coef):
  """Returns the divergence if atoms or compounds removed from counts_1."""
  return _compute_divergence_after_change(
      counts_1, counts_2, sum_1, sum_2, atoms_or_compounds_to_move,
      original_divergence, coef, _ChangeType.REMOVE_FROM_1)


def _compute_divergence_if_removed_2(
    counts_1,
    counts_2, sum_1, sum_2,
    atoms_or_compounds_to_move, original_divergence,
    coef):
  """Returns the divergence if atoms or compounds removed from counts_2."""
  return _compute_divergence_after_change(
      counts_1, counts_2, sum_1, sum_2, atoms_or_compounds_to_move,
      original_divergence, coef, _ChangeType.REMOVE_FROM_2)


def _compute_coverage_if_added_1(atom_counts_1,
                                 num_all_atoms,
                                 atoms_to_move):
  """Returns the atom coverage of items_1 if atoms are added to counts_1."""
  atom_counts_1_change = 0
  for atom in atoms_to_move:
    if atom not in atom_counts_1:
      atom_counts_1_change += 1

  return (len(atom_counts_1) + atom_counts_1_change) / num_all_atoms


def _compute_coverage_if_added_2(atom_counts_1,
                                 num_all_atoms,
                                 atoms_to_move):
  """Returns the atom coverage of items_1 if atoms are added to counts_2."""
  # Adding atoms to items_2 does not affect atom coverage in items_1.
  del atoms_to_move
  return len(atom_counts_1) / num_all_atoms


def _compute_coverage_if_removed_1(atom_counts_1,
                                   num_all_atoms,
                                   atoms_to_move):
  """Returns the atom coverage of items_1 if atoms are removed from counts_1."""
  atom_counts_1_change = 0
  for atom in atoms_to_move:
    if atom_counts_1[atom] <= 1:
      atom_counts_1_change -= 1

  return (len(atom_counts_1) + atom_counts_1_change) / num_all_atoms


def _compute_coverage_if_removed_2(atom_counts_1,
                                   num_all_atoms,
                                   atoms_to_move):
  """Returns the atom coverage of items_1 if atoms are removed from counts_2."""
  # removing atoms from items_2 does not affect atom coverage in items_1.
  del atoms_to_move
  return len(atom_counts_1) / num_all_atoms


def _get_adequacy(atom_divergence, compound_divergence,
                  atom_coverage, target_atom_divergence,
                  atom_similarity_exponent,
                  atom_coverage_exponent):
  """Returns the adaquacy score for atom divergence and compound divergence.

  The adequacy is used to rank the items to be added to the subsets of items.


  Args:
    atom_divergence: The atom divergence.
    compound_divergence: The compound divergence.
    atom_coverage: The coverage of atoms in items_1 among all atoms.
    target_atom_divergence: The target atom divergence.
    atom_similarity_exponent: The exponent used to calculate the atom similarity
      factor.
    atom_coverage_exponent: The exponent used to calculate the atom coverage
      factor.
  """
  # The factors are the larger the better.
  atom_similarity = 1.0 - atom_divergence
  target_atom_similarity = 1.0 - target_atom_divergence
  atom_similarity_factor = (1.0 - max(
      0.0, target_atom_similarity - atom_similarity))**atom_similarity_exponent
  atom_coverage_factor = atom_coverage**atom_coverage_exponent
  compound_divergence_factor = compound_divergence
  return (atom_similarity_factor * compound_divergence_factor *
          atom_coverage_factor)


def _populate_initial_items(items, size,
                            rng):
  result = []
  while len(result) < size:
    result.append(items.pop(rng.choice(len(items))))

  return result


def _get_sampled_item_by_original_index(
    items,
    sample_size,
    rng,
    filter_fn = None):
  """Returns a sample of items and mapping to the original indices."""
  candidate_original_indices = []
  if filter_fn is not None:
    candidate_original_indices = [
        i for i, item in enumerate(items) if filter_fn(item)
    ]

  if candidate_original_indices:
    if len(candidate_original_indices) > sample_size:
      sampled_original_indices = rng.choice(
          candidate_original_indices, sample_size, replace=False)
    else:
      sampled_original_indices = candidate_original_indices
  else:
    # Either filter_fn is None, or all items are filtered out.  In this case we
    # consider all available items, but we avoid realizing range(len(items)) as
    # a list to sample from, since for most typical iterations len(items) is in
    # the range of hundreds of thousands.
    if len(items) > sample_size:
      # This is inefficient if len(items) is just slightly larger than
      # sample_size.
      sampled_original_indices = set()
      while len(sampled_original_indices) < sample_size:
        sampled_original_index = rng.randint(len(items))
        while sampled_original_index in sampled_original_indices:
          sampled_original_index = rng.randint(len(items))
        sampled_original_indices.add(sampled_original_index)
    else:
      sampled_original_indices = list(range(len(items)))

  return {
      original_index: items[original_index]
      for original_index in sampled_original_indices
  }


def _get_sizes_to_delete(num_items,
                         delete_period):
  """Returns the list of total sizes when items should be removed."""
  return list(range(delete_period, num_items, delete_period))


def _select_change_type(current_size_1, current_size_2, size_1,
                        size_2, delete_period,
                        sizes_to_delete,
                        rng):
  """Returns the change type according to current progress of splitting.

  A deletion happens when both subsets have at least `delete_period` items and
  the current total size is one of the scheduled `sizes_to_delete`.
  Otherwise an item is added to whichever subset is proportionally further
  from its target size.

  Args:
    current_size_1: The current number of items in subset 1.
    current_size_2: The current number of items in subset 2.
    size_1: The target size of subset 1.
    size_2: The target size of subset 2.
    delete_period: Neither subset is shrunk while either has fewer items.
    sizes_to_delete: Scheduled total sizes at which an item is removed.  An
      entry that triggers a deletion is removed from this list in place, so
      each scheduled size fires at most once.
    rng: Random number generator used to pick the subset to shrink.

  Returns:
    The `_ChangeType` to apply in this iteration.
  """
  current_total = current_size_1 + current_size_2

  should_delete = (
      current_size_1 >= delete_period and
      current_size_2 >= delete_period and
      current_total in sizes_to_delete)

  if should_delete:
    sizes_to_delete.remove(current_total)
    # We choose the subset to remove item from according to their size.
    return rng.choice(
        [_ChangeType.REMOVE_FROM_1, _ChangeType.REMOVE_FROM_2],
        p=[current_size_1 / current_total, current_size_2 / current_total])

  if not size_2:
    # In the typical use case size_2 is smaller than size_1, and in some small
    # scale tests size_2 could be zero.
    return _ChangeType.ADD_TO_1
  if not size_1:
    # Symmetric to the case above.  Without this guard the fraction below
    # would raise ZeroDivisionError whenever subset 2 still needs items.
    return _ChangeType.ADD_TO_2

  items_1_fraction = current_size_1 / size_1
  items_2_fraction = current_size_2 / size_2
  if current_size_1 < size_1 and items_1_fraction <= items_2_fraction:
    return _ChangeType.ADD_TO_1
  return _ChangeType.ADD_TO_2


@dataclasses.dataclass
class _MaximizeDivergenceState(Generic[Atom, Compound]):
  """Collection of counters and metrics for splitting algorithm iterations.

  "Subset 1" and "subset 2" are the two subsets being built (`items_1` and
  `items_2` in `maximize_divergence`).
  """
  # Every atom occurring in any of the items being split.
  all_atoms: FrozenSet[Atom]
  # Per-subset multiplicities of atoms and compounds.
  atom_counts_1: collections.Counter[Atom]
  atom_counts_2: collections.Counter[Atom]
  compound_counts_1: collections.Counter[Compound]
  compound_counts_2: collections.Counter[Compound]

  # The following fields are computed from the counters above, but we keep them
  # as separate fields and manually update them to improve efficiency, since
  # they are accessed multiple times during each iteration.
  atom_counts_1_sum: int = dataclasses.field(init=False)
  atom_counts_2_sum: int = dataclasses.field(init=False)
  compound_counts_1_sum: int = dataclasses.field(init=False)
  compound_counts_2_sum: int = dataclasses.field(init=False)

  atom_divergence: float = dataclasses.field(init=False)
  compound_divergence: float = dataclasses.field(init=False)
  # Fraction of `all_atoms` that occur in subset 1.
  atom_coverage: float = dataclasses.field(init=False)

  def __post_init__(self):
    self.update_sum_fields()
    self.atom_divergence = mcd_utils.compute_divergence(
        self.atom_counts_1, self.atom_counts_2, coef=_ATOM_COEF)
    self.compound_divergence = mcd_utils.compute_divergence(
        self.compound_counts_1, self.compound_counts_2, coef=_COMPOUND_COEF)
    # Use the same definition as the `_compute_coverage_if_*` helpers, so the
    # initial value agrees with the values assigned after each iteration.  A
    # constant 0.0 would misreport coverage until the first move.
    self.atom_coverage = (
        len(self.atom_counts_1) / len(self.all_atoms)
        if self.all_atoms else 0.0)

  def update_sum_fields(self):
    """Recomputes the four `*_sum` fields from the counters."""
    self.atom_counts_1_sum = sum(self.atom_counts_1.values())
    self.atom_counts_2_sum = sum(self.atom_counts_2.values())
    self.compound_counts_1_sum = sum(self.compound_counts_1.values())
    self.compound_counts_2_sum = sum(self.compound_counts_2.values())


def _get_original_index_of_item_with_max_adequacy_and_metrics(
    state, sampled_item_by_original_index,
    get_compounds_fn,
    get_atoms_fn,
    compute_divergence_if_changed_fn,
    compute_coverage_if_changed_fn,
    options):
  """Returns the original index of item to move, and new divergence/coverage.

  Each sampled item is scored by the adequacy the state would have if that
  item alone were moved.  The scoring uses the given
  `compute_*_if_changed_fn` functions for the pending change type.  Ties are
  broken in favor of the item visited first.

  Args:
    state: The current `_MaximizeDivergenceState`; it is only read here.
    sampled_item_by_original_index: Candidate items keyed by their index in
      the collection they would be moved out of.
    get_compounds_fn: Returns the compounds of an item.
    get_atoms_fn: Returns the atoms of an item.
    compute_divergence_if_changed_fn: One of the `_compute_divergence_if_*`
      functions; it is applied to both atoms and compounds.
    compute_coverage_if_changed_fn: The matching `_compute_coverage_if_*`
      function.
    options: Provides `target_atom_divergence`, `atom_similarity_exponent` and
      `atom_coverage_exponent` for the adequacy score.

  Returns:
    A tuple `(original_index, (atom_divergence, compound_divergence,
    atom_coverage))`.  The metrics are those the state would have after
    moving the item at `original_index`.  If the sample is empty, index 0 is
    returned together with the current metrics of `state`.
  """
  # `None` means no candidate has been scored yet.  The maximum used to be
  # seeded with 0.0.  If every candidate then scored 0.0 or less, index 0 was
  # returned (possibly not even a sampled index) together with the *current*
  # metrics.  The caller would then move an item while storing metrics that
  # ignore the move, corrupting all later incremental divergence updates.
  max_adequacy = None
  original_index_of_item_with_max_adequacy = 0
  new_atom_divergence = state.atom_divergence
  new_compound_divergence = state.compound_divergence
  new_atom_coverage = state.atom_coverage
  num_all_atoms = len(state.all_atoms)
  for original_index, item in sampled_item_by_original_index.items():
    atoms_of_item = get_atoms_fn(item)
    compounds_of_item = get_compounds_fn(item)
    atom_divergence_if_changed = compute_divergence_if_changed_fn(
        counts_1=state.atom_counts_1,
        counts_2=state.atom_counts_2,
        sum_1=state.atom_counts_1_sum,
        sum_2=state.atom_counts_2_sum,
        atoms_or_compounds_to_move=atoms_of_item,
        original_divergence=state.atom_divergence,
        coef=_ATOM_COEF)
    compound_divergence_if_changed = compute_divergence_if_changed_fn(
        counts_1=state.compound_counts_1,
        counts_2=state.compound_counts_2,
        sum_1=state.compound_counts_1_sum,
        sum_2=state.compound_counts_2_sum,
        atoms_or_compounds_to_move=compounds_of_item,
        original_divergence=state.compound_divergence,
        coef=_COMPOUND_COEF)
    coverage_if_changed = compute_coverage_if_changed_fn(
        atom_counts_1=state.atom_counts_1,
        num_all_atoms=num_all_atoms,
        atoms_to_move=atoms_of_item)
    adequacy_of_item = _get_adequacy(
        atom_divergence=atom_divergence_if_changed,
        compound_divergence=compound_divergence_if_changed,
        atom_coverage=coverage_if_changed,
        target_atom_divergence=options.target_atom_divergence,
        atom_similarity_exponent=options.atom_similarity_exponent,
        atom_coverage_exponent=options.atom_coverage_exponent)

    if max_adequacy is None or adequacy_of_item > max_adequacy:
      max_adequacy = adequacy_of_item
      original_index_of_item_with_max_adequacy = original_index
      new_atom_divergence = atom_divergence_if_changed
      new_compound_divergence = compound_divergence_if_changed
      new_atom_coverage = coverage_if_changed

  return (original_index_of_item_with_max_adequacy,
          (new_atom_divergence, new_compound_divergence, new_atom_coverage))


def _subtract_and_maybe_delete_key(counter,
                                   items_to_subtract):
  """Subtracts items from counter and deletes keys with non-positive counts."""
  # An alternative is to use the unary "+" operator (__pos__) on counters to
  # remove zeros and negative counts:
  # https://docs.python.org/3/library/collections.html#counter-objects
  # However, that method iterates over all the items in the counter and can be
  # slow in our case when there are many atoms or compounds:
  # https://github.com/python/cpython/blob/3.10/Lib/collections/__init__.py
  # So here we opt to manage counter subtractions carefully, removing keys that
  # would result in non-positive counts ourselves.
  for item in items_to_subtract:
    if counter[item] <= 1:
      del counter[item]
    else:
      counter[item] -= 1


def maximize_divergence(
    items, size_1, size_2,
    get_compounds_fn,
    get_atoms_fn,
    options,
    rng):
  """Returns two disjoint subsets of items with maximized divergence.

  Each iteration either inserts an item (moves it from `items` into one of the
  two subsets) or deletes one (moves it from a subset back into `items`).
  Deletions happen at the total sizes scheduled by `_get_sizes_to_delete`.
  Among a random sample of candidate items, the one whose move yields the
  highest adequacy score is chosen.

  Note that `items` is modified in place: items placed into either subset are
  popped from it.  The exception is the single-item edge case, where `items`
  is left untouched.

  Args:
    items: The collection of all the items from which to construct two disjoint
      subsets.  This should have size at least size_1 + size_2.  It must
      support `len()`, `pop(index)` and `append()`, since items are moved out
      of it and back into it.
    size_1: The target size of the first subset.
    size_2: The target size of the second subset.
    get_compounds_fn: The function that returns the set of compounds of an item.
    get_atoms_fn: The function that returns the set of atoms of an item.
    options: The options for controlling the algorithm.  This function and its
      helpers read `initial_fraction`, `delete_period`, `sample_size`,
      `filter_items_for_missing_atom`, `target_atom_divergence`,
      `atom_similarity_exponent` and `atom_coverage_exponent`.
    rng: A random number generator.  It must provide `choice` (including the
      `replace` and `p` keyword arguments) and `randint`.  This presumably
      means a `np.random.RandomState`; confirm with callers.

  Returns:
    A tuple `(items_1, items_2)` of lists.  Each subset is seeded with a random
    sample of `int(options.initial_fraction * size) + 1` items, so a subset
    may end up holding items even when its target size is 0.

  Raises:
    ValueError: If `items` has fewer than `size_1 + size_2` elements.
  """
  if len(items) < size_1 + size_2:
    raise ValueError(f'Size of items should be at least size_1 + size_2 = '
                     f'{size_1 + size_2}, got: {len(items)}')

  # An unusual edge case, but our small scale generation tests reduce the number
  # of contexts to just 1.
  if len(items) == 1:
    return list(items), []

  all_atoms = set()
  for item in items:
    all_atoms.update(get_atoms_fn(item))

  # To avoid having to deal with the edge cases where the subsets or the atom
  # or compound counters are empty, we start with a random sample of items in
  # each subset.
  # (For the counters, this could happen when we split by top-level examples,
  # since some examples do not have compounds.)
  initial_sample_size_1 = int(options.initial_fraction * size_1) + 1
  initial_sample_size_2 = int(options.initial_fraction * size_2) + 1
  items_1 = _populate_initial_items(items, initial_sample_size_1, rng)
  items_2 = _populate_initial_items(items, initial_sample_size_2, rng)

  state = _MaximizeDivergenceState(
      all_atoms=frozenset(all_atoms),
      atom_counts_1=mcd_utils.get_all_compounds(items_1, get_atoms_fn),
      atom_counts_2=mcd_utils.get_all_compounds(items_2, get_atoms_fn),
      compound_counts_1=mcd_utils.get_all_compounds(items_1, get_compounds_fn),
      compound_counts_2=mcd_utils.get_all_compounds(items_2, get_compounds_fn))

  logging.info('Initial atom divergence=%f, initial compound divergence=%f',
               state.atom_divergence, state.compound_divergence)

  # Set up a schedule for when to remove items from the subsets.
  # Note: len(items) is taken after the initial samples were popped above, so
  # the schedule only covers totals below the size of the remaining pool.
  sizes_to_delete = _get_sizes_to_delete(len(items), options.delete_period)
  iteration_num = 0
  while len(items_1) < size_1 or len(items_2) < size_2:
    change_type = _select_change_type(
        current_size_1=len(items_1),
        current_size_2=len(items_2),
        size_1=size_1,
        size_2=size_2,
        delete_period=options.delete_period,
        sizes_to_delete=sizes_to_delete,
        rng=rng)

    # In each iteration an item is moved from one collection to another.  For
    # example, ADD_TO_1 moves an item from `items` to `items_1`, while
    # REMOVE_FROM_1 moves an item from `items_1` to `items`.
    filter_fn = None
    if change_type == _ChangeType.ADD_TO_1:
      source = items
      target = items_1
      compute_divergence_if_changed_fn = _compute_divergence_if_added_1
      compute_coverage_if_changed_fn = _compute_coverage_if_added_1

      # Prefer candidates that bring an atom not yet present in items_1.
      if options.filter_items_for_missing_atom:
        missing_atoms = state.all_atoms - state.atom_counts_1.keys()
        if missing_atoms:
          filter_fn = lambda item: bool(get_atoms_fn(item) & missing_atoms)

    elif change_type == _ChangeType.ADD_TO_2:
      source = items
      target = items_2
      compute_divergence_if_changed_fn = _compute_divergence_if_added_2
      compute_coverage_if_changed_fn = _compute_coverage_if_added_2
    elif change_type == _ChangeType.REMOVE_FROM_1:
      source = items_1
      target = items
      compute_divergence_if_changed_fn = _compute_divergence_if_removed_1
      compute_coverage_if_changed_fn = _compute_coverage_if_removed_1
    elif change_type == _ChangeType.REMOVE_FROM_2:
      source = items_2
      target = items
      compute_divergence_if_changed_fn = _compute_divergence_if_removed_2
      compute_coverage_if_changed_fn = _compute_coverage_if_removed_2

      # This is to avoid the situation where all items containing an atom have
      # been assigned to items_2.
      if options.filter_items_for_missing_atom:
        missing_atoms = state.all_atoms - state.atom_counts_1.keys()
        if missing_atoms:
          filter_fn = lambda item: bool(get_atoms_fn(item) & missing_atoms)

    sampled_item_by_original_index = (
        _get_sampled_item_by_original_index(source, options.sample_size, rng,
                                            filter_fn))

    (original_index_of_item_with_max_adequacy,
     (new_atom_divergence, new_compound_divergence, new_atom_coverage
     )) = _get_original_index_of_item_with_max_adequacy_and_metrics(
         state, sampled_item_by_original_index, get_compounds_fn, get_atoms_fn,
         compute_divergence_if_changed_fn, compute_coverage_if_changed_fn,
         options)

    item_to_move = source.pop(original_index_of_item_with_max_adequacy)
    target.append(item_to_move)

    # The metrics were computed incrementally for exactly this move, so they
    # are adopted as-is instead of being recomputed from the counters.
    state.atom_divergence = new_atom_divergence
    state.compound_divergence = new_compound_divergence
    state.atom_coverage = new_atom_coverage

    atoms_of_item_to_move = get_atoms_fn(item_to_move)
    compounds_of_item_to_move = get_compounds_fn(item_to_move)

    if change_type == _ChangeType.ADD_TO_1:
      state.atom_counts_1.update(atoms_of_item_to_move)
      state.compound_counts_1.update(compounds_of_item_to_move)
    elif change_type == _ChangeType.ADD_TO_2:
      state.atom_counts_2.update(atoms_of_item_to_move)
      state.compound_counts_2.update(compounds_of_item_to_move)
    elif change_type == _ChangeType.REMOVE_FROM_1:
      _subtract_and_maybe_delete_key(state.atom_counts_1, atoms_of_item_to_move)
      _subtract_and_maybe_delete_key(state.compound_counts_1,
                                     compounds_of_item_to_move)
    elif change_type == _ChangeType.REMOVE_FROM_2:
      _subtract_and_maybe_delete_key(state.atom_counts_2, atoms_of_item_to_move)
      _subtract_and_maybe_delete_key(state.compound_counts_2,
                                     compounds_of_item_to_move)

    # Recomputes the `*_sum` fields by summing over the full counters.
    state.update_sum_fields()

    iteration_num += 1
    logging.info(
        'Iteration %d (%s): atom divergence=%f, compound divergence=%f'
        ', atom coverage=%f, items_1 size=%d/%d, items_2 size=%d/%d',
        iteration_num, change_type.value, state.atom_divergence,
        state.compound_divergence, state.atom_coverage, len(items_1), size_1,
        len(items_2), size_2)

  return items_1, items_2
