# coding=utf-8
# Copyright 2020 The Gsa Net Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialized operators for GSA-Net."""
import functools
from typing import Callable, Tuple

import numpy as np
import tensorflow.compat.v1 as tf


def _combine_last_two_dimensions(tensor):
  """Reshapes tensor so that the last two dimensions become one dimension.

  The last two dimensions must be statically known, since their product is
  computed in Python. Leading dimensions may be statically unknown (for
  example a dynamic batch size); their sizes are then read from the runtime
  shape of `tensor`.

  Args:
    tensor: Input tensor of rank at least 2, with shape (..., dim1, dim0).

  Returns:
    Tensor with shape (..., dim1 * dim0).
  """
  static_shape = tensor.get_shape().as_list()
  dim1, dim0 = static_shape[-2:]
  leading_dims = static_shape[:-2]
  if None in leading_dims:
    # `tf.reshape` cannot take `None` entries in its target shape, so swap in
    # the runtime size for every leading dimension that is not known
    # statically. Nothing is added to the graph when the shape is fully static.
    runtime_shape = tf.shape(tensor)
    leading_dims = [
        runtime_shape[axis] if size is None else size
        for axis, size in enumerate(leading_dims)
    ]
  return tf.reshape(tensor, leading_dims + [dim1 * dim0])


def _compute_attention_component(
    inputs, depth, name, head_count):
  """Computes queries, keys, or values.

  Args:
    inputs: Input features with shape (batch_size, height, width, input_depth).
    depth: Total number of channels for all attention heads.
    name: The name of the component to compute, in
        {'queries', 'keys', 'values'}.
    head_count: Number of attention heads.

  Returns:
    Output features with shape
        (batch_size, height, width, head_count, depth // head_count).

  Raises:
    ValueError: If `name` not one of the expected values, or `depth` is less
        than `head_count`.
  """
  if name not in {'queries', 'keys', 'values'}:
    raise ValueError(
        'name ("{}") not in {{"queries", "keys", "values"}}.'.format(name))
  if depth < head_count:
    raise ValueError('depth ("{}") must be at least head_count ({}).'.format(
        depth, head_count))

  input_depth = inputs.get_shape().as_list()[-1]
  depth_per_head = depth // head_count
  if name == 'queries':
    # He initialization with scaling as in scaled-dot product attention.
    # For scaled dot-product attention, refer to Vaswani, Ashish, et al.
    # "Attention is all you need." NeurIPS. 2017.
    # https://arxiv.org/abs/1706.03762
    initializer_stddev = input_depth ** -0.5 * depth_per_head ** -0.5
  else:
    # He initialization
    initializer_stddev = input_depth ** -0.5
  weight = tf.get_variable(
      name,
      (head_count, input_depth, depth_per_head),
      initializer=tf.random_normal_initializer(stddev=initializer_stddev),
  )
  # Einsum indices:
  # b: batch, x: height, y: width, d: input depth, h: head, e: output depth
  outputs = tf.einsum(
      'bxyd,hde->bxyhe', inputs, weight, name='compute_{}'.format(name))
  return outputs


def _compute_attention_components(
    inputs,
    depth,
    head_count,
):
  """Projects the inputs into queries, keys, and values.

  The three projections are created, in that order, under the
  'compute_attention_components' variable scope.

  Args:
    inputs: Input features with shape (batch_size, height, width, depth).
    depth: Total number of channels for all attention heads for the keys,
        queries, or values.
    head_count: Number of attention heads.

  Returns:
    queries: Query features with shape
        (batch, height, width, head_count, depth // head_count).
    keys: Key features with shape
        (batch, height, width, head_count, depth // head_count).
    values: Value features with shape
        (batch, height, width, head_count, depth // head_count).
  """
  with tf.variable_scope('compute_attention_components'):
    # The comprehension is evaluated inside the scope, so all three weight
    # variables are created under it.
    queries, keys, values = [
        _compute_attention_component(
            inputs, depth, name=component, head_count=head_count)
        for component in ('queries', 'keys', 'values')
    ]
  return queries, keys, values


