"""This implementation is based on the thrid-party library, timm.
This file should be put under the directory: timm/models/.
"""

import math
from collections import OrderedDict
from dataclasses import dataclass, field, replace
from functools import partial
from typing import Callable, Optional, Union, Tuple, List

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq, named_apply
from .fx_features import register_notrace_function
from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm
from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
from .layers import to_2tuple, extend_tuple, make_divisible, _assert
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias

# Public names exported by ``from <module> import *``. Every entry must be
# defined in this file; the attention-block config class is ``moat_block_cfg``.
__all__ = ['moat_model_config', 'moat_conv_config', 'moat_block_cfg', 'MOAT']


def _cfg(url='', **kwargs):
    """Build a pretrained-config dict from this file's defaults.

    Args:
        url: value stored under the ``'url'`` key.
        **kwargs: extra entries; a key that matches a default replaces its value.

    Returns:
        A new dict holding the defaults below, updated with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        crop_pct=0.95,
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        first_conv='stem.conv1',
        classifier='head.fc',
        fixed_input_size=True,
    )
    # Caller-supplied entries win over the defaults.
    cfg.update(kwargs)
    return cfg


@dataclass
class moat_conv_config:
    """Configuration for the MBConv block (``MbConvBlock``).

    Empty / ``None`` values for ``norm_layer``, ``norm_layer_cl``, ``norm_eps``
    and ``downsample_pool_type`` are filled in by ``__post_init__``.
    """
    block_type: str = 'mbconv'  # only 'mbconv' is accepted
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'  # falls back to pool_type when empty
    attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
    attn_layer: str = 'se'
    attn_act_layer: str = 'silu'
    attn_ratio: float = 0.25
    init_values: Optional[float] = 1e-6
    act_layer: str = 'gelu'
    norm_layer: str = ''
    norm_layer_cl: str = ''
    norm_eps: Optional[float] = None
    remove_se: Optional[bool] = False  # when truthy, MbConvBlock creates no attn (SE) module

    def __post_init__(self):
        """Validate ``block_type`` and resolve unset norm / pooling options.

        Raises:
            AssertionError: if ``block_type`` is not exactly ``'mbconv'``.
        """
        # NOTE: must be a one-element tuple. A bare ('mbconv') is just a string,
        # which turns ``in`` into a substring test and lets 'conv', 'mb' or ''
        # through.
        assert self.block_type in ('mbconv',), f"unsupported block_type: {self.block_type!r}"
        use_mbconv = self.block_type == 'mbconv'
        if not self.norm_layer:
            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
        if not self.norm_layer_cl and not use_mbconv:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-5 if use_mbconv else 1e-6
        self.downsample_pool_type = self.downsample_pool_type or self.pool_type


@dataclass
class moat_block_cfg:
    """Configuration for the MOAT block (``MOATBlock2d``).

    ``mbconv_cfg`` configures the MBConv half of the block and defaults to a
    ``moat_conv_config`` with ``remove_se=True``.
    """
    dim_head: int = 32  # per-head channel width passed to Attention2d
    expand_ratio: float = 4.0
    expand_first: bool = True
    shortcut_bias: bool = True
    attn_bias: bool = True  # bias for the attention qkv / proj convs
    attn_drop: float = 0.
    proj_drop: float = 0.
    pool_type: str = 'avg2'
    rel_pos_type: str = 'bias'  # 'bias' or 'mlp'; any other value disables rel-pos
    rel_pos_dim: int = 512  # for relative position types w/ MLP
    partition_ratio: int = 32
    window_size: Optional[Tuple[int, int]] = None
    init_values: Optional[float] = None  # LayerScale2d init; falsy -> no layer scale
    act_layer: str = 'gelu'
    norm_layer: str = 'layernorm2d'
    norm_layer_cl: str = 'layernorm'
    norm_eps: float = 1e-6

    # A dataclass instance is an unhashable ("mutable") default: Python 3.11+
    # rejects it at class-creation time, and older versions silently share one
    # object between all configs. default_factory builds a fresh one each time.
    mbconv_cfg: moat_conv_config = field(
        default_factory=lambda: moat_conv_config(remove_se=True),
    )

    def __post_init__(self):
        """Normalize ``window_size`` to a 2-tuple when it is set."""
        if self.window_size is not None:
            self.window_size = to_2tuple(self.window_size)


@dataclass
class moat_model_config:
    """Top-level architecture configuration consumed by ``MOAT``."""
    embed_dim: Tuple[int, ...] = (96, 192, 384, 768)  # output channels per stage
    depths: Tuple[int, ...] = (2, 3, 5, 2)  # number of blocks per stage
    block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'MOAT', 'MOAT')  # 'C' or 'MOAT' per stage
    stem_width: Union[int, Tuple[int, int]] = 64
    stem_bias: bool = True
    # Dataclass instances are unhashable defaults: rejected by Python 3.11+ at
    # class creation and shared between instances on older versions. Build a
    # fresh sub-config per model config instead.
    conv_cfg: moat_conv_config = field(default_factory=moat_conv_config)
    moat_cfg: moat_block_cfg = field(default_factory=moat_block_cfg)
    weight_init: str = 'vit_eff'


# Architecture presets, keyed by variant name. Each entry sets the per-stage
# widths (embed_dim), per-stage block counts (depths) and the stem width; all
# variants use two 'C' (MBConv) stages followed by two 'MOAT' stages.
model_cfgs = {
    'tiny_moat_0': moat_model_config(
        stem_width=32,
        embed_dim=(32, 64, 128, 256),
        depths=(2, 3, 7, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'tiny_moat_1': moat_model_config(
        stem_width=40,
        embed_dim=(40, 80, 160, 320),
        depths=(2, 3, 7, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'tiny_moat_2': moat_model_config(
        stem_width=56,
        embed_dim=(56, 112, 224, 448),
        depths=(2, 3, 7, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'tiny_moat_3': moat_model_config(
        stem_width=80,
        embed_dim=(80, 160, 320, 640),
        depths=(2, 3, 7, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'moat_0': moat_model_config(
        stem_width=64,
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'moat_1': moat_model_config(
        stem_width=64,
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'moat_2': moat_model_config(
        stem_width=128,
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'moat_3': moat_model_config(
        stem_width=160,
        embed_dim=(160, 320, 640, 1280),
        depths=(2, 12, 28, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
    'moat_4': moat_model_config(
        stem_width=256,
        embed_dim=(256, 512, 1024, 2048),
        depths=(2, 12, 28, 2),
        block_type=('C', 'C', 'MOAT', 'MOAT'),
    ),
}


class Attention2d(nn.Module):
    """Multi-head self-attention over the flattened spatial positions of an NCHW tensor.

    Query/key/value and the output projection are 1x1 convolutions. The
    attention width is ``dim_out`` when ``expand_first`` is set, otherwise
    ``dim``; the head count is that width divided by ``dim_head``.
    """
    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            rel_pos_cls: Callable = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        if not dim_out:
            dim_out = dim
        attn_width = dim_out if expand_first else dim
        self.num_heads = attn_width // dim_head
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.qkv = nn.Conv2d(dim, attn_width * 3, 1, bias=bias)
        # Optional per-block relative position module, built with num_heads=...
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(attn_width, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        """Apply attention to ``x`` (B, C, H, W).

        ``shared_rel_pos`` is added to the attention logits only when the
        module has no ``rel_pos`` of its own.
        """
        batch, _, height, width = x.shape

        # (B, heads, 3 * dim_head, H*W) -> three (B, heads, dim_head, H*W) tensors
        qkv = self.qkv(x).view(batch, self.num_heads, self.dim_head * 3, -1)
        query, key, value = qkv.chunk(3, dim=2)

        scores = (query.transpose(-2, -1) @ key) * self.scale
        if self.rel_pos is not None:
            scores = self.rel_pos(scores)
        elif shared_rel_pos is not None:
            scores = scores + shared_rel_pos
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (value @ weights.transpose(-2, -1)).view(batch, -1, height, width)
        return self.proj_drop(self.proj(out))


class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    ``gamma`` starts at ``init_values`` and is broadcast against the input as a
    1-D tensor. With ``inplace=True`` the input tensor itself is scaled.
    """
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        scale = self.gamma
        if self.inplace:
            return x.mul_(scale)
        return x * scale


