# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of functions for use in tests."""

import itertools
import re
from typing import Optional, Sequence, Set, Union
from unittest import mock

import tensorflow as tf

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import enums
from conceptual_learning.cscan import grammar_schema as gs
from conceptual_learning.cscan import induction
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import outputs

# Dataset suite spec ID suitable for a small-scale test run.
# This must be a valid ID defined in `dataset_suite_specs.json`.
TEST_DATASET_SUITE_SPEC_ID = 'test'

# Dataset spec IDs suitable for small-scale test runs.
# These must be valid IDs defined in `dataset_specs.json` and should be
# contained in the TEST_DATASET_SUITE_SPEC_ID suite.
TEST_DATASET_SPEC_ID_WITH_TEMPLATE = 'small_with_validation_and_test'
TEST_DATASET_SPEC_ID_WITHOUT_TEMPLATE = (
    'feature_grammar_small_without_template')

# For cases where we don't care whether a grammar template is specified or not.
TEST_DATASET_SPEC_ID = TEST_DATASET_SPEC_ID_WITH_TEMPLATE

# Invalid dataset (or dataset suite) spec ID for use in testing error cases.
# This ID must be defined in neither `dataset_specs.json` nor
# `dataset_suite_specs.json`.
INVALID_DATASET_SPEC_ID = 'invalid_id'


def strip_blank_and_comment_lines(grammar_string):
  """Removes blank lines and comment lines from grammar strings.

  A line counts as blank if it is empty or consists only of whitespace, and as
  a comment line if its first non-whitespace character is `#`.  All other lines
  are kept unchanged (including their indentation and any trailing `#` text)
  and are joined with single newlines, without a trailing newline.

  Args:
    grammar_string: The (possibly multi-line) grammar string to clean up.

  Returns:
    The grammar string with blank lines and comment lines removed.
  """
  return '\n'.join(
      line for line in grammar_string.splitlines()
      # Stripping first makes whitespace-only lines falsy, so they are dropped
      # together with truly empty lines.
      if line.strip() and not line.lstrip().startswith('#'))


def get_grammar_schema_for_scan_finite_nye_standardized():
  """Builds the GrammarSchema counterpart of scan_finite_nye_standardized.fcfg.

  The equivalence between this schema and the grammar file is tested in
  grammar_schema_test.py.

  Returns:
    A newly constructed GrammarSchema.
  """

  def _arg(position, category):
    # Variables are named '?x1', '?x2', ... after their argument position.
    return gs.FunctionArg(variable=f'?x{position}', category=category)

  def _primitive(input_token, output_token, category):
    return gs.PrimitiveMapping(
        input_sequence=[input_token],
        output_sequence=[output_token],
        category=category)

  def _function(phrase, category, arg_categories, num_postfix_args,
                output_sequence):
    args = [
        _arg(position, arg_category)
        for position, arg_category in enumerate(arg_categories, start=1)
    ]
    return gs.FunctionRule(
        function_phrase=[phrase],
        category=category,
        num_args=len(args),
        num_postfix_args=num_postfix_args,
        args=args,
        output_sequence=output_sequence)

  # (input token, output token) pairs, grouped by primitive category.
  action_primitives = [
      ('walk', 'WALK'),
      ('look', 'LOOK'),
      ('run', 'RUN'),
      ('jump', 'JUMP'),
      ('turn', ''),
  ]
  direction_primitives = [
      ('left', 'LTURN'),
      ('right', 'RTURN'),
  ]

  # (category, argument category) of the pass-through rule at levels 1 to 4.
  pass_through_categories = [('D', 'U'), ('V', 'D'), ('S', 'V'), ('C', 'S')]

  schema = gs.GrammarSchema()
  schema.primitives = (
      [_primitive(token, output, 'U') for token, output in action_primitives] +
      [_primitive(token, output, 'W') for token, output in direction_primitives]
  )
  schema.functions_by_level = {
      2: [
          _function('opposite', 'V', ['U', 'W'], 1, ['?x2', '?x2', '?x1']),
          _function('around', 'V', ['U', 'W'], 1, ['?x2', '?x1'] * 4),
      ],
      3: [
          _function('twice', 'S', ['V'], 0, ['?x1'] * 2),
          _function('thrice', 'S', ['V'], 0, ['?x1'] * 3),
      ],
      4: [
          _function('and', 'C', ['S', 'S'], 1, ['?x1', '?x2']),
          _function('after', 'C', ['S', 'S'], 1, ['?x2', '?x1']),
      ],
  }
  schema.pass_through_rules = {
      level: gs.PassThroughRule(category=category, arg=_arg(1, arg_category))
      for level, (category, arg_category) in enumerate(
          pass_through_categories, start=1)
  }
  schema.concat_rule_level = 1
  schema.concat_rule = gs.ConcatRule(
      category='D',
      arg1=_arg(1, 'U'),
      arg2=_arg(2, 'W'),
      output_sequence=['?x2', '?x1'])
  return schema


