""" Relative Position Vision Transformer (ViT) in PyTorch

NOTE: these models are experimental / WIP, expect changes

Hacked together by / Copyright 2022, Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model

# Module-level logger, named after this module's import path.
_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    """Build a pretrained-config dict from the defaults shared by the models in this file.

    Args:
        url: value stored under the ``'url'`` key (empty string by default).
        **kwargs: extra entries; a key that matches a default replaces that default's value.

    Returns:
        A new dict with the defaults below and ``kwargs`` applied on top.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    # caller supplied values win over the defaults
    cfg.update(kwargs)
    return cfg


# Pretrained config per model, keyed by the name of the factory function registered further down.
# Every entry starts from the `_cfg` defaults; an empty `url` means no weight URL is listed here.
default_cfgs = {
    # 'plus' variants, the only entries that override the default 224x224 input size
    'vit_relpos_base_patch32_plus_rpn_256': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
        input_size=(3, 256, 256)),
    'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),

    # small / medium / base variants at the default input size
    'vit_relpos_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
    'vit_relpos_medium_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
    'vit_relpos_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),

    # variants whose factory functions enable the class token
    'vit_relpos_base_patch16_cls_224': _cfg(
        url=''),
    'vit_relpos_base_patch16_gapcls_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),

    # 'rpn' variants, whose factory functions use ResPostRelPosBlock
    'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
    'vit_relpos_medium_patch16_rpn_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
    'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
}


def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
    """Build the pair-wise relative position lookup index for all tokens of a window.

    Adapted from the swin / beit codebases. Every (query, key) token pair is mapped to a
    single integer that encodes their row / column offset.

    Args:
        win_size: (height, width) of the window in tokens.
        class_token: when truthy, one extra leading token is added and three extra index
            values are used for cls->token, token->cls and cls->cls.

    Returns:
        Integer tensor of shape (Wh*Ww, Wh*Ww), or (Wh*Ww + 1, Wh*Ww + 1) with a class token.
    """
    win_h, win_w = win_size[0], win_size[1]

    # (2, Wh*Ww): row and column coordinate of every token
    grid = torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])
    token_coords = torch.stack(grid).flatten(1)

    # (Wh*Ww, Wh*Ww, 2): offset between every pair of tokens
    offsets = (token_coords[:, :, None] - token_coords[:, None, :]).permute(1, 2, 0).contiguous()
    offsets[:, :, 0] += win_h - 1  # shift so the smallest offset is 0
    offsets[:, :, 1] += win_w - 1
    offsets[:, :, 0] *= 2 * win_w - 1  # row offset becomes the stride of the flattened index
    token_index = offsets.sum(-1)  # Wh*Ww, Wh*Ww
    if not class_token:
        return token_index

    # the three largest index values are reserved for pairs that involve the class token
    num_tokens = win_h * win_w + 1
    num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
    full_index = torch.zeros(size=(num_tokens, num_tokens), dtype=offsets.dtype)
    full_index[1:, 1:] = token_index
    full_index[0, 0:] = num_relative_distance - 3
    full_index[0:, 0] = num_relative_distance - 2
    full_index[0, 0] = num_relative_distance - 1
    return full_index


def gen_relative_log_coords(
        win_size: Tuple[int, int],
        pretrained_win_size: Tuple[int, int] = (0, 0),
        mode='swin'
):
    """Generate the table of log-spaced relative coordinates for a window.

    As per the official swin-v2 impl, supporting timm swin-v2-cr coords as well.

    Args:
        win_size: (height, width) of the window in tokens.
        pretrained_win_size: window size used to normalize the coordinates in 'swin' mode;
            only used when its first element is > 0, otherwise ``win_size`` is used.
        mode: 'swin' normalizes offsets to [-8, 8] and applies log2 scaled by 1 / log2(8);
            any other value applies a natural log to the raw offsets.

    Returns:
        Float tensor of shape (2*Wh-1, 2*Ww-1, 2) holding the signed log coordinates.
    """
    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
    relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
    if mode == 'swin':
        norm_size = pretrained_win_size if pretrained_win_size[0] > 0 else win_size
        # An axis of length 1 gives a divisor of 0, which turned the table into NaN / inf.
        # Clamping to 1 leaves every size > 1 unchanged and keeps the result finite.
        relative_coords_table[:, :, 0] /= max(norm_size[0] - 1, 1)
        relative_coords_table[:, :, 1] /= max(norm_size[1] - 1, 1)
        relative_coords_table *= 8  # normalize to -8, 8
        scale = math.log2(8)
    else:
        # FIXME we should support a form of normalization (to -1/1) for this mode?
        scale = math.log2(math.e)  # dividing log2 by log2(e) yields the natural log
    relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
        1.0 + relative_coords_table.abs()) / scale
    return relative_coords_table


