# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for applying the inductive bias of a conceptual learning task."""

import abc
import dataclasses
from typing import Any, Dict

import dataclasses_json

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import production_trees


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class InductiveBias(metaclass=abc.ABCMeta):
  """Abstract base for the inductive biases used in conceptual learning.

  Concrete subclasses implement `can_induce_rule`. The `json_encode` and
  `json_decode` class methods convert a bias to and from a dict that records
  the name of the concrete class under the 'type' key.
  """

  @abc.abstractmethod
  def can_induce_rule(self, rule, context):
    """Decides whether the rule is directly inducible from the given context.

    A false result does not mean that the rule can never be induced. A rule
    may still be inducible if it takes more than one induction step, or if it
    needs further deductive reasoning on top of the first inductive step.

    To obtain the full set of inducible rules, first collect every rule for
    which this method returns true, add those rules to the context, and then
    keep expanding the context iteratively via deductive reasoning. A rule
    reachable in this way, but not by deductive reasoning on the original
    context alone, is considered inducible.

    Each inductive bias implementation decides for itself which rules are
    directly inducible from a context and which are only indirectly so.

    Args:
      rule: The rule for which we want to determine whether it can be induced.
      context: The set of examples that are given to be true.

    Returns:
      Whether the rule can be induced directly from the context.
    """

  @classmethod
  def json_encode(cls, bias):
    """Returns a JSON representation of the InductiveBias.

    Args:
      bias: The inductive bias instance to encode.

    Returns:
      The result of `bias.to_dict()`, with the name of the concrete class of
      `bias` stored under the 'type' key.
    """
    encoded = bias.to_dict()
    # The type tag is what lets json_decode pick the right concrete class.
    encoded.update(type=type(bias).__name__)
    return encoded

  @classmethod
  def json_decode(cls, raw_value):
    """Returns an InductiveBias instance decoded from the raw JSON value.

    A hand-written decoder (together with the matching encoder) is needed
    because dataclasses_json does not support polymorphism out-of-the-box.
    See, e.g.:
    - https://github.com/lidatong/dataclasses-json/issues/106
    - https://github.com/lidatong/dataclasses-json/issues/222

    Args:
      raw_value: JSON representation of the inductive bias.

    Raises:
      ValueError: If the 'type' entry does not name a known inductive bias.
    """
    # Spec files serialized by earlier versions of the code carry no type
    # name, because only one inductive bias implementation existed back then.
    # Falling back to that implementation keeps those files readable.
    type_name = raw_value.get('type', 'IllustrativeExamplesInductiveBias')
    known_bias_classes = (
        IllustrativeExamplesInductiveBias,
        IllustrativeSubstitutionsInductiveBias,
    )
    for bias_class in known_bias_classes:
      if type_name == bias_class.__name__:
        return bias_class.from_dict(raw_value)
    raise ValueError(f'Unknown inductive bias type: {type_name}')


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class IllustrativeExamplesInductiveBias(InductiveBias):
  """Inductive bias that counts the reliable examples illustrating a rule.

  Attributes:
    min_illustrative_examples: Smallest number of context examples that must
      show the given rule in action (i.e., depend on the given rule) without
      depending on any unreliable rules.
  """
  min_illustrative_examples: int = 4

  def can_induce_rule(self, rule, context):
    """See parent class."""
    # An example counts when the rule is listed in its metadata rules and the
    # example itself is not flagged as unreliable.
    illustrative_count = sum(
        1 for example in context
        if rule in example.metadata.rules and not example.is_unreliable)
    return illustrative_count >= self.min_illustrative_examples


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class IllustrativeSubstitutionsInductiveBias(InductiveBias):
  """Inductive bias that counts the substitutions illustrating a rule.

  Two kinds of evidence are required before the given rule may be induced:
  1. Evidence for each variable:
     For every variable, the context must either show a sufficient number of
     unique substitutions for that variable, or contain at least one instance
     of the rule in which the variable is left unsubstituted. (The latter case
     amounts to treating the number of observed substitutions as infinite: the
     variable's behavior is then specified explicitly in a generic form, which
     already covers every possible substitution.)
  2. Evidence that the rule yields a well-defined output:
     An input token is allowed to be meaningless on its own and to carry
     meaning only inside one or more idiomatic phrases. To show that the
     rule's output is meaningful on its own, the context must either show a
     sufficient number of unique outer substitutions, or contain at least one
     example in which the given rule is the topmost rule of the application
     tree (i.e., for which the outer substitution is '__').

  Attributes:
    min_illustrative_variable_substitutions: Minimum number of unique variable
      substitutions that need to be illustrated in the context for each
      variable in the given rule that has not been observed in unsubstituted
      form.
    min_illustrative_outer_substitutions: Minimum number of outer substitutions
      in which the rule needs to be illustrated in the context, if the rule
      was never seen as the topmost node of an example's rule tree.
    reliable_examples_only: If True, only substitutions observed in reliable
      examples (i.e., examples that don't depend on any unreliable rules) are
      counted. If False, all examples are considered.
  """
  min_illustrative_variable_substitutions: int = 4
  min_illustrative_outer_substitutions: int = 2
  reliable_examples_only: bool = True

  def can_induce_rule(self, rule, context):
    """See parent class."""
    metadata = context.metadata
    # Pick the pair of per-rule lookups matching the reliability setting.
    if self.reliable_examples_only:
      variable_substitutions_by_rule, outer_by_rule = (
          metadata.reliable_variable_substitutions_by_rule,
          metadata.reliable_outer_substitutions_by_rule)
    else:
      variable_substitutions_by_rule, outer_by_rule = (
          metadata.variable_substitutions_by_rule,
          metadata.outer_substitutions_by_rule)

    # First criterion: enough evidence for every variable of the rule.
    if rule not in variable_substitutions_by_rule:
      raise ValueError(
          f'Rule missing variable substitutions metadata: {rule}, {self.reliable_examples_only}, {variable_substitutions_by_rule}'
      )
    num_variable_substitutions = (
        production_trees.get_effective_min_num_variable_substitutions(
            variable_substitutions_by_rule[rule]))
    if (num_variable_substitutions <
        self.min_illustrative_variable_substitutions):
      return False

    # Second criterion: enough evidence that the output is well-defined. This
    # is only reached once the variable criterion has passed.
    if rule not in outer_by_rule:
      raise ValueError(f'Rule missing outer substitutions metadata: {rule}')
    num_outer_substitutions = (
        production_trees.get_effective_num_outer_substitutions(
            outer_by_rule[rule]))
    return not (
        num_outer_substitutions < self.min_illustrative_outer_substitutions)


def get_rule_illustration_quality(
    rule, context,
    inductive_bias):
  """Rates how well a context illustrates one of its rules.

  A context illustrates a rule when the rule is either explicit or hidden in
  it. Explicit rules always have GOOD illustration quality. A hidden rule has
  GOOD illustration quality only if it meets the criteria of the given
  inductive bias for being induced from the context; otherwise its quality is
  POOR.

  Args:
    rule: A rule string.
    context: A context in which to evaluate the illustration quality of the
      rule.
    inductive_bias: Inductive bias that encapsulates the criteria for inducing a
      rule from a given context.

  Returns:
    cl.IllustrationQuality.GOOD or cl.IllustrationQuality.POOR.

  Raises:
    ValueError: If the rule is not illustrated in the context.
  """
  if not context.illustrates_rule(rule):
    raise ValueError('Rule not illustrated in context.')

  # The inductive bias is only consulted for rules that are not explicit.
  is_well_illustrated = (
      rule in context.metadata.explicit_rules
      or inductive_bias.can_induce_rule(rule, context))
  if is_well_illustrated:
    return cl.IllustrationQuality.GOOD
  return cl.IllustrationQuality.POOR