class LayerScale2d(nn.Module):
    """Multiply an NCHW input by a learnable per-channel vector ``gamma``.

    ``gamma`` (length ``dim``, initialised to ``init_values``) is reshaped to
    ``(1, dim, 1, 1)`` before the multiply. With ``inplace=True`` the input
    tensor itself is scaled.
    """
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        scale = self.gamma.view(1, -1, 1, 1)
        if self.inplace:
            return x.mul_(scale)
        return x * scale


class Downsample2d(nn.Module):
    """ Pooling-based downsample followed by an optional 1x1 channel expansion.

    Supported ``pool_type`` values:
    * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
    * 'max2' - MaxPool2d w/ kernel_size = stride = 2
    * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
    * 'avg2' - AvgPool2d w/ kernel_size = stride = 2

    A 1x1 conv is applied after pooling only when ``dim != dim_out``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            pool_type: str = 'avg2',
            bias: bool = True,
    ):
        super().__init__()
        assert pool_type in ('max', 'max2', 'avg', 'avg2')
        # (pool class, constructor kwargs) for every mode except 'avg2', which
        # is the fallback; stride defaults to kernel_size when not given.
        pool_specs = {
            'max': (nn.MaxPool2d, dict(kernel_size=3, stride=2, padding=1)),
            'max2': (nn.MaxPool2d, dict(kernel_size=2)),
            'avg': (nn.AvgPool2d, dict(kernel_size=3, stride=2, padding=1, count_include_pad=False)),
        }
        pool_cls, pool_kwargs = pool_specs.get(pool_type, (nn.AvgPool2d, dict(kernel_size=2)))
        self.pool = pool_cls(**pool_kwargs)

        if dim == dim_out:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)

    def forward(self, x):
        pooled = self.pool(x)  # spatial downsample
        return self.expand(pooled)  # expand chs


def _init_transformer(module, name, scheme=''):
    """Initialise a single ``nn.Conv2d`` / ``nn.Linear`` module in place.

    Other module types are left untouched. ``scheme`` picks the weight init:
    'normal', 'trunc_normal', 'xavier_normal', or anything else for the
    ViT-like default (xavier-uniform weights). Biases are zeroed, except in the
    default scheme for modules whose ``name`` contains 'mlp', which get a
    normal init with std 1e-6.
    """
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        return

    has_bias = module.bias is not None
    if scheme == 'normal':
        nn.init.normal_(module.weight, std=.02)
    elif scheme == 'trunc_normal':
        trunc_normal_tf_(module.weight, std=.02)
    elif scheme == 'xavier_normal':
        nn.init.xavier_normal_(module.weight)
    else:
        # vit like
        nn.init.xavier_uniform_(module.weight)
        if has_bias and 'mlp' in name:
            nn.init.normal_(module.bias, std=1e-6)
            return

    if has_bias:
        nn.init.zeros_(module.bias)


class MOATBlock2d(nn.Module):
    """ MOAT block: an MBConv block followed by a residual self-attention branch.
    '2D' NCHW tensor layout

    Any stride (spatial downsampling) is handled by the MBConv block; the
    attention branch runs at ``dim_out`` channels.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            rel_pos_cls: Callable = None,
            cfg: moat_block_cfg = moat_block_cfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        # Convolutional half: changes channels dim -> dim_out and applies the stride.
        self.mbconv = MbConvBlock(in_chs=dim, out_chs=dim_out, stride=stride, cfg=cfg.mbconv_cfg, drop_path=drop_path)

        make_norm = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        # Looked up but not consumed by this block; kept so the lookup of the
        # configured activation name still runs.
        act_layer = get_act_layer(cfg.act_layer)

        # Attention half: pre-norm -> attention -> optional layer scale -> optional drop path.
        self.norm1 = make_norm(dim_out)
        attn_kwargs = dict(
            dim_head=cfg.dim_head,
            expand_first=cfg.expand_first,
            bias=cfg.attn_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.attn = Attention2d(dim_out, dim_out, **attn_kwargs)
        if cfg.init_values:
            self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values)
        else:
            self.ls1 = nn.Identity()
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

    def init_weights(self, scheme=''):
        """Apply ``_init_transformer`` with ``scheme`` to this block via ``named_apply``."""
        init_fn = partial(_init_transformer, scheme=scheme)
        named_apply(init_fn, self)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        x = self.mbconv(x)
        attn_out = self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)
        return x + self.drop_path1(self.ls1(attn_out))