class RelPosMlp(nn.Module):
    """Relative position bias produced by an MLP over log-spaced relative coordinates.

    The MLP maps each (dy, dx) log coordinate to one value per attention head. The resulting
    table is gathered with the relative position index into a (heads, tokens, tokens) bias.
    """

    def __init__(
            self,
            window_size,
            num_heads=8,
            hidden_dim=128,
            class_token=False,
            mode='cr',
            pretrained_window_size=(0, 0)
    ):
        """
        Args:
            window_size: (height, width) of the token grid.
            num_heads: number of attention heads (MLP output features).
            hidden_dim: hidden feature count of the MLP.
            class_token: when truthy, the bias is zero-padded by one leading row and column.
            mode: 'swin' enables the ``16 * sigmoid`` squashing of the bias and a different
                MLP bias setting; it is also forwarded to ``gen_relative_log_coords``.
            pretrained_window_size: forwarded to ``gen_relative_log_coords``.
        """
        super().__init__()
        swin_mode = mode == 'swin'

        self.window_size = window_size
        self.window_area = self.window_size[0] * self.window_size[1]
        self.class_token = 1 if class_token else 0
        self.num_heads = num_heads
        self.bias_shape = (self.window_area, self.window_area, num_heads)
        self.apply_sigmoid = swin_mode

        # NOTE(review): the bias / drop tuples are presumably per-layer settings of Mlp - confirm there
        self.mlp = Mlp(
            2,  # x, y
            hidden_features=hidden_dim,
            out_features=num_heads,
            act_layer=nn.ReLU,
            bias=(True, False) if swin_mode else True,
            drop=(0.125, 0.)
        )

        # both buffers are derived from the window size only, so they stay out of the state dict
        position_index = gen_relative_position_index(window_size)
        self.register_buffer("relative_position_index", position_index, persistent=False)

        log_coords = gen_relative_log_coords(window_size, pretrained_window_size, mode=mode)
        self.register_buffer("rel_coords_log", log_coords, persistent=False)

    def get_bias(self) -> torch.Tensor:
        """Return the bias as a contiguous tensor with a leading broadcast dim of 1."""
        bias = self.mlp(self.rel_coords_log)
        if self.relative_position_index is not None:
            flat_index = self.relative_position_index.view(-1)
            bias = bias.view(-1, self.num_heads)[flat_index]  # Wh*Ww*Wh*Ww, nH
            bias = bias.view(self.bias_shape)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1)  # heads first
        if self.apply_sigmoid:
            bias = 16 * torch.sigmoid(bias)
        if self.class_token:
            # zero bias for the class token: pad one leading column and one leading row
            bias = F.pad(bias, [self.class_token, 0, self.class_token, 0])
        return bias.unsqueeze(0).contiguous()

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        """Add this module's bias to ``attn``; ``shared_rel_pos`` is accepted but not used."""
        return attn + self.get_bias()


class RelPosBias(nn.Module):
    """Relative position bias stored as a learnable table, looked up by relative position index."""

    def __init__(self, window_size, num_heads, class_token=False):
        """
        Args:
            window_size: (height, width) of the token grid.
            num_heads: number of attention heads (columns of the bias table).
            class_token: when truthy, one extra token and three extra table rows are added.
        """
        super().__init__()
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        self.class_token = 1 if class_token else 0
        num_tokens = self.window_area + self.class_token
        self.bias_shape = (num_tokens, num_tokens, num_heads)

        # one row per distinct relative offset, plus 3 rows for pairs that involve the class token
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
        bias_table = torch.zeros(num_relative_distance, num_heads)
        self.relative_position_bias_table = nn.Parameter(bias_table)

        # derived from the window size only, so it is kept out of the state dict
        position_index = gen_relative_position_index(self.window_size, class_token=self.class_token)
        self.register_buffer("relative_position_index", position_index, persistent=False)

        self.init_weights()

    def init_weights(self):
        """Initialize the bias table with a truncated normal (std 0.02)."""
        trunc_normal_(self.relative_position_bias_table, std=.02)

    def get_bias(self) -> torch.Tensor:
        """Return the bias as a contiguous (1, num_heads, tokens, tokens) tensor."""
        flat_index = self.relative_position_index.view(-1)
        # tokens, tokens, num_heads
        bias = self.relative_position_bias_table[flat_index].view(self.bias_shape)
        return bias.permute(2, 0, 1).unsqueeze(0).contiguous()

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        """Add this module's bias to ``attn``; ``shared_rel_pos`` is accepted but not used."""
        return attn + self.get_bias()


