import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.resnet import resnet26d, resnet50d
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg, default_cfgs,\
     PatchEmbed
try:
    from timm.models.vision_transformer import HybridEmbed
except ImportError:
    # for higher version of timm
    from timm.models.vision_transformer_hybrid import HybridEmbed

from irpe import iRPE, iRPE_Cross, METHOD

from modules_with_LE import LELayerNorm, LELinear
import pickle

class RepeatedModuleList(nn.Module):
    """Holds one module per sub-list of a repeated-layer schedule and dispatches to it.

    ``sublist_repeated_times[k]`` is the number of consecutive layers served by
    ``self.instances[k]``.  The currently executing layer is communicated from
    the outside through the ``_layer_id`` attribute (set via ``Module.apply``);
    ``forward`` maps that layer id to the owning sub-list and calls its instance.

    Args:
        iblock: value forwarded as ``iblock=`` to ``Sublist_instance`` (not
            forwarded to iRPE classes or to ``Default_instance``).
        Layer_num: total number of layers; must equal ``sum(sublist_repeated_times)``.
        sublist_repeated_times: number of layers covered by each sub-list.
        sublist_with_k: per-sub-list flag; a falsy entry selects ``Default_instance``.
        Sublist_instance: class instantiated for flagged sub-lists.
        Default_instance: fallback class, used for un-flagged sub-lists and for
            sub-lists with zero repeats.  May be ``None`` only if every flag is truthy.
        *args, **kwargs: forwarded unchanged to whichever class is instantiated.
            NOTE(review): the same kwargs go to ``Default_instance``; confirm the
            fallback class accepts all of them.
    """

    def __init__(self, iblock, Layer_num: int, sublist_repeated_times: list, sublist_with_k: list, Sublist_instance, Default_instance=None, *args, **kwargs):
        super().__init__()
        assert Layer_num == sum(sublist_repeated_times)
        assert len(sublist_repeated_times) == len(sublist_with_k)
        self.Sublist_num = len(sublist_repeated_times)
        self.sublist_repeated_times = sublist_repeated_times
        self.sublist_with_k = sublist_with_k

        # Without a fallback class, every sub-list has to use Sublist_instance.
        assert (Default_instance is not None) or (False not in self.sublist_with_k)

        def set_layer_id(m):
            m._ilayer = 1.0

        modules = []
        for times, with_k in zip(self.sublist_repeated_times, self.sublist_with_k):
            if (times == 0 and Default_instance is not None) or not with_k:
                modules.append(Default_instance(*args, **kwargs))
            elif issubclass(Sublist_instance, iRPE) or issubclass(Sublist_instance, iRPE_Cross):
                # Relative-position encoders do not receive `iblock`; they start
                # with `_ilayer` = 1.0 on themselves and all their sub-modules.
                rpe = Sublist_instance(*args, **kwargs)
                rpe.apply(set_layer_id)
                modules.append(rpe)
            else:
                modules.append(Sublist_instance(iblock=iblock, *args, **kwargs))

        self.instances = nn.ModuleList(modules)

    def _block_index(self, layer_id):
        """Return the index of the sub-list whose layer range contains ``layer_id``."""
        upper = 0
        for idx, times in enumerate(self.sublist_repeated_times):
            upper += times
            if layer_id < upper:
                return idx
        raise IndexError(
            f'layer id {layer_id} is outside the schedule {self.sublist_repeated_times}')

    def forward(self, *args, **kwargs):
        """Call the instance that owns the current ``_layer_id``."""
        if len(self.instances) == 1:
            idx = 0
        else:
            # `instances` has one entry per sub-list, not per layer, so the
            # layer id must be translated to a sub-list index before indexing.
            idx = self._block_index(self._layer_id)
        return self.instances[idx](*args, **kwargs)

    def layer_id_2_repeated_times(self):
        """Derive per-sub-list bookkeeping from the current ``_layer_id``.

        Stores the owning sub-list index in ``self._block_id`` and sets, on this
        module and every sub-module, ``_repeated_id`` (position of the layer
        inside its sub-list) and ``_ilayer`` (that position divided by the
        sub-list's repeat count).
        """
        self._block_id = self._block_index(self._layer_id)

        def set_repeated_id_fn(m):
            m._repeated_id = self._layer_id - sum(self.sublist_repeated_times[:self._block_id])
            m._ilayer = m._repeated_id / float(self.sublist_repeated_times[self._block_id])
        self.apply(set_repeated_id_fn)

    def __repr__(self):
        msg = super().__repr__()
        return msg