def _init_conv(module, name, scheme=''):
    """Initialise a single ``nn.Conv2d`` module in place; other modules are skipped.

    ``scheme`` picks the weight init: 'normal', 'trunc_normal',
    'xavier_normal', or anything else for the EfficientNet-like default
    (normal with std ``sqrt(2 / fan_out)``, where fan_out is
    ``kernel_h * kernel_w * out_channels // groups``). The bias, when present,
    is always zeroed. ``name`` is accepted but not used.
    """
    if not isinstance(module, nn.Conv2d):
        return

    if scheme == 'normal':
        nn.init.normal_(module.weight, std=.02)
    elif scheme == 'trunc_normal':
        trunc_normal_tf_(module.weight, std=.02)
    elif scheme == 'xavier_normal':
        nn.init.xavier_normal_(module.weight)
    else:
        # efficientnet like
        kernel_area = module.kernel_size[0] * module.kernel_size[1]
        fan_out = (kernel_area * module.out_channels) // module.groups
        nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))

    if module.bias is not None:
        nn.init.zeros_(module.bias)


def num_groups(group_size, channels):
    """Return the conv ``groups`` count for ``channels`` split into groups of ``group_size``.

    A falsy ``group_size`` (0 or None) means a normal conv with 1 group;
    ``group_size == 1`` gives a depthwise conv (``groups == channels``).

    Raises:
        AssertionError: if ``channels`` is not divisible by ``group_size``.
    """
    if group_size:
        # NOTE group_size == 1 -> depthwise conv
        assert channels % group_size == 0
        return channels // group_size
    return 1  # normal conv with 1 group


