from bypass.core.activation import ActivationForBypass
from bypass.core.models.attn_dev.pruned_attn import PrunedAttention_ind, PrunedAttention_spatial

import torch_pruning as tp
from torch_pruning import ops
from torch_pruning.pruner import BasePruningFunc
import torch
import timm
import torch.nn as nn
def prune_tp(module:torch.nn.Module,prune_indices:torch.Tensor,target='out',mask_skip=False,attn=False):
    """Prune ``module`` along one channel axis and record the pruned indices.

    The pruner is chosen from the type of ``module``:

    * ``ActivationForBypass``            -> ``BypassActivationPruner``
    * timm ``Attention`` (or ``attn``)   -> ``AttentionNeuronPruner``
    * ``ops.TORCH_PARAMETER``            -> ``ParameterPruner``
    * anything else                      -> looked up in ``tp.function.PrunerBox``

    Args:
        module: Module (or parameter) to prune.
        prune_indices: Tensor of channel indices to remove; converted with
            ``tolist()`` before being handed to the pruner.
        target: ``'out'``, ``'in'`` or ``'mid'``; selects which
            ``prune_*_channels`` method of the pruner is called.
        mask_skip: Accepted for interface compatibility; not used here.
        attn: Force the attention pruner even if ``module`` is not a timm
            ``Attention`` instance.

    Returns:
        Whatever the selected pruner returned (this may be a different object
        than ``module``). If no pruner is available for ``module``, a message
        is printed and ``module`` is returned untouched.

    Raises:
        NotImplementedError: If ``target`` is not one of the supported values.
    """
    if isinstance(module, ActivationForBypass):
        pruning_func = BypassActivationPruner()
    elif isinstance(module, timm.models.vision_transformer.Attention) or attn == True:  # noqa: E712
        pruning_func = AttentionNeuronPruner()
    elif isinstance(module, ops.TORCH_PARAMETER):
        pruning_func = ParameterPruner()
    else:
        try:
            pruning_func = tp.function.PrunerBox[ops.module2type(module)]
        except Exception:
            # Best-effort: unsupported modules are reported and left as they are.
            # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt and
            # SystemExit are never swallowed here.
            print(f'{type(module)} is not supported by torch_pruner. Possibly it is not a leaf node')
            return module

    if target == 'out':
        module = pruning_func.prune_out_channels(module, prune_indices.tolist())
    elif target == 'in':
        module = pruning_func.prune_in_channels(module, prune_indices.tolist())
    elif target == 'mid':
        module = pruning_func.prune_mid_channels(module, prune_indices.tolist())
    else:
        raise NotImplementedError()

    # Record which indices were pruned as a ``<target>_mask`` buffer on the
    # returned module, except when the bypass activation's inner activation
    # already carries that mask, or when the result is a bare parameter (which
    # cannot hold buffers).
    if isinstance(module, ActivationForBypass) and hasattr(module, 'activation') and hasattr(module.activation, f'{target}_mask'):
        pass
    elif isinstance(module, ops.TORCH_PARAMETER):
        pass
    else:
        module.register_buffer(f'{target}_mask', prune_indices)
    return module

class ParameterPruner(BasePruningFunc):
    """Pruner for bare parameters: drops entries along ``pruning_dim``.

    Input and output pruning are the same operation for a parameter, so
    ``prune_in_channels`` is an alias of ``prune_out_channels``.
    """

    TARGET_MODULES = ops.TORCH_PARAMETER

    def __init__(self, pruning_dim=-1):
        super().__init__(pruning_dim=pruning_dim)

    def prune_out_channels(self, tensor, idxs: list) -> nn.Module:
        """Return ``tensor`` with the positions in ``idxs`` removed along ``pruning_dim``."""
        dropped = set(idxs)
        axis_len = tensor.data.shape[self.pruning_dim]
        # Ascending list of the surviving positions on the pruned axis.
        keep_idxs = [i for i in range(axis_len) if i not in dropped]
        return self._prune_parameter_and_grad(tensor, keep_idxs, self.pruning_dim)

    prune_in_channels = prune_out_channels

    def get_out_channels(self, parameter):
        """Size of ``parameter`` along the pruned axis."""
        return parameter.shape[self.pruning_dim]

    def get_in_channels(self, parameter):
        """Size of ``parameter`` along the pruned axis (same axis as the output)."""
        return parameter.shape[self.pruning_dim]


