# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Consistency metric calculation utils."""

import dataclasses
import itertools
from typing import FrozenSet, Iterable, List, Mapping, Sequence, Set, Tuple

from absl import logging
import dataclasses_json
import nltk

from conceptual_learning.cscan import benchmark_generation
from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import inference
from conceptual_learning.cscan import inputs
from conceptual_learning.cscan import nltk_utils
from conceptual_learning.cscan import production_composition

# A model prediction for a single example: (predicted qualifier, reply).
Prediction = Tuple[cl.Qualifier, str]
# Predictions keyed by the MD5 hash of the (flattened) example they belong to.
Predictions = Mapping[str, Prediction]
# Examples that share the same context and are expected to be consistent.
ExampleGroup = Sequence[cl.Example]
ExampleGroups = Sequence[ExampleGroup]
# NOTE(review): presumably a minimal (e.g. contradictory) set of examples;
# these two aliases are not referenced in the visible part of the module.
MinimalSet = FrozenSet[cl.Example]
MinimalSets = Set[MinimalSet]


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class CandidateImplication:
  """A directed candidate implication: `source_examples` => `example`.

  The assumption behind this structure is that the productions of
  `source_examples` can be composed to obtain the production of `example`.
  Whether the candidate counts as an actual implication depends only on the
  qualifiers:
    * If the qualifier of the source set equals that of `example`, it is an
      implication.
    * If the source set is monotonic (M) while `example` is defeasible (D), it
      is a qualifier contradiction.
    * A defeasible source set leading to a monotonic example is neither.

  The qualifier of the source set is M only when every source example is M;
  as soon as one of them is D the whole set is D.

  WARNING: Verifying that `source_examples` really compose into `example`
    would slow the computation down considerably, so it is not checked here.
    Callers are responsible for only building valid candidates.
  """
  example: cl.Example
  source_examples: FrozenSet[cl.Example]

  def is_qualifier_contradiction(self):
    """Whether a monotonic source set leads to a defeasible example."""
    source_qualifier = self._get_source_examples_set_qualifier()
    return bool(source_qualifier == cl.Qualifier.M and
                self.example.qualifier == cl.Qualifier.D)

  def is_implication(self):
    """Whether the example's qualifier agrees with the source set's."""
    return self._get_source_examples_set_qualifier() == self.example.qualifier

  def _get_source_examples_set_qualifier(self):
    """Computes the qualifier of the source examples taken as a whole.

    A single defeasible member makes the set defeasible; an all-monotonic set
    (including the empty set) is monotonic.
    """
    has_defeasible_member = any(
        source.qualifier == cl.Qualifier.D for source in self.source_examples)
    return cl.Qualifier.D if has_defeasible_member else cl.Qualifier.M


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ExampleGroupConsistency:
  """Implication/contradiction counts for one group of same-context examples.

  Attributes:
    context_id: Identifier for the group.
    contradictions: How many contradictions were found in the group.
    implications: How many implications were found in the group.
  """
  context_id: str
  contradictions: int
  implications: int

  def consistency(self):
    """Returns implications / (implications + contradictions) for the group.

    When the group has neither implications nor contradictions the ratio is
    undefined and the sentinel -1 is returned instead.
    """
    total = self.implications + self.contradictions
    if total == 0:
      return -1
    return self.implications / total


@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ExampleSetConsistency:
  """Consistency information aggregated over several example groups.

  Attributes:
    groups_consistency: Per-group consistency information, one entry per group
      of examples.
  """
  groups_consistency: List[ExampleGroupConsistency]

  def consistency(self):
    """Returns the pooled consistency over all the groups.

    Implications and contradictions are first summed over every group and the
    ratio implications / (implications + contradictions) is taken on the
    totals. The sentinel -1 is returned when both totals are zero.
    """
    total_implications = self.implications()
    total = total_implications + self.contradictions()
    return -1 if total == 0 else total_implications / total

  def implications(self):
    """Returns the number of implications summed over all the groups."""
    return sum(group.implications for group in self.groups_consistency)

  def contradictions(self):
    """Returns the number of contradictions summed over all the groups."""
    return sum(group.contradictions for group in self.groups_consistency)


def get_consistency_file_name(prediction_step = -1,
                              is_validation = False):
  """Builds the JSON file name that stores the consistency metric.

  Args:
    prediction_step: Step the predictions were produced at. A negative value
      means the step is unknown, in which case it is left out of the name.
    is_validation: Selects the 'eval' prefix when True and 'test' otherwise.

  Returns:
    A name shaped like `[eval|test]_consistency[step_number].json`.
  """
  prefix = 'eval' if is_validation else 'test'
  step_suffix = '' if prediction_step < 0 else str(prediction_step)
  return f'{prefix}_consistency{step_suffix}.json'


