# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Code which creates networks."""

from absl import logging
from networks import base as networks_base
from networks import default_resnet50
import tensorflow as tf


def make_network(hparams, num_classes, input_shape):
  """Builds the network selected by `hparams.network`.

  The network name is lower-cased and resolved through
  `networks_base.get_network_factory_from_registry`. The resolved factory is
  then called with `hparams.arch` as the network hyperparameters.

  Args:
    hparams: hyperparameter container; this function reads its `network`,
      `arch` and `bfloat16` attributes.
    num_classes: number of output classes.
    input_shape: shape of the input batch, including batch dimension.

  Returns:
    The object returned by the registered factory; expected to implement
    networks_base.NetworkInterface.
  """
  name = hparams.network.lower()
  factory = networks_base.get_network_factory_from_registry(name)
  arch_hparams = hparams.arch
  # Mixed precision is opt-in: tf.bfloat16 when hparams.bfloat16 is truthy,
  # otherwise None is forwarded to the factory.
  if hparams.bfloat16:
    compute_dtype = tf.bfloat16
  else:
    compute_dtype = None
  logging.info('Creating %s network with hyperparams: %s',
               name, arch_hparams)
  factory_kwargs = dict(
      num_classes=num_classes,
      input_shape=input_shape,
      network_hparams=arch_hparams,
      mixed_precision_dtype=compute_dtype,
  )
  return factory(**factory_kwargs)
