# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for calculating similarity between contexts and examples."""

import collections
from typing import AbstractSet, Any, Mapping, Sequence

import numpy as np

from conceptual_learning.cscan import conceptual_learning as cl


def _calculate_value_overlap(value_set: AbstractSet[Any],
                             other_value_set: AbstractSet[Any]) -> float:
  """Returns the fraction of values that also appear in the other set.

  This is similar to the cosine similarity between the two sets, except that
  rather than normalizing the similarity measure based on the number of values
  in each of the sets, this normalizes purely based on the number of values in
  the first value set. This is a more appropriate measure than the cosine
  similarity in the case where the first value set represents a test example,
  while the other represents a train example, and we care primarily about the
  degree to which the information observed in train subsumes that required in
  test.

  Args:
    value_set: Set of values of interest.
    other_value_set: Other set in which we will check whether the values occur.

  Returns:
    The number of values of `value_set` that also occur in `other_value_set`,
    divided by len(value_set); a float in [0, 1]. Returns 1.0 when `value_set`
    is empty, since every value of an empty set trivially occurs in the other.
  """
  if not value_set:
    return 1.0
  # Copy into a `set` rather than calling the unbound
  # `set.intersection(value_set, ...)`, which raises TypeError for non-`set`
  # inputs such as a frozenset.
  num_shared_values = len(set(value_set).intersection(other_value_set))
  return num_shared_values / len(value_set)


def calculate_context_similarity_by_rules(
    context1,
    context2):
  """Computes how much of context1's rule set is covered by context2's rules.

  The measure is asymmetric: it is the fraction of the distinct rules listed in
  context1's metadata that also appear in context2's metadata (1.0 if context1
  lists no rules).

  Args:
    context1: Context whose rules should be covered (e.g. a test context).
    context2: Context whose rules are checked for coverage (e.g. a train
      context).

  Returns:
    A float between 0 and 1.
  """
  covered_rules = set(context1.metadata.rules)
  covering_rules = set(context2.metadata.rules)
  return _calculate_value_overlap(covered_rules, covering_rules)


def calculate_context_similarity_by_examples(
    context1,
    context2):
  """Computes how much of context1's example set is covered by context2.

  The measure is asymmetric: it is the fraction of the distinct examples
  iterated from context1 that also occur in context2 (1.0 if context1 is
  empty). Examples are compared through their own hashing/equality, since both
  contexts are converted to Python sets.

  Args:
    context1: Context whose examples should be covered (e.g. a test context).
    context2: Context whose examples are checked for coverage (e.g. a train
      context).

  Returns:
    A float between 0 and 1.
  """
  covered_examples = set(context1)
  covering_examples = set(context2)
  return _calculate_value_overlap(covered_examples, covering_examples)


def calculate_example_to_context_similarity_by_rules(
    example1,
    context2):
  """Computes how much of an example's rule set is covered by a context's rules.

  The measure is asymmetric: it is the fraction of the distinct rules listed in
  example1's metadata that also appear in context2's metadata (1.0 if example1
  lists no rules).

  Args:
    example1: Example whose rules should be covered.
    context2: Context whose rules are checked for coverage.

  Returns:
    A float between 0 and 1.
  """
  example_rules = set(example1.metadata.rules)
  context_rules = set(context2.metadata.rules)
  return _calculate_value_overlap(example_rules, context_rules)


def _index_contexts_by_rule(
    dataset):
  """Builds an inverted index from each rule to the contexts that use it.

  Args:
    dataset: Dataset whose example groups' contexts are indexed.

  Returns:
    A collections.defaultdict(list) mapping each rule to the positions (in
    dataset.example_groups) of the contexts whose metadata lists that rule, in
    increasing order. A position is repeated once per occurrence when a context
    lists the same rule more than once.
  """
  indices_by_rule = collections.defaultdict(list)
  rule_positions = (
      (rule, position)
      for position, group in enumerate(dataset.example_groups)
      for rule in group.context.metadata.rules)
  for rule, position in rule_positions:
    indices_by_rule[rule].append(position)
  return indices_by_rule


def _index_contexts_by_example(
    dataset):
  """Builds an inverted index from example md5 hashes to context positions.

  Args:
    dataset: Dataset whose example groups' contexts are indexed.

  Returns:
    A collections.defaultdict(list) mapping the md5 hash of each context example
    (as returned by its get_md5_hash method) to the positions (in
    dataset.example_groups) of the contexts containing it, in increasing order.
    A position is repeated when its context yields the same hash more than once.
  """
  indices_by_hash = collections.defaultdict(list)
  hashed_positions = (
      (example.get_md5_hash(), position)
      for position, group in enumerate(dataset.example_groups)
      for example in group.context)
  for example_hash, position in hashed_positions:
    indices_by_hash[example_hash].append(position)
  return indices_by_hash


