# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""ResNet50V2 (Keras Applications) network factory and model wrapper."""
from networks import base as networks_base

import tensorflow as tf

# Short alias for the Keras layers namespace used throughout this module.
layers = tf.keras.layers

# Batch-norm hyperparameter defaults.
# NOTE(review): neither constant is referenced elsewhere in this file.
DEFAULT_BATCH_NORM_EPSILON = 0.00001
DEFAULT_BATCH_NORM_DECAY = 9e-1


class KerasModelWrapper(tf.Module):
  """Wraps a Keras model behind a `__call__(x, is_training)` interface.

  When `skip_head` is True the wrapped Keras model is treated as a headless
  feature extractor: the wrapper appends its own global-average-pooling layer
  and a `num_classes`-way dense (logits) layer. Keeping that classifier outside
  the Keras model is what allows `recreate_classifier` to swap it for a freshly
  initialized one.

  Attributes:
    keras_model: The wrapped Keras model.
    num_classes: Output width of the wrapper-owned dense head.
  """

  def __init__(self,
               keras_model,
               regularized_variables,
               num_classes=1000,
               skip_head=True):
    """Creates the wrapper.

    Args:
      keras_model: Keras model, called as `keras_model(x, training=...)`.
      regularized_variables: Variables exposed via the `regularized_variables`
        property.
      num_classes: Output width of the wrapper-owned dense layer.
      skip_head: If truthy, add a pooling + dense head on top of
        `keras_model`'s output; otherwise `__call__` returns `keras_model`'s
        output unchanged.
    """
    # tf.Module.__init__ sets up the module's name / name scope; skipping it
    # leaves `name` and `name_scope` undefined on the instance.
    super().__init__()
    self.keras_model = keras_model
    self._regularized_variables = regularized_variables
    # Stored unconditionally so the attribute exists regardless of skip_head.
    self.num_classes = num_classes
    self._skip_head = bool(skip_head)

    if self._skip_head:
      self._pool = layers.GlobalAveragePooling2D(name='avg_pool')
      self._logits = layers.Dense(num_classes)

  def recreate_classifier(self):
    """Replaces the wrapper's dense head with a freshly initialized one.

    The new layer is called once on a dummy input so that its variables are
    created immediately.

    Raises:
      ValueError: If the wrapper was built with `skip_head=False`, in which
        case it owns no classifier to recreate.
    """
    if not self._skip_head:
      raise ValueError(
          'recreate_classifier requires a wrapper built with skip_head=True.')
    # The head consumes pooled features, which keep the last dimension of the
    # Keras model's output (assumes channels_last data format).
    hidden_size = self.keras_model.output.shape[-1]
    self._logits = layers.Dense(self.num_classes)
    self._logits(tf.zeros((1, hidden_size)))

  def summary(self):
    """Prints the wrapped Keras model's summary (excludes the wrapper head)."""
    self.keras_model.summary()

  @property
  def regularized_variables(self):
    """Variables supplied at construction time for regularization.

    NOTE(review): the wrapper-owned dense head (present when `skip_head` is
    True) is not part of this list, so its kernel/bias are not included;
    confirm that is intended.
    """
    return self._regularized_variables

  def __call__(self, x, is_training):
    """Runs the Keras model and, if present, the wrapper-owned head.

    Args:
      x: Input batch passed to the Keras model.
      is_training: Forwarded as Keras' `training` flag.

    Returns:
      Logits from the wrapper head when `skip_head` was True; otherwise the
      Keras model's output.
    """
    features = self.keras_model(x, training=is_training)
    if not self._skip_head:
      return features
    return self._logits(self._pool(features))


@networks_base.register_network_factory('keras_resnet50_imagenet')
def keras_resnet50_factory(num_classes,
                           input_shape,
                           network_hparams,
                           mixed_precision_dtype=None):
  """Builds a Keras ResNet50V2 network wrapped in `KerasModelWrapper`.

  Args:
    num_classes: Number of output logits.
    input_shape: Input shape with a leading batch dimension; only
      `input_shape[1:]` is passed to Keras.
    network_hparams: Hyperparameter container supporting `in` and `[]`.
      Recognized keys: `imagenet_weight` (truthy -> load ImageNet-pretrained
      weights) and `skip_dense` (truthy -> leave the pooling + dense head out
      of the Keras model and let the wrapper own it instead).
    mixed_precision_dtype: None, `tf.float32`, or `tf.bfloat16`. With
      `tf.bfloat16` the Keras dtype policy is set to 'mixed_bfloat16' before
      the model is built.

  Returns:
    A `KerasModelWrapper` around the model. Its `regularized_variables` are
    the Keras model's trainable variables whose names contain 'kernel' or
    'bias'.

  Raises:
    ValueError: If `mixed_precision_dtype` is not None, `tf.float32` or
      `tf.bfloat16`.
  """

  def _flag(key):
    # Only relies on `in` and `[]`, so it works for dicts and HParams alike.
    return bool(key in network_hparams and network_hparams[key])

  weights = 'imagenet' if _flag('imagenet_weight') else None

  if mixed_precision_dtype == tf.bfloat16:
    # The policy must be set before any layer is created to affect it.
    mixed_precision = tf.keras.mixed_precision
    if hasattr(mixed_precision, 'set_global_policy'):
      # Stable API (TF >= 2.4); the `experimental` namespace is deprecated
      # there and was later removed.
      mixed_precision.set_global_policy('mixed_bfloat16')
    else:
      mixed_precision.experimental.set_policy(
          mixed_precision.experimental.Policy('mixed_bfloat16'))
  elif (mixed_precision_dtype is not None and
        mixed_precision_dtype != tf.float32):
    raise ValueError('Unsupported mixed precision dtype %s' %
                     mixed_precision_dtype)

  # `classes` has no effect when include_top=False; kept for parity.
  base_model = tf.keras.applications.ResNet50V2(
      include_top=False,
      weights=weights,
      input_shape=input_shape[1:],
      classes=num_classes)

  skip_dense = _flag('skip_dense')
  if skip_dense:
    # Headless backbone; KerasModelWrapper adds pooling + logits itself.
    model = base_model
  else:
    pooled = layers.GlobalAveragePooling2D(name='avg_pool')(
        base_model.output)
    logits = layers.Dense(num_classes)(pooled)
    model = tf.keras.Model(inputs=base_model.input, outputs=logits)

  # Weight decay targets conv/dense kernels and biases (not BN gamma/beta).
  regularized_variables = [
      v for v in model.trainable_variables
      if 'kernel' in v.name or 'bias' in v.name
  ]
  return KerasModelWrapper(model, regularized_variables, num_classes,
                           skip_dense)
