# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for creating distractor rules.

We currently have two strategies for creating distractor rules:

1. Heuristic edits:
This strategy creates distractor rules by editing source rules (as
FeatureGrammar productions): swapping, adding, removing, replacing, or
repeating tokens.

Examples:
  - Swap: "A[sem=(?x1+'WALK')] -> B" becomes "A[sem=('WALK'+?x1)] -> B".
  - Add: "A[sem=?x1] -> B" becomes "A[sem=('WALK'+?x1)] -> B", where the new
    non-variable token must appear in the grammar's output tokens.
  - Remove: "A[sem=(?x1+'WALK')] -> B" becomes "A[sem=?x1] -> B".
  - Replace: "A[sem=(?x1+'WALK')] -> B" becomes "A[sem=(?x1+'JUMP')] -> B".
  - Repeat: "A[sem=(?x1)] -> B" becomes "A[sem=(?x1+?x1)] -> B".

When we add or replace with a non-variable token, we only use tokens that
already appear in the grammar.  When we add or replace with a variable token,
we only use tokens that appear on the RHS of the production.

We also intentionally separate the add/remove/replace/repeat edits that operate
on variables vs non-variables in order to better track which edits are more
difficult for learners.

2. Alternative grammar:
This strategy replaces one of the source rules with a different one that could
have been generated by the same grammar generator.

For example, suppose the positive rule is
"A[sem=(?x1+WALK)] -> B[sem=?x1] 'and' 'walk'", which is built from two source
rules: "A[sem=(?x1+?x2)] -> B[sem=?x1] 'and' B[sem=?x2]" and
"A[sem=WALK] -> 'walk'".  We generate a new grammar with the same grammar
generator and replace one of the source rules with the new source rule that has
the same input sequence.  If, for example, the new grammar has
"A[sem=JUMP] -> 'walk'" and all other productions in the grammar are the same as
in the original grammar, then the negative rule would be
"A[sem=(?x1+JUMP)] -> B[sem=?x1] 'and' 'walk'".