class Mlp(nn.Module):
    """ Two-layer feed-forward block (fc1 -> activation -> dropout -> fc2 -> dropout).

    Both projections are ``RepeatedModuleList`` containers built from the
    ``'mlp_fc1'`` / ``'mlp_fc2'`` entries of ``repeated_times_schedule``; each
    entry is indexed as ``[0]`` (repeat counts, must sum to ``num_layers``) and
    ``[1]`` (per-sub-list flags).  ``bias`` and ``drop`` may be a single value or
    a pair, one per projection.
    """
    def __init__(self, device, iblock, 
                 in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.,
                 repeated_times_schedule=None, num_layers=None, power = 1, power_method = 'power', 
                 use_power_list=[0,1], learngene_d=False, only_rel_last=False, constraint_d=False,
                 gene_dict = None, LELinear_trunc_normal_std = -1.0):
        super().__init__()
        assert isinstance(repeated_times_schedule, dict)
        self.repeated_times_schedule = repeated_times_schedule
        self.num_layers = num_layers
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        schedule = self.repeated_times_schedule

        # Keyword arguments that are identical for both projections.
        shared = dict(
            iblock=iblock,
            Layer_num=num_layers,
            Sublist_instance=LELinear,
            Default_instance=nn.Linear,
            power=power,
            power_method=power_method,
            gene_dict=gene_dict,
            use_power_list=use_power_list,
            learngene_d=learngene_d,
            only_rel_last=only_rel_last,
            constraint_d=constraint_d,
            LELinear_trunc_normal_std=LELinear_trunc_normal_std,
            device=device,
        )

        assert sum(schedule['mlp_fc1'][0]) == num_layers
        self.fc1 = RepeatedModuleList(
            sublist_repeated_times=schedule['mlp_fc1'][0],
            sublist_with_k=schedule['mlp_fc1'][1],
            serve_module_name='LG_fc1',
            in_features=in_features,
            out_features=hidden_features,
            bias=bias_pair[0],
            **shared)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])

        assert sum(schedule['mlp_fc2'][0]) == num_layers
        self.fc2 = RepeatedModuleList(
            sublist_repeated_times=schedule['mlp_fc2'][0],
            sublist_with_k=schedule['mlp_fc2'][1],
            serve_module_name='LG_fc2',
            in_features=hidden_features,
            out_features=out_features,
            bias=bias_pair[1],
            **shared)
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Apply fc1, activation, dropout, fc2, dropout in that order."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

def build_rpe(config, iblock, head_dim, num_heads, num_layers, sublist_repeated_times, device):
    """Build the relative position encoders for queries, keys and values.

    Args:
        config: object exposing ``rpe_q``, ``rpe_k`` and ``rpe_v`` (each either
            ``None`` or an RPE config with ``method``, ``shared_head``, ``mode``
            and ``num_buckets``), or ``None`` to disable RPE entirely.
        iblock: forwarded to ``RepeatedModuleList``.
        head_dim: per-head embedding dimension.
        num_heads: number of attention heads; replaced by 1 for an RPE config
            whose ``shared_head`` is truthy.
        num_layers: total number of layers covered by the schedule.
        sublist_repeated_times: repeat counts, one per sub-list.
        device: accepted for signature symmetry with the other builders;
            currently unused.

    Returns:
        ``(None, None, None)`` when ``config`` is ``None``; otherwise a list
        ``[rpe_q, rpe_k, rpe_v]`` whose entries are ``RepeatedModuleList``
        instances, or ``None`` where the matching config entry is ``None``.
    """
    if config is None:
        return None, None, None
    rpes = [config.rpe_q, config.rpe_k, config.rpe_v]
    # Query and key encoders are built transposed, the value encoder is not.
    transposeds = [True, True, False]

    def _build_single_rpe(rpe, transposed):
        if rpe is None:
            return None

        rpe_cls = iRPE if rpe.method != METHOD.CROSS else iRPE_Cross
        return RepeatedModuleList(
            iblock=iblock,
            Layer_num=num_layers,
            sublist_repeated_times=sublist_repeated_times,
            # Every sub-list gets a real RPE module; there is no fallback class.
            sublist_with_k=[True] * len(sublist_repeated_times),
            Sublist_instance=rpe_cls,
            Default_instance=None,
            head_dim=head_dim,
            num_heads=1 if rpe.shared_head else num_heads,
            mode=rpe.mode,
            method=rpe.method,
            transposed=transposed,
            num_buckets=rpe.num_buckets,
            rpe_config=rpe,
        )

    return [_build_single_rpe(rpe, transposed)
            for rpe, transposed in zip(rpes, transposeds)]


