from torch.nn.utils.prune import CustomFromMask
class OutChannelStructuredFromMask(CustomFromMask):
    """Structured pruning that zeroes whole output channels (dim 0) from a given mask.

    ``mask`` is a tensor with one entry per output channel (truthy = keep,
    falsy = prune). For parameters with more than one dimension it is reshaped
    to ``(C_out, 1, ..., 1)`` and broadcast over the trailing dimensions; for
    1-D parameters (e.g. a bias) it is multiplied element-wise.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, mask, dim=-1):
        # ``dim`` is not read by ``compute_mask`` (the mask is always applied
        # along dim 0); it is only stored so the method exposes the ``dim``
        # attribute torch's structured-pruning machinery expects.
        self.mask = mask
        self.dim = dim

    @classmethod
    def apply(cls, module, name, mask, dim=-1):
        """Prune ``module.<name>`` in place with ``mask`` and return the method instance."""
        # Bypass CustomFromMask.apply (which does not accept ``dim``) and call
        # BasePruningMethod.apply directly so ``dim`` is forwarded to __init__.
        return super(CustomFromMask, cls).apply(module, name, mask, dim=dim)

    def compute_mask(self, t, default_mask):
        """Return ``default_mask`` with the pruned output channels set to zero.

        Args:
            t: the tensor being pruned; only its rank is used.
            default_mask: the current mask for ``t`` (same shape as ``t``).

        Returns:
            A mask with ``default_mask``'s shape and dtype.
        """
        # Bring the user-supplied mask onto the parameter's device so that a
        # CPU mask can be applied to a CUDA module (and vice versa).
        channel_mask = self.mask.to(device=default_mask.device)
        if t.dim() > 1:
            # (C_out,) -> (C_out, 1, ..., 1) so each entry covers a whole channel.
            channel_mask = channel_mask.reshape(-1, *[1] * (t.dim() - 1)).expand_as(default_mask)
        # A bool mask times a float mask promotes to float; cast back explicitly.
        return (default_mask * channel_mask).to(dtype=default_mask.dtype)
class InChannelStructuredFromMask(CustomFromMask):
    """Structured pruning that zeroes whole input channels (dim 1) from a given mask.

    ``mask`` is a tensor with one entry per input channel (truthy = keep,
    falsy = prune). For parameters with more than one dimension it is reshaped
    to ``(1, C_in, 1, ..., 1)`` and broadcast; parameters with at most one
    dimension (e.g. a bias) keep their existing mask unchanged.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, mask, dim=-1):
        # ``dim`` is not read by ``compute_mask`` (the mask is always applied
        # along dim 1); it is only stored so the method exposes the ``dim``
        # attribute torch's structured-pruning machinery expects.
        self.mask = mask
        self.dim = dim

    @classmethod
    def apply(cls, module, name, mask, dim=-1):
        """Prune ``module.<name>`` in place with ``mask`` and return the method instance."""
        # Bypass CustomFromMask.apply (which does not accept ``dim``) and call
        # BasePruningMethod.apply directly so ``dim`` is forwarded to __init__.
        return super(CustomFromMask, cls).apply(module, name, mask, dim=dim)

    def compute_mask(self, t, default_mask):
        """Return ``default_mask`` with the pruned input channels set to zero.

        Args:
            t: the tensor being pruned; only its rank is used.
            default_mask: the current mask for ``t`` (same shape as ``t``).

        Returns:
            ``default_mask`` itself for tensors with at most one dimension,
            otherwise a new mask with ``default_mask``'s shape and dtype.
        """
        if t.dim() <= 1:
            # The bias is added after the input channels are combined, so
            # input-channel pruning has no effect on it.
            return default_mask
        # Bring the user-supplied mask onto the parameter's device so that a
        # CPU mask can be applied to a CUDA module (and vice versa).
        channel_mask = self.mask.to(device=default_mask.device)
        # (C_in,) -> (1, C_in, 1, ..., 1) so each entry covers a whole input channel.
        channel_mask = channel_mask.reshape(1, -1, *[1] * (t.dim() - 2)).expand_as(default_mask)
        # A bool mask times a float mask promotes to float; cast back explicitly.
        return (default_mask * channel_mask).to(dtype=default_mask.dtype)