"""

import copy
import dataclasses
from typing import Callable, Dict, List, Sequence, Union

import nltk
import numpy as np

from conceptual_learning.cscan import grammar_generation
from conceptual_learning.cscan import grammar_representation
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition


class FailedToCreateDistractorError(Exception):
  """Raised when no distractor can be derived from a given positive rule."""


@dataclasses.dataclass
class DistractorCreationResult:
  """A distractor production bundled with metadata on how it was produced.

  Attributes:
    distractor: The distractor (negative) production.
    applied_edits: Names of the edits that were applied, in application order.
      Each entry is either the name of a heuristic edit function or the string
      "sample_from_alternative_grammar".
    new_source_production_by_source_production: Maps every replaced source
      production to the production that replaced it.  Only the
      alternative-grammar strategy fills this in; it defaults to empty.
  """
  distractor: nltk.grammar.Production
  applied_edits: List[str]
  # Left as an empty dict by the heuristic-edit strategy.
  new_source_production_by_source_production: Dict[
      nltk.grammar.Production, nltk.grammar.Production] = dataclasses.field(
          default_factory=dict)


# Negative example strategy: heuristic edits.
class _FailedToApplyEditError(Exception):
  """Signals that a single heuristic edit is not applicable to a production.

  Each _ProductionEdit is expected to check its input up front and raise this
  error, rather than return an unchanged or malformed production.
  """


# Signature shared by every heuristic edit:
# (production, grammar, rng) -> edited production.
# Implementations should raise _FailedToApplyEditError for inapplicable input.
_ProductionEdit = Callable[
    [
        nltk.grammar.Production,
        nltk.grammar.FeatureGrammar,
        np.random.RandomState,
    ],
    nltk.grammar.Production,
]


def _lhs_tokens_from_production(
    production):
  """Returns the tokens of the production's LHS "sem" feature as a new list.

  A bare str or nltk.Variable is wrapped in a one-element list.  A
  FeatureValueConcat or FeatureValueTuple is converted to a list of its
  elements.  A fresh list is built on every call, so callers may mutate it
  without affecting the production.

  Args:
    production: An nltk.grammar.Production whose LHS has a "sem" feature.

  Returns:
    A list of the LHS semantics tokens.

  Raises:
    ValueError: If the "sem" feature has an unsupported type.
  """
  sem = production.lhs()['sem']
  if isinstance(sem, (str, nltk.Variable)):
    return [sem]
  if isinstance(
      sem,
      (nltk.featstruct.FeatureValueConcat, nltk.featstruct.FeatureValueTuple)):
    return list(sem)
  raise ValueError("Production LHS's sem feature should be str, Variable, "
                   f'FeatureValueConcat, or FeatureValueTuple.  got: {sem}')


def _lhs_non_variable_indices_from_production(
    production):
  """Finds the positions of non-variable (str) tokens in the LHS semantics.

  Args:
    production: An nltk.grammar.Production whose LHS has a "sem" feature.

  Returns:
    Ascending indices into `_lhs_tokens_from_production(production)` whose
    token is a str.
  """
  return [
      position
      for position, token in enumerate(_lhs_tokens_from_production(production))
      if isinstance(token, str)
  ]


def _lhs_variable_indices_from_production(
    production):
  """Finds the positions of nltk.Variable tokens in the LHS semantics.

  Args:
    production: An nltk.grammar.Production whose LHS has a "sem" feature.

  Returns:
    Ascending indices into `_lhs_tokens_from_production(production)` whose
    token is an nltk.Variable.
  """
  return [
      position
      for position, token in enumerate(_lhs_tokens_from_production(production))
      if isinstance(token, nltk.Variable)
  ]


def _distinct_lhs_non_variables_from_production(
    production):
  """Collects the unique non-variable LHS tokens of a production.

  Args:
    production: An nltk.grammar.Production whose LHS has a "sem" feature.

  Returns:
    The distinct str tokens of the LHS semantics, in sorted order.
  """
  lhs_tokens = _lhs_tokens_from_production(production)
  unique_tokens = {
      lhs_tokens[position]
      for position in _lhs_non_variable_indices_from_production(production)
  }
  return sorted(unique_tokens)


def _distinct_lhs_variables_from_production(
    production):
  """Collects the unique nltk.Variable LHS tokens of a production.

  Args:
    production: An nltk.grammar.Production whose LHS has a "sem" feature.

  Returns:
    The distinct variable tokens of the LHS semantics, in sorted order.
  """
  lhs_tokens = _lhs_tokens_from_production(production)
  unique_variables = {
      lhs_tokens[position]
      for position in _lhs_variable_indices_from_production(production)
  }
  return sorted(unique_variables)


def _rhs_variable_indices_from_production(
    production):
  """Finds the positions of feature-structure nonterminals on the RHS.

  These nonterminals are the ones that carry "sem" variables; plain terminals
  are skipped.

  Args:
    production: An nltk.grammar.Production.

  Returns:
    Ascending indices into `production.rhs()` whose symbol is an
    nltk.grammar.FeatStructNonterminal.
  """
  rhs_symbols = production.rhs()
  return [
      position for position, symbol in enumerate(rhs_symbols)
      if isinstance(symbol, nltk.grammar.FeatStructNonterminal)
  ]


def _distinct_rhs_variables_from_production(
    production):
  """Collects the unique "sem" values of the RHS feature nonterminals.

  Args:
    production: An nltk.grammar.Production.

  Returns:
    The distinct "sem" feature values of the RHS FeatStructNonterminals, in
    sorted order.
  """
  rhs_symbols = production.rhs()
  return sorted({
      rhs_symbols[position]['sem']
      for position in _rhs_variable_indices_from_production(production)
  })


def _distinct_non_variables_from_grammar(
    grammar):
  """Collects the unique non-variable LHS tokens across a whole grammar.

  Args:
    grammar: An nltk.grammar.FeatureGrammar whose productions' LHS have a
      "sem" feature.

  Returns:
    The distinct str tokens found on the LHS of any production, sorted.
  """
  seen_tokens = set()
  for grammar_production in grammar.productions():
    lhs_tokens = _lhs_tokens_from_production(grammar_production)
    seen_tokens.update(
        lhs_tokens[position] for position in
        _lhs_non_variable_indices_from_production(grammar_production))
  return sorted(seen_tokens)


def _new_production_from_production_and_tokens(
    production,
    tokens):
  """Builds a copy of `production` whose LHS semantics are set to `tokens`.

  The new LHS keeps only the original LHS type and the new "sem" feature; any
  other LHS features are not carried over.  The RHS is reused unchanged.

  Args:
    production: A nltk.grammar.Production.
    tokens: A list of strings or nltk.Variables for the new LHS "sem" feature.
      An empty list yields '', a single token is used as-is, and several tokens
      are wrapped in a FeatureValueConcat.

  Returns:
    The new nltk.grammar.Production.
  """
  if not tokens:
    sem_value = ''
  elif len(tokens) > 1:
    sem_value = nltk.featstruct.FeatureValueConcat(values=tokens)
  else:
    sem_value = tokens[0]
  lhs = nltk.grammar.FeatStructNonterminal(
      production.lhs()[nltk.grammar.TYPE], sem=sem_value)
  return nltk.grammar.Production(lhs=lhs, rhs=production.rhs()[:])


def _swap(production,
          grammar,
          rng):
  """Swaps two randomly chosen LHS tokens so that the token order changes.

  Pairs of distinct positions are resampled until the swap actually alters the
  sequence, i.e. until the two swapped tokens differ.

  Raises:
    _FailedToApplyEditError: If the LHS has fewer than two distinct tokens.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  original_tokens = _lhs_tokens_from_production(production)
  if len(set(original_tokens)) < 2:
    raise _FailedToApplyEditError("The production's LHS does not have at least "
                                  'two distinct tokens.')

  swapped = original_tokens[:]
  while swapped == original_tokens:
    swapped = original_tokens[:]
    first, second = rng.choice(range(len(swapped)), size=2, replace=False)
    swapped[first], swapped[second] = swapped[second], swapped[first]

  return _new_production_from_production_and_tokens(production, swapped)


