import torch
from timm.models.registry import register_model
from models import *


def get_deit_rpe_config():
    """Build the iRPE configuration used by every Mini-DeiT builder in this module.

    The configuration is produced by ``irpe.get_rpe_config`` with these options:
    ``ratio=1.9``, ``method="product"``, ``mode='ctx'``, ``shared_head=True``,
    ``skip=0`` and ``rpe_on='k'``.

    Returns:
        Whatever ``irpe.get_rpe_config`` returns for the options above.
    """
    # Local import: ``irpe`` is only needed at the moment a config is built.
    from irpe import get_rpe_config as build_rpe_config

    # NOTE(review): skip=0 is presumably paired with the builders here passing
    # use_cls_token=False (no leading token to skip) — confirm against irpe.
    options = dict(
        ratio=1.9,
        method="product",
        mode='ctx',
        shared_head=True,
        skip=0,
        rpe_on='k',
    )
    return build_rpe_config(**options)

def get_repeated_shcedule(depth):
    """Return a repetition schedule that uses the same entry for every sub-module.

    The (misspelled) function name is kept as-is because the model builders in
    this module call it by that name.

    Args:
        depth: value stored as the single element of each entry's first list.

    Returns:
        A new dict mapping each of ``norm1``, ``norm2``, ``attn_rpe``,
        ``attn_qkv``, ``attn_transform1``, ``attn_transform2``, ``attn_proj``,
        ``mlp_fc1`` and ``mlp_fc2`` (in that order) to ``[[depth], [True]]``.
        NOTE(review): presumably "one group repeated ``depth`` times, with the
        flag enabled"; the exact meaning is defined by the consuming model.
    """
    module_names = (
        'norm1',
        'norm2',
        'attn_rpe',
        'attn_qkv',
        'attn_transform1',
        'attn_transform2',
        'attn_proj',
        'mlp_fc1',
        'mlp_fc2',
    )
    # The comprehension builds fresh inner lists for every key, so mutating
    # one entry never affects another.
    return {name: [[depth], [True]] for name in module_names}

@register_model
def mini_deit_tiny_patch16_224(pretrained=False, **kwargs):
    """Mini-DeiT variant of ``deit_tiny_patch16_224`` (registered with timm).

    Args:
        pretrained: forwarded unchanged to ``deit_tiny_patch16_224``.
        **kwargs: forwarded unchanged; passing ``rpe_config``,
            ``use_cls_token`` or ``repeated_times_schedule`` here raises
            ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``deit_tiny_patch16_224`` returns.
    """
    rpe_cfg = get_deit_rpe_config()
    schedule = get_repeated_shcedule(12)
    return deit_tiny_patch16_224(
        pretrained=pretrained,
        rpe_config=rpe_cfg,
        use_cls_token=False,
        repeated_times_schedule=schedule,
        **kwargs
    )


@register_model
def mini_deit_small_patch16_224(pretrained=False, repeated_time=1, **kwargs):
    """Mini-DeiT variant of ``deit_small_patch16_224`` (registered with timm).

    Args:
        pretrained: forwarded unchanged to ``deit_small_patch16_224``.
        repeated_time: used only to build the repetition schedule
            (``get_repeated_shcedule(depth=repeated_time)``); it is NOT
            forwarded to ``deit_small_patch16_224``.
        **kwargs: forwarded unchanged; passing ``rpe_config``,
            ``use_cls_token`` or ``repeated_times_schedule`` here raises
            ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``deit_small_patch16_224`` returns.
    """
    # NOTE(review): default repeated_time=1 and not forwarding it to the
    # builder both look deliberate-or-accidental — confirm intended.
    rpe_cfg = get_deit_rpe_config()
    schedule = get_repeated_shcedule(depth=repeated_time)
    return deit_small_patch16_224(
        pretrained=pretrained,
        rpe_config=rpe_cfg,
        use_cls_token=False,
        repeated_times_schedule=schedule,
        **kwargs
    )

@register_model
def mini_deit_base_patch16_224(pretrained=False, repeated_time=12, **kwargs):
    """Mini-DeiT variant of ``deit_base_patch16_224`` (registered with timm).

    Args:
        pretrained: forwarded unchanged to ``deit_base_patch16_224``.
        repeated_time: used to build the repetition schedule AND forwarded to
            ``deit_base_patch16_224`` as ``repeated_time``.
        **kwargs: forwarded unchanged; passing ``rpe_config``,
            ``use_cls_token`` or ``repeated_times_schedule`` here raises
            ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``deit_base_patch16_224`` returns.
    """
    rpe_cfg = get_deit_rpe_config()
    schedule = get_repeated_shcedule(repeated_time)
    return deit_base_patch16_224(
        pretrained=pretrained,
        repeated_time=repeated_time,
        rpe_config=rpe_cfg,
        use_cls_token=False,
        repeated_times_schedule=schedule,
        **kwargs
    )

@register_model
def mini_deit_base_patch16_384(pretrained=False, **kwargs):
    """Mini-DeiT variant of ``deit_base_patch16_384`` (registered with timm).

    Args:
        pretrained: forwarded unchanged to ``deit_base_patch16_384``.
        **kwargs: forwarded unchanged; passing ``rpe_config``,
            ``use_cls_token`` or ``repeated_times_schedule`` here raises
            ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``deit_base_patch16_384`` returns.
    """
    rpe_cfg = get_deit_rpe_config()
    schedule = get_repeated_shcedule(12)
    return deit_base_patch16_384(
        pretrained=pretrained,
        rpe_config=rpe_cfg,
        use_cls_token=False,
        repeated_times_schedule=schedule,
        **kwargs
    )

@register_model
def mini_deit_micro_patch16_224(pretrained=False, **kwargs):
    """Mini-DeiT variant of ``deit_micro_patch16_224`` (registered with timm).

    Args:
        pretrained: forwarded unchanged to ``deit_micro_patch16_224``.
        **kwargs: forwarded unchanged; passing ``rpe_config``,
            ``use_cls_token`` or ``repeated_times_schedule`` here raises
            ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``deit_micro_patch16_224`` returns.
    """
    # NOTE(review): the schedule hard-codes 12 — confirm it matches the depth
    # that deit_micro_patch16_224 builds.
    rpe_cfg = get_deit_rpe_config()
    schedule = get_repeated_shcedule(12)
    return deit_micro_patch16_224(
        pretrained=pretrained,
        rpe_config=rpe_cfg,
        use_cls_token=False,
        repeated_times_schedule=schedule,
        **kwargs
    )

@register_model
def mini_deit_large_patch16_224(pretrained=False, **kwargs):
    """Mini-DeiT variant of ``deit_large_patch16_224`` (registered with timm).

    Args:
        pretrained: forwarded unchanged to ``deit_large_patch16_224``.
        **kwargs: forwarded unchanged; passing ``rpe_config``,
            ``use_cls_token`` or ``repeated_times_schedule`` here raises
            ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``deit_large_patch16_224`` returns.
    """
    # NOTE(review): the schedule hard-codes 12 — confirm it matches the depth
    # that deit_large_patch16_224 builds.
    rpe_cfg = get_deit_rpe_config()
    schedule = get_repeated_shcedule(12)
    return deit_large_patch16_224(
        pretrained=pretrained,
        rpe_config=rpe_cfg,
        use_cls_token=False,
        repeated_times_schedule=schedule,
        **kwargs
    )