# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for calculating and populating train similarity metadata."""

import collections
from typing import Dict, Mapping, Optional, Sequence, Tuple

import tensorflow_datasets as tfds

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import similarity_metrics


def _index_dataset_by_request(
    dataset):
  """Groups the flattened examples of a dataset by their request.

  The dataset itself is in grouped form, i.e. each context is factored out of
  the top-level examples that share it. The values stored in the index, on the
  other hand, are self-contained examples that carry a context in addition to
  the request, reply and qualifier. These flat examples are constructed on the
  fly, with each of their fields pointing to the relevant existing object in
  the dataset.

  Args:
    dataset: Dataset containing the examples to be indexed.

  Returns:
    A `collections.defaultdict` keyed by request. Each value is the list of
    flat examples with that request, in the order in which
    `dataset.to_flat_examples()` yielded them. Indexing the result with a
    request that was never seen yields (and stores) an empty list.
  """
  index = collections.defaultdict(list)
  for flat_example in dataset.to_flat_examples():
    same_request_bucket = index[flat_example.request]
    same_request_bucket.append(flat_example)
  return index


def _sort_examples_by_context_example_overlap_similarity(
    example, other_examples
):
  """Ranks other_examples by the similarity of their context to example's.

  The similarity of each candidate is obtained from
  `similarity_metrics.calculate_context_similarity_by_examples`, called with
  the context of `example` and the context of the candidate.

  Args:
    example: Example with respect to which to calculate the similarity.
    other_examples: Examples to sort.

  Returns:
    A new list of (other_example, similarity) pairs, starting with the most
    similar one. Pairs with equal similarity keep the relative order they had
    in other_examples.
  """
  scored_examples = []
  for candidate in other_examples:
    score = similarity_metrics.calculate_context_similarity_by_examples(
        example.context, candidate.context)
    scored_examples.append((candidate, score))
  # list.sort is stable, also with reverse=True, so ties keep input order.
  scored_examples.sort(key=lambda scored: scored[1], reverse=True)
  return scored_examples


def _get_nearest_example_with_similarity_score(
    examples_by_descending_similarity
):
  """Returns the first (example, similarity) entry of an already sorted list.

  Args:
    examples_by_descending_similarity: Sequence of (example, similarity) pairs
      ordered from the most similar to the least similar.

  Returns:
    A tuple (nearest_neighbor, nearest_similarity) taken from the first pair,
    or (None, 0.0) if the sequence is empty.
  """
  if not examples_by_descending_similarity:
    return None, 0.0
  best_entry = examples_by_descending_similarity[0]
  return best_entry[0], best_entry[1]


def _get_consensus_reply(
    examples_by_descending_similarity
):
  """Returns the reply that wins a similarity-weighted vote.

  Every example votes for its own reply with a weight equal to its similarity;
  the reply with the largest total weight wins. If several replies share the
  largest total, the one that was encountered first wins.

  Args:
    examples_by_descending_similarity: Iterable of (example, similarity) pairs.

  Returns:
    The winning reply, or None if there are no examples.
  """
  vote_weight_by_reply: Dict[str, float] = {}
  for voter, weight in examples_by_descending_similarity:
    vote_weight_by_reply[voter.reply] = (
        vote_weight_by_reply.get(voter.reply, 0.0) + weight)

  if not vote_weight_by_reply:
    return None

  ranked_replies = list(vote_weight_by_reply.items())
  # Stable descending sort: among equal totals the earliest-seen reply leads.
  ranked_replies.sort(key=lambda tallied: tallied[1], reverse=True)
  winning_reply, _ = ranked_replies[0]
  return winning_reply


def _get_consensus_qualifier(
    examples_by_descending_similarity
):
  """Returns the qualifier that wins a similarity-weighted vote.

  Every example votes for its own qualifier with a weight equal to its
  similarity; the qualifier with the largest total weight wins. If several
  qualifiers share the largest total, the one encountered first wins.

  Args:
    examples_by_descending_similarity: Iterable of (example, similarity) pairs.

  Returns:
    The winning qualifier, or None if there are no examples.
  """
  vote_weight_by_qualifier = {}
  for voter, weight in examples_by_descending_similarity:
    vote_weight_by_qualifier[voter.qualifier] = (
        vote_weight_by_qualifier.get(voter.qualifier, 0.0) + weight)

  if not vote_weight_by_qualifier:
    return None

  ranked_qualifiers = list(vote_weight_by_qualifier.items())
  # Stable descending sort: among equal totals the earliest-seen one leads.
  ranked_qualifiers.sort(key=lambda tallied: tallied[1], reverse=True)
  winning_qualifier, _ = ranked_qualifiers[0]
  return winning_qualifier