def _add_non_variable(production,
                      grammar,
                      rng):
  """Returns a new production with one extra non-variable LHS token.

  The new token is drawn uniformly from the distinct non-variable tokens found
  on the LHS of any production in `grammar`.  It is inserted at a uniformly
  random position of the LHS semantics; the end of the sequence is a valid
  position.

  Raises:
    _FailedToApplyEditError: If the grammar has no non-variable token.
  """
  grammar_non_variables = _distinct_non_variables_from_grammar(grammar)
  if not grammar_non_variables:
    raise _FailedToApplyEditError('The grammar does not have any non-variable '
                                  'token.')

  tokens = _lhs_tokens_from_production(production)
  # randint's upper bound is exclusive, so positions 0..len(tokens) are valid.
  index = rng.randint(0, len(tokens) + 1)
  new_token = rng.choice(grammar_non_variables)
  tokens.insert(index, new_token)

  return _new_production_from_production_and_tokens(production, tokens)


def _add_variable(production,
                  grammar,
                  rng):
  """Inserts one of the RHS "sem" variables into the LHS semantics.

  Both the insertion position (end included) and the variable are sampled
  uniformly at random.

  Raises:
    _FailedToApplyEditError: If the RHS has no feature nonterminal to take a
      variable from.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  candidate_variables = _distinct_rhs_variables_from_production(production)
  if not candidate_variables:
    raise _FailedToApplyEditError("The production's RHS does not have any "
                                  'variable token.')

  lhs_tokens = _lhs_tokens_from_production(production)
  insert_at = rng.randint(0, len(lhs_tokens) + 1)
  lhs_tokens.insert(insert_at, rng.choice(candidate_variables))
  return _new_production_from_production_and_tokens(production, lhs_tokens)


def _remove_non_variable(production,
                         grammar,
                         rng):
  """Drops one randomly chosen non-variable token from the LHS semantics.

  Raises:
    _FailedToApplyEditError: If the LHS has no non-variable token.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  removable_positions = _lhs_non_variable_indices_from_production(production)
  if not removable_positions:
    raise _FailedToApplyEditError("The production's LHS does not have any "
                                  'non-variable token.')

  remaining_tokens = _lhs_tokens_from_production(production)
  del remaining_tokens[rng.choice(removable_positions)]
  return _new_production_from_production_and_tokens(production,
                                                    remaining_tokens)