def create_grammar_options_with_fixed_number_of_rules(
    num_primitives = 5,
    num_precedence_levels = 3,
    num_functions_per_level = 2,
    has_pass_through_rules = True,
    has_concat_rule = False):
  """Returns GrammarOptions that produce grammars with a fixed number of rules.

  The defaults are reasonable values for tests that do not care about the
  exact numbers.

  Args:
    num_primitives: Exact number of primitive mappings to generate.
    num_precedence_levels: Exact number of precedence levels for which to
      generate non-primitive rules.
    num_functions_per_level: Exact number of function rules to generate for any
      given precedence level.
    has_pass_through_rules: If True, then every precedence level will contain a
      PassThroughRule; otherwise, none will.
    has_concat_rule: Whether the grammar as a whole should contain a ConcatRule.
  """
  # Booleans are turned into probabilities of exactly 1.0 or 0.0.
  pass_through_probability = float(has_pass_through_rules)
  concat_probability = float(has_concat_rule)

  fixed_size_kwargs = dict(
      num_primitives=num_primitives,
      num_precedence_levels=num_precedence_levels,
      # A single category per level is requested because multiple categories
      # could force the number of rules per level to exceed the requested one.
      max_num_categories_per_level=1,
      # Identical lower and upper bounds pin the function count per level.
      min_num_functions_per_level=num_functions_per_level,
      max_num_functions_per_level=num_functions_per_level,
      prob_pass_through_rule=pass_through_probability,
      prob_concat_rule=concat_probability,
  )
  return inputs.GrammarOptions(**fixed_size_kwargs)


def get_expected_num_rules_per_grammar(options):
  """Returns how many rules grammars created from these options will contain.

  The count is the number of primitives, plus for every precedence level the
  number of function rules and (if enabled) one pass-through rule, plus one if
  a concat rule is enabled.

  Args:
    options: The options based on which the grammar will be generated. These
      must be options that produce grammars with a fixed number of rules, such
      as would be returned by create_grammar_options_with_fixed_number_of_rules.

  Returns:
    The exact number of rules as an int.

  Raises:
    ValueError: If the number of primitives or precedence levels is None, if
      the minimum and maximum number of functions per level differ, or if one
      of the rule probabilities is not a whole number.
  """

  def _is_whole_number(probability):
    # Going through float() also accepts int probabilities such as 0 or 1;
    # int.is_integer() only exists from Python 3.12 onwards.
    return float(probability).is_integer()

  if (options.num_primitives is None or
      options.num_precedence_levels is None or
      options.min_num_functions_per_level != options.max_num_functions_per_level
      or not _is_whole_number(options.prob_pass_through_rule) or
      not _is_whole_number(options.prob_concat_rule)):
    raise ValueError(
        'Only defined if the number of rules in the grammar is fixed')

  num_rules_per_level = (
      options.min_num_functions_per_level +
      int(options.prob_pass_through_rule))
  return (options.num_primitives +
          options.num_precedence_levels * num_rules_per_level +
          int(options.prob_concat_rule))