def prune_loader_tp(module:torch.nn.Module,state_dict:dict):
    """Load a (possibly pruned) ``state_dict`` into ``module``.

    For every entry of ``state_dict`` whose name does not end in ``out_mask``
    or ``in_mask`` the owning sub-module is looked up, and if the tensor in the
    checkpoint is smaller than the one in the model along dim 0 / dim 1 the
    sub-module is pruned with ``prune_tp`` (removing the leading indices) so
    the shapes match. ``ActivationForBypass`` modules whose ``status`` is 0 are
    additionally switched via ``embed()`` / ``proj()`` depending on which of
    ``delta2`` / ``delta`` the checkpoint contains, and the matching entry of
    ``module.bypass_units`` (if present) gets its ``status`` updated.

    Args:
        module: Model to load into; modified in place.
        state_dict: Checkpoint mapping parameter/buffer names to tensors.

    Returns:
        A one-element tuple ``(module,)``. NOTE(review): the tuple looks
        accidental (trailing comma) but is kept so existing callers keep working.

    Raises:
        KeyError: If ``load_state_dict`` reports unexpected keys.
    """
    new_state_dict = {}
    for name, weight in state_dict.items():
        # Mask buffers are bookkeeping written by prune_tp; they are not loaded.
        if name.endswith(('out_mask', 'in_mask')):
            continue
        new_state_dict[name] = weight

        parent_module_name, _, attr_name = name.rpartition('.')
        parent_module = module.get_submodule(parent_module_name)
        model_weight = getattr(parent_module, attr_name)
        if model_weight is None:
            # The model has the slot but no tensor yet: create it from the checkpoint.
            setattr(parent_module, attr_name, torch.nn.Parameter(data=weight.data, requires_grad=True))
            model_weight = getattr(parent_module, attr_name)

        if isinstance(parent_module, ActivationForBypass) and parent_module.status == 0:  # already pruned
            if f'{parent_module_name}.delta2' in state_dict:
                parent_module.embed()
                bypass_status = 1
            elif f'{parent_module_name}.delta' in state_dict:
                # Usually always true; false only when the model was converted
                # straight from the pretrained weights.
                parent_module.embed()
                parent_module.proj()
                bypass_status = 2
            else:
                bypass_status = 0
            if hasattr(module, 'bypass_units'):  # Bypassing model
                parent_group = [group for k, group in module.bypass_units.items() if group.D == parent_module][0]
                parent_group.status = bypass_status

        if model_weight.shape == weight.shape:
            continue
        if model_weight.shape[0] != weight.shape[0]:  # out_prune
            prune_nodes = model_weight.shape[0] - weight.shape[0]
            prune_tp(parent_module, torch.tensor(range(prune_nodes)), 'out')
        if model_weight.dim() > 1 and model_weight.shape[1] != weight.shape[1]:  # in_prune
            prune_nodes = model_weight.shape[1] - weight.shape[1]
            prune_tp(parent_module, torch.tensor(range(prune_nodes)), 'in')

    missing_keys, unexpected_keys = module.load_state_dict(new_state_dict, strict=False)
    if len(unexpected_keys) > 0:
        raise KeyError(f'unexpected keys found: {unexpected_keys}')
    # Mask buffers are never loaded, so they are always "missing"; only warn
    # when something other than a mask is absent from the checkpoint.
    reportable_missing = [x for x in missing_keys if not x.endswith("mask")]
    if reportable_missing:
        print(f'[WARN] missing keys found while loading model: {reportable_missing}')
    return (module,)