def custom_structured_from_mask(module, name, mask, dim, target='out'):
    """Apply channel-wise structured pruning to ``module.<name>`` in place.

    Args:
        module: the module that owns the parameter to prune.
        name: the parameter name, e.g. ``'weight'`` or ``'bias'``.
        mask: per-channel keep mask (truthy = keep).
        dim: recorded on the pruning method; the axis actually masked is fixed
            by ``target`` (dim 0 for ``'out'``, dim 1 for ``'in'``).
        target: ``'out'`` to prune output channels, ``'in'`` to prune input channels.

    Returns:
        The same ``module``, now carrying the pruning reparametrization.

    Raises:
        ValueError: if ``target`` is neither ``'out'`` nor ``'in'``.
    """
    if target == 'out':
        method = OutChannelStructuredFromMask
    elif target == 'in':
        method = InChannelStructuredFromMask
    else:
        # Previously an unknown target silently left the module unpruned.
        raise ValueError(f"target must be 'out' or 'in', got {target!r}")
    method.apply(module, name, mask=mask, dim=dim)
    return module


if __name__ == '__main__':
    # Experiment: build a channel mask from the first bypass layer's D.delta,
    # prune that layer's W (output channels) and A (input channels) with it,
    # and compare evaluation accuracy/loss before and after pruning.
    import shutil

    from bypass.train import BypassTrainerBase

    # Single source for the |delta| threshold, so the trainer's prune_epsilon
    # and the mask below cannot drift apart.
    EPSILON = 1e-4
    CHECKPOINT = "/workspace/jaeheun_MildPruning/save/mild_pruning_W/cifar10_NaiveBypassBNresnet56/20240117-014157/save/model_pruning_dev_point_0_layer0.pt"

    def report(tag, result):
        """Print mean accuracy and loss of an ``eval_loop()`` result."""
        print(f'{tag}_acc:{result.acc/result.count}')
        print(f'{tag}_loss:{result.loss/result.count}')

    trainer = BypassTrainerBase.load_from(CHECKPOINT, autosave=False, prune_epsilon=EPSILON)
    bypass_layer = trainer.model.bypass_layers[0]
    delta_vec = bypass_layer['D'].delta
    # NOTE(review): deletes the directory containing summary_path, presumably
    # the run directory load_from just created — confirm it never holds data
    # worth keeping.
    shutil.rmtree(trainer.configs.summary_path.parent)

    report('before_prune', trainer.eval_loop())

    # Channels with |delta| below EPSILON are kept; the rest are pruned.
    preserve_d_mask = delta_vec.abs() < EPSILON
    prune_d_mask = delta_vec.abs() >= EPSILON

    print(f'epsilon: {EPSILON}')
    # Count via sum() rather than materializing the selected values as numpy.
    print(f'preserve nodes: {int(preserve_d_mask.sum())}')
    print(f'prune nodes: {int(prune_d_mask.sum())}')

    W_layer = bypass_layer['W'][0]
    A_layer = bypass_layer['A'][0]

    # The same mask prunes W's output channels and A's input channels;
    # presumably W's outputs feed A's inputs — confirm against the model.
    custom_structured_from_mask(W_layer, 'weight', preserve_d_mask, dim=0, target='out')
    if W_layer.bias is not None:
        custom_structured_from_mask(W_layer, 'bias', preserve_d_mask, dim=0, target='out')

    # In-channel masking acts on dim 1, so record dim=1 for the weight.
    custom_structured_from_mask(A_layer, 'weight', preserve_d_mask, dim=1, target='in')
    if A_layer.bias is not None:
        # A no-op mask for the bias, but it keeps A's parametrization uniform.
        custom_structured_from_mask(A_layer, 'bias', preserve_d_mask, dim=0, target='in')

    report('after_prune', trainer.eval_loop())

    