def calculate_context_similarity_matrix_by_rules(
    validation_or_test_set,
    train_set):
  """Returns a matrix representing similarities between test and train contexts.

  similarity_matrix[i][j] holds a double value between 0 and 1 representing the
  similarity between the ith test context and the jth train context: the
  fraction of the distinct rules of the test context that also appear in the
  train context. 0 means that none of the rules from the test context appear in
  the train context, and 1 means that all of them appear in the train context.
  A rule listed more than once in a context is counted only once, so for test
  contexts with at least one rule the entries equal
  calculate_context_similarity_by_rules(test_context, train_context). A test
  context with no rules gets a row of zeros (whereas
  calculate_context_similarity_by_rules returns 1.0 in that case).

  Args:
    validation_or_test_set: Validation or test set containing the contexts for
      which train similarity is to be evaluated.
    train_set: Train set containing the contexts to be compared against.

  Returns:
    A float numpy array of shape
    (len(validation_or_test_set.example_groups), len(train_set.example_groups)).
  """
  # For efficiency, we first index the contexts by rule, and then process only
  # the cross product of the train/test contexts that share any given rule.
  # We thereby avoid processing a full cross product of all train/test contexts,
  # which could potentially be much larger. (This speeds up the similarity
  # matrix calculation by around 30x on a 10K context dataset.)
  num_test_contexts = len(validation_or_test_set.example_groups)
  similarity_matrix = np.zeros((num_test_contexts,
                                len(train_set.example_groups)))
  # Number of distinct rules in each test context, used for normalization.
  num_distinct_test_rules = np.zeros(num_test_contexts)

  test_context_indices_by_rule = _index_contexts_by_rule(validation_or_test_set)
  train_context_indices_by_rule = _index_contexts_by_rule(train_set)

  # Tally up the number of shared distinct rules.
  for rule, test_context_indices in test_context_indices_by_rule.items():
    # A context listing the same rule several times appears several times in
    # the index; dedupe so that each rule is counted at most once per pair of
    # contexts, which keeps the similarities within [0, 1].
    test_indices = sorted(set(test_context_indices))
    num_distinct_test_rules[test_indices] += 1.0
    train_indices = sorted(set(train_context_indices_by_rule.get(rule, ())))
    if train_indices:
      similarity_matrix[np.ix_(test_indices, train_indices)] += 1.0

  # Divide by the number of distinct rules in each test context, leaving the
  # rows of test contexts without rules at zero.
  has_rules = num_distinct_test_rules > 0
  denominators = num_distinct_test_rules[has_rules, np.newaxis]
  similarity_matrix[has_rules] /= denominators

  return similarity_matrix


def calculate_context_similarity_matrix_by_examples(
    validation_or_test_set,
    train_set):
  """Returns a matrix representing similarities between test and train contexts.

  similarity_matrix[i][j] holds a double value between 0 and 1 representing the
  similarity between the ith test context and the jth train context: the
  fraction of the distinct examples (identified by their md5 hashes) of the
  test context that also appear in the train context. 0 means that none of the
  examples from the test context appear in the train context, and 1 means that
  all of them appear in the train context. An example occurring more than once
  in a context is counted only once. A test context with no examples gets a row
  of zeros (whereas calculate_context_similarity_by_examples returns 1.0 when
  its first context is empty).

  Args:
    validation_or_test_set: Validation or test set containing the contexts for
      which train similarity is to be evaluated.
    train_set: Train set containing the contexts to be compared against.

  Returns:
    A float numpy array of shape
    (len(validation_or_test_set.example_groups), len(train_set.example_groups)).
  """
  # For efficiency, we first index the contexts by example, and then process
  # only the cross product of the train/test contexts that share any given
  # examples. We thereby avoid processing a full cross product of all train/test
  # contexts, which could potentially be much larger. (This speeds up the
  # similarity matrix calculation by around 30x on a 10K context dataset.)
  num_test_contexts = len(validation_or_test_set.example_groups)
  similarity_matrix = np.zeros((num_test_contexts,
                                len(train_set.example_groups)))
  # Number of distinct example hashes in each test context, for normalization.
  num_distinct_test_examples = np.zeros(num_test_contexts)

  test_context_indices_by_example = _index_contexts_by_example(
      validation_or_test_set)
  train_context_indices_by_example = _index_contexts_by_example(train_set)

  # Tally up the number of shared distinct examples.
  for example_hash, test_context_indices in (
      test_context_indices_by_example.items()):
    # A context containing the same example several times appears several
    # times in the index; dedupe so that each example is counted at most once
    # per pair of contexts, which keeps the similarities within [0, 1].
    test_indices = sorted(set(test_context_indices))
    num_distinct_test_examples[test_indices] += 1.0
    train_indices = sorted(
        set(train_context_indices_by_example.get(example_hash, ())))
    if train_indices:
      similarity_matrix[np.ix_(test_indices, train_indices)] += 1.0

  # Divide by the number of distinct examples in each test context, leaving the
  # rows of empty test contexts at zero.
  nonempty = num_distinct_test_examples > 0
  denominators = num_distinct_test_examples[nonempty, np.newaxis]
  similarity_matrix[nonempty] /= denominators
  return similarity_matrix