class MbConvBlock(nn.Module):
    """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)

    The projected output (after optional drop path) is added to a shortcut of
    the input. The shortcut is a ``Downsample2d`` when ``stride == 2`` and an
    identity otherwise. Where the stride is applied in the main path is
    selected by ``cfg.stride_mode``.
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: moat_conv_config = moat_conv_config(),
            drop_path: float = 0.,
    ):
        super(MbConvBlock, self).__init__()
        norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
        expand_base = out_chs if cfg.expand_output else in_chs
        mid_chs = make_divisible(expand_base * cfg.expand_ratio)
        groups = num_groups(cfg.group_size, mid_chs)

        self.shortcut = (
            Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias)
            if stride == 2 else nn.Identity()
        )

        # Exactly one of pool / 1x1 conv / kxk conv carries the stride; the
        # kxk conv uses dilation[0] only when it is the strided op.
        mode = cfg.stride_mode
        assert mode in ('pool', '1x1', 'dw')
        strided_before_kxk = mode in ('pool', '1x1')
        stride_pool = stride if mode == 'pool' else 1
        stride_1 = stride if mode == '1x1' else 1
        stride_2 = 1 if strided_before_kxk else stride
        dilation_2 = dilation[1] if strided_before_kxk else dilation[0]

        self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act)
        self.down = (
            Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type)
            if stride_pool > 1 else nn.Identity()
        )
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1)
        self.norm1 = norm_act_layer(mid_chs)

        self.conv2_kxk = create_conv2d(
            mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups)

        attn_kwargs = {}
        if isinstance(cfg.attn_layer, str) and cfg.attn_layer in ('se', 'eca'):
            attn_kwargs['act_layer'] = cfg.attn_act_layer
            attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))

        # Channel attention sits either before norm2 (attn_early) or after it,
        # or is absent when cfg.remove_se is set. Assignment order
        # (se_early, norm2, se) is kept so module registration order is stable.
        attn_enabled = not cfg.remove_se
        early_attn = attn_enabled and bool(cfg.attn_early)
        late_attn = attn_enabled and not cfg.attn_early
        self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) if early_attn else None
        self.norm2 = norm_act_layer(mid_chs)
        self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) if late_attn else None

        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        """Apply ``_init_conv`` with ``scheme`` to this block via ``named_apply``."""
        init_fn = partial(_init_conv, scheme=scheme)
        named_apply(init_fn, self)

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.down(self.pre_norm(x))

        # 1x1 expansion conv & norm-act
        x = self.norm1(self.conv1_1x1(x))

        # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
        x = self.conv2_kxk(x)
        if self.se_early is not None:
            x = self.se_early(x)
        x = self.norm2(x)
        if self.se is not None:
            x = self.se(x)

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
        return self.drop_path(x) + residual


def get_rel_pos_cls(cfg: 'moat_block_cfg', window_size):
    """Return a factory for the relative-position module selected by ``cfg``.

    Args:
        cfg: block config; ``rel_pos_type`` is read, and ``rel_pos_dim`` too
            when the type is 'mlp'.
        window_size: forwarded as ``window_size=`` to the selected class.

    Returns:
        ``partial(RelPosMlp, ...)`` for 'mlp', ``partial(RelPosBias, ...)`` for
        'bias', or ``None`` for any other ``rel_pos_type``.
    """
    # NOTE: the annotation is a string naming the config class defined in this
    # file. The previous annotation referenced an undefined name, which raised
    # NameError as soon as this ``def`` was executed at import time.
    if cfg.rel_pos_type == 'mlp':
        return partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
    if cfg.rel_pos_type == 'bias':
        return partial(RelPosBias, window_size=window_size)
    return None


class MOATStage(nn.Module):
    """One stage of the network: a sequence of 'C' (MBConv) and/or 'MOAT' blocks.

    The first block receives ``stride`` and maps ``in_chs -> out_chs``; every
    later block uses stride 1 and ``out_chs -> out_chs``.

    Args:
        in_chs: input channels of the stage.
        out_chs: output channels of every block in the stage.
        stride: stride given to the first block.
        depth: number of blocks; ``block_types`` is extended to this length
            with ``extend_tuple``.
        feat_size: passed as the window size to the relative-position factory
            of 'MOAT' blocks.
        block_types: 'C' / 'MOAT', either one value or one per block.
        conv_cfg: config for 'C' blocks.
        moat_cfg: config for 'MOAT' blocks.
        drop_path: a single rate used for every block, or one rate per block.
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 2,
            depth: int = 4,
            feat_size: Tuple[int, int] = (14, 14),
            block_types: Union[str, Tuple[str]] = 'C',
            conv_cfg: moat_conv_config = moat_conv_config(),
            moat_cfg: moat_block_cfg = moat_block_cfg(),
            drop_path: Union[float, List[float]] = 0.,
    ):
        super().__init__()
        self.grad_checkpointing = False

        block_types = extend_tuple(block_types, depth)
        # The signature allows a scalar rate (and defaults to 0.), but the
        # loop below indexes per block; broadcast a scalar so the default
        # no longer fails with "'float' object is not subscriptable".
        if isinstance(drop_path, (list, tuple)):
            drop_path_rates = list(drop_path)
        else:
            drop_path_rates = [drop_path] * len(block_types)

        blocks = []
        for i, t in enumerate(block_types):
            # Only the first block of the stage downsamples / changes width.
            block_stride = stride if i == 0 else 1
            assert t in ('C', 'MOAT')
            if t == 'C':
                blocks.append(MbConvBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    cfg=conv_cfg,
                    drop_path=drop_path_rates[i],
                ))
            elif t == 'MOAT':
                rel_pos_cls = get_rel_pos_cls(moat_cfg, feat_size)
                blocks.append(MOATBlock2d(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    rel_pos_cls=rel_pos_cls,
                    cfg=moat_cfg,
                    drop_path=drop_path_rates[i],
                ))
            in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        # Gradient checkpointing is skipped while scripting.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class Stem(nn.Module):
    """Network stem: strided conv -> norm+act -> conv.

    ``out_chs`` may be a single width (expanded with ``to_2tuple``) or a
    list/tuple giving the widths of the two convs. ``self.out_chs`` holds the
    last width and ``self.stride`` is 2.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            act_layer: str = 'gelu',
            norm_layer: str = 'batchnorm2d',
            norm_eps: float = 1e-5,
    ):
        super().__init__()
        widths = out_chs if isinstance(out_chs, (list, tuple)) else to_2tuple(out_chs)
        first_chs, second_chs = widths[0], widths[1]

        norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
        self.out_chs = widths[-1]
        self.stride = 2

        # Only the first conv is strided.
        self.conv1 = create_conv2d(in_chs, first_chs, kernel_size, stride=2)
        self.norm1 = norm_act_layer(first_chs)
        self.conv2 = create_conv2d(first_chs, second_chs, kernel_size, stride=1)

    def init_weights(self, scheme=''):
        """Apply ``_init_conv`` with ``scheme`` to the stem via ``named_apply``."""
        init_fn = partial(_init_conv, scheme=scheme)
        named_apply(init_fn, self)

    def forward(self, x):
        return self.conv2(self.norm1(self.conv1(x)))


class MOAT(nn.Module):
    """MOAT classification network: stem -> stages -> final norm -> classifier head.

    One ``MOATStage`` is built per entry of ``cfg.embed_dim``. Each stage uses
    the ``MOATStage`` default stride, and the feature size handed to a stage is
    the previous size reduced by a factor of 2 (rounded up).
    """

    def __init__(
            self,
            cfg: moat_model_config,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = cfg.embed_dim[-1]
        self.embed_dim = cfg.embed_dim
        self.drop_rate = drop_rate
        self.grad_checkpointing = False

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
            act_layer=cfg.conv_cfg.act_layer,
            norm_layer=cfg.conv_cfg.norm_layer,
            norm_eps=cfg.conv_cfg.norm_eps,
        )

        # Spatial size of the feature map after the stem.
        stem_stride = to_2tuple(self.stem.stride)
        feat_size = tuple(side // step for side, step in zip(img_size, stem_stride))

        num_stages = len(cfg.embed_dim)
        assert len(cfg.depths) == num_stages
        # Drop-path rates rise linearly from 0 to drop_path_rate over all
        # blocks, then are split into one list per stage.
        all_rates = torch.linspace(0, drop_path_rate, sum(cfg.depths))
        dpr = [stage_rates.tolist() for stage_rates in all_rates.split(cfg.depths)]

        stage_stride = 2
        in_chs = self.stem.out_chs
        stages = []
        for stage_idx, out_chs in enumerate(cfg.embed_dim):
            feat_size = tuple((r - 1) // stage_stride + 1 for r in feat_size)
            stages.append(MOATStage(
                in_chs,
                out_chs,
                depth=cfg.depths[stage_idx],
                block_types=cfg.block_type[stage_idx],
                conv_cfg=cfg.conv_cfg,
                moat_cfg=cfg.moat_cfg,
                feat_size=feat_size,
                drop_path=dpr[stage_idx],
            ))
            in_chs = out_chs
        self.stages = nn.Sequential(*stages)

        self.norm = get_norm_layer(cfg.moat_cfg.norm_layer)(self.num_features, eps=cfg.moat_cfg.norm_eps)

        # Classifier head
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        # Weight init (default PyTorch init works well for AdamW if scheme not set)
        assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
        if cfg.weight_init:
            named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)

    def _init_weights(self, module, name, scheme=''):
        """Call ``module.init_weights`` if it exists, retrying without ``scheme`` on TypeError."""
        if not hasattr(module, 'init_weights'):
            return
        try:
            module.init_weights(scheme=scheme)
        except TypeError:
            module.init_weights()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of relative-position parameters to exclude from weight decay."""
        markers = ["relative_position_bias_table", "rel_pos.mlp"]
        return {
            param_name for param_name, _ in self.named_parameters()
            if any(marker in param_name for marker in markers)
        }

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns grouping parameters by stem / stage / norm (``coarse`` is unused)."""
        return {
            'stem': r'^stem',  # stem and embed
            'blocks': [(r'^stages\.(\d+)', None), (r'^norm', (99999,))],
        }

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Set the ``grad_checkpointing`` flag on every stage."""
        for stage in self.stages:
            stage.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the head with a fresh ``ClassifierHead``; keeps the current pooling if none is given."""
        self.num_classes = num_classes
        pool_type = self.head.global_pool.pool_type if global_pool is None else global_pool
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=pool_type, drop_rate=self.drop_rate)

    def forward_features(self, x):
        return self.norm(self.stages(self.stem(x)))

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        features = self.forward_features(x)
        return self.forward_head(features)


def _create_moat(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Build a ``MOAT`` model through ``build_model_with_cfg``.

    The architecture config is ``model_cfgs[cfg_variant]`` when ``cfg_variant``
    is truthy, otherwise ``model_cfgs[variant]``.
    """
    cfg_key = cfg_variant if cfg_variant else variant
    return build_model_with_cfg(
        MOAT,
        variant,
        pretrained,
        model_cfg=model_cfgs[cfg_key],
        feature_cfg={'flatten_sequential': True},
        **kwargs,
    )


