import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
    get_act_layer, get_norm_layer, LayerType

# Public API: the @register_model factory functions defined in this module
# (tiny / small / base variants, each with depth 3 through 12).
__all__ = [
    'deit_tiny_patch16_224_L3', 'deit_tiny_patch16_224_L4', 'deit_tiny_patch16_224_L5', 'deit_tiny_patch16_224_L6',
    'deit_tiny_patch16_224_L7', 'deit_tiny_patch16_224_L8', 'deit_tiny_patch16_224_L9', 'deit_tiny_patch16_224_L10',
    'deit_tiny_patch16_224_L11', 'deit_tiny_patch16_224_L12',
    'deit_small_patch16_224_L3', 'deit_small_patch16_224_L4', 'deit_small_patch16_224_L5', 'deit_small_patch16_224_L6',
    'deit_small_patch16_224_L7', 'deit_small_patch16_224_L8', 'deit_small_patch16_224_L9', 'deit_small_patch16_224_L10',
    'deit_small_patch16_224_L11', 'deit_small_patch16_224_L12',
    'deit_base_patch16_256_L3', 'deit_base_patch16_256_L4', 'deit_base_patch16_256_L5', 'deit_base_patch16_256_L6',
    'deit_base_patch16_256_L7', 'deit_base_patch16_256_L8', 'deit_base_patch16_256_L9', 'deit_base_patch16_256_L10',
    'deit_base_patch16_256_L11', 'deit_base_patch16_256_L12',
]

class Mlp(nn.Module):
    """Two-layer perceptron: ``Linear -> GELU -> Linear``.

    NOTE: this definition shadows the ``Mlp`` imported from ``timm.layers``
    at the top of the file.

    Args:
        in_dim: size of the input features.
        outdim: size of the output features.
        hidden_dim: width of the hidden layer; when ``None`` it defaults to
            ``int((in_dim + outdim) / 2)``.

    A ``dropout`` module is created but is not applied in ``forward``.
    """

    def __init__(self, in_dim, outdim, hidden_dim=None):
        super().__init__()
        width = hidden_dim if hidden_dim is not None else int((in_dim + outdim) / 2)
        self.fc1 = nn.Linear(in_dim, width)
        self.fc2 = nn.Linear(width, outdim)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout()

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std=1e-6) normal biases."""
        layers = (self.fc1, self.fc2)
        # Weights first, then biases: keeps the random-number consumption order fixed.
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Apply ``fc2(gelu(fc1(x)))`` to the last dimension of ``x``."""
        hidden = self.act_fn(self.fc1(x))
        return self.fc2(hidden)

class PartVisionTransformer(VisionTransformer):
    """Vision Transformer whose head works on per-point vectors instead of patch tokens.

    After the transformer blocks, the token tensor is reshaped so that every
    patch token of width ``embed_dim`` is split into
    ``patch_size[0] * patch_size[1]`` vectors of width ``point_embed_dim``.
    The final norm, ``fc_norm`` and the linear head are rebuilt to operate on
    that reduced width.

    Attributes:
        total_points: ``img_size[0] * img_size[1]`` of the patch embedding.
        one_path_points: ``patch_size[0] * patch_size[1]`` of the patch embedding.
        point_embed_dim: ``embed_dim // one_path_points``.

    All constructor arguments are forwarded unchanged to ``VisionTransformer``.

    Raises:
        ValueError: if ``embed_dim`` is not a multiple of the patch area, in
            which case a patch token cannot be split evenly into points.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        img_size = self.patch_embed.img_size
        patch_size = self.patch_embed.patch_size
        self.total_points = img_size[0] * img_size[1]
        self.one_path_points = patch_size[0] * patch_size[1]
        # Fail at construction time: a truncated width (e.g. int(192 / 256) == 0)
        # would otherwise only surface as a reshape error inside forward().
        if self.embed_dim % self.one_path_points != 0:
            raise ValueError(
                f'embed_dim ({self.embed_dim}) must be a multiple of the patch area '
                f'({self.one_path_points}) to split each patch token into per-point vectors'
            )
        self.point_embed_dim = self.embed_dim // self.one_path_points
        self.head = nn.Linear(self.point_embed_dim, self.num_classes)

        # norm_layer is optional for the base class, so do not require the key here.
        # The fallback mirrors the LayerNorm(eps=1e-6) used by the factories in this file.
        norm_spec = kwargs.get('norm_layer')
        if norm_spec is not None:
            norm_layer = get_norm_layer(norm_spec)
        else:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        # Only replace norms the base class actually enabled (non-Identity).
        if not isinstance(self.norm, nn.Identity):
            self.norm = norm_layer(self.point_embed_dim)
        if not isinstance(self.fc_norm, nn.Identity):
            self.fc_norm = norm_layer(self.point_embed_dim)
        # Run the inherited weight init now that head and norms have their final sizes.
        self.init_weights()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Embed, run the transformer blocks and return normalized per-point features.

        Returns:
            Tensor of shape ``(batch, total_points, point_embed_dim)``.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        # Keep the incoming batch size instead of assuming a batch of one.
        x = x.reshape(x.shape[0], self.total_points, self.point_embed_dim)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Apply optional attention pooling, ``fc_norm``, head dropout and the head.

        Args:
            x: features as returned by ``forward_features``.
            pre_logits: when True, return the features before the linear head.
        """
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward_heads(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Alias of ``forward_head`` kept for callers using the older name."""
        return self.forward_head(x, pre_logits=pre_logits)

    def forward(self, x: torch.Tensor):
        """Return ``(features, output)``: per-point features and the head output."""
        x = self.forward_features(x)
        output = self.forward_head(x)
        return x, output


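# NOTE: the factories inside the string literal below are disabled and kept for
# reference only. Most of them refer to DistilledVisionTransformer, which is
# neither defined nor imported in the visible part of this file.
'''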
@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, pretrained_cfg=None, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, pretrained_cfg=None, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
'''


@register_model
def deit_tiny_patch16_224_L3(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=3.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=3,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L4(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=4.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=4,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L5(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=5.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=5,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L6(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=6.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=6,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L7(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=7.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=7,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L8(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=8.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=8,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L9(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=9.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=9,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L10(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=10.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=10,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L11(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=11.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=11,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_tiny_patch16_224_L12(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny ``PartVisionTransformer`` variant with depth=12.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=192, num_heads=3, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 192 / (16 * 16) is below 1 -- confirm this
    variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L3(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=3.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=3,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L4(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=4.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=4,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L5(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=5.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=5,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L6(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=6.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=6,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L7(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=7.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=7,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L8(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=8.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=8,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L9(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=9.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=9,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L10(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=10.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=10,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L11(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=11.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=11,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained

    return net


@register_model
def deit_small_patch16_224_L12(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small ``PartVisionTransformer`` variant with depth=12.

    Fixed configuration: img_size=256 (despite ``_224`` in the name),
    patch_size=16, embed_dim=384, num_heads=6, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: must be falsy; checked with an ``assert`` after construction.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.

    NOTE(review): PartVisionTransformer derives point_embed_dim from
    embed_dim / patch area; 384 is not a multiple of 16 * 16 -- confirm
    this variant is usable.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    assert not pretrained
    return net


@register_model
def deit_base_patch16_256_L3(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=3.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=3,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def deit_base_patch16_256_L4(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=4.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=4,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L5(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=5.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=5,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L6(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=6.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=6,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L7(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=7.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=7,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L8(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=8.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=8,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L9(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=9.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=9,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L10(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=10.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=10,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L11(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=11.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=11,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net


@register_model
def deit_base_patch16_256_L12(in_channel=36, pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base ``PartVisionTransformer`` variant with depth=12.

    Fixed configuration: img_size=256, patch_size=8 (despite ``patch16`` in the
    name), embed_dim=96*64, num_heads=12, no class token, ``global_pool=''``.

    Args:
        in_channel: number of input channels (forwarded as ``in_chans``).
        pretrained: accepted but ignored; no pretrained weights are loaded.
        pretrained_cfg: accepted but unused.
        pretrained_cfg_overlay: accepted but unused.
        **kwargs: forwarded to ``PartVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    vit_args = dict(
        img_size=256,
        in_chans=in_channel,
        weight_init='skip',
        patch_size=8,
        embed_dim=96 * 64,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        global_pool='',
        class_token=False,
        norm_layer=layer_norm,
    )
    net = PartVisionTransformer(**vit_args, **kwargs)
    net.default_cfg = _cfg()

    return net