def get_expected_num_top_level_examples(options):
  """Returns how many top-level examples a generated dataset should contain.

  This is the number of top-level examples that a dataset generated from the
  given options is expected to contain, assuming that no unrecoverable errors
  occur: the number of contexts times the number of requests per context.

  Args:
    options: The options based on which the dataset will be generated.
  """
  num_contexts = options.num_contexts
  num_requests_per_context = options.num_requests_per_context
  return num_contexts * num_requests_per_context


def maybe_read_file(filepath):
  """Returns the contents of the file if it exists, or else an empty string.

  Args:
    filepath: Path of the file to read via `tf.io.gfile.GFile`.
  """
  try:
    with tf.io.gfile.GFile(filepath) as file_handle:
      return file_handle.read()
  except tf.errors.NotFoundError:
    # A missing file is not an error here; it just means no contents.
    return ''


def create_sampling_options(
    num_contexts = 2,
    num_requests_per_context = 50,
    num_rules = -1,
    illustrative_example_non_rule_fraction = 0.5,
    num_examples_per_hidden_rule = 3,
    omitted_fraction = 0.0,
    omitted_fraction_stddev = 0.0,
    unreliable_fraction = 0.0,
    unreliable_fraction_stddev = 0.0,
    explicit_fraction = 0.5,
    explicit_fraction_stddev = 0.0,
    non_rule_fraction = 0.5,
    negative_example_fraction = 0.5,
    negative_example_fraction_stddev = 0.0,
    defeasible_example_fraction = 0.5,
    defeasible_example_fraction_stddev = 0.0,
    unknown_example_fraction = 0.2,
    unknown_example_fraction_stddev = 0.0,
    max_attempts_per_example = 100,
    max_attempts_per_context = 10,
    max_attempts_per_grammar = 5,
    min_illustrative_examples = 4,
    rule_format = enums.RuleFormat.FEATURE_GRAMMAR_PRODUCTION,
    max_input_length_standard = 0,
    max_input_length_compact = 0,
    max_output_length_standard = 0,
    max_output_length_compact = 0,
):
  """Returns SamplingOptions suitable for a small-scale test run.

  Most parameters are forwarded unchanged to the SamplingOptions field of the
  same name.  The exceptions are `min_illustrative_examples`, which is wrapped
  in an `induction.IllustrativeExamplesInductiveBias` and passed as
  `inductive_bias`, and the four `max_*_length_*` parameters, which are grouped
  into an `inputs.ExampleLengthOptions` and passed as `lengths`.
  """
  inductive_bias = induction.IllustrativeExamplesInductiveBias(
      min_illustrative_examples=min_illustrative_examples)
  length_options = inputs.ExampleLengthOptions(
      max_input_length_standard=max_input_length_standard,
      max_input_length_compact=max_input_length_compact,
      max_output_length_standard=max_output_length_standard,
      max_output_length_compact=max_output_length_compact)

  forwarded_kwargs = dict(
      # Counts.
      num_contexts=num_contexts,
      num_requests_per_context=num_requests_per_context,
      num_examples_per_hidden_rule=num_examples_per_hidden_rule,
      num_rules=num_rules,
      # Fractions and their standard deviations.
      illustrative_example_non_rule_fraction=(
          illustrative_example_non_rule_fraction),
      omitted_fraction=omitted_fraction,
      omitted_fraction_stddev=omitted_fraction_stddev,
      unreliable_fraction=unreliable_fraction,
      unreliable_fraction_stddev=unreliable_fraction_stddev,
      explicit_fraction=explicit_fraction,
      explicit_fraction_stddev=explicit_fraction_stddev,
      non_rule_fraction=non_rule_fraction,
      negative_example_fraction=negative_example_fraction,
      negative_example_fraction_stddev=negative_example_fraction_stddev,
      defeasible_example_fraction=defeasible_example_fraction,
      defeasible_example_fraction_stddev=defeasible_example_fraction_stddev,
      unknown_example_fraction=unknown_example_fraction,
      unknown_example_fraction_stddev=unknown_example_fraction_stddev,
      # Retry limits.
      max_attempts_per_example=max_attempts_per_example,
      max_attempts_per_context=max_attempts_per_context,
      max_attempts_per_grammar=max_attempts_per_grammar,
      rule_format=rule_format,
  )
  return inputs.SamplingOptions(
      inductive_bias=inductive_bias, lengths=length_options, **forwarded_kwargs)