@register_model
def moat_0(pretrained=False, **kwargs):
    """Create the model registered as 'moat_0' from the 'moat_0' entry of ``model_cfgs``."""
    variant = 'moat_0'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def moat_1(pretrained=False, **kwargs):
    """Create the model registered as 'moat_1' from the 'moat_1' entry of ``model_cfgs``."""
    variant = 'moat_1'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def moat_2(pretrained=False, **kwargs):
    """Create the model registered as 'moat_2' from the 'moat_2' entry of ``model_cfgs``."""
    variant = 'moat_2'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def moat_3(pretrained=False, **kwargs):
    """Create the model registered as 'moat_3' from the 'moat_3' entry of ``model_cfgs``."""
    variant = 'moat_3'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def moat_4(pretrained=False, **kwargs):
    """Create the model registered as 'moat_4' from the 'moat_4' entry of ``model_cfgs``."""
    variant = 'moat_4'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def tiny_moat_0(pretrained=False, **kwargs):
    """Create the model registered as 'tiny_moat_0' from the 'tiny_moat_0' entry of ``model_cfgs``."""
    variant = 'tiny_moat_0'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def tiny_moat_1(pretrained=False, **kwargs):
    """Create the model registered as 'tiny_moat_1' from the 'tiny_moat_1' entry of ``model_cfgs``."""
    variant = 'tiny_moat_1'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def tiny_moat_2(pretrained=False, **kwargs):
    """Create the model registered as 'tiny_moat_2' from the 'tiny_moat_2' entry of ``model_cfgs``."""
    variant = 'tiny_moat_2'
    return _create_moat(variant, pretrained=pretrained, **kwargs)

@register_model
def tiny_moat_3(pretrained=False, **kwargs):
    """Create the model registered as 'tiny_moat_3' from the 'tiny_moat_3' entry of ``model_cfgs``."""
    variant = 'tiny_moat_3'
    return _create_moat(variant, pretrained=pretrained, **kwargs)
