import torch_pruning as tp
from torch_pruning import ops
from torch_pruning.pruner import BasePruningFunc
import torch

def prune_tp(module: torch.nn.Module, prune_indices: torch.Tensor, target: str = 'out') -> torch.nn.Module:
    """Prune channels of a single module in place using torch_pruning.

    Args:
        module: The layer to prune. Its pruning function is looked up in
            ``tp.function.PrunerBox`` under the key ``ops.module2type(module)``.
        prune_indices: Tensor of channel indices to remove.
            NOTE(review): assumed to hold integer channel indices (not a
            boolean mask) — confirm with callers.
        target: ``'out'`` prunes output channels via ``prune_out_channels``;
            ``'in'`` prunes input channels via ``prune_in_channels``.

    Returns:
        ``module``, modified in place, with ``prune_indices`` registered as
        the buffer ``f'{target}_mask'``. If the module's type has no entry in
        ``PrunerBox``, a message is printed and ``module`` is returned
        unchanged (no buffer is registered).

    Raises:
        NotImplementedError: If ``target`` is neither ``'out'`` nor ``'in'``.
    """
    # Validate the argument up front so a bad ``target`` is always reported,
    # even for module types that would otherwise be skipped as unsupported.
    if target not in ('out', 'in'):
        raise NotImplementedError(f"target must be 'out' or 'in', got {target!r}")

    try:
        pruning_func = tp.function.PrunerBox[ops.module2type(module)]
    except KeyError:
        # Only a missing PrunerBox entry means "unsupported"; any other error
        # (e.g. an API mismatch with the installed torch_pruning) propagates.
        print(f'{type(module)} is not supported by torch_pruning. Possibly it is not a leaf node')
        return module

    # Normalize once:
    # - reshape(-1) accepts a 0-d (single-index) tensor;
    # - tolist() yields plain ints in a fresh list, so the pruner cannot
    #   mutate the caller's tensor through shared numpy memory;
    # - set() drops duplicates.
    # NOTE(review): torch_pruning pruners appear to derive the new channel
    # count from len(idxs), so duplicates could desynchronize it — verify.
    idxs = sorted(set(prune_indices.reshape(-1).cpu().tolist()))

    if target == 'out':
        pruning_func.prune_out_channels(module, idxs)
    else:
        pruning_func.prune_in_channels(module, idxs)

    # Keep the originally requested indices on the module (saved in state_dict).
    module.register_buffer(f'{target}_mask', prune_indices)
    return module