# Based on:
# https://github.com/google/flax/blob/43b358c/examples/imagenet/models.py
#
# With CIFAR variant from: https://github.com/kuangliu/pytorch-cifar

# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax implementation of ResNet V1."""

# See issue #620.
# pytype: disable=wrong-arg-count

from functools import partial
from typing import Any, Callable, Sequence, Tuple

from flax import linen as nn
import jax.numpy as jnp

# Type of a module constructor: a Module class or any callable returning one.
ModuleDef = Callable[..., nn.Module]

# Kernel initializer reproducing torch.nn.Linear's default weight init,
# U(-1/sqrt(k), 1/sqrt(k)) with k = fan_in. It is also used below for conv
# kernels. A uniform distribution on [-a, a] has variance (2a)^2 / 12 = a^2 / 3.
# With a = 1/sqrt(fan_in) this gives 1 / (3 * fan_in), i.e. uniform variance
# scaling with scale 1/3 in 'fan_in' mode. Only kernels use this; biases are
# left to the layer's own default.
pytorch_kernel_init = nn.initializers.variance_scaling(
    scale=1 / 3,
    mode='fan_in',
    distribution='uniform',
)


class ResNetBlock(nn.Module):
    """Basic residual block with two 3x3 convolutions.

    Branch: conv(3x3, strides) -> norm -> act -> conv(3x3) -> norm.
    Shortcut: identity, or a 1x1 projection (plus norm) when the branch output
    shape differs from the input shape. ``act`` is applied to the sum.
    """
    # Output channels of both 3x3 convolutions.
    filters: int
    # Convolution constructor.
    conv: ModuleDef
    # Normalization constructor; unused when `no_bn` is True.
    norm: ModuleDef
    # Activation used inside the branch and after the residual sum.
    act: Callable
    # Strides of the first 3x3 conv and of the projection shortcut.
    strides: Tuple[int, int] = (1, 1)
    # Drop all normalization layers; convolutions then carry a bias instead.
    no_bn: bool = False
    # If False (default), zero-initialise the branch's last layer: the last
    # conv kernel when `no_bn`, otherwise the scale of the last norm.
    # Set True to keep the default initialisers.
    no_init_zero: bool = False

    @nn.compact
    def __call__(self, x, norm_kwargs=None):
        """Apply the block.

        Args:
          x: input activations.
          norm_kwargs: optional keyword arguments forwarded to every norm call.

        Returns:
          ``act(shortcut(x) + branch(x))``.
        """
        # Submodules are auto-named in creation order, so the construction
        # order below must stay fixed to keep parameter names stable.
        kw = norm_kwargs if norm_kwargs else {}
        with_norm = not self.no_bn

        # First 3x3 conv carries the stride.
        y = self.conv(self.filters, (3, 3), self.strides, use_bias=self.no_bn)(x)
        if with_norm:
            y = self.norm()(y, **kw)
        y = self.act(y)

        # Second 3x3 conv, optionally zero-initialised (kernel or norm scale).
        if with_norm:
            y = self.conv(self.filters, (3, 3), use_bias=False)(y)
            scale = {} if self.no_init_zero else {'scale_init': nn.initializers.zeros}
            y = self.norm(**scale)(y, **kw)
        else:
            kernel = {} if self.no_init_zero else {'kernel_init': nn.initializers.zeros}
            y = self.conv(self.filters, (3, 3), use_bias=True, **kernel)(y)

        # Project the shortcut only when the shapes disagree.
        shortcut = x
        if shortcut.shape != y.shape:
            shortcut = self.conv(self.filters, (1, 1), self.strides,
                                 use_bias=self.no_bn, name='conv_proj')(shortcut)
            if with_norm:
                shortcut = self.norm(name='norm_proj')(shortcut, **kw)

        return self.act(shortcut + y)


class BottleneckResNetBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The branch narrows to ``filters`` channels and expands back to
    ``4 * filters``. The shortcut is the identity, or a strided 1x1 projection
    (plus norm) when the branch output shape differs from the input shape.
    ``act`` is applied to the sum.
    """
    # Width of the two inner convolutions; the block outputs 4 * filters.
    filters: int
    # Convolution constructor.
    conv: ModuleDef
    # Normalization constructor; unused when `no_bn` is True.
    norm: ModuleDef
    # Activation used inside the branch and after the residual sum.
    act: Callable
    # Strides of the 3x3 conv and of the projection shortcut.
    strides: Tuple[int, int] = (1, 1)
    # Drop all normalization layers; convolutions then carry a bias instead.
    no_bn: bool = False
    # If False (default), zero-initialise the branch's last layer: the final
    # 1x1 kernel when `no_bn`, otherwise the scale of the final norm.
    # Set True to keep the default initialisers.
    no_init_zero: bool = False

    @nn.compact
    def __call__(self, x, norm_kwargs=None):
        """Apply the block.

        Args:
          x: input activations.
          norm_kwargs: optional keyword arguments forwarded to every norm call.

        Returns:
          ``act(shortcut(x) + branch(x))`` with ``4 * filters`` channels.
        """
        # Submodules are auto-named in creation order, so the construction
        # order below must stay fixed to keep parameter names stable.
        kw = norm_kwargs if norm_kwargs else {}
        with_norm = not self.no_bn
        width = self.filters
        expanded = 4 * width

        # 1x1 reduce.
        y = self.conv(width, (1, 1), use_bias=self.no_bn)(x)
        if with_norm:
            y = self.norm()(y, **kw)
        y = self.act(y)

        # 3x3, carries the stride.
        y = self.conv(width, (3, 3), self.strides, use_bias=self.no_bn)(y)
        if with_norm:
            y = self.norm()(y, **kw)
        y = self.act(y)

        # 1x1 expand, optionally zero-initialised (kernel or norm scale).
        if with_norm:
            y = self.conv(expanded, (1, 1), use_bias=False)(y)
            scale = {} if self.no_init_zero else {'scale_init': nn.initializers.zeros}
            y = self.norm(**scale)(y, **kw)
        else:
            kernel = {} if self.no_init_zero else {'kernel_init': nn.initializers.zeros}
            y = self.conv(expanded, (1, 1), use_bias=True, **kernel)(y)

        # Project the shortcut only when the shapes disagree.
        shortcut = x
        if shortcut.shape != y.shape:
            shortcut = self.conv(expanded, (1, 1), self.strides,
                                 use_bias=self.no_bn, name='conv_proj')(shortcut)
            if with_norm:
                shortcut = self.norm(name='norm_proj')(shortcut, **kw)

        return self.act(shortcut + y)


class ResNet(nn.Module):
    """ResNet V1, CIFAR variant.

    The stem is a single 3x3, stride-1 convolution (+ norm + ReLU) with no
    pooling, followed by the residual stages, a global average pool and a
    Dense classifier.
    """
    # Number of residual blocks per stage. Stage i uses num_filters * 2**i
    # filters; the first block of every stage except stage 0 uses stride 2.
    stage_sizes: Sequence[int]
    # Residual block class, e.g. ResNetBlock or BottleneckResNetBlock.
    block_cls: ModuleDef
    # Output width of the final Dense layer.
    num_classes: int
    # Filters of the stem and of the first stage.
    num_filters: int = 64
    # Activation used inside the residual blocks (the stem always uses ReLU).
    act: Callable = nn.relu
    # Convolution constructor; kernels default to `pytorch_kernel_init`.
    conv: ModuleDef = nn.Conv
    # Normalization constructor; unused when `no_bn` is True.
    norm: ModuleDef = nn.BatchNorm
    # Drop all normalization layers; convolutions then carry a bias instead.
    no_bn: bool = False
    # If True, do NOT zero-initialise the last layer of each residual branch.
    no_init_zero: bool = False

    @nn.compact
    def __call__(self, x, norm_kwargs=None):
        """Apply the network.

        Args:
          x: channels-last input; axes -3 and -2 are the spatial axes (they
            are averaged before the classifier).
          norm_kwargs: optional keyword arguments forwarded to every norm-layer
            call, e.g. ``{'use_running_average': not train}`` for BatchNorm.

        Returns:
          Unnormalised class scores with trailing dimension ``num_classes``.
        """
        norm_kwargs = norm_kwargs or {}

        def conv(*args, kernel_init=pytorch_kernel_init, **kwargs):
            # Default kernels to the PyTorch-style init while letting callers
            # override it. All other keyword arguments (padding, use_bias,
            # name, ...) must be forwarded, otherwise they are silently lost.
            return self.conv(*args, kernel_init=kernel_init, **kwargs)

        # Stem: 3x3 conv with stride 1.
        x = conv(self.num_filters, (3, 3), (1, 1), padding=[(1, 1), (1, 1)],
                 use_bias=self.no_bn, name='conv_init')(x)
        x = x if self.no_bn else self.norm(name='bn_init')(x, **norm_kwargs)
        x = nn.relu(x)

        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                # Downsample at the start of every stage except the first.
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = self.block_cls(self.num_filters * 2 ** i,
                                   strides=strides,
                                   conv=conv,
                                   norm=self.norm,
                                   act=self.act,
                                   no_bn=self.no_bn,
                                   no_init_zero=self.no_init_zero)(x, norm_kwargs=norm_kwargs)

        # Global average pool over the spatial axes, then classify.
        x = jnp.mean(x, axis=(-3, -2))
        x = nn.Dense(self.num_classes, kernel_init=pytorch_kernel_init)(x)
        return x


# Standard depths. The positional arguments fill ResNet's `stage_sizes` and
# `block_cls`; callers still supply `num_classes` (and any other field).
ResNet18 = partial(ResNet, [2] * 4, ResNetBlock)
ResNet34 = partial(ResNet, [3, 4, 6, 3], ResNetBlock)
ResNet50 = partial(ResNet, [3, 4, 6, 3], BottleneckResNetBlock)
ResNet101 = partial(ResNet, [3, 4, 23, 3], BottleneckResNetBlock)
ResNet152 = partial(ResNet, [3, 8, 36, 3], BottleneckResNetBlock)
ResNet200 = partial(ResNet, [3, 24, 36, 3], BottleneckResNetBlock)

# ResNet-18 layout with every convolution built from nn.ConvLocal.
ResNet18Local = partial(ResNet, [2] * 4, ResNetBlock, conv=nn.ConvLocal)

# Single-stage, single-block models. Used for testing only.
_ResNet1 = partial(ResNet, [1], ResNetBlock)
_ResNet1Local = partial(ResNet, [1], ResNetBlock, conv=nn.ConvLocal)