def multi_head_batch_normalization(
    inputs,
    is_training = True,
    batch_norm_momentum = 0.99,
    batch_norm_epsilon = 1e-5,
):
  """Batch normalization for a multi-head feature map.

  This function applies batch normalization to a multi-head feature tensor of
  shape (batch_size, height, width, head_count, depth). The head and depth
  dimensions are merged into a single channel axis before normalizing, so that
  every (head, channel) pair is normalized separately, and the result is
  reshaped back to the input layout. Naively applying
  `tf.layers.batch_normalization` would result in treating the head dimension
  as a spatiotemporal dimension and incorrectly aggregating batch statistics
  along that dimension.

  The `head_count` and `depth` dimensions must be statically known. The
  leading dimensions may be statically unknown (for example a dynamic batch
  size); their sizes are then read from the runtime shape of `inputs`.

  This function accepts `None` as `inputs` and returns `None` in that case.

  Args:
    inputs: Input features to the batch normalization layer with shape
        (batch_size, height, width, head_count, depth).
    is_training: Whether in training or evaluation mode.
    batch_norm_momentum: Momentum for the batch normalization layer.
    batch_norm_epsilon: Epsilon for the batch normalization layer.

  Returns:
    Output features with shape (batch_size, height, width, head_count, depth).
  """
  if inputs is None:
    return None

  static_shape = inputs.get_shape().as_list()
  head_count, depth = static_shape[-2:]
  leading_dims = static_shape[:-2]
  if None in leading_dims:
    # `tf.reshape` cannot take `None` entries in its target shape, so swap in
    # the runtime size for every leading dimension that is not known
    # statically. Nothing is added to the graph when the shape is fully static.
    runtime_shape = tf.shape(inputs)
    leading_dims = [
        runtime_shape[axis] if size is None else size
        for axis, size in enumerate(leading_dims)
    ]

  # Merge (head_count, depth) into one channel axis so that normalization
  # statistics are kept per head and per channel.
  merged = tf.reshape(inputs, leading_dims + [head_count * depth])
  outputs = tf.layers.batch_normalization(
      inputs=merged,
      axis=-1,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      center=True,
      scale=True,
      training=is_training,
      fused=True,
      gamma_initializer=tf.ones_initializer(),
  )
  # Restore the separate head and depth dimensions.
  return tf.reshape(outputs, leading_dims + [head_count, depth])


def softmax_spatial(tensor):
  """Applies softmax jointly over the height and width axes.

  Args:
    tensor: Input features with shape (batch_size, height, width, depth) or
        (batch_size, height, width, head_count, depth).

  Returns:
    Normalized features with the same shape as the input features, with softmax
        function applied to the height and width dimensions, such that the
        features sum up to 1 along those dimensions.
  """
  height_width_axes = (1, 2)
  # Subtract the per-map maximum before exponentiating so that the largest
  # exponent is 0; this leaves the softmax unchanged but avoids overflow.
  shifted = tensor - tf.reduce_max(
      tensor, axis=height_width_axes, keepdims=True)
  unnormalized = tf.math.exp(shifted)
  partition = tf.reduce_sum(
      unnormalized, axis=height_width_axes, keepdims=True)
  return unnormalized / partition


def _generate_lookup_tensor(length):
  """Returns lookup tensor to reindex relative embeddings to absolute indices.

  Args:
    length: Total length of the input dimension to reindex to.

  Returns:
    Lookup tensor with shape (length, length, 2 * length - 1).
  """
  possible_shifts = 2 * length - 1
  lookup_tensor_numpy = np.zeros((length, length, possible_shifts))
  for i in range(length):
    for j in range(length):
      shift = (j - i) + length - 1
      lookup_tensor_numpy[i, j, shift] = 1
  lookup_tensor = tf.constant(lookup_tensor_numpy, dtype=tf.float32)
  return lookup_tensor


def _compute_relative_logits(queries, dimension):
  """Computes relative logits along a dimension.

  Args:
    queries: Query features with shape
        (batch, height, width, head_count, depth).
    dimension: Dimension of interest for relative logit computation, in
        {'height', 'width'}.

  Returns:
    Relative positional attention logits with shape
        (batch, head_count, height, width, length), for length one of height or
        width, depending on `dimension`.
  """
  shape = queries.get_shape().as_list()
  _, height, width, _, depth = shape
  if dimension == 'height':
    length = height
  elif dimension == 'width':
    length = width
  else:
    raise ValueError('Unknown dimension: {}'.format(dimension))

  lookup_tensor = _generate_lookup_tensor(length)
  with tf.variable_scope('{}_relative_embeddings'.format(dimension)):
    possible_shifts = 2 * length - 1
    initializer_stddev = depth ** -0.5
    relative_embeddings = tf.get_variable(
        'embeddings',
        (possible_shifts, depth),
        initializer=tf.random_normal_initializer(stddev=initializer_stddev),
    )
  # Einsum indices:
  # x: query index, i: key index, s: relative shift
  reindexed_embeddings = tf.einsum(
      'xis,sd->xid', lookup_tensor, relative_embeddings,
      name='reindex_{}_embeddings'.format(dimension))

  # Einsum indices:
  # b: batch, h: head, d: key depth
  # x: query height, i: key height
  # y: query width, j: key width
  if dimension == 'height':
    relative_logits = tf.einsum(
        'bxyhd,xid->bhxyi',
        queries,
        reindexed_embeddings,
        name='height_relative_logits',
    )
  elif dimension == 'width':
    relative_logits = tf.einsum(
        'bxyhd,yjd->bhxyj',
        queries,
        reindexed_embeddings,
        name='weight_relative_logits',
    )
  return relative_logits


