# coding=utf-8
# Copyright 2021 The Spectral Bias Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Builds the Shake-Shake Model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import tensorflow as tf

from spectral_bias.label_smoothing import custom_ops as ops


def round_int(x):
  """Returns `x` rounded to the nearest integer, ties going up, as an `int`."""
  nearest = math.floor(x + 0.5)
  return int(nearest)


def shortcut(x, output_filters, stride):
  """Reshapes `x` so it can be added to a residual branch's output.

  Spatial size is reduced with an average pool whose window and stride both
  equal `stride` (only when `stride > 1`). The channel count is raised to
  `output_filters` by appending zero channels on the last axis.

  Args:
    x: Input tensor. Its channel count is read from `x.shape[3]` and padding is
      applied on axis 3, so it is presumably rank 4, channels-last (NHWC) --
      TODO confirm against callers.
    output_filters: Channel count the result must have. Must be at least the
      channel count of `x`.
    stride: Spatial stride of the residual branch. Values greater than 1
      trigger an average pool with window and stride equal to `stride`.

  Returns:
    `x`, pooled and/or zero-padded to match the branch output.

  Raises:
    ValueError: If `output_filters` is smaller than the channel count of `x`.
  """
  num_filters = int(x.shape[3])
  if stride > 1:
    # Window == stride, so for stride 2 this is the same 2x2/2 pool as before.
    x = ops.avg_pool(x, stride, stride=stride, padding='SAME')
  if num_filters != output_filters:
    diff = output_filters - num_filters
    if diff < 0:
      # Explicit check instead of `assert`, which is stripped under `python -O`.
      raise ValueError(
          'shortcut cannot reduce channels: input has {} but output_filters '
          'is {}'.format(num_filters, output_filters))
    # Pad only the channel axis, appending `diff` zero channels at the end.
    padding = [[0, 0], [0, 0], [0, 0], [0, diff]]
    x = tf.pad(x, padding)
  return x


def calc_prob(curr_layer, total_layers, p_l):
  """Calculates the ShakeDrop gate probability for a given layer.

  The value decays linearly from 1 (at `curr_layer == 0`) to `1 - p_l` (at
  `curr_layer == total_layers`). `bottleneck_layer` uses it as the
  probability that its per-example gate is 1, i.e. that the residual branch
  passes through unscaled. It is therefore a keep probability, not a drop
  probability.

  Args:
    curr_layer: Index of the current layer (this file's caller counts from 1).
    total_layers: Total number of layers; must be non-zero.
    p_l: Total decrease of the probability between layer 0 and the last layer.

  Returns:
    `1 - (curr_layer / total_layers) * p_l` as a float.

  Raises:
    ZeroDivisionError: If `total_layers` is zero.
  """
  # float() keeps the division true division even for integer arguments.
  return 1 - (float(curr_layer) / total_layers) * p_l


