# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax modules and functions for using external memory."""

from typing import Any, Optional, Tuple

from absl import logging
from flax import linen
import gin
import jax
from transformer import memory_layer



# Loose type aliases used in annotations throughout this module.
PRNGKey = Any
# A shape is a tuple of ints of any length (one entry per dimension).
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
# Opaque handle to an off-device memory resource.
MemoryResource = Any


class MemoryManager:
  """Manages any external resources that may be required by external memory.

  MemoryManager also functions as a factory, to create Flax modules that will
  read and write to whatever external memory has been configured.

  Attributes:
    batch_size: The number of separate documents in a batch.
    mode: The run mode, e.g. "train" or "test".  Stored as given; this class
      does not otherwise read it.
    num_heads: The number of transformer heads.
    key_size: The length of the key vectors.
    value_size: The length of the value vectors.
    database_size: The total number of tokens in the database, or None.
    dtype: The datatype used for keys and values.
    off_device_memory: Optional object managing off-device memory, or None
      to use on-device memory.
  """

  def __init__(self,
               batch_size: int,
               mode: str,
               num_heads: int,
               key_size: int,
               value_size: int,
               database_size: Optional[int] = None,
               dtype: Dtype = "float32",
               off_device_memory: Optional[MemoryResource] = None):
    """Create a MemoryManager object.

    A MemoryManager configures external memory, and is used as a factory to
    construct flax modules that read or write to the memory.

    Args:
      batch_size: The number of separate documents in a batch.
      mode:       The run mode, e.g. "train" or "test".
      num_heads:  The number of transformer heads.
      key_size:   The length of the key vectors.
      value_size: The length of the value vectors.
      database_size:  The total number of tokens in the database.  Required
          (must not be None) when on-device memory is used.
      dtype:      The datatype used for keys and values.
      off_device_memory: An object which manages underlying SCAM memory.
          If None, then the model will use on-device memory.
    """
    self.batch_size = batch_size
    self.mode = mode
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size
    self.database_size = database_size
    self.dtype = dtype
    self.off_device_memory = off_device_memory

  def create_memory_layer(self) -> linen.Module:
    """Create a flax Module that implements external memory.

    Returns:
      A memory_layer.BatchedMemory wrapping a memory_layer.MemoryOnTpu that
      is configured with batch_size * num_heads datasets.

    Raises:
      ValueError: If off-device memory was configured (it is not supported
        at this time), or if database_size is None.
    """
    # Validate the configuration up front, before building anything.
    if self.off_device_memory is not None:
      raise ValueError("Off-device memory is not supported at this time.")
    if self.database_size is None:
      # An explicit raise (rather than an assert) survives `python -O`.
      raise ValueError("database_size must be set to use on-device memory.")

    # On-device memory uses one dataset for every head of every document in
    # the batch.
    num_datasets = self.batch_size * self.num_heads
    mem_layer = memory_layer.MemoryOnTpu(num_datasets=num_datasets,
                                         key_features=self.key_size,
                                         value_features=self.value_size,
                                         database_size=self.database_size,
                                         dtype=self.dtype)
    # Handle queries of shape [batch_size, seq_len, num_heads, kv_features]
    return memory_layer.BatchedMemory(mem_layer,
                                      split_dimensions=(0, -2))


@gin.configurable
def memory_on_tpu_factory(batch_size: int,
                          mode: str,
                          num_heads: int = gin.REQUIRED,
                          key_size: int = gin.REQUIRED,
                          value_size: int = gin.REQUIRED,
                          database_size: int = gin.REQUIRED,
                          dtype: Dtype = gin.REQUIRED) -> MemoryManager:
  """Build a MemoryManager that implements SCAM memory on device.

  Args:
    batch_size: The number of separate documents in a batch.
    mode: The run mode, e.g. "train" or "test".
    num_heads: The number of transformer heads (gin.REQUIRED by default).
    key_size: The length of the key vectors (gin.REQUIRED by default).
    value_size: The length of the value vectors (gin.REQUIRED by default).
    database_size: The total number of tokens in the database (gin.REQUIRED
      by default).
    dtype: The datatype used for keys and values (gin.REQUIRED by default).

  Returns:
    A MemoryManager constructed with the given settings and with
    off_device_memory set to None.
  """
  manager_settings = dict(
      batch_size=batch_size,
      mode=mode,
      num_heads=num_heads,
      key_size=key_size,
      value_size=value_size,
      database_size=database_size,
      dtype=dtype,
  )
  # No off-device resource is attached to the manager.
  return MemoryManager(off_device_memory=None, **manager_settings)