class AttentionNeuronPruner(BasePruningFunc):
    """Pruner for timm ``Attention`` blocks (``qkv`` and ``proj`` linears).

    * ``in``  pruning removes input channels of ``qkv``.
    * ``out`` pruning removes output channels of ``proj``.
    * ``mid`` pruning removes rows of ``qkv`` (and the matching ``proj``
      inputs) and converts the block to ``PrunedAttention_spatial``.
    """

    def prune_in_channels(self,layer: timm.models.vision_transformer.Attention, idxs: list):
        # NOTE(review): this returns the pruned ``qkv`` sub-layer (the result of
        # prune_tp), not ``layer`` itself -- confirm callers expect that.
        return prune_tp(layer.qkv,torch.tensor(idxs),target='in')

    def prune_out_channels(self,layer: timm.models.vision_transformer.Attention, idxs: list):
        # NOTE(review): returns the pruned ``proj`` sub-layer, not ``layer``.
        return prune_tp(layer.proj,torch.tensor(idxs),target='out')

    def prune_mid_channels(self,layer:timm.models.vision_transformer.Attention,idxs:list):
        """Prune rows of the fused ``qkv`` projection and return a new layer.

        ``idxs`` index the rows of ``layer.qkv.weight``. The rows are treated
        as three equal consecutive chunks (q, k, v). A q/k position survives
        only if it is kept in both q and k; v positions are handled
        independently and also determine which ``proj`` inputs are removed.

        Returns:
            The ``PrunedAttention_spatial`` created from ``layer``, with
            ``qk_size`` / ``v_size`` set to the surviving counts divided by
            ``num_heads`` (stored as 0-dim tensors).
        """
        num_heads = layer.num_heads
        qkv_rows = layer.qkv.weight.shape[0]

        preserve_mask = torch.ones((qkv_rows,),dtype=torch.bool)
        preserve_mask[idxs] = 0

        # Derive the per-projection size from the output rows rather than from
        # the input features, so the split stays valid after input pruning.
        dim = qkv_rows // 3
        q_mask, k_mask, v_mask = preserve_mask.split([dim,dim,dim])

        # q and k must keep identical positions for the dot product to line up.
        qk_mask = torch.logical_and(q_mask,k_mask)
        qkv_mask = torch.concat([qk_mask,qk_mask,v_mask])

        pruned_layer = PrunedAttention_spatial.convert_from(layer)
        pruned_layer.qk_size = qk_mask.sum()//num_heads
        pruned_layer.v_size = v_mask.sum()//num_heads

        # attn.qkv pruning
        qkv_prune_indices = torch.where(~qkv_mask)[0]
        prune_tp(pruned_layer.qkv,qkv_prune_indices,'out')

        # attn.proj pruning
        proj_prune_indices = torch.where(~v_mask)[0]
        prune_tp(pruned_layer.proj,proj_prune_indices,'in')
        return pruned_layer

    def prune_head_channels(self,layer: timm.models.vision_transformer.Attention, idxs: list):
        # NOTE(review): returns the ``NotImplemented`` singleton instead of
        # raising; kept as-is in case callers test for it.
        return NotImplemented

    def get_mid_channels(self,layer):
        """Number of rows of the fused ``qkv`` weight."""
        return layer.qkv.weight.shape[0]

    def get_out_channels(self, layer):
        """Number of output features of ``proj``."""
        return layer.proj.weight.shape[0]

    def get_in_channels(self, layer):
        """Number of input features of ``qkv``."""
        return layer.qkv.weight.shape[1]
    
