# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Base classes for network defition."""

import tensorflow.compat.v2 as tf

from utils import registry


def register_network_factory(name):
  """Decorator which registers a network factory.

  The factory is stored in the 'network_factory' registry, which is the same
  registry that `get_network_factory_from_registry` reads from.

  Args:
    name: network factory will be registered under this name.

  Returns:
    decorator which registers network factory method.

  Example usage (the factory should accept the same arguments as
  `network_factory_interface`):

    @register_network_factory('resnet')
    def make_resnet(num_classes, input_shape, network_hparams,
                    mixed_precision_dtype=None):
      # ...

  """
  return registry.register('network_factory', name)


def get_network_factory_from_registry(network_name):
  """Looks up a network factory in the 'network_factory' registry.

  Args:
    network_name: name under which the factory was registered, e.g. via
      `register_network_factory`.

  Returns:
    The entry stored under `network_name` in the 'network_factory' registry.
  """
  factories = registry.get_registry('network_factory')
  return factories[network_name]


class NetworkInterface(tf.Module):
  """Interface that every neural network definition is expected to provide.

  Concrete networks must supply the members declared below. Inheriting from
  this class is optional: any object that implements the same methods and
  properties satisfies the interface.
  """

  def __call__(self, x, is_training):
    """Runs the network on the given input.

    Args:
      x: input tensor.
      is_training: whether to run training mode or eval mode.

    Returns:
      Result of evaluating the network on tensor `x`.

    Raises:
      NotImplementedError: always; implementations must override this.
    """
    message = '__call__ must be implemented in a subclass'
    raise NotImplementedError(message)

  @property
  def regularized_variables(self):
    """List of variables that regularization should be applied to.

    Raises:
      NotImplementedError: always; implementations must override this.
    """
    message = 'regularized_variables must be implemented in a subclass'
    raise NotImplementedError(message)


def network_factory_interface(num_classes,
                              input_shape,
                              network_hparams,
                              mixed_precision_dtype=None):
  """Reference signature that every network factory should follow.

  This function only documents the expected interface. It builds nothing and
  always raises.

  Args:
    num_classes: number of output classes.
    input_shape: shape of the input tensor.
    network_hparams: hyperparameters of this specific network architecture.
    mixed_precision_dtype: data type for mixed precision calculations.
      If None then default tf.float32 will be used for all calculations.

  Returns:
    Network class which should implement NetworkInterface.

  Raises:
    NotImplementedError: always, since this is only an interface example.
  """
  # The arguments are intentionally unused: this is a signature template.
  del num_classes, input_shape, network_hparams, mixed_precision_dtype
  raise NotImplementedError(
      'This is just an example of interface of network factory.')