def bottleneck_layer(x, n, stride, prob, is_training, alpha, beta):
  """Bottleneck residual unit with ShakeDrop regularization.

  The residual branch is BN -> 1x1 conv (n) -> BN -> ReLU -> 3x3 conv (n,
  `stride`) -> BN -> ReLU -> 1x1 conv (4 * n) -> BN.

  During training, each example's branch output is scaled by a random factor.
  The forward pass uses one factor and the backward pass another. At
  evaluation time the branch is scaled by the expected forward factor. The
  result is added to `shortcut(x, 4 * n, stride)`.

  Args:
    x: Input tensor. `tf.shape(x)[0]` is used as the batch size and the random
      scales have shape [batch, 1, 1, 1], so x is presumably rank 4 (NHWC) --
      TODO confirm.
    n: Bottleneck width; the unit outputs `4 * n` channels.
    stride: Stride of the 3x3 conv, also forwarded to `shortcut`.
    prob: Probability that the per-example gate is 1, i.e. that the branch
      passes through unscaled. Also embedded in the variable scope name.
    is_training: Evaluated as a Python truth value. True selects the
      stochastic path; false selects the expectation path.
    alpha: [low, high] range (low < high) of the uniform forward-pass scale
      applied when the gate is 0.
    beta: [low, high] range (low < high) of the uniform backward-pass scale
      applied when the gate is 0.

  Returns:
    Tensor with `4 * n` channels: the scaled residual branch plus the shortcut.
  """
  assert alpha[1] > alpha[0]
  assert beta[1] > beta[0]
  # NOTE(review): the scope name embeds the float `prob`. Two calls with an
  # equal `prob` would therefore share a scope. The caller in this file
  # appears to pass a distinct prob per layer.
  with tf.variable_scope('bottleneck_{}'.format(prob)):
    input_layer = x
    # Residual branch: pre-BN, 1x1 contract, strided 3x3, 1x1 expand to 4n.
    x = ops.maybe_normalize(x, scope='bn_1_pre')
    x = ops.conv2d(x, n, 1, scope='1x1_conv_contract')
    x = ops.maybe_normalize(x, scope='bn_1_post')
    x = tf.nn.relu(x)
    x = ops.conv2d(x, n, 3, stride=stride, scope='3x3')
    x = ops.maybe_normalize(x, scope='bn_2')
    x = tf.nn.relu(x)
    x = ops.conv2d(x, n * 4, 1, scope='1x1_conv_expand')
    x = ops.maybe_normalize(x, scope='bn_3')

    # Apply ShakeDrop regularization to the residual branch.
    if is_training:
      batch_size = tf.shape(x)[0]
      # One gate / scale per example, broadcast over the remaining axes.
      bern_shape = [batch_size, 1, 1, 1]
      # floor(prob + U[0, 1)) is 1 with probability `prob` (for prob in
      # [0, 1]) and 0 otherwise: a per-example Bernoulli(prob) gate.
      random_tensor = prob
      random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32)
      binary_tensor = tf.floor(random_tensor)

      alpha_values = tf.random_uniform(
          [batch_size, 1, 1, 1], minval=alpha[0], maxval=alpha[1],
          dtype=tf.float32)
      beta_values = tf.random_uniform(
          [batch_size, 1, 1, 1], minval=beta[0], maxval=beta[1],
          dtype=tf.float32)
      # b + s - b*s is 1 when the gate b is 1 and s when b is 0. The
      # forward scale uses s = alpha, the backward scale uses s = beta.
      rand_forward = (
          binary_tensor + alpha_values - binary_tensor * alpha_values)
      rand_backward = (
          binary_tensor + beta_values - binary_tensor * beta_values)
      # The forward value equals x * rand_forward. stop_gradient hides the
      # difference term, so the gradient w.r.t. x is that of
      # x * rand_backward.
      x = x * rand_backward + tf.stop_gradient(x * rand_forward -
                                               x * rand_backward)
    else:
      expected_alpha = (alpha[1] + alpha[0])/2
      # Deterministic scale E[b + a - b*a] for independent b ~ Bernoulli(prob)
      # and a ~ U(alpha[0], alpha[1]): prob + E[a] - prob * E[a].
      x = (prob + expected_alpha - prob * expected_alpha) * x

    # Shortcut pooled/padded to match the branch's stride and 4n channels.
    res = shortcut(input_layer, n * 4, stride)
    return x + res


def build_shake_drop_model(images, num_classes, is_training):
  """Builds the PyramidNet Shake-Drop model.

  Build the PyramidNet Shake-Drop model from https://arxiv.org/abs/1802.02375.

  Architecture:
    - An initial 3x3 conv + BN.
    - Three stages of `n` bottleneck units each.
    - Final BN, ReLU, global average pool and a fully connected layer.

  The bottleneck width grows by `pyramid_alpha / (3 * n)` per unit. The width
  is accumulated as a float and rounded for each unit. The first unit of the
  second and third stages uses stride 2. The ShakeDrop gate probability for
  each unit comes from `calc_prob` and decreases linearly with depth.

  Args:
    images: Tensor of images that will be fed into the PyramidNet model.
    num_classes: Number of classes that the model needs to predict.
    is_training: Is the model training or not. Forwarded to every bottleneck
      unit, where it selects the stochastic or the expectation ShakeDrop path.

  Returns:
    A `(logits, features)` tuple. `logits` is the output of the final fully
    connected layer. `features` is the globally average-pooled tensor fed
    into it.
  """
  # ShakeDrop Hparams
  p_l = 0.5
  alpha_shake = [-1, 1]
  beta_shake = [0, 1]

  # PyramidNet Hparams (`pyramid_alpha` is the total widening across all
  # units, unrelated to the ShakeDrop `alpha_shake` range).
  pyramid_alpha = 200
  depth = 272
  # This is for the bottleneck architecture specifically
  n = int((depth - 2) / 9)
  start_channel = 16
  add_channel = pyramid_alpha / (3 * n)

  # Building the models
  x = images
  x = ops.conv2d(x, 16, 3, scope='init_conv')
  x = ops.maybe_normalize(x, scope='init_bn')

  layer_num = 1
  total_layers = n * 3
  # Three stages; only the first unit of each stage uses the stage's stride.
  for stage_stride in (1, 2, 2):
    for unit in range(n):
      # Keep the running float width and round per unit so the channel
      # schedule matches incremental accumulation exactly.
      start_channel += add_channel
      prob = calc_prob(layer_num, total_layers, p_l)
      stride = stage_stride if unit == 0 else 1
      x = bottleneck_layer(
          x, round_int(start_channel), stride, prob, is_training, alpha_shake,
          beta_shake)
      layer_num += 1

  assert layer_num - 1 == total_layers
  x = ops.maybe_normalize(x, scope='final_bn')
  x = tf.nn.relu(x)
  x = ops.global_avg_pool(x)
  # Fully connected
  logits = ops.fc(x, num_classes)
  return logits, x