class BypassActivationPruner(BasePruningFunc):
    """Pruner for ``ActivationForBypass`` layers.

    Removes entries from the layer's ``delta`` (and ``delta2`` when present)
    parameters and, if the wrapped ``activation`` has a ``weight``, prunes
    that activation as well. Input and output pruning are the same operation.
    """

    def prune_out_channels(self, layer: ActivationForBypass, idxs: list):
        """Remove the channels in ``idxs`` from ``layer`` in place and return it."""
        if layer.num_parameters == 1: # prune nothing
            return layer
        keep_idxs = list(set(range(layer.num_parameters)) - set(idxs))
        keep_idxs.sort()
        # Use the number of surviving channels rather than subtracting
        # len(idxs): the two differ when idxs has duplicates or out-of-range
        # entries, and the counter must match the pruned parameter length.
        layer.num_parameters = len(keep_idxs)
        layer.delta = self._prune_parameter_and_grad(layer.delta, keep_idxs, 0)

        if getattr(layer,'delta2',None) is not None:
            layer.delta2 = self._prune_parameter_and_grad(layer.delta2,keep_idxs,0)

        if hasattr(layer.activation,'weight'):
            # The inner activation carries per-channel weights; prune it too.
            prune_tp(layer.activation,torch.tensor(idxs,device=layer.activation.weight.device),'out')

        return layer

    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        """Return ``layer.num_parameters``, or ``None`` when it is 1 (nothing prunable)."""
        if layer.num_parameters == 1:
            return None
        else:
            return layer.num_parameters

    def get_in_channels(self, layer):
        """Same as :meth:`get_out_channels`."""
        return self.get_out_channels(layer=layer)

if __name__ == '__main__':
    # Manual smoke test: build a pretrained DeiT, wrap its qkv / MLP activation
    # layers with bypass activations, then prune one attention block.
    import torch
    from timm.models import create_model
    from timm.models.vision_transformer import VisionTransformer, _cfg, Attention

    import sys
    sys.path.append('/workspace/IPPRO_pruning')
    from bypass.core.activation import TrivialActivationForBypass,TrivialActivationForDx2,ActivationForDx2


    model_id = 'deit_base_patch16_224'

    pretrained_model = create_model(
            model_id,
            pretrained=True,
            num_classes=1000,
            drop_rate=0,
            drop_path_rate=0.1,
            drop_block_rate=None,
        )
    pretrained_model.eval()

    # Iterate over a snapshot: the loop body replaces sub-modules, and
    # named_modules() is a lazy generator over the live module tree.
    for name, module in list(pretrained_model.named_modules()):
        parent_module_name, _, node_name = name.rpartition('.')
        if node_name == 'qkv':
            parent_module:torch.nn.Module = pretrained_model.get_submodule(parent_module_name)
            new_module = TrivialActivationForBypass(module.weight.shape[0],module,channel_last=True) #Note: you can use TrivialActivationForDx2 for alternative
            parent_module.register_module(node_name,new_module)
        if name.endswith('mlp.act'):
            parent_module:torch.nn.Module = pretrained_model.get_submodule(parent_module_name)
            new_module = ActivationForDx2(parent_module.fc1.out_features, module,channel_last=True)
            parent_module.register_module(node_name,new_module)

    pruning_indices = torch.tensor([1,1000,2000])

    layer_name= 'blocks.0.attn'
    layer = pretrained_model.get_submodule(layer_name)

    # NOTE(review): this uses the default target='out'. For an attention layer
    # that path returns the pruned ``proj`` sub-layer, which the code below
    # would then install in place of the whole attention block. The index
    # values suggest target='mid' may have been intended -- confirm.
    pruned_layer = prune_tp(layer,pruning_indices)
    print(id(layer))
    print(id(pruned_layer))

    # If pruning produced a new object, swap it into the parent module.
    if pruned_layer is not layer:
        parent_module_name, _, node_name = layer_name.rpartition('.')
        parent_module:torch.nn.Module = pretrained_model.get_submodule(parent_module_name)
        parent_module.register_module(node_name,pruned_layer)
        del layer