# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data structures bundling side outputs of the dataset generation process.

These include various types of diagnostic information, such as event counters
and timing information. Note that the main output of the dataset generation
process -- that is, the ExampleSet itself -- is not covered here, but rather in
conceptual_learning.py, as ExampleSet is central to the conceptual learning
framework and serves multiple purposes other than simply being the output of a
dataset generation run.
"""

import collections
import dataclasses
import math
import sys
from typing import AbstractSet, Any, Dict, Optional

import dataclasses_json
import tensorflow_datasets as tfds

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import production_trees


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class DistributionSummaryStats:
  """Key statistics summarizing a distribution, e.g., of input/output lengths.

  Although the underlying values of the original distribution could potentially
  be integers (e.g., for distributions of input/output lengths), we represent
  all summary statistics except for count as floats to support both int and
  float distributions.

  Stats can be accumulated one value at a time via `update_with_value`, or two
  partial summaries can be merged with `+`.

  Attributes:
    min: Minimum value. Defaults to the largest finite float so that the first
      observed value always replaces it.
    max: Maximum value. Defaults to the most negative finite float so that the
      first observed value always replaces it.
    sum: Sum of all values (i.e., numerator of the mean).
    sum_of_squares: Sum of the squares of all values (used to compute the
      standard deviation incrementally).
    count: Number of values (i.e., denominator of the mean).
    mean: Average (mean) value; 0.0 when no values have been observed.
    stddev: Sample standard deviation; 0.0 when fewer than two values have
      been observed.
  """
  min: float = sys.float_info.max
  # sys.float_info.min is the smallest *positive* normalized float, not the
  # most negative one, so it is not a valid starting point for a running max
  # (e.g., all-zero lengths would report a max of ~2.2e-308).
  max: float = -sys.float_info.max
  sum: float = 0.0
  sum_of_squares: float = 0.0
  count: int = 0

  # The following properties are derived from the attributes above, but we are
  # materializing their values here anyway so that they will be included when
  # the counters are serialized to JSON.
  mean: float = 0.0
  stddev: float = 0.0

  def _update_mean(self):
    """Updates the value of self.mean based on the other attributes."""
    self.mean = self.sum / self.count if self.count else 0.0

  def _update_stddev(self):
    """Updates the value of self.stddev based on the other attributes.

    Assumes self.mean has already been updated.
    """
    if self.count < 2:
      self.stddev = 0.0
      return

    # Incremental calculation of standard deviation, based on this equivalence:
    #   (a-b)^2 = a^2 - 2ab + b^2
    #   ==> SUM(Xi-Xmean)^2 = SUM(Xi^2) - 2*SUM(Xi*XMean) + SUM(Xmean^2)
    variance = ((self.sum_of_squares - 2.0 * self.mean * self.sum +
                 self.count * self.mean * self.mean) / (self.count - 1.0))
    # Floating-point cancellation in the formula above can turn a variance that
    # is mathematically zero (e.g., all values equal) into a tiny negative
    # number, on which math.sqrt would raise ValueError. (`max` here is the
    # builtin: class attributes are not visible inside method bodies.)
    self.stddev = math.sqrt(max(variance, 0.0))

  def _update_derived_properties(self):
    """Updates all derived properties based on the underlying attributes."""
    # These need to be updated in this order, as stddev depends on mean.
    self._update_mean()
    self._update_stddev()

  def __add__(self,
              other):
    """Returns the summary of the union of both underlying value collections.

    Args:
      other: Another DistributionSummaryStats to merge with this one.

    Returns:
      A new DistributionSummaryStats; neither operand is modified.
    """
    result = DistributionSummaryStats(
        min=min(self.min, other.min),
        max=max(self.max, other.max),
        sum=self.sum + other.sum,
        sum_of_squares=self.sum_of_squares + other.sum_of_squares,
        count=self.count + other.count)

    result._update_derived_properties()

    return result

  def update_with_value(self, value):
    """Updates the stats in place to reflect an observation of `value`."""
    self.min = min(self.min, value)
    self.max = max(self.max, value)
    self.sum += value
    self.sum_of_squares += value * value
    self.count += 1

    self._update_derived_properties()

@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ExampleAttemptCounters:
  """Counters tracking numbers of example generation attempts.

  Every attempt to generate an example -- whether a top-level example or one
  nested inside another example's context -- falls into exactly one of the
  buckets below, according to whether it succeeded and, if not, why it failed.
  Each attribute is a number of examples.

  Attributes:
    already_in_context: Examples discarded because their request matched that
      of another example in their own context.
    ambiguous: Examples discarded because their request could be parsed in
      multiple ways by the underlying grammar.
    duplicate: Examples discarded because an identical example already existed
      in the dataset.
    max_derivation_level_reached: Examples that could not be generated because
      the random walk through the grammar reached the maximum level.
    missing_target_rule: Examples discarded because their parse did not use the
      grammar rule they were meant to illustrate.
    unparseable: Examples discarded because their request could not be parsed
      using the underlying grammar.
    unable_to_create_negative_example: Examples discarded because a positive
      example could not be converted into a related negative distractor rule.
    context_example_inconsistent: Context examples discarded because they would
      introduce inconsistencies in the context.
    non_rule_example_inconsistent: Top-level non-rule examples discarded
      because they were inconsistent with the context.
    rule_example_inconsistent: Top-level rule examples discarded because they
      were inconsistent with the context.
    distractor_consistent_with_context: Top-level rule examples discarded
      because they were consistent with the context.
    wrong_reply_when_targeting_unknown: Top-level examples discarded because
      they were known to the inference engine when an example with an unknown
      reply was requested.
    wrong_qualifier_when_targeting_monotonic: Top-level examples discarded
      because they did not have qualifier M when one was requested.
    wrong_qualifier_when_targeting_defeasible: Top-level examples discarded
      because they did not have qualifier D when one was requested.
    illustrating_distractor_rule_too_many_times: Context examples discarded
      because they would illustrate a distractor rule too many times.
    exceeded_max_input_length: Top-level examples discarded because their input
      string exceeded the maximum allowed number of tokens.
    exceeded_max_output_length: Top-level examples discarded because their
      output string exceeded the maximum allowed number of tokens.
    valid: Valid examples generated.
  """
  already_in_context: int = 0
  ambiguous: int = 0
  duplicate: int = 0
  max_derivation_level_reached: int = 0
  missing_target_rule: int = 0
  unparseable: int = 0
  unable_to_create_negative_example: int = 0
  context_example_inconsistent: int = 0
  non_rule_example_inconsistent: int = 0
  rule_example_inconsistent: int = 0
  distractor_consistent_with_context: int = 0
  wrong_reply_when_targeting_unknown: int = 0
  wrong_qualifier_when_targeting_monotonic: int = 0
  wrong_qualifier_when_targeting_defeasible: int = 0
  illustrating_distractor_rule_too_many_times: int = 0
  exceeded_max_input_length: int = 0
  exceeded_max_output_length: int = 0
  valid: int = 0

  def get_total(self):
    """Returns the total number of examples attempted.

    Because the buckets are mutually exclusive, this is the sum of all fields.
    """
    return sum(getattr(self, f.name) for f in dataclasses.fields(self))

  def get_valid_fraction(self):
    """Returns the fraction of attempted examples that were valid (0.0 if none)."""
    # The tiny epsilon keeps the division defined when nothing was attempted.
    denominator = self.get_total() + 1e-100
    return float(self.valid) / denominator

  def __add__(self, other):
    """Returns a new ExampleAttemptCounters holding the field-wise sums."""
    summed = {
        f.name: getattr(self, f.name) + getattr(other, f.name)
        for f in dataclasses.fields(self)
    }
    return ExampleAttemptCounters(**summed)


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ContextAttemptCounters:
  """Counters tracking numbers of context generation attempts.

  Attributes:
    poor_illustration_quality: Number of context attempts rejected because the
      context was of poor illustration quality.
    unreliable_rule_illustrated_only_one_way: Number of context attempts
      rejected because at least one unreliable rule was illustrated only as one
      distractor rule.
    exceeded_max_input_length: Number of contexts that were discarded because
      they exceeded or came too close to exceeding the maximum allowed top-level
      example input length.
  """
  poor_illustration_quality: int = 0
  unreliable_rule_illustrated_only_one_way: int = 0
  exceeded_max_input_length: int = 0

  def __add__(self, other):
    """Returns a new ContextAttemptCounters holding the field-wise sums."""
    return ContextAttemptCounters(**{
        counter_field.name: (getattr(self, counter_field.name) +
                             getattr(other, counter_field.name))
        for counter_field in dataclasses.fields(ContextAttemptCounters)
    })


def _counter_field():
  """Returns a dataclass field that defaults to an empty collections.Counter.

  All dataclass fields of type Counter should be created with this method so
  that they round-trip through dataclasses_json correctly.
  """
  # An explicit decoder is required: without it, restoring from JSON would
  # mangle the Counter, e.g. producing
  #   Counter({('non_rule_request', 5): 1})
  # instead of
  #   Counter({'non_rule_request': 5})
  json_metadata = dataclasses_json.config(decoder=collections.Counter)
  return dataclasses.field(
      default_factory=collections.Counter, metadata=json_metadata)


def _by_output_pattern_counter_encoder(counter):
  """Summarizes ExampleCounters.by_output_pattern for JSON serialization.

  Args:
    counter: Mapping from output pattern to the number of examples having it.

  Returns:
    A dict with the number of distinct output patterns and the smallest and
    largest per-pattern example counts (both 0 when `counter` is empty).
  """
  per_pattern_counts = counter.values()
  return {
      'num_distinct_output_patterns': len(counter),
      'min_num_examples': min(per_pattern_counts, default=0),
      'max_num_examples': max(per_pattern_counts, default=0),
  }


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ExampleCounters:
  """Counters tracking numbers of examples generated in the dataset.

  Only top-level examples are counted; examples nested inside other examples'
  contexts are not.

  Attributes:
    by_request_type: Number of examples broken down by request type.
    by_example_type: Number of examples broken down by example type.
    by_qualifier: Number of examples broken down by qualifier.
    by_knownness: Number of examples broken down by the reply type: UNKNOWN,
      KNOWN_DEFEASIBLE, and KNOWN_MONOTONIC.
    by_derivation_level: Number of examples broken down by derivation level.
    by_num_variables: Number of examples broken down by number of unbound
      variables.
    by_derivation_level_and_num_variables: Number of examples broken down by
      derivation level and number of unbound variables, keyed by the str() of
      the (derivation_level, num_variables) tuple.
    by_num_omitted_rules: Number of examples broken down by how many of their
      rules are omitted in the context.
    by_num_explicit_rules: Number of examples broken down by how many of their
      rules are explicit in the context.
    by_num_unreliable_rules: Number of examples broken down by number of
      dependency rules that are unreliable in the context, measured as the size
      of the example's metadata.distractor_rules_by_unreliable_rule.
    by_num_hidden_rules: Number of examples broken down by how many of their
      rules are hidden in the context.
    by_num_rules: Number of examples broken down by number of dependency rules.
    by_applied_edits: Number of negative rule examples broken down by the type
      of heuristic edit(s) applied to create them.
    by_triviality: Number of examples broken down by triviality. See
      conceptual_learning.Triviality for the triviality types.
    by_output_pattern: Number of examples broken down by output pattern. This
      counter is JSON-serialized only as a summary and cannot be recovered.
    input_length_stats_standard: Statistics summarizing the distribution of
      top-level example input lengths, based on the STANDARD string
      representation and whitespace tokenizer.
    output_length_stats_standard: Statistics summarizing the distribution of
      top-level example output lengths, based on the STANDARD string
      representation and whitespace tokenizer.
    input_length_stats_compact: Statistics summarizing the distribution of
      top-level example input lengths, based on the COMPACT string
      representation and T5X's default SentencePieceVocabulary tokenizer.
    output_length_stats_compact: Statistics summarizing the distribution of
      top-level example output lengths, based on the COMPACT string
      representation and T5X's default SentencePieceVocabulary tokenizer.
  """
  by_request_type: collections.Counter[str] = _counter_field()
  by_example_type: collections.Counter[str] = _counter_field()
  by_qualifier: collections.Counter[str] = _counter_field()
  by_knownness: collections.Counter[str] = _counter_field()
  by_derivation_level: collections.Counter[int] = _counter_field()
  by_num_variables: collections.Counter[int] = _counter_field()
  by_derivation_level_and_num_variables: collections.Counter[str] = (
      _counter_field())
  by_num_omitted_rules: collections.Counter[int] = _counter_field()
  by_num_explicit_rules: collections.Counter[int] = _counter_field()
  by_num_unreliable_rules: collections.Counter[int] = _counter_field()
  by_num_hidden_rules: collections.Counter[int] = _counter_field()
  by_num_rules: collections.Counter[int] = _counter_field()
  by_applied_edits: collections.Counter[str] = _counter_field()
  by_triviality: collections.Counter[str] = _counter_field()
  # During dataset generation and splitting this is a real Counter. When
  # serializing to JSON it is reduced to summary stats to keep the files
  # readable, at the cost of not being able to restore the Counter itself.
  by_output_pattern: collections.Counter[str] = dataclasses.field(
      default_factory=collections.Counter,
      compare=False,
      metadata=dataclasses_json.config(
          encoder=_by_output_pattern_counter_encoder, decoder=dict))

  input_length_stats_standard: DistributionSummaryStats = dataclasses.field(
      default_factory=DistributionSummaryStats)
  output_length_stats_standard: DistributionSummaryStats = dataclasses.field(
      default_factory=DistributionSummaryStats)
  input_length_stats_compact: DistributionSummaryStats = dataclasses.field(
      default_factory=DistributionSummaryStats)
  output_length_stats_compact: DistributionSummaryStats = dataclasses.field(
      default_factory=DistributionSummaryStats)

  def get_total(self):
    """Returns the total number of examples generated.

    Every example has exactly one request type, so summing that breakdown
    yields the total.
    """
    return sum(self.by_request_type.values())

  def get_fraction_by_example_type_by_request_type(self):
    """Returns the fraction of examples of each example type, per request type.

    The fractions are grouped by request type, so the result looks like:
    {
        'rule_request': {k1: v1, k2: v2, ...},
        'non_rule_request': {l1: w1, l2: w2, ...}
    }
    where v1 + v2 + ... should be 1.0, or 0.0 when there are no examples of
    that request type (e.g., no rule examples).
    """
    fractions = {}
    for request_type, example_types in cl.EXAMPLE_TYPES_BY_REQUEST_TYPE.items():
      denominator = self.by_request_type[request_type]
      fractions[request_type] = {
          example_type: (self.by_example_type[example_type] / denominator
                         if denominator else 0.0)
          for example_type in example_types
      }
    return fractions

  def __add__(self, other):
    """Returns a new ExampleCounters holding the field-wise sums.

    Counter fields are merged with Counter addition and the length stats with
    DistributionSummaryStats addition.
    """
    merged = {
        f.name: getattr(self, f.name) + getattr(other, f.name)
        for f in dataclasses.fields(ExampleCounters)
    }
    return ExampleCounters(**merged)

  def update_with_example_and_context(self, example, context=None):
    """Tallies a top-level example (and optionally its context).

    Not idempotent: call this exactly once for every top-level example added
    to the dataset, otherwise the example is counted more than once.

    Args:
      example: An Example whose content and metadata are used to update the
        counters.
      context: The optional context in which the example appears. When None,
        by_num_omitted_rules, by_num_explicit_rules, by_num_unreliable_rules,
        by_num_hidden_rules and by_triviality are left untouched.
    """
    metadata = example.metadata

    self.by_qualifier[example.qualifier] += 1
    self.by_request_type[example.get_request_type()] += 1
    self.by_example_type[example.get_example_type()] += 1
    self.by_knownness[example.get_knownness()] += 1
    self.by_derivation_level[metadata.derivation_level] += 1
    self.by_num_variables[metadata.num_variables] += 1
    # Ideally the tuple itself would be the key, but dataclasses_json's to_json
    # turns any Collection into a list, which raises TypeError as a dict key.
    level_and_variables = (metadata.derivation_level, metadata.num_variables)
    self.by_derivation_level_and_num_variables[str(level_and_variables)] += 1

    if context is not None:
      context_metadata = context.metadata
      omitted = metadata.rules.intersection(context_metadata.omitted_rules)
      self.by_num_omitted_rules[len(omitted)] += 1
      explicit = metadata.rules.intersection(context_metadata.explicit_rules)
      self.by_num_explicit_rules[len(explicit)] += 1
      self.by_num_unreliable_rules[
          len(metadata.distractor_rules_by_unreliable_rule)] += 1
      hidden = metadata.rules.intersection(context_metadata.hidden_rules)
      self.by_num_hidden_rules[len(hidden)] += 1
      self.by_triviality[example.get_triviality(context)] += 1

    self.by_num_rules[len(metadata.rules)] += 1
    self.by_output_pattern[example.get_output_pattern()] += 1

    for edit in metadata.applied_edits:
      self.by_applied_edits[edit] += 1

    length_observations = (
        (self.input_length_stats_standard, metadata.input_length_standard),
        (self.output_length_stats_standard, metadata.output_length_standard),
        (self.input_length_stats_compact, metadata.input_length_compact),
        (self.output_length_stats_compact, metadata.output_length_compact),
    )
    for stats, length in length_observations:
      stats.update_with_value(length)


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ContextCounters:
  """Counters tracking numbers of contexts generated in the dataset.

  Attributes:
    total: Total number of unique contexts generated (incremented once per
      call to update_with_context).
    by_num_omitted_rules: Number of contexts broken down by number of omitted
      rules.
    by_num_explicit_rules: Number of contexts broken down by number of explicit
      rules.
    by_num_unreliable_rules: Number of contexts broken down by number of
      unreliable rules.
    by_num_hidden_rules: Number of contexts broken down by number of hidden
      rules.
    by_num_hidden_true_rules: Number of contexts broken down by number of hidden
      rules that satisfy the inductive bias.
    by_num_hidden_unknown_rules: Number of contexts broken down by number of
      hidden rules that don't satisfy the inductive bias.
    by_num_rules: Number of contexts broken down by number of rules.
    by_num_examples: Number of contexts broken down by number of examples in the
      context.
    by_num_unknown_rule_top_level_examples: Number of contexts broken down by
      number of top-level unknown rule examples in the example group.
    by_num_positive_rule_top_level_examples: Number of contexts broken down by
      number of top-level positive rule examples in the example group.
    by_num_negative_rule_top_level_examples: Number of contexts broken down by
      number of top-level negative rule examples in the example group.
    by_num_unknown_nonrule_top_level_examples: Number of contexts broken down by
      number of top-level unknown non-rule examples in the example group.
  """
  total: int = 0
  by_num_omitted_rules: collections.Counter[int] = _counter_field()
  by_num_explicit_rules: collections.Counter[int] = _counter_field()
  by_num_unreliable_rules: collections.Counter[int] = _counter_field()
  by_num_hidden_rules: collections.Counter[int] = _counter_field()
  by_num_hidden_true_rules: collections.Counter[int] = _counter_field()
  by_num_hidden_unknown_rules: collections.Counter[int] = _counter_field()
  by_num_rules: collections.Counter[int] = _counter_field()
  by_num_examples: collections.Counter[int] = _counter_field()
  by_num_unknown_rule_top_level_examples: collections.Counter[int] = (
      _counter_field())
  by_num_positive_rule_top_level_examples: collections.Counter[int] = (
      _counter_field())
  by_num_negative_rule_top_level_examples: collections.Counter[int] = (
      _counter_field())
  by_num_unknown_nonrule_top_level_examples: collections.Counter[int] = (
      _counter_field())

  def __add__(self, other):
    """Returns a new ContextCounters holding the field-wise sums."""
    merged = {
        f.name: getattr(self, f.name) + getattr(other, f.name)
        for f in dataclasses.fields(ContextCounters)
    }
    return ContextCounters(**merged)

  def update_with_context(self, context):
    """Tallies one context into the counters.

    Not idempotent: call this exactly once for every context added to the
    dataset, otherwise the context is counted more than once.

    Args:
      context: A FrozenExampleSet whose content and metadata are used to update
        the counters.
    """
    self.total += 1

    metadata = context.metadata
    # Each counter is keyed by the size of the corresponding rule collection.
    rule_breakdowns = (
        (self.by_num_omitted_rules, metadata.omitted_rules),
        (self.by_num_explicit_rules, metadata.explicit_rules),
        (self.by_num_unreliable_rules, metadata.unreliable_rules),
        (self.by_num_hidden_rules, metadata.hidden_rules),
        (self.by_num_hidden_true_rules, metadata.hidden_true_rules),
        (self.by_num_hidden_unknown_rules, metadata.hidden_unknown_rules),
        (self.by_num_rules, metadata.rules),
    )
    for counter, rule_collection in rule_breakdowns:
      counter[len(rule_collection)] += 1

    self.by_num_examples[len(context)] += 1

  def update_with_example_group(self, example_group):
    """Tallies, per example category, how many top-level examples the group has.

    Args:
      example_group: Iterable of the top-level examples sharing one context.
    """
    unknown_rule = 0
    positive_rule = 0
    negative_rule = 0
    unknown_nonrule = 0
    for example in example_group:
      example_type = example.get_example_type()
      if example_type == cl.ExampleType.RULE_UNKNOWN_D:
        unknown_rule += 1
      elif example_type in (cl.ExampleType.RULE_KNOWN_TRUE_D,
                            cl.ExampleType.RULE_KNOWN_TRUE_M):
        positive_rule += 1
      elif example_type in (cl.ExampleType.RULE_KNOWN_FALSE_D,
                            cl.ExampleType.RULE_KNOWN_FALSE_M):
        negative_rule += 1
      elif example_type == cl.ExampleType.NONRULE_UNKNOWN_D:
        unknown_nonrule += 1
      # Any other example type is not tallied here.

    self.by_num_unknown_rule_top_level_examples[unknown_rule] += 1
    self.by_num_positive_rule_top_level_examples[positive_rule] += 1
    self.by_num_negative_rule_top_level_examples[negative_rule] += 1
    self.by_num_unknown_nonrule_top_level_examples[unknown_nonrule] += 1


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class RuleCounters:
  """Counters tracking numbers of rules generated in the dataset.

  Rules from different contexts are counted separately, even when their strings
  happen to be identical.

  Attributes:
    total: Total number of rules.
    by_num_context_examples: Number of rules broken down by the number of
      context examples that depend on the rule.
    by_num_reliable_context_examples: Number of rules broken down by the number
      of reliable illustrative examples.
    by_min_reliable_derivation_level: Number of rules broken down by the
      smallest derivation level among all reliable context examples that depend
      on the rule. Only counts rules that were illustrated in at least one
      reliable example (which thus leaves out omitted, unreliable and distractor
      rules).
    by_num_context_variable_substitutions: Number of rules broken down by the
      minimum number of variable substitutions across all of the variables of
      the given rule. Only counts rules that have variables. If there is at
      least one example for which the given variable is left unsubstituted, then
      considers the number of variable substitutions to be infinite. Only counts
      rules that were illustrated in at least one example.
    by_num_context_outer_substitutions: Number of rules broken down by the
      number of outer substitutions. If there is at least one example for which
      the given rule is the topmost rule in the application tree (i.e., for
      which the outer substitution is '__'), then considers the number of outer
      substitutions to be infinite. Only counts rules that were illustrated in
      at least one example.
    by_num_context_reliable_variable_substitutions: Same as
      by_num_context_variable_substitutions, but only counting substitutions
      observed in reliable examples.
    by_num_context_reliable_outer_substitutions: Same as
      by_num_context_outer_substitutions, but only counting substitutions
      observed in reliable examples.
  """
  total: int = 0
  by_num_context_examples: collections.Counter[int] = _counter_field()
  by_num_reliable_context_examples: collections.Counter[int] = _counter_field()
  by_min_reliable_derivation_level: collections.Counter[int] = _counter_field()
  by_num_context_variable_substitutions: collections.Counter[int] = (
      _counter_field())
  by_num_context_outer_substitutions: collections.Counter[int] = (
      _counter_field())
  by_num_context_reliable_variable_substitutions: collections.Counter[int] = (
      _counter_field())
  by_num_context_reliable_outer_substitutions: collections.Counter[int] = (
      _counter_field())

  def __add__(self, other):
    """Returns a new RuleCounters holding the field-wise sums."""
    merged = {
        f.name: getattr(self, f.name) + getattr(other, f.name)
        for f in dataclasses.fields(RuleCounters)
    }
    return RuleCounters(**merged)

  def update_with_context(self, context, rules):
    """Updates the rule counters with information from the given context.

    Args:
      context: The context whose content and metadata will be used to update the
        counters.
      rules: The subset of rules from the given context to consider (e.g., could
        be the full set of rules from the given context, or just the explicit
        rules, etc.).
    """
    self.total += len(rules)

    # Every considered rule starts at 0, so rules that no context example uses
    # still show up (under key 0) in the example-count breakdowns.
    example_count = collections.Counter(dict.fromkeys(rules, 0))
    reliable_example_count = collections.Counter(dict.fromkeys(rules, 0))
    reliable_levels = collections.defaultdict(list)

    for context_example in context:
      for rule in context_example.metadata.rules:
        if rule in rules:
          example_count[rule] += 1
          if not context_example.is_unreliable:
            reliable_example_count[rule] += 1
            reliable_levels[rule].append(
                context_example.metadata.derivation_level)

    for count in example_count.values():
      self.by_num_context_examples[count] += 1
    for count in reliable_example_count.values():
      self.by_num_reliable_context_examples[count] += 1
    # Only rules with at least one reliable example have an entry here.
    for levels in reliable_levels.values():
      self.by_min_reliable_derivation_level[min(levels)] += 1

    # (substitutions by rule, function reducing them to a count, target counter)
    metadata = context.metadata
    substitution_breakdowns = (
        (metadata.variable_substitutions_by_rule,
         production_trees.get_effective_min_num_variable_substitutions,
         self.by_num_context_variable_substitutions),
        (metadata.outer_substitutions_by_rule,
         production_trees.get_effective_num_outer_substitutions,
         self.by_num_context_outer_substitutions),
        (metadata.reliable_variable_substitutions_by_rule,
         production_trees.get_effective_min_num_variable_substitutions,
         self.by_num_context_reliable_variable_substitutions),
        (metadata.reliable_outer_substitutions_by_rule,
         production_trees.get_effective_num_outer_substitutions,
         self.by_num_context_reliable_outer_substitutions),
    )
    for substitutions_by_rule, to_count, counter in substitution_breakdowns:
      for rule, substitutions in substitutions_by_rule.items():
        if rule in rules:
          counter[to_count(substitutions)] += 1


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class RuleBreakdownCounters:
  """Bundle of rule counters broken down by rule type.

  Attributes:
    all: Counters tracking all rules regardless of type, i.e., the context's
      rules together with its distractor rules.
    explicit: Counters tracking explicit rules.
    hidden_true: Counters tracking hidden rules with rule reply TRUE.
    hidden_unknown: Counters tracking hidden rules with rule reply UNKNOWN.
    unreliable: Counters tracking unreliable rules.
    distractor: Counters tracking distractor rules.
    omitted: Counters tracking omitted rules.
  """

  all: RuleCounters = dataclasses.field(default_factory=RuleCounters)
  explicit: RuleCounters = dataclasses.field(default_factory=RuleCounters)
  hidden_true: RuleCounters = dataclasses.field(default_factory=RuleCounters)
  hidden_unknown: RuleCounters = dataclasses.field(default_factory=RuleCounters)
  unreliable: RuleCounters = dataclasses.field(default_factory=RuleCounters)
  distractor: RuleCounters = dataclasses.field(default_factory=RuleCounters)
  omitted: RuleCounters = dataclasses.field(default_factory=RuleCounters)

  def __add__(self, other):
    """Returns a new RuleBreakdownCounters with each bucket summed pairwise."""
    return RuleBreakdownCounters(
        all=self.all + other.all,
        explicit=self.explicit + other.explicit,
        hidden_true=self.hidden_true + other.hidden_true,
        hidden_unknown=self.hidden_unknown + other.hidden_unknown,
        unreliable=self.unreliable + other.unreliable,
        distractor=self.distractor + other.distractor,
        omitted=self.omitted + other.omitted)

  def update_with_context(self, context):
    """Updates each rule-type bucket with the matching rules of the context.

    Every bucket is updated from the same context, but restricted to the set
    of rules of the corresponding type, as listed in `context.metadata`.

    Args:
      context: The context whose content and metadata will be used to update the
        counters.
    """
    metadata = context.metadata
    # The "all" bucket is fed the union of the context's rules and its
    # distractor rules, so that distractors are counted there as well.
    self.all.update_with_context(
        context, set(metadata.rules) | set(metadata.distractor_rules))
    self.explicit.update_with_context(context, set(metadata.explicit_rules))
    self.hidden_true.update_with_context(context,
                                         set(metadata.hidden_true_rules))
    self.hidden_unknown.update_with_context(
        context, set(metadata.hidden_unknown_rules))
    self.unreliable.update_with_context(context,
                                        set(metadata.unreliable_rules))
    self.distractor.update_with_context(context,
                                        set(metadata.distractor_rules))
    self.omitted.update_with_context(context, set(metadata.omitted_rules))


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class GenerationErrorCounters:
  """Tallies of unrecoverable errors encountered while sampling.

  Attributes:
    failed_to_illustrate_target_rule: How often the generator gave up on
      producing an example that illustrates a given target rule because it hit
      the maximum number of failed attempts. This covers both top-level
      examples and examples nested inside other examples' contexts. Each
      occurrence leaves the resulting context or top-level dataset with at
      least one example fewer than intended. For a healthy dataset, this
      should ideally be zero.
    failed_to_generate_example_of_desired_request_type: How often the
      generator gave up on producing an example of a particular request type
      (e.g., a negative rule request) because it hit the maximum number of
      failed attempts. Each occurrence leaves the top-level dataset with at
      least one example fewer than intended. For a healthy dataset, this
      should ideally be zero.
    failed_to_generate_derived_production: How often the generator gave up on
      producing a derived production.
    failed_to_generate_context: How often the generator failed to produce a
      context of sufficient quality.
    failed_to_generate_grammar: How often the generator failed to produce a
      grammar from a given grammar template. (Normally at most 1, since
      dataset generation is aborted when this happens.)
  """
  failed_to_illustrate_target_rule: int = 0
  failed_to_generate_example_of_desired_request_type: int = 0
  failed_to_generate_derived_production: int = 0
  failed_to_generate_context: int = 0
  failed_to_generate_grammar: int = 0

  def __add__(self, other):
    """Returns a new GenerationErrorCounters holding field-wise sums."""
    totals = {}
    for counter in dataclasses.fields(self):
      key = counter.name
      totals[key] = getattr(self, key) + getattr(other, key)
    return GenerationErrorCounters(**totals)


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class GenerationCounters:
  """Counters collected over the course of a benchmark generation run.

  Attributes:
    example_attempts: Counts of example generation attempts.
    examples: Counts of the examples generated in the dataset.
    context_examples: Counts of the examples appearing in the contexts of the
      dataset.
    context_attempts: Counts of context generation attempts.
    contexts: Counts of the contexts generated in the dataset.
    rules: Counts of the rules generated in the dataset, broken down by type.
    errors: Counts of unrecoverable errors that occurred during sampling.
  """
  example_attempts: ExampleAttemptCounters = dataclasses.field(
      default_factory=ExampleAttemptCounters)
  examples: ExampleCounters = dataclasses.field(
      default_factory=ExampleCounters)
  context_examples: ExampleCounters = dataclasses.field(
      default_factory=ExampleCounters)
  context_attempts: ContextAttemptCounters = dataclasses.field(
      default_factory=ContextAttemptCounters)
  contexts: ContextCounters = dataclasses.field(
      default_factory=ContextCounters)
  rules: RuleBreakdownCounters = dataclasses.field(
      default_factory=RuleBreakdownCounters)
  errors: GenerationErrorCounters = dataclasses.field(
      default_factory=GenerationErrorCounters)

  def __add__(self, other):
    """Returns a new GenerationCounters whose members are pairwise sums."""
    summed = {}
    for member in dataclasses.fields(self):
      mine = getattr(self, member.name)
      theirs = getattr(other, member.name)
      summed[member.name] = mine + theirs
    return GenerationCounters(**summed)


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class GenerationTiming:
  """Wall-clock timings, in seconds, for the phases of a generation run.

  Attributes:
    generate_dataset: Seconds spent generating and writing the dataset.
    split_dataset: Seconds spent generating and writing the splits.
    summarize_dataset: Seconds spent generating the dataset summary.
    total: Seconds spent on the run as a whole (derived, read-only).
  """
  generate_dataset: float = 0.0
  split_dataset: float = 0.0
  summarize_dataset: float = 0.0

  @property
  def total(self):
    """Sum of the three phase timings."""
    phases_sum = self.generate_dataset + self.split_dataset
    return phases_sum + self.summarize_dataset

  def __add__(self, other):
    """Returns a new GenerationTiming with per-phase timings summed."""
    combined = {
        phase.name: getattr(self, phase.name) + getattr(other, phase.name)
        for phase in dataclasses.fields(self)
    }
    return GenerationTiming(**combined)


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class SplittingStatsOfAlgorithm:
  """Statistics describing how one splitting algorithm split the dataset.

  Instances are serialized via dataclasses_json's to_json, which requires every
  dict key to be a string. Pairs of splits are therefore keyed as, e.g.,
  "train:test", and compounds (originally tuples of strings) are keyed by their
  str() representation.

  Attributes:
    atom_divergence: Atom divergence for each pair of splits.
    compound_divergence: Compound divergence for each pair of splits.
    num_items_by_atom_by_split: For each split, a mapping from atom to the
      number of items (contexts or top-level examples) containing that atom.
    num_items_by_compound_by_split: For each split, a mapping from compound to
      the number of items (contexts or top-level examples) containing that
      compound.
    atom_coverage_by_split: For each split, the fraction of all atoms that
      appear in it.
    compound_coverage_by_split: For each split, the fraction of all compounds
      that appear in it.
  """
  atom_divergence: Dict[str, float] = dataclasses.field(
      default_factory=dict)
  compound_divergence: Dict[str, float] = dataclasses.field(
      default_factory=dict)
  num_items_by_atom_by_split: Dict[tfds.Split, Dict[str, int]] = (
      dataclasses.field(default_factory=dict))
  num_items_by_compound_by_split: Dict[tfds.Split, Dict[str, int]] = (
      dataclasses.field(default_factory=dict))
  atom_coverage_by_split: Dict[tfds.Split, float] = dataclasses.field(
      default_factory=dict)
  compound_coverage_by_split: Dict[tfds.Split, float] = dataclasses.field(
      default_factory=dict)


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class SplittingStats:
  """Dataset-splitting statistics, grouped by splitting algorithm.

  Attributes:
    stats_by_algorithm: Maps each inputs.SplitBy value to the statistics
      gathered for the split produced by that algorithm.
  """
  stats_by_algorithm: Dict[inputs.SplitBy, SplittingStatsOfAlgorithm] = (
      dataclasses.field(default_factory=dict)
  )


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class GenerationStats:
  """Top-level bundle of statistics gathered during a benchmark generation run.

  Attributes:
    counters: Event counters collected during the run.
    timing: Elapsed-time measurements for the run.
    splitting_stats: Statistics describing the dataset splitting.
  """
  counters: GenerationCounters = dataclasses.field(
      default_factory=GenerationCounters)
  timing: GenerationTiming = dataclasses.field(
      default_factory=GenerationTiming)
  splitting_stats: SplittingStats = dataclasses.field(
      default_factory=SplittingStats)

  def __add__(self, other):
    """Sums the counters and timings of two GenerationStats.

    splitting_stats is not combined: the returned object carries a freshly
    default-constructed (empty) SplittingStats.
    """
    merged_counters = self.counters + other.counters
    merged_timing = self.timing + other.timing
    return GenerationStats(counters=merged_counters, timing=merged_timing)