def _get_base_inference_engine(
    dataset_spec):
  """Returns an inference engine pre-loaded with the pass-through rules.

  Only the pass-through rules of the dataset spec's fixed phrase structure
  grammar are added to the engine, all as monotonic productions. These are
  treated as implicit knowledge that is fixed throughout the contexts.

  Args:
    dataset_spec: Spec used to generate the dataset of interest. We need this
      information as different dataset specs have different implicit knowledge
      which we assume is known monotonically and is fixed throughout the
      contexts.

  Returns:
    A fresh `inference.InferenceEngine` (tracking multiple provenances per
    production) that contains only the pass-through rules.

  Raises:
    ValueError: If no fixed phrase structure grammar is found for
      `dataset_spec`.
  """
  phrase_structure_grammar = (
      benchmark_generation
      .load_fixed_phrase_structure_grammar_template_for_dataset_spec(
          dataset_spec))
  # Validate the grammar before building anything so we fail fast.
  if phrase_structure_grammar is None:
    raise ValueError(
        'No phrase structure grammar was detected for dataset spec '
        f'{dataset_spec.id}')
  base_inference_engine = inference.InferenceEngine(
      track_multiple_provenances=True)
  for rule in phrase_structure_grammar.get_all_rules():
    rule_production = nltk_utils.production_from_production_string(
        rule.to_rule_string())
    # Only pass-through rules are considered implicit background knowledge.
    if nltk_utils.is_pass_through_rule(rule_production):
      base_inference_engine.add_production(rule_production, is_monotonic=True)
  return base_inference_engine


def _get_production_from_example_and_prediction(
    example, prediction_reply):
  """Builds the production that corresponds to a model prediction.

  The ground truth production stored in `example.metadata` acts as a template:
    * Non-rule requests: the rhs and the lhs type are kept, while the lhs `sem`
      feature is rebuilt from the whitespace-separated prediction tokens.
    * Any other (rule) request: the prediction must assert the rule, in which
      case the ground truth production is returned unchanged; otherwise a
      ValueError is raised.

  Args:
    example: Original example. Only its request type and metadata production
      are used.
    prediction_reply: Reply predicted by the model.

  Returns:
    The `nltk.Production` corresponding to the prediction.

  Raises:
    ValueError: If a rule request's prediction is not `cl.RuleReply.TRUE`.
  """
  if example.get_request_type() != cl.RequestType.NON_RULE:
    if prediction_reply != cl.RuleReply.TRUE:
      raise ValueError('Only asserted rules are supported. Got '
                       f'{prediction_reply}')
    return example.metadata.production

  template = example.metadata.production
  # Tokens denoting variables must become nltk.Variable objects because
  # equality between productions depends on the token types in `sem`.
  sem_values = [
      nltk.Variable(token)
      if nltk_utils.is_output_token_a_variable(token) else token
      for token in prediction_reply.split()
  ]
  if not sem_values:
    new_sem = ''
  elif len(sem_values) == 1:
    new_sem = sem_values[0]
  else:
    new_sem = nltk.featstruct.FeatureValueConcat(sem_values)
  new_lhs = nltk.grammar.FeatStructNonterminal({
      'sem': new_sem,
      nltk.grammar.TYPE: template.lhs()[nltk.grammar.TYPE],
  })
  return nltk.Production(lhs=new_lhs, rhs=template.rhs())


def _merge_examples_and_predictions(
    examples,
    predictions):
  """Replaces the ground truth of each example by the model predictions.

  Examples without a prediction are skipped (with a warning), as are rule
  requests whose predicted reply does not assert the rule.

  Args:
    examples: Ground truth examples; their MD5 hashes key `predictions`.
    predictions: (qualifier, reply) predictions keyed by example MD5 hash.

  Returns:
    A list of new examples carrying the predicted qualifier, reply and the
    matching production.
  """
  merged_examples = []
  for original in examples:
    md5_hash = original.get_md5_hash()
    if md5_hash not in predictions:
      logging.warning('Prediction was not found for example with md5 hash: %r.',
                      md5_hash)
      continue
    predicted_qualifier, predicted_reply = predictions[md5_hash]
    # Rejected rules carry no production to reason about; drop them.
    is_rejected_rule = (
        original.get_request_type() == cl.RequestType.RULE and
        predicted_reply != cl.RuleReply.TRUE)
    if is_rejected_rule:
      continue
    predicted_production = _get_production_from_example_and_prediction(
        original, predicted_reply)
    merged_examples.append(
        cl.Example(
            context=original.context,
            request=original.request,
            reply=predicted_reply,
            qualifier=predicted_qualifier,
            metadata=cl.ExampleMetadata(production=predicted_production)))
  return merged_examples