class MiniAttention(nn.Module):
    '''
    Multi-head self-attention with optional image relative position encoding
    and two head-mixing linear transforms (one before, one after the softmax).

    Every learnable projection is a ``RepeatedModuleList`` built from the
    matching ``repeated_times_schedule`` entry (``'attn_qkv'``, ``'attn_proj'``,
    ``'attn_rpe'``, ``'attn_transform1'``, ``'attn_transform2'``); each entry's
    repeat counts must sum to ``num_layers``.
    '''

    def __init__(self, device, iblock,
                 dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., rpe_config=None,
                 repeated_times_schedule=None, num_layers=None, 
                 power=1, power_method='power', use_power_list=[0,1], learngene_d=False, only_rel_last=False, constraint_d=False, 
                 gene_dict = None, LELinear_trunc_normal_std = -1.0):
        super().__init__()
        assert isinstance(repeated_times_schedule, dict)
        self.repeated_times_schedule = repeated_times_schedule
        self.num_layers = num_layers
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        schedule = self.repeated_times_schedule
        # Keyword arguments shared by all four linear containers below.
        shared = dict(
            iblock=iblock,
            Layer_num=num_layers,
            Sublist_instance=LELinear,
            Default_instance=nn.Linear,
            power=power,
            power_method=power_method,
            gene_dict=gene_dict,
            use_power_list=use_power_list,
            learngene_d=learngene_d,
            only_rel_last=only_rel_last,
            constraint_d=constraint_d,
            LELinear_trunc_normal_std=LELinear_trunc_normal_std,
            device=device,
        )

        assert sum(schedule['attn_qkv'][0]) == num_layers
        self.qkv = RepeatedModuleList(
            sublist_repeated_times=schedule['attn_qkv'][0],
            sublist_with_k=schedule['attn_qkv'][1],
            serve_module_name='LG_qkv',
            in_features=dim, out_features=dim*3, bias=qkv_bias,
            **shared)
        self.attn_drop = nn.Dropout(attn_drop)

        assert sum(schedule['attn_proj'][0]) == num_layers
        self.proj = RepeatedModuleList(
            sublist_repeated_times=schedule['attn_proj'][0],
            sublist_with_k=schedule['attn_proj'][1],
            serve_module_name='LG_attn_proj',
            in_features=dim, out_features=dim,
            **shared)
        self.proj_drop = nn.Dropout(proj_drop)

        # image relative position encoding; each entry is a module or None
        assert sum(schedule['attn_rpe'][0]) == num_layers
        self.rpe_q, self.rpe_k, self.rpe_v = build_rpe(
            rpe_config,
            iblock=iblock, head_dim=head_dim, num_heads=num_heads, num_layers=num_layers,
            sublist_repeated_times=schedule['attn_rpe'][0],
            device=device)

        # Head-mixing transforms: square (num_heads x num_heads), bias-free.
        assert sum(schedule['attn_transform1'][0]) == num_layers
        self.LElinearP_w = RepeatedModuleList(
            sublist_repeated_times=schedule['attn_transform1'][0],
            sublist_with_k=schedule['attn_transform1'][1],
            serve_module_name='LG_P_w',
            in_features=num_heads, out_features=num_heads, bias=False,
            **shared)

        assert sum(schedule['attn_transform2'][0]) == num_layers
        self.LElinearP_l = RepeatedModuleList(
            sublist_repeated_times=schedule['attn_transform2'][0],
            sublist_with_k=schedule['attn_transform2'][1],
            serve_module_name='LG_P_l',
            in_features=num_heads, out_features=num_heads, bias=False,
            **shared)

    def forward(self, x):
        B, N, C = x.shape
        heads = self.num_heads
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q = qkv[0]
        k = qkv[1]
        v = qkv[2]
        q *= self.scale

        attn = q @ k.transpose(-2, -1)

        # image relative position on keys
        if self.rpe_k is not None:
            attn += self.rpe_k(q)

        # image relative position on queries
        if self.rpe_q is not None:
            attn += self.rpe_q(k * self.scale).transpose(2, 3)

        # Mix attention logits across heads: move the head axis last so the
        # linear layer acts on it, then move it back.
        if self.LElinearP_l is not None:
            attn = self.LElinearP_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        attn = attn.softmax(dim=-1)

        # Same head mixing, applied to the normalised attention weights.
        if self.LElinearP_w is not None:
            attn = self.LElinearP_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        attn = self.attn_drop(attn)
        out = attn @ v

        # image relative position on values
        if self.rpe_v is not None:
            out += self.rpe_v(attn)

        # (B, heads, N, C // heads) -> (B, N, C)
        merged = out.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))


