# coding=utf-8
# Copyright 2018 .
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains base implementation for relation classification tasks."""


import flax.linen as nn
import jax.numpy as jnp
from mentionmemory.encoders import encoder_registry
from mentionmemory.modules import mlp
from mentionmemory.tasks import downstream_encoder_task
from mentionmemory.tasks import task_registry
from mentionmemory.utils import default_values
from mentionmemory.utils import jax_utils as jut
from mentionmemory.utils import metric_utils
from mentionmemory.utils.custom_types import Array, Dtype, MetricGroups  # pylint: disable=g-multiple-import
import mentionmemory.utils.mention_preprocess_utils as mention_preprocess_utils
import ml_collections
import numpy as np
import tensorflow.compat.v2 as tf


class RelationClassifierModel(nn.Module):
  """Encoder wrapper with classification layer for relation classification.

  This model takes mention-annotated text with special mentions per sample
  denoted as "subject" and "object". The model generates mention encodings for
  these two mentions, concatenates them and applies (potentially, multi-layer)
  n-ary classification layers.

  Attributes:
    num_classes: number of classification labels.
    num_layers: number of classification MLP layers on top of mention encodings.
    input_dim: input dimensionality of classification MLP layers. This must be
      equal to 2 * mention encodings size.
    hidden_dim: hidden dimensionality of classification MLP layers.
    dropout_rate: dropout rate of classification MLP layers.
    encoder_name: name of encoder model to use to encode passage.
    encoder_config: encoder hyperparameters.
    dtype: precision of computation.
    mention_encodings_feature: feature name for encodings of target mentions.
  """

  num_classes: int
  num_layers: int
  input_dim: int
  hidden_dim: int
  dropout_rate: float
  encoder_name: str
  encoder_config: ml_collections.FrozenConfigDict
  dtype: Dtype

  mention_encodings_feature: str = 'target_mention_encodings'
  layer_norm_epsilon: float = default_values.layer_norm_epsilon

  def setup(self):
    self.encoder = encoder_registry.get_registered_encoder(
        self.encoder_name)(**self.encoder_config)

    self.classification_mlp_layers = [
        mlp.MLPBlock(  # pylint: disable=g-complex-comprehension
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            layer_norm_epsilon=self.layer_norm_epsilon,
        ) for _ in range(self.num_layers)
    ]

    self.linear_classifier = nn.Dense(self.num_classes, dtype=self.dtype)

  def __call__(self, batch, deterministic):
    _, loss_helpers, logging_helpers = self.encoder.forward(
        batch, deterministic)
    mention_encodings = loss_helpers[self.mention_encodings_feature]

    subject_mention_encodings = jut.matmul_slice(
        mention_encodings, batch['mention_subject_indices'])

    object_mention_encodings = jut.matmul_slice(mention_encodings,
                                                batch['mention_object_indices'])

    relation_encodings = jnp.concatenate(
        [subject_mention_encodings, object_mention_encodings], -1)

    for mlp_layer in self.classification_mlp_layers:
      relation_encodings = mlp_layer(relation_encodings, deterministic)

    classifier_logits = self.linear_classifier(relation_encodings)
    loss_helpers['classifier_logits'] = classifier_logits

    return loss_helpers, logging_helpers


@task_registry.register_task('relation_classifier')
class RelationClassifierTask(downstream_encoder_task.DownstreamEncoderTask):
  """Class for relation classification task."""
  model_class = RelationClassifierModel

  @classmethod
  def make_loss_fn(
      cls, config
  ):
    """Creates loss function for Relation Classifier training.

    Args:
      config: task configuration.

    Returns:
      Loss function.
    """

    ignore_label = config.ignore_label

    def loss_fn(
        model_config,
        model_params,
        model_vars,
        batch,
        deterministic,
        dropout_rng = None,
    ):
      """Loss function used by Relation Classifier task. See BaseTask."""

      variable_dict = {'params': model_params}
      variable_dict.update(model_vars)
      loss_helpers, _ = cls.build_model(model_config).apply(
          variable_dict, batch, deterministic=deterministic, rngs=dropout_rng)

      weights = jnp.ones_like(batch['classifier_target'])

      loss, denom = metric_utils.compute_weighted_cross_entropy(
          loss_helpers['classifier_logits'], batch['classifier_target'],
          weights)

      acc, _ = metric_utils.compute_weighted_accuracy(
          loss_helpers['classifier_logits'], batch['classifier_target'],
          weights)

      predictions = jnp.argmax(loss_helpers['classifier_logits'], axis=-1)

      tp, fp, fn = metric_utils.compute_tp_fp_fn_weighted(
          predictions, batch['classifier_target'], weights, ignore_label)

      metrics = {
          'agg': {
              'loss': loss,
              'denominator': denom,
              'acc': acc,
          },
          'micro_precision': {
              'value': tp,
              'denominator': tp + fp,
          },
          'micro_recall': {
              'value': tp,
              'denominator': tp + fn,
          }
      }

      auxiliary_output = {'predictions': predictions}
      auxiliary_output.update(cls.get_auxiliary_output(loss_helpers))

      return loss, metrics, auxiliary_output

    return loss_fn

  @staticmethod
  def make_collater_fn(
      config
  ):
    """Produces function to preprocess batches for relation classification task.

    This function samples and flattens mentions from input data.

    Args:
      config: task configuration.

    Returns:
      collater function.
    """
    encoder_config = config.model_config.encoder_config
    max_length = encoder_config.max_length
    bsz = config.per_device_batch_size
    max_batch_mentions = config.max_mentions * bsz
    n_candidate_mentions = config.max_mentions_per_sample * bsz

    if config.max_mentions < 2:
      raise ValueError('Need at least two mentions per sample in order to '
                       'include object and subject mentions.')

    def collater_fn(batch):
      """Collater function for relation classification task. See BaseTask."""

      def flatten_bsz(tensor):
        return tf.reshape(tensor, [bsz])

      new_batch = {
          'text_ids': batch['text_ids'],
          'text_mask': batch['text_mask'],
          'classifier_target': flatten_bsz(batch['target']),
      }

      # Sample mentions across batch

      # We want to make sure that the subject / object mentions always have
      # priority when we sample `max_batch_mentions` out of all available
      # mentions. Additionally, we want these subject / object  mentions to be
      # in the same order as their samples. In other words, we want the first
      # sampled mention to be object mention from the first sample, the second
      # sampled mention to be subject mention from the first sample, the third
      # sampled mention to be object mention from the second sample, etc.

      subj_index = flatten_bsz(batch['subject_mention_indices'])
      obj_index = flatten_bsz(batch['object_mention_indices'])

      # Adjust subject / object mention positions in individual samples to their
      # positions in flattened mentions.
      shift = tf.range(
          bsz, dtype=obj_index.dtype) * config.max_mentions_per_sample
      mention_target_indices = tf.reshape(
          tf.stack([subj_index + shift, obj_index + shift], axis=1), [-1])

      # Sample the rest of the mentions uniformly across batch
      scores = tf.random.uniform(shape=tf.shape(batch['mention_mask']))
      scores = scores * tf.cast(batch['mention_mask'], tf.float32)

      # We want to adjust scores for target mentions so they don't get sampled
      # for the second time. We achive this by making their scores negative.
      def set_negative_scores(scores, indices):
        indices_2d = tf.stack([tf.range(bsz, dtype=indices.dtype), indices],
                              axis=1)
        return tf.tensor_scatter_nd_update(scores, indices_2d,
                                           tf.fill(tf.shape(indices), -1.0))

      # Note that since we're using 2D scores (not yet flattened for simplicity)
      # we use unadjusted `subj_index` and `obj_index`.
      scores = set_negative_scores(scores, subj_index)
      scores = set_negative_scores(scores, obj_index)

      # There are `2 * bsz` target mentions which were already chosen
      num_to_sample = tf.maximum(max_batch_mentions - 2 * bsz, 0)
      sampled_scores, sampled_indices = tf.math.top_k(
          tf.reshape(scores, [-1]), num_to_sample, sorted=True)

      # Note that negative scores indicate that we have double-sampled some of
      # the target mentions (we set their scores to negative right above).
      # In this case, we remove them.
      num_not_double_sampled = tf.reduce_sum(
          tf.cast(tf.not_equal(sampled_scores, -1), tf.int32))
      sampled_indices = sampled_indices[:num_not_double_sampled]

      # Combine target mentions (subject / object) with sampled mentions
      mention_target_indices = tf.cast(mention_target_indices,
                                       sampled_indices.dtype)
      sampled_indices = tf.concat([mention_target_indices, sampled_indices],
                                  axis=0)

      sampled_indices = mention_preprocess_utils.dynamic_padding_1d(
          sampled_indices, max_batch_mentions)

      mention_mask = tf.reshape(batch['mention_mask'], [n_candidate_mentions])
      new_batch['mention_mask'] = tf.gather(mention_mask, sampled_indices)
      new_batch['mention_start_positions'] = tf.gather(
          tf.reshape(batch['mention_start_positions'], [n_candidate_mentions]),
          sampled_indices)
      new_batch['mention_end_positions'] = tf.gather(
          tf.reshape(batch['mention_end_positions'], [n_candidate_mentions]),
          sampled_indices)
      new_batch['mention_batch_positions'] = tf.gather(
          tf.repeat(tf.range(bsz), config.max_mentions_per_sample),
          sampled_indices)

      new_batch['mention_target_indices'] = tf.range(2 * bsz)
      new_batch['mention_subject_indices'] = tf.range(bsz) * 2
      new_batch['mention_object_indices'] = tf.range(bsz) * 2 + 1

      new_batch['mention_target_batch_positions'] = tf.gather(
          new_batch['mention_batch_positions'],
          new_batch['mention_target_indices'])
      new_batch['mention_target_start_positions'] = tf.gather(
          new_batch['mention_start_positions'],
          new_batch['mention_target_indices'])
      new_batch['mention_target_end_positions'] = tf.gather(
          new_batch['mention_end_positions'],
          new_batch['mention_target_indices'])
      new_batch['mention_target_weights'] = tf.ones(2 * bsz)

      # Fake IDs -- some encoders (ReadTwice) need them
      new_batch['mention_target_ids'] = tf.zeros(2 * bsz)

      new_batch['segment_ids'] = tf.zeros_like(batch['text_ids'])

      position_ids = tf.expand_dims(tf.range(max_length), axis=0)
      new_batch['position_ids'] = tf.tile(position_ids, (bsz, 1))

      return new_batch

    return collater_fn

  @staticmethod
  def get_name_to_features(
      config):
    """Return feature dict for decoding purposes. See BaseTask for details."""
    encoder_config = config.model_config.encoder_config
    max_length = encoder_config.max_length

    name_to_features = {
        'text_ids':
            tf.io.FixedLenFeature(max_length, tf.int64),
        'text_mask':
            tf.io.FixedLenFeature(max_length, tf.int64),
        'target':
            tf.io.FixedLenFeature(1, tf.int64),
        'mention_start_positions':
            tf.io.FixedLenFeature(config.max_mentions_per_sample, tf.int64),
        'mention_end_positions':
            tf.io.FixedLenFeature(config.max_mentions_per_sample, tf.int64),
        'mention_mask':
            tf.io.FixedLenFeature(config.max_mentions_per_sample, tf.int64),
        'subject_mention_indices':
            tf.io.FixedLenFeature(1, tf.int64),
        'object_mention_indices':
            tf.io.FixedLenFeature(1, tf.int64),
    }

    return name_to_features

  @staticmethod
  def dummy_input(config):
    """Produces model-specific dummy input batch. See BaseTask for details."""

    encoder_config = config.model_config.encoder_config
    bsz = config.per_device_batch_size
    text_shape = (bsz, encoder_config.max_length)
    mention_shape = (config.max_mentions)
    mention_target_shape = (2 * bsz)
    int_type = jnp.int32

    position_ids = np.arange(encoder_config.max_length)
    position_ids = np.tile(position_ids, (bsz, 1))

    dummy_input = {
        'text_ids':
            jnp.ones(text_shape, int_type),
        'text_mask':
            jnp.ones(text_shape, int_type),
        'position_ids':
            jnp.asarray(position_ids, int_type),
        'segment_ids':
            jnp.zeros(text_shape, int_type),
        'classifier_target':
            jnp.ones((bsz,), int_type),
        'mention_start_positions':
            jnp.zeros(mention_shape, int_type),
        'mention_end_positions':
            jnp.zeros(mention_shape, int_type),
        'mention_mask':
            jnp.ones(mention_shape, int_type),
        'mention_batch_positions':
            jnp.ones(mention_shape, int_type),
        'mention_target_indices':
            jnp.arange(mention_target_shape, dtype=int_type),
        'mention_target_weights':
            jnp.ones(mention_target_shape, dtype=int_type),
        'mention_object_indices':
            jnp.arange(bsz, dtype=int_type),
        'mention_subject_indices':
            jnp.arange(bsz, dtype=int_type),
        'mention_target_batch_positions':
            jnp.arange(mention_target_shape, dtype=int_type),
        'mention_target_start_positions':
            jnp.zeros(mention_target_shape, int_type),
        'mention_target_end_positions':
            jnp.zeros(mention_target_shape, int_type),
        'mention_target_ids':
            jnp.zeros(mention_target_shape, int_type),
    }

    return dummy_input