def _get_unique_source_productions_from_provenances(
    provenances
):
  """Collects the distinct source production sets behind `provenances`.

  For each provenance, the productions it was composed with plus its source
  production form one set, with pass-through rules left out.

  Args:
    provenances: Provenances of a production, as stored in the inference
      engine's `provenances_by_production`.

  Returns:
    A set of frozensets of productions, one per distinct source set.
  """
  def _non_pass_through_sources(provenance):
    composed = [production for production, _ in provenance.compositions]
    return frozenset(
        production for production in composed + [provenance.source]
        if not nltk_utils.is_pass_through_rule(production))

  return {_non_pass_through_sources(prov) for prov in provenances}


def _implication_sets_from_inference_engine(
    inference_engine,
    examples_by_production,
):
  """Returns the candidate implications found in the inference engine.

  Two kinds of candidates are produced, each pointing from a set of source
  examples to a single target example:
    1- Examples sharing the same production: each one is a candidate source
      for every other example with that production (in both directions).
    2- Provenance based: for every set of (non pass-through) source productions
      that derives a production, every combination of examples realizing those
      source productions is a candidate source for each example of the derived
      production. The trivial self provenance is skipped.

  Candidates are neither classified nor filtered for minimality here; callers
  use `CandidateImplication.is_implication` and
  `CandidateImplication.is_qualifier_contradiction` to classify them.

  Args:
    inference_engine: The inference engine state. This function reads its
      `provenances_by_production` attribute.
    examples_by_production: Source examples (excluding the pass-through rules)
      that were added to the inference engine, keyed by their production.

  Returns:
    A set of `CandidateImplication`s.
  """
  candidate_implications = set()
  for production, examples in examples_by_production.items():
    # Examples sharing a production form either an implication or a qualifier
    # contradiction between the two of them alone. `permutations` yields
    # nothing when there is fewer than two examples.
    for source_example, target_example in itertools.permutations(examples, 2):
      candidate_implications.add(
          CandidateImplication(
              example=target_example,
              source_examples=frozenset([source_example])))
    # Additionally, we will have one candidate implication for each set of
    # source productions from the production provenance (as the premise) and
    # example (as the conclusion).
    production_provenances = inference_engine.provenances_by_production.get(
        production, [])
    unique_source_productions = _get_unique_source_productions_from_provenances(
        production_provenances)
    for source_productions in unique_source_productions:
      if source_productions == frozenset((production,)):
        # Ignore self provenance.
        continue
      # Each production may be realized by several examples, so we take the
      # cartesian product of the options. E.g. with source productions
      # [p1, p2, p3] and {p1: [e1], p2: [e2.1, e2.2], p3: [e3]}, the options are
      # [[e1], [e2.1, e2.2], [e3]] and the product is
      # [[e1, e2.1, e3], [e1, e2.2, e3]].
      source_examples_options = [
          examples_by_production[p] for p in source_productions
      ]
      for source_examples_option in itertools.product(
          *source_examples_options):
        source_examples = frozenset(source_examples_option)
        for example in examples:
          candidate_implications.add(
              CandidateImplication(
                  example=example, source_examples=source_examples))

  return candidate_implications