class MiniBlock(nn.Module):
    """Pre-norm transformer block (attention + MLP) built from repeated-layer containers.

    The block can be executed several times with different ``_layer_id`` values
    (set from outside via ``Module.apply``); ``_layer_id`` selects the drop-path
    module in ``forward``.

    Args:
        block_id: index of this block; stored normalised as ``block_id / depth``
            in ``self._iblock`` and forwarded to the sub-modules as ``iblock``.
        device: forwarded to the sub-modules.
        depth: number of blocks in the model, used to normalise ``block_id``.
        dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,
        act_layer, rpe_config: usual transformer hyper-parameters, forwarded to
            ``MiniAttention`` / ``Mlp``.
        drop_paths: one stochastic-depth rate per repeated layer; a rate of 0
            (or less) becomes ``nn.Identity``.
        repeated_times_schedule: dict with at least ``'norm1'`` and ``'norm2'``
            entries (plus those required by ``MiniAttention`` and ``Mlp``); the
            repeat counts of each must sum to ``num_layers``.
        num_layers: how many times this block is executed per forward pass.
        power, power_method, use_power_list, learngene_d, only_rel_last,
        constraint_d, gene_dict, LELinear_trunc_normal_std: forwarded unchanged
            to the LE* layers.
    """

    def __init__(self, block_id, device, depth,
                 dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_paths=[0.], act_layer=nn.GELU, rpe_config=None, 
                 repeated_times_schedule=None, num_layers = 1, power = 1, power_method = 'power', 
                 use_power_list=[0,1], learngene_d=False, only_rel_last=False, constraint_d=False, gene_dict = None, LELinear_trunc_normal_std = -1.0, ):
        super().__init__()
        assert isinstance(repeated_times_schedule, dict)
        self.repeated_times_schedule = repeated_times_schedule
        self._iblock = block_id/float(depth)

        assert sum(self.repeated_times_schedule['norm1'][0]) == num_layers
        self.norm1 = RepeatedModuleList(iblock=self._iblock, Layer_num=num_layers, 
                                        sublist_repeated_times=self.repeated_times_schedule['norm1'][0],
                                        sublist_with_k=self.repeated_times_schedule['norm1'][1],
                                        Sublist_instance=LELayerNorm, Default_instance=nn.LayerNorm, 
                                        serve_module_name = 'LG_norm1', normalized_shape = dim, 
                                        power = power, power_method = power_method, 
                                        gene_dict = gene_dict, use_power_list=use_power_list, learngene_d=learngene_d, only_rel_last=only_rel_last, constraint_d=constraint_d, 
                                        device=device)
        assert sum(self.repeated_times_schedule['norm2'][0]) == num_layers
        self.norm2 = RepeatedModuleList(iblock=self._iblock, Layer_num=num_layers, 
                                        sublist_repeated_times=self.repeated_times_schedule['norm2'][0],
                                        sublist_with_k=self.repeated_times_schedule['norm2'][1],
                                        Sublist_instance=LELayerNorm, Default_instance=nn.LayerNorm, 
                                        serve_module_name = 'LG_norm2', normalized_shape = dim, 
                                        power = power, power_method = power_method, 
                                        gene_dict = gene_dict, use_power_list=use_power_list, learngene_d=learngene_d, only_rel_last=only_rel_last, constraint_d=constraint_d, 
                                        device=device)

        self.attn = MiniAttention(
            dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, rpe_config=rpe_config,
            repeated_times_schedule=repeated_times_schedule, num_layers = num_layers, 
            power = power, power_method = power_method, gene_dict = gene_dict, use_power_list=use_power_list, 
            learngene_d=learngene_d, only_rel_last=only_rel_last, constraint_d=constraint_d, LELinear_trunc_normal_std = LELinear_trunc_normal_std,
            device=device, iblock=self._iblock)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_paths = nn.ModuleList([DropPath(drop_path) if drop_path > 0. else nn.Identity() for drop_path in drop_paths])
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, 
            repeated_times_schedule=repeated_times_schedule, num_layers = num_layers, 
            power = power, power_method = power_method, 
            gene_dict = gene_dict, use_power_list=use_power_list, learngene_d=learngene_d, only_rel_last=only_rel_last, constraint_d=constraint_d, LELinear_trunc_normal_std = LELinear_trunc_normal_std,
            device=device, iblock=self._iblock)
        
        self.power = power
        self.power_method = power_method

    def dump_gene_dict(self, save_path):
        """Pickle the merged weight dicts of the first instance of every LE container.

        Collects ``build_weight_dict()`` from ``instances[0]`` of norm1, norm2,
        the attention qkv / proj / head-mixing layers and the two MLP layers
        (later entries override earlier ones on key collisions) and writes the
        result to ``save_path``.
        """
        sources = (
            self.norm1,
            self.norm2,
            self.attn.qkv,
            self.attn.proj,
            self.attn.LElinearP_w,
            self.attn.LElinearP_l,
            self.mlp.fc1,
            self.mlp.fc2,
        )
        gene_dict = {}
        for container in sources:
            gene_dict.update(container.instances[0].build_weight_dict())

        # Context manager guarantees the handle is closed even if pickling fails.
        with open(save_path, 'wb') as f_save:
            pickle.dump(gene_dict, f_save)

    def forward(self, x):
        """Run attention and MLP residual branches for the current ``_layer_id``."""
        # The same drop-path module is used for both residual branches.
        drop_path = self.drop_paths[self._layer_id]
        x = x + drop_path(self.attn(self.norm1(x)))
        x = x + drop_path(self.mlp(self.norm2(x)))
        return x