class RelPosAttention(nn.Module):
    """Multi-head self-attention with an optional relative position bias on the attention logits."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim: input / output feature dim, must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            qkv_bias: enable bias on the fused qkv projection.
            rel_pos_cls: optional callable, invoked as ``rel_pos_cls(num_heads=...)``, that
                builds this layer's own relative position module.
            attn_drop: dropout rate on the attention weights.
            proj_drop: dropout rate after the output projection.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if rel_pos_cls:
            self.rel_pos = rel_pos_cls(num_heads=num_heads)
        else:
            self.rel_pos = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        """Apply attention to ``x`` of shape (B, N, C).

        The layer's own ``rel_pos`` module takes precedence; ``shared_rel_pos`` is only added
        directly to the logits when the layer has no ``rel_pos`` module.
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # unbind keeps torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        if self.rel_pos is not None:
            logits = self.rel_pos(logits, shared_rel_pos=shared_rel_pos)
        elif shared_rel_pos is not None:
            logits = logits + shared_rel_pos
        weights = self.attn_drop(logits.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))


class LayerScale(nn.Module):
    """Learnable per-channel scale applied over the last dim of the input."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        """
        Args:
            dim: number of channels (length of the scale vector).
            init_values: initial value of every scale entry.
            inplace: scale the input tensor in place instead of returning a new tensor.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            # mutates and returns the caller's tensor
            return x.mul_(self.gamma)
        return x * self.gamma


class RelPosBlock(nn.Module):
    """Pre-norm transformer block: norm -> attention / MLP -> optional layer-scale -> drop-path."""

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        """
        Args:
            dim: feature dim of the block.
            num_heads: number of attention heads.
            mlp_ratio: ratio of MLP hidden dim to ``dim``.
            qkv_bias: enable bias on the qkv projection.
            rel_pos_cls: optional factory for a per-block relative position module.
            init_values: layer-scale init value; a falsy value disables layer-scale.
            drop: dropout rate for the attention projection and the MLP.
            attn_drop: dropout rate on the attention weights.
            drop_path: stochastic depth rate; values <= 0 disable it.
            act_layer: MLP activation layer.
            norm_layer: normalization layer.
        """
        super().__init__()

        def build_scale():
            return LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        def build_drop_path():
            # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = RelPosAttention(
            dim,
            num_heads,
            qkv_bias=qkv_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = build_scale()
        self.drop_path1 = build_drop_path()

        # MLP branch
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = build_scale()
        self.drop_path2 = build_drop_path()

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        attn_out = self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x


class ResPostRelPosBlock(nn.Module):
    """Residual post-norm transformer block: the norm is applied to each branch output
    (attention, then MLP) before it is added back to the residual stream."""

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        """
        Args:
            dim: feature dim of the block.
            num_heads: number of attention heads.
            mlp_ratio: ratio of MLP hidden dim to ``dim``.
            qkv_bias: enable bias on the qkv projection.
            rel_pos_cls: optional factory for a per-block relative position module.
            init_values: when not None, both norm weights are initialized to this constant.
            drop: dropout rate for the attention projection and the MLP.
            attn_drop: dropout rate on the attention weights.
            drop_path: stochastic depth rate; values <= 0 disable it.
            act_layer: MLP activation layer.
            norm_layer: normalization layer.
        """
        super().__init__()
        self.init_values = init_values

        def build_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # attention branch
        self.attn = RelPosAttention(
            dim,
            num_heads,
            qkv_bias=qkv_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm1 = norm_layer(dim)
        self.drop_path1 = build_drop_path()

        # MLP branch
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm2 = norm_layer(dim)
        self.drop_path2 = build_drop_path()

        self.init_weights()

    def init_weights(self):
        # NOTE this init overrides that base model init with specific changes for the block type
        if self.init_values is None:
            return
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.weight, self.init_values)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        attn_out = self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos))
        x = x + self.drop_path1(attn_out)
        mlp_out = self.norm2(self.mlp(x))
        return x + self.drop_path2(mlp_out)


