# coding=utf-8
# Copyright 2024 The Language Tale Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of ResNet V1 in Flax.

"Deep Residual Learning for Image Recognition"
He et al., 2015, [https://arxiv.org/abs/1512.03385]
"""

# See issue #620.
# pytype: disable=wrong-arg-count
# pytype: disable=missing-parameter
# pytype: disable=wrong-keyword-args

import functools
from typing import Any, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp

Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1), use_bias=False)
Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3), use_bias=False)


class ResNetBlock(nn.Module):
  """ResNet block without bottleneck used in ResNet-18 and ResNet-34."""

  filters: int
  norm: Any
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv3x3(self.filters, strides=self.strides, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, name="conv2")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Tengyu et al,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(scale_init=nn.initializers.zeros, name="bn2")(x)

    if residual.shape != x.shape:
      residual = Conv1x1(
          self.filters, strides=self.strides, name="proj_conv")(
              residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x


class BottleneckResNetBlock(ResNetBlock):
  """Bottleneck ResNet block used in ResNet-50 and larger."""

  @nn.compact
  def __call__(self, x):
    residual = x

    x = Conv1x1(self.filters, name="conv1")(x)
    x = self.norm(name="bn1")(x)
    x = nn.relu(x)
    x = Conv3x3(self.filters, strides=self.strides, name="conv2")(x)
    x = self.norm(name="bn2")(x)
    x = nn.relu(x)
    x = Conv1x1(4 * self.filters, name="conv3")(x)
    # Initializing the scale to 0 has been common practice since "Fixup
    # Initialization: Residual Learning Without Normalization" Tengyu et al,
    # 2019, [https://openreview.net/forum?id=H1gsz30cKX].
    x = self.norm(name="bn3")(x)

    if residual.shape != x.shape:
      residual = Conv1x1(
          4 * self.filters, strides=self.strides, name="proj_conv")(
              residual)
      residual = self.norm(name="proj_bn")(residual)

    x = nn.relu(residual + x)
    return x


class ResNetStage(nn.Module):
  """ResNet stage consistent of multiple ResNet blocks."""

  stage_size: int
  filters: int
  block_cls: Type[ResNetBlock]
  norm: Any
  first_block_strides: Tuple[int, int]

  @nn.compact
  def __call__(self, x):
    for i in range(self.stage_size):
      x = self.block_cls(
          filters=self.filters,
          norm=self.norm,
          strides=self.first_block_strides if i == 0 else (1, 1),
          name=f"block{i + 1}")(
              x)
    return x


class ResNet(nn.Module):
  """Construct ResNet V1 with `num_classes` outputs.

  Attributes:
    num_classes: Number of nodes in the final layer.
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
        ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: Tuple with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    width_factor: Factor applied to the number of filters. The 64 * width_factor
      is the number of filters in the first stage, every consecutive stage
      doubles the number of filters.
  """
  num_classes: int
  block_cls: Type[ResNetBlock]
  stage_sizes: Tuple[int]
  width_factor: int = 1

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x`.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      The output head with `num_classes` entries.
    """
    width = 64 * self.width_factor
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9)

    # Root block
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7),
        strides=(2, 2),
        use_bias=False,
        name="init_conv")(
            x)
    x = norm(name="init_bn")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    # Stages
    for i, stage_size in enumerate(self.stage_sizes):
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=(1, 1) if i == 0 else (2, 2),
          name=f"stage{i + 1}")(
              x)

    # Head
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        self.num_classes, kernel_init=nn.initializers.zeros, name="head")(
            x)
    return x


ResNet18 = functools.partial(
    ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
ResNet34 = functools.partial(
    ResNet, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock)
ResNet50 = functools.partial(
    ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
ResNet101 = functools.partial(
    ResNet, stage_sizes=(3, 4, 23, 3), block_cls=BottleneckResNetBlock)
ResNet152 = functools.partial(
    ResNet, stage_sizes=(3, 8, 36, 3), block_cls=BottleneckResNetBlock)
ResNet200 = functools.partial(
    ResNet, stage_sizes=(3, 24, 36, 3), block_cls=BottleneckResNetBlock)


class MultiscaleResNet(nn.Module):
  """Construct a multiscale ResNet.

  Attributes:
    block_cls: Class for the blocks. ResNet-50 and larger use
      `BottleneckResNetBlock` (convolutions: 1x1, 3x3, 1x1), ResNet-18 and
      ResNet-34 use `ResNetBlock` without bottleneck (two 3x3 convolutions).
    stage_sizes: Tuple with the number of ResNet blocks in each stage. Number of
      stages can be varied.
    width_factor: Factor applied to the number of filters. The 64 * width_factor
      is the number of filters in the first stage, every consecutive stage
      doubles the number of filters.
  """
  block_cls: Type[ResNetBlock]
  stage_sizes: Tuple[int, Ellipsis]
  width_factor: int = 1

  @nn.compact
  def __call__(self, x, *, train):
    """Apply the ResNet to the inputs `x`.

    Args:
      x: Inputs.
      train: Whether to use BatchNorm in training or inference mode.

    Returns:
      The output head with `num_classes` entries.
    """
    width = 64 * self.width_factor
    norm = functools.partial(
        nn.BatchNorm, use_running_average=not train, momentum=0.9)

    all_outputs = []

    # Root block
    x = nn.Conv(
        features=width,
        kernel_size=(7, 7),
        strides=(2, 2),
        use_bias=False,
        name="init_conv")(
            x)

    all_outputs.append(x)

    x = norm(name="init_bn")(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

    all_outputs.append(x)

    # Stages
    for i, stage_size in enumerate(self.stage_sizes):
      x = ResNetStage(
          stage_size,
          filters=width * 2**i,
          block_cls=self.block_cls,
          norm=norm,
          first_block_strides=(1, 1) if i == 0 else (2, 2),
          name=f"stage{i + 1}")(
              x)
      all_outputs.append(x)

    return all_outputs