class RepeatedMiniBlock(nn.Module):
    """Runs a single ``MiniBlock`` ``repeated_times`` times in sequence.

    Before each pass the current pass index is written to ``_layer_id`` on the
    block and all of its sub-modules, so that layer-dependent containers can
    pick the right behaviour.
    """

    def __init__(self, block_id, repeated_times: int, repeated_times_schedule: dict, **kwargs):
        super().__init__()
        self.num_layers = repeated_times
        self.block = MiniBlock(
            block_id=block_id,
            repeated_times_schedule=repeated_times_schedule,
            num_layers=repeated_times,
            **kwargs,
        )

        # Record the repeat count on this module and every descendant.
        def _tag_num_layers(module):
            module._num_layers = repeated_times

        self.apply(_tag_num_layers)

    def _select_layer(self, layer_idx):
        """Broadcast ``layer_idx`` to the block and refresh derived bookkeeping."""
        def _tag(module):
            module._layer_id = layer_idx
            # Containers that expose this hook recompute their per-sub-list state.
            if hasattr(module, 'layer_id_2_repeated_times'):
                module.layer_id_2_repeated_times()

        self.block.apply(_tag)

    def forward(self, x):
        for layer_idx in range(self.num_layers):
            self._select_layer(layer_idx)
            x = self.block(x)
        return x

    def __repr__(self):
        return f'{super().__repr__()}(num_layers={self.num_layers})'