def create_example_from_explicit_rule(rule):
  """Returns an Example that asserts the given rule.

  The example uses the rule itself as the request and `cl.RuleReply.TRUE` as
  the reply.  Its metadata lists the rule as the only rule and as the target
  rule.

  Args:
    rule: The rule to wrap in an example.
  """
  rule_metadata = cl.ExampleMetadata(rules={rule}, target_rule=rule)
  return cl.Example(
      request=rule, reply=cl.RuleReply.TRUE, metadata=rule_metadata)


def create_context(explicit_rules,
                   hidden_rules):
  """Returns a FrozenExampleSet containing the given rules.

  Args:
    explicit_rules: Rules to add as explicit rules.  Each one is added together
      with the example built by create_example_from_explicit_rule.
    hidden_rules: Rules to add as hidden rules.  Each one is added with an
      empty list of examples.
  """
  context_builder = cl.ExampleSet()
  for explicit_rule in explicit_rules:
    rule_example = create_example_from_explicit_rule(explicit_rule)
    context_builder.add_explicit_rule(explicit_rule, rule_example)
  for hidden_rule in hidden_rules:
    context_builder.add_hidden_rule(hidden_rule, [])
  return cl.FrozenExampleSet.from_example_set(context_builder)


def create_example_group(label,
                         with_omitted_rule = False,
                         rules = None,
                         num_examples_per_example_group = 1,
                         num_rules_per_example = 1):
  """Returns an ExampleGroup with a one-example context and labeled examples.

  Args:
    label: String embedded in every generated request, reply and omitted rule,
      used to tell the contents of different groups apart.
    with_omitted_rule: If True, the context gets the omitted rule
      `omitted_rule_{label}`.
    rules: Rules recorded in the metadata of the context example.  Defaults to
      an empty set.
    num_examples_per_example_group: Number of top-level examples to add.
    num_rules_per_example: Maximum number of rules recorded in the metadata of
      each top-level example.  Example `i` gets the rules at positions
      `i` to `i + num_rules_per_example - 1` in the iteration order of `rules`.
  """
  rules = set() if rules is None else rules

  context_builder = cl.ExampleSet()
  if with_omitted_rule:
    context_builder.add_omitted_rule(f'omitted_rule_{label}')
  context_example = cl.Example(
      request=f'context_request_{label}',
      reply=f'context_reply_{label}',
      metadata=cl.ExampleMetadata(rules=rules))
  context_builder.add_example(context_example)
  frozen_context = cl.FrozenExampleSet.from_example_set(context_builder)

  group = cl.ExampleGroup(context=frozen_context)
  for index in range(num_examples_per_example_group):
    # Each example sees a sliding window over the rules, starting at its index.
    rule_window = slice(index, index + num_rules_per_example)
    example_metadata = cl.ExampleMetadata(rules=set(list(rules)[rule_window]))
    group.add_example(
        cl.Example(
            request=f'top_level_request_{label}_{index}',
            reply=f'top_level_reply_{label}_{index}',
            metadata=example_metadata))
  return group


def create_grouped_example_set(
    num_contexts_with_omitted_rules,
    num_contexts_without_omitted_rules):
  """Returns a GroupedExampleSet with the requested mix of example groups.

  The groups whose context has an omitted rule come first, labeled
  `with_omitted_{i}`, followed by the groups without one, labeled
  `without_omitted_{i}`.

  Args:
    num_contexts_with_omitted_rules: Number of example groups whose context
      contains an omitted rule.
    num_contexts_without_omitted_rules: Number of example groups whose context
      contains no omitted rule.
  """
  example_groups = [
      create_example_group(f'with_omitted_{i}', with_omitted_rule=True)
      for i in range(num_contexts_with_omitted_rules)
  ]
  example_groups.extend(
      create_example_group(f'without_omitted_{i}', with_omitted_rule=False)
      for i in range(num_contexts_without_omitted_rules))
  return cl.GroupedExampleSet(example_groups=example_groups)