def _contradictory_sets_from_inference_engine(
    inference_engine,
    examples_by_production
):
  """Returns the minimal contradictory example sets in the inference engine.

  This function performs the following steps:
    1- Contradictions are extracted by checking all the directly contradicted
      productions in the inference engine.
    2- The provenances of these contradictory pairs are retrieved and the source
      productions are extracted for each one.
    3- To get the candidate contradictory sets we take all possible
      (source1, source2) pairs and compute their union. A directly
      contradicted production that is itself a source production is added to
      its own side of the pair.
    4- Each candidate that is a superset of another candidate is dropped,
      leaving only the minimal ones.
    5- Each minimal production set is expanded into every example set that
      realizes it, because one production can be associated with several
      examples.

  Args:
    inference_engine: The inference engine state. This function reads its
      `all_productions` and `provenances_by_production` attributes and calls
      `get_contradicting_productions`.
    examples_by_production: Source examples (excluding the pass-through rules)
      that were added to the inference engine, keyed by their production.

  Returns:
    A set of frozensets of examples, one per minimal contradictory example set.
  """
  contradictory_sets = []
  for current_production in inference_engine.all_productions:
    # Retrieve the direct contradictions for current_production.
    direct_contradictions = inference_engine.get_contradicting_productions(
        current_production).difference({current_production})
    if not direct_contradictions:
      # If there are no direct contradictions, move on to the next production.
      continue
    current_is_source = current_production in examples_by_production
    # Extract unique source productions sets from the provenances of
    # current_production.
    provenances = inference_engine.provenances_by_production.get(
        current_production, [])
    unique_source_productions = (
        _get_unique_source_productions_from_provenances(provenances))
    for other_contradicting_production in direct_contradictions:
      other_is_source = (
          other_contradicting_production in examples_by_production)
      # Extract unique source productions sets from the provenances of
      # the other contradicting production.
      other_provenances = inference_engine.provenances_by_production.get(
          other_contradicting_production, [])
      unique_other_source_productions = (
          _get_unique_source_productions_from_provenances(other_provenances))
      # Take the cartesian product of the unique sources for both sides. E.g.
      # with {source1_1, source1_2} for current_production and
      # {source2_1, source2_2, source2_3} for the other production, we get the
      # six unions source1_i.union(source2_j).
      for source, other_source in itertools.product(
          unique_source_productions, unique_other_source_productions):
        # If the directly contradicted productions are source productions
        # then add each of them to its own side's source set.
        if current_is_source:
          source = source.union({current_production})
        if other_is_source:
          # Extend `other_source` (not `source`), otherwise the provenance of
          # the other contradicting production would be discarded.
          other_source = other_source.union({other_contradicting_production})
        contradictory_sets.append(source.union(other_source))

  # Visit candidates from smallest to largest: a candidate is minimal only if
  # none of the already accepted (smaller or equal-sized) sets is a subset of
  # it. Duplicates are dropped too, since a set is a subset of itself.
  minimal_contradictory_productions_sets = set()
  for contradictory_set in sorted(contradictory_sets, key=len):
    if not any(
        minimal_set.issubset(contradictory_set)
        for minimal_set in minimal_contradictory_productions_sets):
      minimal_contradictory_productions_sets.add(contradictory_set)

  # Expand each minimal production set into all the example sets realizing it
  # (one production can be associated with several example objects).
  minimal_contradictory_examples_sets = set()
  for minimal_contradictory_productions_set in (
      minimal_contradictory_productions_sets):
    minimal_contradictory_examples_options = [
        examples_by_production[p] for p in minimal_contradictory_productions_set
    ]
    for minimal_contradictory_examples_option in itertools.product(
        *minimal_contradictory_examples_options):
      minimal_contradictory_examples_sets.add(
          frozenset(minimal_contradictory_examples_option))
  return minimal_contradictory_examples_sets


def _count_example_level_consistency(minimal_contradictory_sets,
                                     candidate_implications):
  """Counts implied and contradicted examples (the C† consistency variant).

  Implications are one-directional (a set of examples implies one example), so
  only the implied example is counted. Every example of a minimal
  contradictory set counts as contradicted by the rest of that set. Qualifier
  contradictions are also one-directional: only the implied example counts as
  contradicted.

  Args:
    minimal_contradictory_sets: Minimal contradictory sets of examples.
    candidate_implications: Candidate implications to classify.

  Returns:
    An `(implications_count, contradictions_count)` tuple.
  """
  contradicted_examples = set()
  implied_examples = set()
  for contradictory_set in minimal_contradictory_sets:
    contradicted_examples.update(contradictory_set)
  for candidate_implication in candidate_implications:
    if candidate_implication.is_qualifier_contradiction():
      contradicted_examples.add(candidate_implication.example)
    elif candidate_implication.is_implication():
      # Implied iff it can be inferred from the source examples and its
      # qualifier matches that of the source set.
      implied_examples.add(candidate_implication.example)
  return len(implied_examples), len(contradicted_examples)


def _count_set_level_consistency(minimal_contradictory_sets,
                                 candidate_implications):
  """Counts minimal implication/contradiction sets (standard consistency).

  Contradictions are the minimal contradictory sets plus the candidate
  implications that are qualifier contradictions. Implications are the unique
  sets formed by a candidate's source examples together with its implied
  example, so that e.g. e1 -> e2 and e2 -> e1 count as a single implication.

  Args:
    minimal_contradictory_sets: Minimal contradictory sets of examples.
    candidate_implications: Candidate implications to classify.

  Returns:
    An `(implications_count, contradictions_count)` tuple.
  """
  contradictions_count = len(minimal_contradictory_sets)
  implied_sets = set()
  for candidate_implication in candidate_implications:
    if candidate_implication.is_qualifier_contradiction():
      contradictions_count += 1
    elif candidate_implication.is_implication():
      implied_sets.add(
          candidate_implication.source_examples.union(
              {candidate_implication.example}))
  return len(implied_sets), contradictions_count