def _remove_variable(production,
                     grammar,
                     rng):
  """Drops one randomly chosen variable token from the LHS semantics.

  Raises:
    _FailedToApplyEditError: If the LHS has no variable token.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  removable_positions = _lhs_variable_indices_from_production(production)
  if not removable_positions:
    raise _FailedToApplyEditError("The production's LHS does not have any "
                                  'variable token.')

  remaining_tokens = _lhs_tokens_from_production(production)
  del remaining_tokens[rng.choice(removable_positions)]
  return _new_production_from_production_and_tokens(production,
                                                    remaining_tokens)


def _replace_non_variable(
    production, grammar,
    rng):
  """Returns a new production with one non-variable LHS token replaced.

  On each try, a non-variable LHS position is chosen uniformly and its token is
  replaced by a token drawn uniformly from the grammar's distinct non-variable
  tokens.  Tries repeat until the token sequence actually changes.

  Raises:
    _FailedToApplyEditError: If the LHS has no non-variable token, or the
      grammar has no non-variable token that differs from the LHS ones.
  """
  lhs_non_variables = _distinct_lhs_non_variables_from_production(production)
  grammar_non_variables = _distinct_non_variables_from_grammar(grammar)
  # No change is possible if any of these hold:
  # - there is nothing on the LHS to replace;
  # - the grammar has no token to draw from (rng.choice would raise a
  #   ValueError that callers do not expect);
  # - the only LHS non-variable is also the only grammar non-variable.
  if (not lhs_non_variables or not grammar_non_variables or
      (len(lhs_non_variables) == 1 and
       set(grammar_non_variables) <= set(lhs_non_variables))):
    raise _FailedToApplyEditError("The production's LHS does not have any "
                                  'non-variable token, or the grammar does not '
                                  'have a different non-variable token.')

  lhs_non_variable_indices = _lhs_non_variable_indices_from_production(
      production)
  tokens = _lhs_tokens_from_production(production)
  # Rejection sampling; the check above guarantees a changing draw exists.
  while True:
    new_tokens = tokens[:]
    index = rng.choice(lhs_non_variable_indices)
    new_token = rng.choice(grammar_non_variables)
    new_tokens[index] = new_token
    if new_tokens != tokens:
      break

  return _new_production_from_production_and_tokens(production, new_tokens)


def _replace_variable(production,
                      grammar,
                      rng):
  """Returns a new production with one LHS variable replaced by an RHS one.

  On each try, a variable LHS position is chosen uniformly and its token is
  replaced by a variable drawn uniformly from the production's distinct RHS
  "sem" values.  Tries repeat until the token sequence actually changes.

  Raises:
    _FailedToApplyEditError: If the LHS has no variable token, or the RHS has
      no variable that differs from the LHS ones.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  lhs_variables = _distinct_lhs_variables_from_production(production)
  rhs_variables = _distinct_rhs_variables_from_production(production)
  # No change is possible if any of these hold:
  # - there is no LHS variable;
  # - the RHS offers no variable to draw from (rng.choice would raise a
  #   ValueError that callers do not expect);
  # - the only LHS variable is also the only RHS variable.
  if (not lhs_variables or not rhs_variables or
      (len(lhs_variables) == 1 and set(rhs_variables) <= set(lhs_variables))):
    raise _FailedToApplyEditError("The production's LHS does not have any "
                                  "variable token, or the production's RHS "
                                  'does not have a different variable token.')

  lhs_variable_indices = _lhs_variable_indices_from_production(production)
  tokens = _lhs_tokens_from_production(production)
  # Rejection sampling; the check above guarantees a changing draw exists.
  while True:
    new_tokens = tokens[:]
    index = rng.choice(lhs_variable_indices)
    new_variable = rng.choice(rhs_variables)
    new_tokens[index] = new_variable
    if new_tokens != tokens:
      break

  return _new_production_from_production_and_tokens(production, new_tokens)


