"""Implementation of attention in MOAT block."""

import string
from typing import Optional
import numpy as np
import tensorflow as tf
from hparam_configs import create_config_from_dict

# Module-level cache of the lookup tensors built by `generate_lookup_tensor`,
# so that repeated requests for the same lookup reuse one tensor instead of
# rebuilding it.
LOOKUP_TENSOR_CACHE = {}


def generate_lookup_tensor(length: int,
                           max_relative_position: Optional[int] = None,
                           dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
  """Generates one-hot lookup tensor to reindex embeddings along one dimension.

  The result is memoized in `LOOKUP_TENSOR_CACHE`, keyed by length, maximum
  relative position and dtype, so that lookups requested with different dtypes
  do not evict each other.

  Args:
    length: The length to reindex to.
    max_relative_position: The maximum relative position to consider.
      Relative position embeddings for distances above this threshold
      are zeroed out. Defaults to `length - 1`, i.e. no distance is dropped.
    dtype: The data type of the returned lookup tensor.

  Returns:
    lookup_tensor: A constant tf.Tensor with shape
      [length, length, relative_position_range], where
      relative_position_range = 2 * max_relative_position + 1. The element
      satisfies lookup_tensor[n, m, v] = 1{m - n + max_relative_position = v},
      where n, m mean two positions while v the relative position selection.
      Pairs with |m - n| > max_relative_position have an all-zero row.
  """
  if max_relative_position is None:
    max_relative_position = length - 1
  dtype = tf.as_dtype(dtype)
  # The dtype is part of the key: the cached value is a typed constant, so a
  # key without it would be rebuilt whenever callers alternate between dtypes.
  lookup_key = ('lookup_matrix', length, max_relative_position, dtype.name)

  if lookup_key not in LOOKUP_TENSOR_CACHE:
    relative_position_range = 2 * max_relative_position + 1
    lookup_tensor = np.zeros([length, length, relative_position_range])
    positions = np.arange(length)
    # relative[n, m] = m - n for every pair of positions.
    relative = positions[np.newaxis, :] - positions[:, np.newaxis]
    # Only pairs within the allowed distance get a one-hot entry.
    rows, cols = np.nonzero(np.abs(relative) <= max_relative_position)
    lookup_tensor[rows, cols,
                  relative[rows, cols] + max_relative_position] = 1
    LOOKUP_TENSOR_CACHE[lookup_key] = tf.constant(lookup_tensor, dtype)
  return LOOKUP_TENSOR_CACHE[lookup_key]


def reindex_2d_einsum_lookup(relative_position_tensor: tf.Tensor,
                             height: int,
                             width: int,
                             max_relative_height: Optional[int] = None,
                             max_relative_width: Optional[int] = None,
                             h_axis: int = 0) -> tf.Tensor:
  """Expands a 2d relative position bias into a position-pair bias.

  The height and width axes are reindexed separately, each through a one-hot
  lookup tensor applied with an einsum, and the four resulting spatial axes
  are then flattened into two axes of size height * width.

  Args:
    relative_position_tensor: A tensor of shape
      [..., relative_position_embedding_height,
      relative_position_embedding_width, ...].
    height: The height to reindex to.
    width: The width to reindex to.
    max_relative_height: Maximum relative height.
      Position embeddings corresponding to vertical distances larger
      than max_relative_height are zeroed out. None to disable.
    max_relative_width: Maximum relative width.
      Position embeddings corresponding to horizontal distances larger
      than max_relative_width are zeroed out. None to disable.
    h_axis: Axis corresponding to relative_position_embedding_height; the
      width axis is taken to be the one right after it. Default to 0.

  Returns:
    reindexed_position_embedding: A Tensor of shape
      [..., height * width, height * width, ...]
  """
  embedding_dtype = relative_position_tensor.dtype
  lookup_along_height = generate_lookup_tensor(
      height, max_relative_position=max_relative_height,
      dtype=embedding_dtype)
  lookup_along_width = generate_lookup_tensor(
      width, max_relative_position=max_relative_width,
      dtype=embedding_dtype)

  # Every non-spatial axis gets its own einsum label, counted up from 'n';
  # the labels are split into those before and those after the spatial axes.
  num_extra_axes = relative_position_tensor.shape.rank - 2
  extra_labels = ''.join(
      chr(code) for code in range(ord('n'), ord('n') + num_extra_axes))
  leading, trailing = extra_labels[:h_axis], extra_labels[h_axis:]

  # First replace the relative-height axis `h` by a pair of rows (i, x), then
  # the relative-width axis `w` by a pair of columns (j, y).
  height_equation = '{0}hw{1},ixh->{0}ixw{1}'.format(leading, trailing)
  width_equation = '{0}ixw{1},jyw->{0}ijxy{1}'.format(leading, trailing)
  expanded = tf.einsum(
      height_equation, relative_position_tensor, lookup_along_height,
      name='height_lookup')
  expanded = tf.einsum(
      width_equation, expanded, lookup_along_width, name='width_lookup')

  # Flatten (i, j) and (x, y) into one axis each; the two spatial axes of the
  # input shape are overwritten in place, so the rank is unchanged.
  flattened_size = height * width
  output_shape = relative_position_tensor.shape.as_list()
  output_shape[h_axis] = flattened_size
  output_shape[h_axis + 1] = flattened_size
  return tf.reshape(expanded, output_shape)


class TrailDense(tf.keras.layers.Layer):
  """A dense layer that projects features in multiple trailing axes.

  This layer projects features from multiple dimensions to multiple dimensions.
  The trailing axes with size n mean the last n dimensions. This layer avoids
  the extra uses of reshape operations.

  An einsum expression string is generated in this layer, examples:
    - For 4D tensors in conv, a common expression would be 'ABCD,DE->ABCE'.
    - For `q/k/v` head projection in multi-head attention with two output
      trailing dimensions, the expression is 'ABC,CDE->ABDE'
    - For `o` output projection in multi-head attention with
      input_begin_axis = -2, the expression is 'ABCD,CDE->ABE'
  """

  def __init__(self,
               output_trailing_dimensions,
               input_begin_axis=-1,
               use_bias=True,
               kernel_initializer=tf.random_normal_initializer(stddev=0.02),
               bias_initializer=tf.zeros_initializer,
               name='dense'):
    """Initializes TrailDense layer.

    Args:
      output_trailing_dimensions: A list of integers, multiple output
        dimensions in trailing axes. This avoids extra reshape
        operation that splits one single output dimension.
      input_begin_axis: A negative integer, the beginning axes of the input.
        This saves extra reshape operation to merge multiple input dimension.
      use_bias: A boolean, whether to use learnable bias in the layer.
      kernel_initializer: Initializer for the kernel weights matrix.
      bias_initializer: Initializer for the bias vector.
      name: A string, layer name.
    """

    super().__init__(name=name)
    self._output_trailing_dimensions = output_trailing_dimensions
    self._input_begin_axis = input_begin_axis
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  def build(self, input_shape):
    """Creates variables and einsum expression based on input shape.

    Args:
      input_shape: Shape of the input; anything `tf.TensorShape` accepts.

    Raises:
      ValueError: If the input rank is unknown, or if the input and output
        axes together need more than the 26 available einsum labels.
    """

    # Accept plain lists/tuples as well as TensorShape instances.
    input_shape = tf.TensorShape(input_shape)
    input_rank = input_shape.rank
    if input_rank is None:
      raise ValueError('TrailDense requires an input with a known rank.')

    o_only_size = len(self._output_trailing_dimensions)
    # Validate before creating any weight. Exactly 26 labels still fit in
    # `string.ascii_uppercase`, so only strictly more than 26 is rejected.
    num_labels = input_rank + o_only_size
    if num_labels > len(string.ascii_uppercase):
      raise ValueError('Cannot use einsum as input rank + output rank > 26.')

    # The kernel spans the projected input axes followed by the output axes.
    weight_shape = input_shape[self._input_begin_axis:].concatenate(
        self._output_trailing_dimensions)
    self.weight = self.add_weight(
        name='weight',
        shape=weight_shape,
        initializer=self._kernel_initializer,
        trainable=True)
    if self._use_bias:
      self.bias = self.add_weight(
          name='bias',
          shape=self._output_trailing_dimensions,
          initializer=self._bias_initializer,
          trainable=True)

    # Create einsum expression. Labels are handed out in order: leading axes
    # kept as-is (shared), input axes that get contracted, then output axes.
    shared_size = self._input_begin_axis % input_rank
    i_only_size = input_rank - shared_size
    einsum_str = string.ascii_uppercase[:num_labels]
    shared_str = einsum_str[:shared_size]
    i_only_str = einsum_str[shared_size:input_rank]
    o_only_str = einsum_str[input_rank:input_rank + o_only_size]

    input_str = '{}{}'.format(shared_str, i_only_str)
    output_str = '{}{}'.format(shared_str, o_only_str)
    weight_str = '{}{}'.format(i_only_str, o_only_str)
    self.einsum_expr = '{},{}->{}'.format(input_str, weight_str, output_str)

  def call(self, inputs):
    """Projects `inputs` with the kernel and adds the bias if enabled."""
    output = tf.einsum(self.einsum_expr, inputs, self.weight)
    if self._use_bias:
      # The bias has the shape of the output trailing axes and broadcasts over
      # the shared leading axes.
      output += self.bias
    return output


class Attention(tf.keras.layers.Layer):
  """Implementation of Attention.

  This layer performs global self-attention [1] on the input. The input shape
  is [batch_size, height, width, channels] and the output shape is
  [batch_size, height * width, hidden_size].

  If one would like to extend the global self-attention to the local window
  attention [2], they could fold the windows into the batch dimension, i.e.
  reshape the input to
  [batch_size * num_window, window_height, window_width, channels]
  (`call` requires a rank-4 input), followed by applying this class.

  [1] Attention Is All You Need.
    NeurIPS 2017.
      Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
      Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin.

  [2] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
    ICCV 2021.
      Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang,
      Stephen Lin, Baining Guo
  """

  def _retrieve_config(self, config):
    """Retrieves the config of Attention.

    Args:
      config: A dictionary containing the following keys.
        -hidden_size: An integer, output channels.
        -head_size: An integer, the head size of attention.
        -relative_position_embedding_type: A string, type of relative position
          embedding in Attention. Only '2d_multi_head' is supported now.
          If None, no relative position embedding will be used.
        -scale_ratio: A number or a list/tuple of two numbers, scaling factors
          for the position embedding in height and width dimensions. For
          example, [14/14, 16/14] means the position embedding is created for
          window 14 x 14, but will be interpolated to 14 x 16.
        -kernel_initializer: Initializer for the kernel weights matrix.
        -bias_initializer: Initializer for the bias vector.
        -name: A string, layer name.

    Returns:
      The config object built by `create_config_from_dict`, with the optional
      keys filled in with their defaults when absent.
    """

    required_keys = ['hidden_size', 'head_size']
    optional_keys = {
        'relative_position_embedding_type': None,
        'scale_ratio': None,
        'kernel_initializer': tf.random_normal_initializer(stddev=0.02),
        'bias_initializer': tf.zeros_initializer,
        'name': 'attention',
    }
    config = create_config_from_dict(config, required_keys, optional_keys)
    return config

  def __init__(self, **config):
    """Initializes the q/k/v/o projections from the keyword config.

    Args:
      **config: The keys documented in `_retrieve_config`.
    """
    self._config = self._retrieve_config(config)
    super().__init__(name=self._config.name)
    self._config.num_heads = self._config.hidden_size // self._config.head_size

    self._q_proj = TrailDense(
        output_trailing_dimensions=[self._config.num_heads,
                                    self._config.head_size],
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer,
        name='q')
    self._k_proj = TrailDense(
        output_trailing_dimensions=[self._config.num_heads,
                                    self._config.head_size],
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer,
        name='k')
    self._v_proj = TrailDense(
        output_trailing_dimensions=[self._config.num_heads,
                                    self._config.head_size],
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer,
        name='v')
    # Merges the [num_heads, head_size] axes back into hidden_size.
    self._o_proj = TrailDense(
        output_trailing_dimensions=[self._config.hidden_size],
        input_begin_axis=-2,
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer,
        name='o')

    self._q_scale = self._config.head_size ** -0.5

  def _parse_scale_ratio(self):
    """Returns (height_scale, width_scale), or None when no ratio is set.

    Raises:
      ValueError: If scale_ratio is neither a number nor a list/tuple of
        length 2.
    """
    scale_ratio = self._config.scale_ratio
    if scale_ratio is None:
      return None
    if isinstance(scale_ratio, (list, tuple)) and len(scale_ratio) == 2:
      height_scale, width_scale = scale_ratio
      return height_scale, width_scale
    # Integers are valid ratios too; bool is excluded although it subclasses
    # int, because it is never a meaningful scale.
    if (isinstance(scale_ratio, (int, float)) and
        not isinstance(scale_ratio, bool)):
      return scale_ratio, scale_ratio
    raise ValueError(
        'scale ratio should be float or list of floats with length 2')

  def _reindex_position_embedding(self, height, width):
    """Expands the learned embedding to a [num_heads, h * w, h * w] bias.

    Args:
      height: An integer, height of the attended feature map.
      width: An integer, width of the attended feature map.

    Returns:
      A tensor of shape [num_heads, height * width, height * width] in the
      layer's compute dtype, derived from the current value of
      `self.relative_position_embedding`.
    """
    embedding = self.relative_position_embedding
    if self._config.scale_ratio is not None:
      # The embedding was created for a differently sized window; resize it to
      # the size this feature map needs. `tf.image.resize` expects a trailing
      # channel axis, so add one and drop it afterwards (num_heads acts as the
      # batch axis).
      embedding = tf.expand_dims(embedding, axis=-1)
      embedding = tf.cast(
          tf.image.resize(embedding, [2 * height - 1, 2 * width - 1]),
          self.compute_dtype)
      embedding = tf.squeeze(embedding, axis=-1)
    else:
      embedding = tf.cast(embedding, self.compute_dtype)

    # Axis 0 is num_heads, so the relative-height axis is axis 1.
    return reindex_2d_einsum_lookup(
        embedding, height, width, height - 1, width - 1, h_axis=1)

  def build(self, input_shape):
    """Creates the relative position embedding, if one is configured.

    Args:
      input_shape: Shape of the input, [batch_size, height, width, channels].

    Raises:
      ValueError: If the embedding type is not supported, if the input is not
        rank 4 while an embedding is requested, or if scale_ratio is invalid.
    """
    embedding_type = self._config.relative_position_embedding_type
    if embedding_type is None:
      self.reindexed_position_embedding = None
      return
    if embedding_type != '2d_multi_head':
      # Fail here rather than with an AttributeError later in `call`.
      raise ValueError(
          'Unsupported relative_position_embedding_type: {!r}. Only '
          "'2d_multi_head' or None is supported.".format(embedding_type))

    input_shape = tf.TensorShape(input_shape)
    if input_shape.rank != 4:
      raise ValueError(
          'The input shape should be [batch_size, height, width, channels]')
    input_shape_list = input_shape.as_list()
    height, width = input_shape_list[-3], input_shape_list[-2]

    scales = self._parse_scale_ratio()
    if scales is None:
      relative_position_embedding_height = 2 * height - 1
      relative_position_embedding_width = 2 * width - 1
    else:
      # The variable is created at the un-scaled window size and resized to
      # the actual feature map size when it is used.
      height_scale, width_scale = scales
      relative_position_embedding_height = (
          2 * round(height / height_scale) - 1)
      relative_position_embedding_width = (
          2 * round(width / width_scale) - 1)

    self.relative_position_embedding = self.add_weight(
        name='relative_position_embedding',
        shape=[self._config.num_heads,
               relative_position_embedding_height,
               relative_position_embedding_width],
        initializer=self._config.kernel_initializer,
        trainable=True)

    # Kept for backward compatibility with code reading this attribute. When
    # build() executes eagerly this is only a snapshot of the initial weights,
    # which is why `call` recomputes the bias from the live variable.
    self.reindexed_position_embedding = self._reindex_position_embedding(
        height, width)

  def call(self, query, training=None):
    """Applies multi-head self-attention over all spatial positions.

    Args:
      query: A tensor of shape [batch_size, height, width, channels] with
        static height, width and channels.
      training: Unused by this layer; accepted for Keras compatibility.

    Returns:
      A tensor of shape [batch_size, height * width, hidden_size].
    """
    _, h, w, channels = query.shape.as_list()
    # Flatten the spatial grid into a sequence of h * w tokens.
    query = tf.reshape(query, [-1, h * w, channels])

    q_heads = self._q_proj(query)
    k_heads = self._k_proj(query)
    v_heads = self._v_proj(query)
    q_heads *= self._q_scale

    attention_logits = tf.einsum('BSNK, BTNK -> BNST', q_heads, k_heads)

    if self._config.relative_position_embedding_type == '2d_multi_head':
      # Derived from the variable on every call so that weight updates
      # (training steps, restored checkpoints) are reflected and gradients
      # can reach the embedding.
      attention_logits += self._reindex_position_embedding(h, w)

    # Softmax is evaluated in float32 and cast back to the logits dtype.
    attention_probs = tf.cast(
        tf.nn.softmax(tf.cast(attention_logits, tf.float32), axis=-1),
        attention_logits.dtype)

    attention_out = tf.einsum('BNST, BTNK -> BSNK', attention_probs, v_heads)
    output = self._o_proj(attention_out)
    return output