def consistency_for_example_group(
    example_group,
    context_id,
    inference_engine,
    example_level_consistency = False):
  """Returns the consistency measures of `example_group`.

  Args:
    example_group: Examples that share similar context so they are supposed to
      be consistent.
    context_id: MD5 hash of the context, used as an identifier for each examples
      group.
    inference_engine: Base inference engine to use. It's assumed that the fixed
      phrase structure is pre-loaded here. NOTE: it is mutated in place, as the
      production of every example in `example_group` is force-added to it.
    example_level_consistency: If true, return how many examples were implied
      and how many were contradicted, otherwise return how many minimal
      contradictory and implication sets are there.

  Returns:
    An `ExampleGroupConsistency` with the implication and contradiction counts
    of the group.
  """
  # Group the examples by production. The keys are also used to tell whether a
  # production in the inference engine comes from one of these examples.
  examples_by_production = {}
  for example in example_group:
    examples_by_production.setdefault(
        example.metadata.production, []).append(example)
  # Construct the inference engine by force adding all the examples in the
  # example group.
  for example in example_group:
    inference_engine.force_add_production(
        example.metadata.production,
        is_monotonic=example.qualifier == cl.Qualifier.M)
  # Extract the minimal contradictory sets and the candidate implications
  # from the inference engine.
  minimal_contradictory_sets = _contradictory_sets_from_inference_engine(
      inference_engine, examples_by_production)
  candidate_implications = _implication_sets_from_inference_engine(
      inference_engine, examples_by_production)

  if example_level_consistency:
    implications_count, contradictions_count = (
        _count_example_level_consistency(
            minimal_contradictory_sets, candidate_implications))
  else:
    implications_count, contradictions_count = _count_set_level_consistency(
        minimal_contradictory_sets, candidate_implications)

  return ExampleGroupConsistency(
      context_id=context_id,
      implications=implications_count,
      contradictions=contradictions_count)


def consistency_for_example_groups(
    example_groups,
    dataset_spec,
    example_level_consistency = False):
  """Returns consistency of example groups.

  Args:
    example_groups: Groups of examples; all examples in a group are expected to
      share the same context. Empty groups are skipped.
    dataset_spec: Spec the examples were generated from; used to build the base
      inference engine with the pass-through rules.
    example_level_consistency: If true, count implied/contradicted examples,
      otherwise count minimal implication/contradictory sets.

  Returns:
    An `ExampleSetConsistency` with one entry per non-empty group, in order.
  """
  groups_consistency = []
  for example_group in example_groups:
    # Ignore empty example groups. An example group can be empty in some cases
    # like when creating MCD splits.
    if not example_group:
      continue
    # Each group gets its own base engine (pass-through rules only, which are
    # implicit knowledge fixed throughout the dataset): the per-group call
    # force-adds the group's productions, so a shared engine would accumulate
    # productions from earlier groups.
    inference_engine = _get_base_inference_engine(dataset_spec)
    groups_consistency.append(
        consistency_for_example_group(
            example_group,
            context_id=example_group[0].context.get_md5_hash(),
            inference_engine=inference_engine,
            example_level_consistency=example_level_consistency,
        ))
  return ExampleSetConsistency(groups_consistency=groups_consistency)


def compute_consistency_for_model_predictions(
    example_set, predictions,
    dataset_spec,
    example_level_consistency = False):
  """Computes the consistency of a model's predictions on `example_set`.

  The ground truth of every example is swapped for the model's predicted
  qualifier and reply, group by group, and the consistency of the resulting
  groups is then measured.

  Args:
    example_set: Ground truth examples, used to extract production information.
    predictions: Model predictions for each example as a reply and a qualifier.
    dataset_spec: Dataset spec of `example_set`. Used to load the fixed phrase
      structure grammar for the inference engine.
    example_level_consistency: If true, return how many examples were implied
      and how many were contradicted, otherwise return how many minimal
      contradictory and implication sets are there.

  Returns:
    The `ExampleSetConsistency` of the prediction groups.
  """
  # `to_flat_examples` embeds the context into each example so that
  # `get_md5_hash()` matches the hash used during training.
  predictions_groups = [
      _merge_examples_and_predictions(group.to_flat_examples(), predictions)
      for group in example_set.example_groups
  ]
  return consistency_for_example_groups(
      predictions_groups,
      dataset_spec=dataset_spec,
      example_level_consistency=example_level_consistency)