class VisionTransformerRelPos(nn.Module):
    """ Vision Transformer w/ Relative Position Bias

    Differing from classic vit, this impl
      * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
      * defaults to no class token (can be enabled)
      * defaults to global avg pool for head (can be changed)
      * layer-scale (residual branch gain) enabled
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
            class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'avg')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token (default: False)
            fc_norm (bool): use pre classifier norm instead of pre-pool
            rel_pos_type (str): type of relative position; a value starting with 'mlp' selects
                RelPosMlp ('swin' anywhere in the value selects its 'swin' mode), any other value
                selects RelPosBias
            shared_rel_pos (bool): share relative pos across all blocks
            rel_pos_dim (int): hidden dim of the relative position MLP, only used for 'mlp' types
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme, 'skip' leaves init_weights() uncalled
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        # 'token' pooling reads the class token, so it needs one
        assert class_token or global_pool != 'token'
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1 if class_token else 0
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        feat_size = self.patch_embed.grid_size

        # the relative position window spans the whole patch grid
        rel_pos_args = dict(window_size=feat_size, class_token=class_token)
        if rel_pos_type.startswith('mlp'):
            if rel_pos_dim:
                rel_pos_args['hidden_dim'] = rel_pos_dim
            if 'swin' in rel_pos_type:
                rel_pos_args['mode'] = 'swin'
            rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
        else:
            rel_pos_cls = partial(RelPosBias, **rel_pos_args)
        self.shared_rel_pos = None
        if shared_rel_pos:
            self.shared_rel_pos = rel_pos_cls(num_heads=num_heads)
            # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
            rel_pos_cls = None

        self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            block_fn(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls,
                init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        # exactly one of norm (pre-pool) / fc_norm (post-pool) is a real norm layer
        self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()

        # Classifier Head
        self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity()
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode=''):
        """Initialize the class token; other weights keep their PyTorch default init."""
        assert mode in ('jax', 'moco', '')
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        # FIXME weight init scheme using PyTorch defaults currently
        #named_apply(get_init_weights_vit(mode, head_bias), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names to exclude from weight decay."""
        return {'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns that group parameters into stem and blocks (``coarse`` is unused)."""
        return dict(
            stem=r'^cls_token|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing of the transformer blocks."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the classifier head module."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        """Replace the head with a fresh one for ``num_classes``, optionally changing the pooling."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed patches, prepend the class token if present, and run all blocks and the final norm."""
        x = self.patch_embed(x)
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        # the shared bias is computed once per forward pass and handed to every block
        shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # checkpoint() raises on unknown keyword arguments in its reentrant mode,
                # so the bias has to be passed positionally
                x = checkpoint(blk, x, shared_rel_pos)
            else:
                x = blk(x, shared_rel_pos=shared_rel_pos)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the token sequence, apply fc_norm and, unless ``pre_logits``, the classifier head."""
        if self.global_pool:
            # 'avg' averages the non-class tokens, anything else takes the first (class) token
            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
    """Build a VisionTransformerRelPos model for ``variant`` via ``build_model_with_cfg``.

    Raises:
        RuntimeError: if a truthy ``features_only`` is passed in ``kwargs``.
    """
    features_only = kwargs.get('features_only', None)
    if features_only:
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs)


@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
    """ ViT-Base 'plus' (ViT-B/32+): patch 32, width 896, 14 heads.
    Relative log-coord position and residual post-norm blocks, no class token.
    """
    arch_kwargs = dict(
        patch_size=32,
        embed_dim=896,
        depth=12,
        num_heads=14,
        block_fn=ResPostRelPosBlock,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
    """ ViT-Base 'plus' (ViT-B/16+): patch 16, width 896, 14 heads.
    Relative log-coord position, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=896,
        depth=12,
        num_heads=14,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_plus_240', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_small_patch16_224(pretrained=False, **kwargs):
    """ Small variant: patch 16, width 384, 6 heads, no qkv bias, fc-norm.
    Relative log-coord position, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        qkv_bias=False,
        fc_norm=True,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_small_patch16_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_medium_patch16_224(pretrained=False, **kwargs):
    """ Medium variant: patch 16, width 512, 8 heads, no qkv bias, fc-norm.
    Relative log-coord position, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        qkv_bias=False,
        fc_norm=True,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_medium_patch16_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16): patch 16, width 768, 12 heads, no qkv bias, fc-norm.
    Relative log-coord position, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_bias=False,
        fc_norm=True,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16): patch 16, width 768, 12 heads, no qkv bias.
    Relative log-coord position, class token present and used for pooling.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_bias=False,
        class_token=True,
        global_pool='token',
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_cls_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16): patch 16, width 768, 12 heads, no qkv bias, fc-norm.
    Relative log-coord position, class token present.

    NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm
    was not disabled. Leaving here for comparisons w/ a future re-train as it performs quite well.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_bias=False,
        fc_norm=True,
        class_token=True,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs):
    """ Small variant: patch 16, width 384, 6 heads, no qkv bias.
    Relative log-coord position and residual post-norm blocks, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        qkv_bias=False,
        block_fn=ResPostRelPosBlock,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs):
    """ Medium variant: patch 16, width 512, 8 heads, no qkv bias.
    Relative log-coord position and residual post-norm blocks, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        qkv_bias=False,
        block_fn=ResPostRelPosBlock,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **arch_kwargs)


@register_model
def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16): patch 16, width 768, 12 heads, no qkv bias.
    Relative log-coord position and residual post-norm blocks, no class token.
    """
    arch_kwargs = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_bias=False,
        block_fn=ResPostRelPosBlock,
        **kwargs,
    )
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **arch_kwargs)