def _repeat_non_variable(production,
                         grammar,
                         rng):
  """Duplicates one randomly chosen non-variable LHS token in place.

  The copy is inserted right before the chosen token, so the two occurrences
  end up adjacent.

  Raises:
    _FailedToApplyEditError: If the LHS has no non-variable token.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  candidate_positions = _lhs_non_variable_indices_from_production(production)
  if not candidate_positions:
    raise _FailedToApplyEditError("The production's LHS does not have any "
                                  'non-variable token.')

  lhs_tokens = _lhs_tokens_from_production(production)
  chosen = rng.choice(candidate_positions)
  duplicate = lhs_tokens[chosen]
  lhs_tokens.insert(chosen, duplicate)
  return _new_production_from_production_and_tokens(production, lhs_tokens)


def _repeat_variable(production,
                     grammar,
                     rng):
  """Duplicates one randomly chosen LHS variable in place.

  The copy is inserted right before the chosen variable, so the two
  occurrences end up adjacent.

  Raises:
    _FailedToApplyEditError: If the LHS has no variable token.
  """
  del grammar  # Unused; kept to match the _ProductionEdit signature.
  candidate_positions = _lhs_variable_indices_from_production(production)
  if not candidate_positions:
    raise _FailedToApplyEditError("The production's LHS does not have any "
                                  'variable token.')

  lhs_tokens = _lhs_tokens_from_production(production)
  chosen = rng.choice(candidate_positions)
  duplicate = lhs_tokens[chosen]
  lhs_tokens.insert(chosen, duplicate)
  return _new_production_from_production_and_tokens(production, lhs_tokens)


# Every heuristic edit, in a fixed order.  It must remain a mutable list:
# create_distractor_production_with_heuristic_edit copies it and shuffles the
# copy with the caller's rng, so the order here also affects reproducibility.
supported_edits = [
    _swap,
    _add_non_variable,
    _add_variable,
    _remove_non_variable,
    _remove_variable,
    _replace_non_variable,
    _replace_variable,
    _repeat_non_variable,
    _repeat_variable,
]


def _apply_edit(production,
                grammar,
                rng,
                edit):
  """Runs a single heuristic `edit` and normalizes the resulting semantics.

  Any _FailedToApplyEditError raised by `edit` propagates to the caller.
  """
  return production_composition.normalize_semantics(
      edit(production, grammar, rng))


def create_distractor_production_with_heuristic_edit(
    production, grammar,
    rng, max_edits):
  """Creates a distractor production by applying random heuristic edits.

  The number of edits is sampled uniformly from [1, max_edits].  For each
  edit, the supported edits are shuffled.  The first edit that applies and
  yields a production different from the original `production` is kept.
  There is no backtracking.

  With the current implementation, the returned production and intermediate
  productions are guaranteed to be different from the original
  production, but the intermediate productions could be equal to each other.

  Args:
    production: The production to create a distractor for.  The production's LHS
      should have a "sem" feature that is a str, an nltk.Variable, an
      nltk.featstruct.FeatureValueConcat, or an
      nltk.featstruct.FeatureValueTuple.
    grammar: A grammar containing the production.  This is used as the source
      for non-variable tokens.
    rng: Random number generator.
    max_edits: The maximum number of edits to be applied to the production.
      Must be at least 1.  The actual number of edits is uniformly sampled
      from [1, max_edits].

  Returns:
    A DistractorCreationResult instance.  Its applied_edits field lists the
    names of the applied edit functions in application order.

  Raises:
    ValueError: If max_edits is less than 1.
    FailedToCreateDistractorError: If, at some step, none of the supported
      edits can be applied without reproducing the original production.
  """
  if max_edits < 1:
    raise ValueError(f'max_edits must be at least 1, got: {max_edits}')

  # randint's upper bound is exclusive, so this samples from [1, max_edits].
  target_num_edits = rng.randint(1, max_edits + 1)
  distractor_production = copy.deepcopy(production)
  applied_edits = []
  for _ in range(target_num_edits):
    available_edits = supported_edits[:]
    rng.shuffle(available_edits)
    # We greedily apply any random edit that can be successfully applied and
    # continue without backtracking.
    for edit in available_edits:
      try:
        edited_production = _apply_edit(distractor_production, grammar, rng,
                                        edit)
      except _FailedToApplyEditError:
        continue  # Not applicable to the current production; try the next.
      if edited_production == production:
        # Never accept an edit that lands back on the original production.
        continue
      distractor_production = edited_production
      applied_edits.append(edit.__name__)
      break
    else:
      raise FailedToCreateDistractorError(
          f'Failed to create distractor for production: '
          f'{production}, distractor_production: {distractor_production}, '
          f'target_num_edits: {target_num_edits}, '
          f'applied_edits: {applied_edits}.')

  return DistractorCreationResult(
      distractor=distractor_production, applied_edits=applied_edits)


# Negative example strategy: alternative grammar.
def _rhs_string_from_production(production):
  """Returns the right-hand side of `str(production)`.

  A production's string form is "<lhs> -> <rhs>".  Everything after the first
  " -> " separator is returned.  An RHS that itself contains " -> " (e.g.
  inside a quoted terminal) is therefore kept intact instead of being
  truncated.

  Raises:
    IndexError: If the string form contains no " -> " separator.
  """
  return str(production).split(' -> ', 1)[1]


def create_distractor_production_with_alternative_grammar(
    production, grammar,
    rng,
    grammar_generator,
    provenance_by_production,
    max_attempts_per_negative_example):
  """Creates a distractor production.

  The implementation replaces one of the source productions with a new one
  that has the same input sequence but may have a different output sequence.
  Each attempt works as follows:
    1. Sample one non-pass-through source production of `production`.
    2. Regenerate it with `grammar_generator`, using a template built from
       `grammar`.
    3. Rebuild `production` from its provenance with the regenerated source
       production substituted in.
  The first rebuilt production that differs from `production` is returned.

  Args:
    production: The production to create a distractor for.
    grammar: The grammar from which the production is generated.
    rng: Random number generator.
    grammar_generator: The grammar generator that generated the grammar.
    provenance_by_production: A ProductionProvenanceDict instance to look up
      source productions of the production, and to record the distractor's
      provenance in.  If `production` has no entry yet, one with `production`
      as its own source is inserted.
    max_attempts_per_negative_example: The maximum number of times to resample
      a source production and regenerate a grammar before giving up.

  Returns:
    A DistractorCreationResult instance whose applied_edits field is set to
    ["sample_from_alternative_grammar"].

  Raises:
    ValueError: If the new grammar does not have exactly one production with
      the same RHS as the sampled source production, or if the sampled source
      production cannot be replaced in the provenance.
    FailedToCreateDistractorError: If production is a pass-through rule, if
      its provenance has no non-pass-through source production, or if the
      function fails to create a distractor after
      max_attempts_per_negative_example attempts.
  """
  # Currently in the tests where we use the fake inference engine, example
  # sampling falls back to sampling productions from the grammar directly, and
  # pass-through rules could get sampled.
  if nltk_utils.is_pass_through_rule(production):
    raise FailedToCreateDistractorError(
        f'Cannot create distractor for pass-through rule {production}, '
        f'grammar: {grammar}')

  # The production could be one of the source productions in the grammar when
  # it is chosen to be unreliable, in which case it would not have been recorded
  # in provenance_by_production.
  provenance = provenance_by_production.setdefault(
      production,
      production_composition.ProductionProvenance(source=production))

  # Distinct non-pass-through source productions, sorted by string so that
  # sampling with a seeded rng is reproducible.
  all_sources = {provenance.source} | {
      other_parent for other_parent, _ in provenance.compositions
  }
  source_productions = sorted(
      {
          source for source in all_sources
          if not nltk_utils.is_pass_through_rule(source)
      },
      key=str)
  if not source_productions:
    # Without this check, rng.choice below would raise an opaque ValueError.
    raise FailedToCreateDistractorError(
        f'No non-pass-through source production to resample for production '
        f'{production}.\nProvenance: {provenance}\nGrammar: {grammar}')

  for _ in range(max_attempts_per_negative_example):
    # Sample a source production from the provenance to be resampled.
    source_production = rng.choice(source_productions)

    # Create a grammar schema template that is identical to the given grammar
    # except for the grammar production whose RHS is the same as the RHS of the
    # sampled source production.
    target_rhs_string = _rhs_string_from_production(source_production)

    # The closure is only used within this iteration, so capturing the loop
    # variable target_rhs_string is safe.
    def rule_filter(rule_schema):
      # NOTE(review): this is a suffix match on the whole rule string, so a
      # rule whose RHS merely ends with target_rhs_string would also match;
      # confirm this is intended.
      return rule_schema.to_rule_string().endswith(target_rhs_string)

    grammar_schema = grammar_representation.grammar_schema_from_feature_grammar(
        grammar)
    template = (
        grammar_generation.fixed_phrase_structure_template_from_grammar_schema(
            grammar_schema, rule_filter))

    # Generate a new feature grammar that may assign a different output
    # sequence to the sampled source rule.
    new_grammar = grammar_generator.generate_grammar(template)

    productions_with_matching_rhs = [
        new_grammar_production
        for new_grammar_production in new_grammar.productions()
        if _rhs_string_from_production(new_grammar_production) ==
        target_rhs_string
    ]
    if len(productions_with_matching_rhs) != 1:
      raise ValueError(
          'Expected exactly one production with target RHS string '
          f'{target_rhs_string}, got: {productions_with_matching_rhs} for '
          f'production {production} and source production {source_production}.'
          f'\nProvenance: {provenance}'
          f'\nGrammar: {grammar}'
          f'\nNew grammar: {new_grammar}')

    new_source_production = productions_with_matching_rhs[0]

    # Now we rebuild the production using the new source production.
    try:
      new_provenance = provenance.replace(source_production,
                                          new_source_production,
                                          provenance_by_production)
    except Exception as e:
      message = (
          f'Failed to replace source production {source_production} with new '
          f'source production {new_source_production} for creating distractor '
          f'for production {production}.'
          f'\nProvenance: {provenance}'
          f'\nGrammar: {grammar}'
          f'\nNew grammar: {new_grammar}')
      raise ValueError(message) from e

    new_production = new_provenance.get_production()

    # The regenerated rule may coincide with the original; if so, try again.
    if new_production != production:
      new_source_production_by_source_production = {
          source_production: new_source_production
      }
      break
  else:
    raise FailedToCreateDistractorError(
        f'Failed to create distractor for production {production} after '
        f'{max_attempts_per_negative_example} attempts, grammar: {grammar}')

  return DistractorCreationResult(
      distractor=new_production,
      applied_edits=['sample_from_alternative_grammar'],
      new_source_production_by_source_production=(
          new_source_production_by_source_production))