def _populate_example_train_similarity_metadata(
    example,
    train_examples_by_request):
  """Populates the train similarity metadata of the given example.

  Compares the example against the train examples that share its request and
  stores the resulting statistics in `example.metadata.train_similarity`.

  Args:
    example: Example to modify.
    train_examples_by_request: Index mapping requests to top-level train
      examples that contain them. The index is only read here; a request that
      is missing from it is treated as having no train examples.
  """
  # Use .get() rather than indexing. If the index is a defaultdict, indexing
  # with an unseen request would insert a new empty entry into the shared
  # index for every request that does not occur in train; if it is a plain
  # mapping, indexing would raise KeyError.
  train_examples_with_same_request = train_examples_by_request.get(
      example.request, ())
  train_examples_with_same_request_by_descending_similarity = (
      _sort_examples_by_context_example_overlap_similarity(
          example, train_examples_with_same_request))
  nearest_neighbor, nearest_similarity = (
      _get_nearest_example_with_similarity_score(
          train_examples_with_same_request_by_descending_similarity))
  # None is the "no neighbor" sentinel, so test for it explicitly instead of
  # relying on the truthiness of the example object.
  has_nearest_neighbor = nearest_neighbor is not None

  consensus_reply = _get_consensus_reply(
      train_examples_with_same_request_by_descending_similarity)
  consensus_qualifier = _get_consensus_qualifier(
      train_examples_with_same_request_by_descending_similarity)

  # Partition the train examples with the same request by their reply.
  train_examples_by_reply = {}
  for train_example in train_examples_with_same_request:
    train_examples_by_reply.setdefault(train_example.reply,
                                       []).append(train_example)
  train_examples_with_same_request_and_reply = (
      train_examples_by_reply.get(example.reply, ()))
  # "Same output" means that the qualifier matches in addition to the reply.
  train_examples_with_same_request_and_output = [
      train_example
      for train_example in train_examples_with_same_request_and_reply
      if train_example.qualifier == example.qualifier
  ]

  example.metadata.train_similarity = cl.ExampleTrainSimilarityMetadata(
      num_train_examples_with_same_request=(
          len(train_examples_with_same_request)),
      num_train_examples_with_same_request_and_reply=(
          len(train_examples_with_same_request_and_reply)),
      num_train_examples_with_same_request_and_output=(
          len(train_examples_with_same_request_and_output)),
      num_unique_train_replies=len(train_examples_by_reply),
      nearest_reply_matches=(
          has_nearest_neighbor and nearest_neighbor.reply == example.reply),
      nearest_qualifier_matches=(
          has_nearest_neighbor and
          nearest_neighbor.qualifier == example.qualifier),
      nearest_similarity=nearest_similarity,
      # The consensus values are None when no train example has this request.
      consensus_reply_matches=(consensus_reply == example.reply),
      consensus_qualifier_matches=(consensus_qualifier == example.qualifier),
  )


def _populate_context_train_similarity_metadata(
    validation_or_test_set,
    train_set):
  """Populates the train similarity metadata of the contexts in the example set.

  Two similarity matrices between the given example set and the train set are
  computed (one by rule overlap, one by example overlap). For the i-th example
  group, the maximum of row i of each matrix is stored in
  `example_group.context.metadata.train_similarity`.

  Args:
    validation_or_test_set: The example set to modify.
    train_set: Train set to compare against.
  """
  rule_overlap_matrix = (
      similarity_metrics.calculate_context_similarity_matrix_by_rules(
          validation_or_test_set, train_set))
  example_overlap_matrix = (
      similarity_metrics.calculate_context_similarity_matrix_by_examples(
          validation_or_test_set, train_set))

  for row, group in enumerate(validation_or_test_set.example_groups):
    # NOTE(review): the rows look like numpy arrays, where initial=0.0 makes
    # max() return 0.0 for an empty row (presumably the case of an empty train
    # set) -- confirm against similarity_metrics.
    best_rule_overlap = rule_overlap_matrix[row].max(initial=0.0)
    best_example_overlap = example_overlap_matrix[row].max(initial=0.0)
    group.context.metadata.train_similarity = (
        cl.ExampleSetTrainSimilarityMetadata(
            nearest_similarity_by_rule_overlap=best_rule_overlap,
            nearest_similarity_by_example_overlap=best_example_overlap))


def _populate_train_similarity_metadata_for_one_split(
    validation_or_test_set,
    train_set,
    train_examples_by_request):
  """Populates the train similarity metadata in the validation or test set.

  The context-level metadata is populated first, followed by the metadata of
  every flat example of every example group.

  Args:
    validation_or_test_set: The example set to modify.
    train_set: The train set to compare against.
    train_examples_by_request: An index of the contents of train_set, mapping
      requests to top-level flattened examples that contain the request.
  """
  _populate_context_train_similarity_metadata(validation_or_test_set, train_set)

  # Lazily walk the flat examples group by group, in dataset order.
  flat_examples = (
      flat_example
      for group in validation_or_test_set.example_groups
      for flat_example in group.to_flat_examples())
  for flat_example in flat_examples:
    _populate_example_train_similarity_metadata(flat_example,
                                                train_examples_by_request)


def populate_train_similarity_metadata(
    splits):
  """Populates split-specific metadata in the test and validation splits.

  The train split is indexed by request once, and that index is then shared
  when processing the validation split and the test split (in that order). A
  split that is missing from `splits` is replaced by an empty
  `cl.GroupedExampleSet`.

  Args:
    splits: Mapping from `tfds.Split` keys (TRAIN, VALIDATION, TEST) to the
      example set of that split. The validation and test example sets are
      modified in place.
  """
  train_set = splits.get(tfds.Split.TRAIN, cl.GroupedExampleSet())
  held_out_sets = (
      splits.get(tfds.Split.VALIDATION, cl.GroupedExampleSet()),
      splits.get(tfds.Split.TEST, cl.GroupedExampleSet()),
  )

  train_examples_by_request = _index_dataset_by_request(train_set)
  for held_out_set in held_out_sets:
    _populate_train_similarity_metadata_for_one_split(
        held_out_set, train_set, train_examples_by_request)