def create_grouped_example_set_with_rules(
    num_contexts,
    num_rules_per_context = 3,
    num_examples_per_example_group = 1,
    num_rules_per_example = 1):
  """Returns a GroupedExampleSet whose contexts draw from a shared rule pool.

  The pool consists of the rules `rule_0` to `rule_{2 * num_contexts - 1}`.
  Context `i` is labeled `context_{i}` and receives the pool entries at
  positions `i` to `i + num_rules_per_context - 1`.  This is intended to be
  used for testing compound divergence split.

  Args:
    num_contexts: Number of example groups to create.
    num_rules_per_context: Maximum number of pool rules given to each context.
    num_examples_per_example_group: Number of top-level examples per group.
    num_rules_per_example: Maximum number of rules per top-level example.
  """
  rule_pool = [f'rule_{i}' for i in range(2 * num_contexts)]

  def _group_for_context(context_index):
    # Each context only gets a window of the shared pool.
    context_rules = set(
        rule_pool[context_index:context_index + num_rules_per_context])
    return create_example_group(
        f'context_{context_index}',
        rules=context_rules,
        num_examples_per_example_group=num_examples_per_example_group,
        num_rules_per_example=num_rules_per_example)

  example_groups = [
      _group_for_context(context_index)
      for context_index in range(num_contexts)
  ]
  return cl.GroupedExampleSet(example_groups=example_groups)


def get_dataset_summary(
    counters,
    dataset):
  """Returns a summary string for use in test assertion output.

  The summary lists the counters, the number of top-level examples, and the
  size and rule counts of the first context.  If the dataset has no top-level
  examples, an empty FrozenExampleSet stands in for the first context.

  Args:
    counters: Counters to include verbatim in the summary.
    dataset: The dataset to summarize; either a cl.ExampleSet or a
      cl.GroupedExampleSet.

  Raises:
    TypeError: If the dataset is neither a cl.ExampleSet nor a
      cl.GroupedExampleSet.
  """
  context_0 = cl.FrozenExampleSet()
  if isinstance(dataset, cl.ExampleSet):
    num_top_level_examples = len(dataset)
    if num_top_level_examples:
      context_0 = dataset[0].context
  elif isinstance(dataset, cl.GroupedExampleSet):
    num_top_level_examples = len(dataset.to_example_set())
    if num_top_level_examples:
      context_0 = dataset.example_groups[0].context
  else:
    # Without this branch, the example count would be unbound below and the
    # caller would get an UnboundLocalError that hides the real problem.
    raise TypeError(
        'Expected a cl.ExampleSet or cl.GroupedExampleSet, but got '
        f'{type(dataset).__name__}')
  return (f'\ncounters = {counters}'
          f'\n# top-level examples = {num_top_level_examples}'
          f'\n# nested examples in 1st context = {len(context_0)}'
          f'\n# rules in 1st context = {len(context_0.metadata.rules)} ('
          f'omitted = {len(context_0.metadata.omitted_rules)}, '
          f'explicit = {len(context_0.metadata.explicit_rules)}, '
          f'hidden = {len(context_0.metadata.hidden_rules)})')


