# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BiT models as in the paper (ResNet V2) w/ loading of public weights.

See reproduction proof: http://(internal link)
"""

import functools
import re
from typing import Optional, Sequence, Union

# from big_vision.models import bit
# from big_vision.models import common
import flax.linen as nn
import jax
import jax.numpy as jnp

# from big_vision import utils as u
import jaxrl_m.vision.bigvision_common as common
import jaxrl_m.vision.bigvision_utils as u


def standardize(x, axis, eps):
    x = x - jnp.mean(x, axis=axis, keepdims=True)
    x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
    return x


# Defined our own, because we compute normalizing variance slightly differently,
# which does affect performance when loading pre-trained weights!
class GroupNorm(nn.Module):
    """Group normalization (arxiv.org/abs/1803.08494)."""
    ngroups: int = 32

    @nn.compact
    def __call__(self, x):
        input_shape = x.shape
        group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups)

        x = x.reshape(group_shape)

        # Standardize along spatial and group dimensions
        x = standardize(x, axis=[1, 2, 4], eps=1e-5)
        x = x.reshape(input_shape)

        bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]])
        x = x * self.param('scale', nn.initializers.ones, bias_scale_shape)
        x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape)
        return x


class StdConv(nn.Conv):

    def param(self, name, *a, **kw):
        param = super().param(name, *a, **kw)
        if name == 'kernel':
            param = standardize(param, axis=[0, 1, 2], eps=1e-10)
        return param


class RootBlock(nn.Module):
    """Root block of ResNet."""
    width: int

    @nn.compact
    def __call__(self, x):
        x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
                    use_bias=False, name='conv_root')(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])
        return x


class ResidualUnit(nn.Module):
    """Bottleneck ResNet block."""
    nmid: Optional[int] = None
    strides: Sequence[int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        nmid = self.nmid or x.shape[-1] // 4
        nout = nmid * 4
        conv = functools.partial(StdConv, use_bias=False)

        residual = x
        x = GroupNorm(name='gn1')(x)
        x = nn.relu(x)

        if x.shape[-1] != nout or self.strides != (1, 1):
            residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x)

        x = conv(nmid, (1, 1), name='conv1')(x)
        x = GroupNorm(name='gn2')(x)
        x = nn.relu(x)
        x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)],
                 name='conv2')(x)
        x = GroupNorm(name='gn3')(x)
        x = nn.relu(x)
        x = conv(nout, (1, 1), name='conv3')(x)

        return x + residual


class ResNetStage(nn.Module):
    """A stage (sequence of same-resolution blocks)."""
    block_size: int
    nmid: Optional[int] = None
    first_stride: Sequence[int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        out = {}
        x = out['unit01'] = ResidualUnit(
            self.nmid, strides=self.first_stride, name='unit01')(x)
        for i in range(1, self.block_size):
            x = out[f'unit{i + 1:02d}'] = ResidualUnit(
                self.nmid, name=f'unit{i + 1:02d}')(x)
        return x, out


class Model(nn.Module):
    """ResNetV2."""
    num_classes: int
    width: int = 1
    depth: Union[int, Sequence[int]] = 50  # 5/101/152, or list of block depths.
    image_shape: tuple = None

    @nn.compact
    def __call__(self, x, *, train=False):
        if self.image_shape is not None:
            x = jax.image.resize(x, (*x.shape[:-3], *self.image_shape, x.shape[-1]), 'bilinear')
            print('Resizing to %s' % str(self.image_shape))
        x = x / 255.0
        blocks = get_block_desc(self.depth)
        width = int(64 * self.width)
        out = {}

        x = out['stem'] = RootBlock(width=width, name='root_block')(x)

        # Blocks
        x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x)
        for i, block_size in enumerate(blocks[1:], 1):
            x, out[f'stage{i + 1}'] = ResNetStage(
                block_size, width * 2 ** i,
                first_stride=(2, 2), name=f'block{i + 1}')(x)

        # Pre-head
        x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x)
        x = out['pre_logits_2d'] = nn.relu(x)
        x = out['pre_logits'] = jnp.mean(x, axis=(1, 2))

        # Head
        if self.num_classes:
            head = nn.Dense(self.num_classes, name='head',
                            kernel_init=nn.initializers.zeros)
            out['logits_2d'] = head(out['pre_logits_2d'])
            x = out['logits'] = head(out['pre_logits'])

        return x  # , out


def get_block_desc(depth):
    if isinstance(depth, list):  # Be robust to silly mistakes.
        depth = tuple(depth)
    return {
        26: [2, 2, 2, 2],  # From timm, gets ~75% on ImageNet.
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3]
    }.get(depth, depth)


def load(init_params, init_file, model_cfg, dont_load=()):
    """Loads the TF-dumped NumPy or big_vision checkpoint.

    Args:
      init_params: random init params from which the new head is taken.
      init_file: comes from `config.model_init`, can either be an absolute
        path (ie starts with /) to the checkpoint, or a string like
        "L-imagenet2012" describing one of the variants from the paper.
      model_cfg: the model configuration.
      dont_load: list of param names to be reset to init.

    Returns:
      The loaded parameters.
    """

    # Support for vanity model names from the paper.
    vanity = {
        'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz',
        'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz',
    }
    if init_file[0] in ('L', 'M', 'S'):  # The models from the original paper.
        # Supported names are of the following type:
        # - 'M' or 'S': the original "upstream" model without fine-tuning.
        # - 'M-ILSVRC2012': i21k model fine-tuned on i1k.
        # - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101.
        #    each VTAB fine-tuning was run 3x, so there's run0, run1, run2.
        if '-' in init_file:
            up, down = init_file[0], init_file[1:]
        else:
            up, down = init_file, ''
        down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)  # normalize
        fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz'
        fname = f'gs://bit_models/{fname}'
    else:
        fname = vanity.get(init_file, init_file)

    params = u.load_params(None, fname)
    params = maybe_convert_big_transfer_format(params)
    return common.merge_params(params, init_params, dont_load)


def maybe_convert_big_transfer_format(params_tf):
    """If the checkpoint comes from legacy codebase, convert it."""

    # Only do anything at all if we recognize the format.
    if 'resnet' not in params_tf:
        return params_tf

    # For ease of processing and backwards compatibility, flatten again:
    params_tf = dict(u.tree_flatten_with_names(params_tf)[0])

    # Works around some files containing weird naming of variables:
    for k in list(params_tf):
        k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k)
        if k2 != k:
            params_tf[k2] = params_tf[k]
            del params_tf[k]

    params = {
        'root_block': {'conv_root': {'kernel': params_tf[
            'resnet/root_block/standardized_conv2d/kernel']}},
        'norm-pre-head': {
            'bias': params_tf['resnet/group_norm/beta'][None, None, None],
            'scale': params_tf['resnet/group_norm/gamma'][None, None, None],
        },
        'head': {
            'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0],
            'bias': params_tf['resnet/head/conv2d/bias'],
        }
    }

    for block in ('block1', 'block2', 'block3', 'block4'):
        params[block] = {}
        units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys()
                     if p.find(block) >= 0])
        for unit in units:
            params[block][unit] = {}
            for i, group in enumerate('abc', 1):
                params[block][unit][f'conv{i}'] = {
                    'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel']
                    # pylint: disable=line-too-long
                }
                params[block][unit][f'gn{i}'] = {
                    'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None],
                    # pylint: disable=line-too-long
                    'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None],
                    # pylint: disable=line-too-long
                }

            projs = [p for p in params_tf.keys()
                     if p.find(f'{block}/{unit}/a/proj') >= 0]
            assert len(projs) <= 1
            if projs:
                params[block][unit]['conv_proj'] = {
                    'kernel': params_tf[projs[0]]
                }

    return params


import functools as ft

resnetv2_configs = {
    'resnetv2-26-1': ft.partial(Model, num_classes=None, depth=26),
    'resnetv2-26-1-128': ft.partial(Model, num_classes=None, depth=26, image_shape=(128, 128)),
    'resnetv2-50-1': ft.partial(Model, num_classes=None, depth=50),
    'resnetv2-50-1-128': ft.partial(Model, num_classes=None, depth=50, image_shape=(128, 128)),
}