def efficient_relative_attention_2d(
    queries,
    keys,
    values,
    batch_norm_fn=tf.identity,
):
  """2D efficient attention with relative positional embeddings.

  The output is the sum of two branches:
    * a content branch, which first aggregates `keys` and `values` over all
      spatial positions into a per-head (key depth, value depth) context and
      then reads that context out with `queries`;
    * a positional branch, which aggregates `values` along the height axis
      using relative logits, applies `batch_norm_fn`, and then aggregates the
      result along the width axis using relative logits.

  Args:
    queries: Query features with shape
        (batch_size, height, width, head_count, depth).
    keys: Key features with shape
        (batch_size, height, width, head_count, depth).
    values: Value features with shape
        (batch_size, height, width, head_count, depth).
    batch_norm_fn: Batch normalization function to apply on the intermediate
        output, with pre-set parameters for `is_training`,
        `batch_norm_momentum`, and `batch_norm_epsilon`.

  Returns:
    Output features with shape (batch_size, height, width, head_count, depth).
  """
  with tf.variable_scope('efficient_relative_attention_2d'):
    # Content branch. Einsum indices:
    # b: batch, h: head, x: height, y: width, k: key depth, v: value depth
    global_context = tf.einsum('bxyhk,bxyhv->bhkv', keys, values)
    content_branch = tf.einsum('bxyhk,bhkv->bxyhv', queries, global_context)

    logits_along_height = _compute_relative_logits(queries, 'height')
    logits_along_width = _compute_relative_logits(queries, 'width')

    # Positional branch. Einsum indices:
    # b: batch, h: head, v: value depth
    # x: query height, i: key height
    # y: query width, j: key width
    column_aggregate = batch_norm_fn(
        tf.einsum(
            'bhxyi,biyhv->bxyhv', logits_along_height, values,
            name='height_einsum'))
    positional_branch = tf.einsum(
        'bhxyj,bxjhv->bxyhv', logits_along_width, column_aggregate,
        name='width_einsum')

  return content_branch + positional_branch


def global_self_attention(
    inputs,
    depth,
    head_count,
    is_training=True,
    batch_norm_momentum=0.99,
    batch_norm_epsilon=1e-5,
):
  """Global self-attention (GSA) layer.

  Projects the inputs into multi-head queries, keys, and values, batch
  normalizes each of them, applies a spatial softmax to the keys, runs
  efficient attention with relative positional embeddings, and finally merges
  the head and per-head depth dimensions of the result.

  Args:
    inputs: Input features with shape (batch_size, height, width, input_depth).
    depth: Number of channels for the outputs.
    head_count: Number of attention heads.
    is_training: Whether in training or evaluation mode.
    batch_norm_momentum: Momentum for the batch normalization layer.
    batch_norm_epsilon: Epsilon for the batch normalization layer.

  Returns:
    Output features with shape (batch_size, height, width, depth).

  Raises:
    ValueError: If depth is not divisible by the number of attention heads.
  """
  if depth % head_count != 0:
    message = 'depth ({}) must be divisible by head_count ({}).'.format(
        depth, head_count)
    raise ValueError(message)

  def normalize(features):
    """Multi-head batch normalization bound to this layer's settings."""
    return multi_head_batch_normalization(
        features,
        is_training=is_training,
        batch_norm_momentum=batch_norm_momentum,
        batch_norm_epsilon=batch_norm_epsilon,
    )

  with tf.variable_scope('global_self_attention'):
    queries, keys, values = _compute_attention_components(
        inputs, depth, head_count)

    # Normalization order (queries, keys, values) is kept fixed because each
    # call creates its own batch normalization layer.
    queries = normalize(queries)
    keys = softmax_spatial(normalize(keys))
    values = normalize(values)

    attended = efficient_relative_attention_2d(
        queries, keys, values, normalize)
    # Merge (head_count, depth // head_count) back into a single channel axis.
    merged = _combine_last_two_dimensions(attended)
  return merged
