# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BiT models as in the paper (ResNet V2) w/ loading of public weights.

See reproduction proof: http://(internal link)
"""

import functools
import re
from typing import Optional, Sequence, Union

# from big_vision.models import bit
# from big_vision.models import common
import flax.linen as nn
import jax
import jax.numpy as jnp

# from big_vision import utils as u
import jaxrl_m.vision.bigvision_common as common
import jaxrl_m.vision.bigvision_utils as u


def standardize(x, axis, eps):
    """Shifts and rescales `x` to zero mean and unit second moment over `axis`.

    Args:
      x: array to normalize.
      axis: axis (or axes) the statistics are computed over; the reduced
        dimensions are kept so the statistics broadcast against `x`.
      eps: constant added to the variance before taking the square root.

    Returns:
      The normalized array, same shape as `x`.
    """
    centered = x - jnp.mean(x, axis=axis, keepdims=True)
    # Biased variance estimate: mean of the squared, already-centered values.
    variance = jnp.mean(jnp.square(centered), axis=axis, keepdims=True)
    return centered / jnp.sqrt(variance + eps)


# Defined our own, because we compute normalizing variance slightly differently,
# which does affect performance when loading pre-trained weights!
class GroupNorm(nn.Module):
    """Group normalization (arxiv.org/abs/1803.08494).

    The last axis is split into `ngroups` groups; each group is standardized
    together with axes 1 and 2, then a learned per-channel scale and bias are
    applied. The learned parameters have shape (1, 1, 1, channels), so the
    input is presumably 4-D (batch, height, width, channels) -- TODO confirm.
    """

    # Number of groups the last (channel) axis is divided into.
    ngroups: int = 32

    @nn.compact
    def __call__(self, x):
        orig_shape = x.shape
        num_channels = orig_shape[-1]

        # Split the channel axis into (groups, channels_per_group).
        grouped_shape = orig_shape[:-1] + (
            self.ngroups,
            num_channels // self.ngroups,
        )
        grouped = x.reshape(grouped_shape)

        # Statistics are shared across axes 1, 2 and the within-group axis 4.
        normed = standardize(grouped, axis=[1, 2, 4], eps=1e-5)
        normed = normed.reshape(orig_shape)

        # Per-channel affine parameters, broadcast over the leading axes.
        affine_shape = (1, 1, 1, num_channels)
        scale = self.param("scale", nn.initializers.ones, affine_shape)
        normed = normed * scale
        bias = self.param("bias", nn.initializers.zeros, affine_shape)
        return normed + bias


class StdConv(nn.Conv):
    """Convolution whose kernel is standardized every time it is fetched.

    Only the parameter named "kernel" is affected; it is standardized over
    axes 0, 1 and 2 (presumably the spatial and input-channel axes, i.e.
    separately for each output filter -- TODO confirm against nn.Conv's kernel
    layout). All other parameters are returned unchanged.
    """

    def param(self, name, *a, **kw):
        value = super().param(name, *a, **kw)
        if name != "kernel":
            return value
        return standardize(value, axis=[0, 1, 2], eps=1e-10)


class RootBlock(nn.Module):
    """Root block of ResNet: a strided 7x7 StdConv followed by 3x3 max-pooling."""

    # Number of output features of the root convolution.
    width: int

    @nn.compact
    def __call__(self, x):
        root_conv = StdConv(
            features=self.width,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            name="conv_root",
        )
        features = root_conv(x)
        # Both the conv and the pooling use stride 2 with explicit padding.
        pooled = nn.max_pool(
            features,
            window_shape=(3, 3),
            strides=(2, 2),
            padding=[(1, 1), (1, 1)],
        )
        return pooled


class ResidualUnit(nn.Module):
    """Bottleneck ResNet block with pre-activation (GroupNorm + ReLU first).

    The main path is 1x1 -> 3x3 -> 1x1 StdConvs, each preceded by GroupNorm
    and ReLU. The shortcut is the raw input, unless the channel count or the
    stride changes, in which case a 1x1 projection of the first activation is
    used instead.
    """

    # Bottleneck width; when unset, a quarter of the input channels is used.
    nmid: Optional[int] = None
    # Stride applied in the 3x3 conv (and in the projection shortcut).
    strides: Sequence[int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        bottleneck = self.nmid or x.shape[-1] // 4
        out_channels = bottleneck * 4
        make_conv = functools.partial(StdConv, use_bias=False)

        shortcut = x
        h = nn.relu(GroupNorm(name="gn1")(x))

        # A projection is needed whenever the shortcut would not match the
        # main path's output in channels or spatial resolution.
        needs_projection = h.shape[-1] != out_channels or self.strides != (1, 1)
        if needs_projection:
            shortcut = make_conv(
                out_channels, (1, 1), self.strides, name="conv_proj"
            )(h)

        h = make_conv(bottleneck, (1, 1), name="conv1")(h)
        h = nn.relu(GroupNorm(name="gn2")(h))
        h = make_conv(
            bottleneck,
            (3, 3),
            self.strides,
            padding=[(1, 1), (1, 1)],
            name="conv2",
        )(h)
        h = nn.relu(GroupNorm(name="gn3")(h))
        h = make_conv(out_channels, (1, 1), name="conv3")(h)

        return h + shortcut


class ResNetStage(nn.Module):
    """A stage (sequence of same-resolution blocks).

    Only the first unit ("unit01") receives `first_stride`; the remaining
    units use ResidualUnit's default stride. Returns the final activation
    together with a dict of every unit's output keyed by unit name.
    """

    # Total number of residual units in the stage.
    block_size: int
    # Bottleneck width forwarded to every unit.
    nmid: Optional[int] = None
    # Stride of the first unit.
    first_stride: Sequence[int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        # The first unit is always created, independent of `block_size`.
        first_unit = ResidualUnit(
            self.nmid, strides=self.first_stride, name="unit01"
        )
        x = first_unit(x)
        unit_outputs = {"unit01": x}

        for unit_no in range(2, self.block_size + 1):
            unit_name = f"unit{unit_no:02d}"
            x = ResidualUnit(self.nmid, name=unit_name)(x)
            unit_outputs[unit_name] = x

        return x, unit_outputs


class Model(nn.Module):
    """ResNetV2.

    Pipeline: optional bilinear resize -> rescale inputs -> root block ->
    residual stages -> GroupNorm + ReLU -> mean over axes (1, 2) -> optional
    linear head. When `num_classes` is falsy the pooled features are returned;
    otherwise the logits are. Intermediate activations are gathered in a local
    dict that is currently not returned.
    """

    # Size of the classification head; a falsy value disables the head.
    num_classes: int
    # Multiplier on the base width of 64 features.
    width: int = 1
    depth: Union[int, Sequence[int]] = 50  # 50/101/152, or list of block depths.
    # If set, inputs are resized so that the two axes before the last one
    # have this size.
    image_shape: tuple = None

    @nn.compact
    def __call__(self, x, *, train=False):
        # NOTE: `train` is accepted but not used anywhere in this method.
        if self.image_shape is not None:
            resized_shape = (*x.shape[:-3], *self.image_shape, x.shape[-1])
            x = jax.image.resize(x, resized_shape, "bilinear")
            print("Resizing to %s" % str(self.image_shape))

        # put inputs in [-1, 1] (presumably they arrive as [0, 255] pixels)
        x = x.astype(jnp.float32) / 127.5 - 1.0

        stage_depths = get_block_desc(self.depth)
        base_width = int(64 * self.width)
        out = {}

        x = RootBlock(width=base_width, name="root_block")(x)
        out["stem"] = x

        # Blocks: the first stage keeps the resolution, every later stage
        # doubles the bottleneck width and starts with a stride-2 unit.
        x, stage_out = ResNetStage(
            stage_depths[0], nmid=base_width, name="block1"
        )(x)
        out["stage1"] = stage_out
        for stage_no, block_size in enumerate(stage_depths[1:], start=2):
            stage = ResNetStage(
                block_size,
                base_width * 2 ** (stage_no - 1),
                first_stride=(2, 2),
                name=f"block{stage_no}",
            )
            x, stage_out = stage(x)
            out[f"stage{stage_no}"] = stage_out

        # Pre-head
        x = GroupNorm(name="norm-pre-head")(x)
        out["norm_pre_head"] = x
        x = nn.relu(x)
        out["pre_logits_2d"] = x
        x = jnp.mean(x, axis=(1, 2))
        out["pre_logits"] = x

        if not self.num_classes:
            return x  # , out

        # Head: the same Dense layer is applied to the un-pooled and the
        # pooled features; only the pooled logits are returned.
        head = nn.Dense(
            self.num_classes, name="head", kernel_init=nn.initializers.zeros
        )
        out["logits_2d"] = head(out["pre_logits_2d"])
        x = head(out["pre_logits"])
        out["logits"] = x

        return x  # , out


def get_block_desc(depth):
    """Returns the number of residual units per stage for `depth`.

    Args:
      depth: either one of the named depths (26, 50, 101, 152, 200) or an
        explicit sequence of per-stage block counts. A list is converted to a
        tuple; any other non-integer value is passed through unchanged.

    Returns:
      A sequence with one block count per stage.

    Raises:
      ValueError: if `depth` is an integer that is not one of the named depths.
        Previously such a value was returned as-is and only failed later with
        an opaque "'int' object is not subscriptable" error.
    """
    if isinstance(depth, list):  # Be robust to silly mistakes.
        depth = tuple(depth)

    # Rebuilt on every call so callers always get fresh, unshared lists.
    named_depths = {
        26: [2, 2, 2, 2],  # From timm, gets ~75% on ImageNet.
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    blocks = named_depths.get(depth, depth)

    # Still an int here means the lookup missed: fail early and clearly.
    if isinstance(blocks, int):
        raise ValueError(
            f"Unsupported ResNet depth {depth!r}; expected one of "
            f"{sorted(named_depths)} or a sequence of per-stage block counts."
        )
    return blocks


def load(init_params, checkpoint_path, dont_load=()):
    """Loads the TF-dumped NumPy or big_vision checkpoint.

    The checkpoint is read, converted from the legacy BiT layout when that
    layout is detected, and then merged with `init_params`.

    Args:
      init_params: random init params from which the new head is taken.
      checkpoint_path: path to checkpoint.
      dont_load: list of param names to be reset to init.

    Returns:
      The loaded parameters.
    """
    restored = maybe_convert_big_transfer_format(
        u.load_params(None, checkpoint_path)
    )
    return common.merge_params(restored, init_params, dont_load)


def maybe_convert_big_transfer_format(params_tf):
    """If the checkpoint comes from legacy codebase, convert it.

    Args:
      params_tf: checkpoint parameters. The legacy layout is recognized by a
        top-level "resnet" entry; anything else is returned untouched.

    Returns:
      `params_tf` itself when the format is not recognized, otherwise a new
      nested dict with the entries "root_block", "norm-pre-head", "head" and
      "block1" .. "block4" that this file's `Model` uses as module names.
    """

    # Only do anything at all if we recognize the format.
    if "resnet" not in params_tf:
        return params_tf

    # For ease of processing and backwards compatibility, flatten again:
    flat = dict(u.tree_flatten_with_names(params_tf)[0])

    # Works around some files containing weird naming of variables:
    for old_name in list(flat):
        new_name = re.sub(
            "/standardized_conv2d_\\d+/", "/standardized_conv2d/", old_name
        )
        if new_name != old_name:
            flat[new_name] = flat.pop(old_name)

    def as_affine(vector):
        # Prepend three singleton axes so the value matches the
        # (1, 1, 1, channels) scale/bias parameters of GroupNorm.
        return vector[None, None, None]

    converted = {
        "root_block": {
            "conv_root": {
                "kernel": flat["resnet/root_block/standardized_conv2d/kernel"]
            }
        },
        "norm-pre-head": {
            "bias": as_affine(flat["resnet/group_norm/beta"]),
            "scale": as_affine(flat["resnet/group_norm/gamma"]),
        },
        "head": {
            # Drops the two leading axes of the stored kernel (presumably a
            # 1x1 conv kernel being turned into a dense kernel).
            "kernel": flat["resnet/head/conv2d/kernel"][0, 0],
            "bias": flat["resnet/head/conv2d/bias"],
        },
    }

    for block in ("block1", "block2", "block3", "block4"):
        # All distinct "unitNN" names that occur under this block.
        unit_names = {
            re.findall(r"unit\d+", name)[0] for name in flat if block in name
        }

        block_params = {}
        for unit in unit_names:
            unit_params = {}
            # Legacy sub-groups a/b/c map onto conv1/gn1 .. conv3/gn3.
            for idx, group in enumerate("abc", 1):
                unit_params[f"conv{idx}"] = {
                    "kernel": flat[
                        f"resnet/{block}/{unit}/{group}/standardized_conv2d/kernel"
                    ]
                }
                unit_params[f"gn{idx}"] = {
                    "bias": as_affine(
                        flat[f"resnet/{block}/{unit}/{group}/group_norm/beta"]
                    ),
                    "scale": as_affine(
                        flat[f"resnet/{block}/{unit}/{group}/group_norm/gamma"]
                    ),
                }

            # Optional projection shortcut: at most one entry per unit.
            projs = [name for name in flat if f"{block}/{unit}/a/proj" in name]
            assert len(projs) <= 1
            if projs:
                unit_params["conv_proj"] = {"kernel": flat[projs[0]]}

            block_params[unit] = unit_params

        converted[block] = block_params

    return converted


import functools as ft

# Shared base for all configs: `num_classes=None` makes `Model` skip the
# classification head and return the pooled features.
_headless_resnetv2 = ft.partial(Model, num_classes=None)

# Partially-applied `Model` constructors keyed by config name. The "-128"
# variants additionally resize inputs to 128x128 inside the model.
resnetv2_configs = {
    "resnetv2-26-1": ft.partial(_headless_resnetv2, depth=26),
    "resnetv2-26-1-128": ft.partial(
        _headless_resnetv2, depth=26, image_shape=(128, 128)
    ),
    "resnetv2-50-1": ft.partial(_headless_resnetv2, depth=50),
    "resnetv2-50-1-128": ft.partial(
        _headless_resnetv2, depth=50, image_shape=(128, 128)
    ),
}
