from functools import partial

import timm
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn, _cfg, default_cfgs

import torch
from bypass.core.activation import TrivialActivationForBypass,TrivialActivationForDx2,ActivationForDx2

class BypassViT_timm(VisionTransformer):
    """timm ``VisionTransformer`` with its ``qkv`` and ``mlp.act`` submodules wrapped.

    After the regular ViT has been constructed, every submodule whose own name
    is ``qkv`` is replaced by ``bypass_attn_cls(...)`` and every submodule whose
    dotted path ends in ``mlp.act`` is replaced by ``bypass_activation_cls(...)``.
    In both cases the original module is handed to the wrapper as its second
    positional argument. Subclasses can override the two class attributes to
    choose different wrapper classes.
    """

    # Wrapper installed in place of each ``qkv`` module.
    # Note: you can use TrivialActivationForDx2 as an alternative.
    bypass_attn_cls = TrivialActivationForBypass
    # Wrapper installed in place of each ``mlp.act`` module.
    bypass_activation_cls = ActivationForDx2

    def __init__(self, *args, **kwargs):
        """Build the ViT, then swap in the bypass wrappers.

        All positional and keyword arguments are forwarded unchanged to
        ``VisionTransformer.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Snapshot the traversal before touching the tree: modules are replaced
        # below, and mutating the hierarchy while the lazy ``named_modules()``
        # generator is still running is fragile.
        for name, module in list(self.named_modules()):
            # 'blocks.0.attn.qkv' -> ('blocks.0.attn', 'qkv'); a top-level name
            # yields an empty parent path, which get_submodule resolves to self.
            parent_module_name, _, node_name = name.rpartition('.')
            if node_name == 'qkv':
                parent_module: torch.nn.Module = self.get_submodule(parent_module_name)
                # First weight dimension is passed as the wrapper's size argument
                # (for an nn.Linear this is out_features).
                new_module = self.bypass_attn_cls(module.weight.shape[0], module, channel_last=True)
                parent_module.register_module(node_name, new_module)
            elif name.endswith('mlp.act'):
                parent_module: torch.nn.Module = self.get_submodule(parent_module_name)
                # The activation follows fc1, so its width is fc1's output width.
                new_module = self.bypass_activation_cls(parent_module.fc1.out_features, module, channel_last=True)
                parent_module.register_module(node_name, new_module)

class imagenet_DeiTBase(BypassViT_timm):
    """``BypassViT_timm`` configured as ``deit_base_patch16_224``.

    Fixes the architecture to patch size 16, embedding dim 768, depth 12 and
    12 heads, and can optionally load the pretrained checkpoint referenced by
    timm's ``default_cfgs`` entry for ``model_id``.
    """

    # Key into timm's ``default_cfgs``; its 'url' entry locates the checkpoint.
    model_id = 'deit_base_patch16_224'
    bypass_attn_cls = TrivialActivationForBypass
    bypass_activation_cls = ActivationForDx2

    def __init__(self, pretrained=True, **kwargs):
        """Build the model and, if requested, load pretrained weights.

        Args:
            pretrained: When true, download the checkpoint from
                ``default_cfg['url']`` and load it non-strictly.
            **kwargs: Extra keyword arguments forwarded to the
                ``VisionTransformer`` constructor.
        """
        import warnings

        super().__init__(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
                         norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), **kwargs)

        self.default_cfg = default_cfgs[self.model_id]
        if not pretrained:
            return

        checkpoint = torch.hub.load_state_dict_from_url(
            url=self.default_cfg['url'],
            map_location="cpu", check_hash=True
        )
        # The checkpoint was saved from a plain ViT, while the qkv modules here
        # are wrapped; keys are redirected to 'qkv.activation'.
        # NOTE(review): presumably the wrapper stores the wrapped module under
        # the attribute 'activation' -- confirm against bypass.core.activation.
        checkpoint_model = {
            k.replace('qkv', 'qkv.activation'): v
            for k, v in checkpoint["model"].items()
        }
        # strict=False because the model and checkpoint key sets are not
        # required to match exactly (wrapper-specific state is an assumption).
        load_result = self.load_state_dict(checkpoint_model, strict=False)
        # Checkpoint entries that found no home mean pretrained weights were
        # dropped; surface that instead of ignoring it silently.
        if load_result.unexpected_keys:
            warnings.warn(
                f"{type(self).__name__}: {len(load_result.unexpected_keys)} checkpoint "
                f"key(s) were not loaded: {load_result.unexpected_keys}"
            )


if __name__ == "__main__":
    # Smoke test: construct the DeiT-Base bypass model with pretrained weights.
    model = imagenet_DeiTBase(pretrained=True)

    # Sentinel output: reaching this line means construction did not raise.
    print(1)