def make_fake_inference_engine(*,
                               all_defeasible = False,
                               never_contains = False,
                               all_consistent = False):
  """Returns a fake InferenceEngine instance for testing.

  The returned object is a MagicMock on which the public methods of the
  InferenceEngine class are replaced by deterministic fakes.  The first time a
  production is queried, it is assigned the next entry from a fixed cycle of
  canned answers; every later query about that production reuses that entry.

  Args:
    all_defeasible: If True, contains_monotonic_production always returns
      False, contains_defeasible_production returns True (unless never_contains
      is set), and inconsistency_if_production_added returns cl.Qualifier.D
      (unless all_consistent is set).
    never_contains: If True, contains_production,
      contains_monotonic_production and contains_defeasible_production always
      return False.
    all_consistent: If True, consistent_if_production_added always returns True
      and inconsistency_if_production_added always returns None.
  """

  def _inconsistency_of_type(qualifier):
    return inference.Inconsistency(
        type=qualifier,
        incoming_inconsistency=nltk_utils.production_from_production_string(
            "U[sem='A'] -> 'a'"),
        existing_inconsistency=nltk_utils.production_from_production_string(
            "U[sem='B'] -> 'a'"),
        incoming_inconsistency_source=set(),
        existing_inconsistency_source=set())

  def _canned_answer(consistent, contains, contains_monotonic, inconsistency):
    return {
        'consistent': consistent,
        'contains': contains,
        'contains_monotonic': contains_monotonic,
        'inconsistency': inconsistency,
    }

  # The canned answers respect two invariants: a production that is contained
  # is never inconsistent, and a production contained monotonically is also
  # contained overall.
  canned_answers = [
      _canned_answer(True, True, True, None),
      _canned_answer(True, True, False, None),
      _canned_answer(True, False, False, None),
      _canned_answer(False, False, False,
                     _inconsistency_of_type(cl.Qualifier.M)),
      _canned_answer(False, False, False,
                     _inconsistency_of_type(cl.Qualifier.D)),
  ]
  answer_cycle = itertools.cycle(canned_answers)
  answer_by_production = {}

  def _answer_for(production):
    # The cycle only advances for productions that have not been seen before,
    # so repeated queries about a production stay mutually consistent.
    if production not in answer_by_production:
      answer_by_production[production] = next(answer_cycle)
    return answer_by_production[production]

  def fake_contains_monotonic_production(production):
    if never_contains or all_defeasible:
      return False
    return _answer_for(production)['contains_monotonic']

  def fake_contains_defeasible_production(production):
    if never_contains:
      return False
    if all_defeasible:
      return True
    answer = _answer_for(production)
    return answer['contains'] and not answer['contains_monotonic']

  def fake_contains_production(production):
    if never_contains:
      return False
    return _answer_for(production)['contains']

  def fake_consistent_if_production_added(production, is_monotonic=False):
    del is_monotonic  # Unused by the fake.
    if all_consistent:
      return True
    return _answer_for(production)['consistent']

  def fake_backup_states():
    # backup_states is used to copy the base inference engine during dataset
    # generation, so the copy must be a fake configured in the same way.
    return make_fake_inference_engine(
        all_defeasible=all_defeasible,
        never_contains=never_contains,
        all_consistent=all_consistent)

  def fake_add_production(production, is_monotonic=False, restore=False):
    del production, is_monotonic, restore  # Adding a production is a no-op.

  def fake_inconsistency_if_production_added(production):
    if all_consistent:
      return None
    answer = _answer_for(production)
    if all_defeasible:
      # Overwrites the canned entry in place, as the answer dicts are shared.
      answer['inconsistency'] = cl.Qualifier.D
    return answer['inconsistency']

  def fake_get_productions_of_num_variables(num_variables):
    del num_variables  # Unused by the fake.
    return []

  def fake_get_productions_of_lhs_symbol(symbol):
    del symbol  # Unused by the fake.
    return set()

  fake_attributes = {
      # Kept empty so that productions are sampled from the grammar instead of
      # from the inference engine during testing.
      'all_productions': set(),
      'contains_monotonic_production': fake_contains_monotonic_production,
      'contains_defeasible_production': fake_contains_defeasible_production,
      'contains_production': fake_contains_production,
      'consistent_if_production_added': fake_consistent_if_production_added,
      'backup_states': fake_backup_states,
      # copy_monotonic_engine behaves similarly enough to backup_states, so
      # the same fake is reused for it.
      'copy_monotonic_engine': fake_backup_states,
      'add_production': fake_add_production,
      'inconsistency_if_production_added': (
          fake_inconsistency_if_production_added),
      'get_productions_of_num_variables': (
          fake_get_productions_of_num_variables),
      'get_productions_of_lhs_symbol': fake_get_productions_of_lhs_symbol,
  }

  # Cannot spec using the real class because it may already be mocked out.
  fake_inference_engine = mock.MagicMock()
  for attribute_name, attribute_value in fake_attributes.items():
    setattr(fake_inference_engine, attribute_name, attribute_value)
  return fake_inference_engine