class VisionTransformer_d(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
                           and image relative position encoding

    The encoder consists of ``kwargs['depth'] // repeated_time`` blocks.  With
    ``repeated_time > 1`` each block is a ``RepeatedMiniBlock`` that executes
    one shared ``MiniBlock`` ``repeated_time`` times; otherwise plain
    ``MiniBlock`` instances are used.

    Required entries in ``**kwargs`` (a missing one raises ``KeyError``):
    ``depth``, ``gene_dict_path`` (``''`` disables loading), ``use_power_list``,
    ``learngene_d``, ``only_rel_last``, ``constraint_d``, ``device`` and
    ``LELinear_trunc_normal_std``.  Optional: ``power`` together with
    ``power_method`` (defaults ``1`` / ``'power'`` when ``power`` is absent).

    ``norm_layer`` and ``use_transform`` are accepted but not used by this
    constructor.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, repeated_time=1,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=LELayerNorm, rpe_config=None,
                 use_cls_token=True, repeated_times_schedule=None, use_transform=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        assert isinstance(repeated_times_schedule, dict)

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None
        pos_embed_len = 1 + num_patches if use_cls_token else num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, kwargs['depth'])]  # stochastic depth decay rule

        p, pm = (kwargs['power'], kwargs['power_method']) if 'power' in kwargs else (1, 'power')

        gene_dict = None
        if kwargs['gene_dict_path'] != '':
            # SECURITY: pickle.load executes arbitrary code from the file; only
            # point gene_dict_path at files from a trusted source.
            with open(kwargs['gene_dict_path'], 'rb') as f_read:
                gene_dict = pickle.load(f_read)
        LELinear_trunc_normal_std = kwargs['LELinear_trunc_normal_std']

        depth = kwargs['depth'] // repeated_time

        print("==============use_power=====")
        print(kwargs['use_power_list'])
        assert not (kwargs['only_rel_last'] and kwargs['constraint_d'])
        block_kwargs = dict(
                depth=depth,
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate,rpe_config=rpe_config,
                power = p, power_method = pm, gene_dict=gene_dict, use_power_list=kwargs['use_power_list'], learngene_d=kwargs['learngene_d'], 
                only_rel_last=kwargs['only_rel_last'], constraint_d=kwargs['constraint_d'], LELinear_trunc_normal_std=LELinear_trunc_normal_std,
                device=kwargs['device'])

        blocks = []
        for i in range(depth):
            if repeated_time > 1:
                block = RepeatedMiniBlock(
                    block_id=i,
                    repeated_times=repeated_time,
                    repeated_times_schedule=repeated_times_schedule,
                    # Each repeated block indexes drop_paths with its local
                    # layer id (0..repeated_time-1), so it must receive only
                    # its own slice of the depth-wise schedule.
                    drop_paths=dpr[i * repeated_time:(i + 1) * repeated_time],
                    **block_kwargs,
                )
            else:
                block = MiniBlock(block_id=i, repeated_times_schedule=repeated_times_schedule, drop_paths=[dpr[i]], **block_kwargs)
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)

        # Give every module created so far a valid default layer id so that
        # plain MiniBlocks (never driven by RepeatedMiniBlock) can run forward.
        def set_layer_id(m):
            m._layer_id = 0
        self.apply(set_layer_id)

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        #self.repr = nn.Linear(embed_dim, representation_size)
        #self.repr_act = nn.Tanh()

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if not use_cls_token:
            self.avgpool = nn.AdaptiveAvgPool1d(1)
        else:
            self.avgpool = None

        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)
        # Weight re-initialisation is skipped when either flag is set.
        if not kwargs['learngene_d'] and not kwargs['constraint_d']:
            self.apply(self._init_weights)
            self.apply(self._init_custom_weights)

    def dump_gene_dict(self, save_path):
        """Pickle the gene dict of the first encoder block to ``save_path``.

        ``self.blocks`` is an ``nn.ModuleList``; its first entry is either a
        ``MiniBlock`` or a ``RepeatedMiniBlock`` wrapping one in ``.block``.
        NOTE(review): only the first block is exported — confirm this is the
        intended source when the model has more than one block.
        """
        first = self.blocks[0]
        target = first.block if isinstance(first, RepeatedMiniBlock) else first
        target.dump_gene_dict(save_path)

    def _init_weights(self, m):
        """Per-module initialiser used with ``Module.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            # trunc_normal_(m.weight, std=.008)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LELayerNorm) or isinstance(m, nn.LayerNorm):
            m.reset_parameters()
        elif isinstance(m, LELinear):
            m.reset_parameters()

    def _init_custom_weights(self, m):
        """Call a module's own ``init_weights`` hook when it defines one."""
        if hasattr(m, 'init_weights'):
            m.init_weights()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        ``global_pool`` is accepted but ignored.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed the input, run all blocks and the final norm.

        Returns the class-token embedding ``x[:, 0]`` when a class token is
        used, otherwise the full normalised token sequence.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.cls_token is not None:
            return x[:, 0]
        else:
            return x

    def forward(self, x):
        """Compute features, pool over tokens if no class token is used, apply the head."""
        x = self.forward_features(x)
        if self.avgpool is not None:
            x = self.avgpool(x.transpose(1, 2))  # (B, C, 1)
            x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    @staticmethod
    def export_gene():
        """Placeholder; currently does nothing and returns ``None``.

        Declared static so it can be called on the class (as before) and on an
        instance (which previously raised ``TypeError`` for the missing ``self``).
        """
        return