import torch
from torch.nn.utils import prune

class ThresholdPruning(prune.BasePruningMethod):
	"""Unstructured pruning method driven by a fixed magnitude cutoff.

	An entry survives only when its absolute value is strictly greater
	than ``threshold``; entries whose magnitude is equal to or below the
	cutoff are marked ``False`` in the mask.
	"""

	PRUNING_TYPE = "unstructured"

	def __init__(self, threshold):
		"""Remember the magnitude cutoff used by :meth:`compute_mask`.

		Args:
			threshold: Value the absolute entries are compared against.
		"""
		self.threshold = threshold

	def compute_mask(self, tensor, default_mask):
		"""Build the keep-mask for ``tensor``.

		Args:
			tensor: Values whose magnitudes are tested against the cutoff.
			default_mask: Accepted to satisfy the base-class signature;
				it is not consulted by this method.

		Returns:
			The element-wise result of ``|tensor| > threshold``.
		"""
		magnitudes = tensor.abs()
		# Strict comparison: a magnitude exactly equal to the cutoff is dropped.
		keep = magnitudes > self.threshold
		return keep

class GroupThresholdPruning(prune.BasePruningMethod):
	"""Placeholder pruning method that currently prunes nothing.

	``compute_mask`` hands back a copy of the incoming ``default_mask``
	unchanged, so applying this method leaves the existing mask as is.

	NOTE(review): the class name suggests group-wise threshold pruning is
	intended, but no thresholding logic is implemented here yet.
	"""

	PRUNING_TYPE = "unstructured"

	def compute_mask(self, tensor, default_mask):
		"""Return an independent copy of ``default_mask``.

		Args:
			tensor: Tensor being pruned; not inspected by this method.
			default_mask: Mask to start from.

		Returns:
			A clone of ``default_mask`` (same values, separate storage).
		"""
		# Clone so that later in-place edits of the returned mask cannot
		# alter the caller's default_mask.
		return default_mask.clone()
		
		
def group_unstructured(module, name):
	"""Apply :class:`GroupThresholdPruning` to one parameter of ``module``.

	Args:
		module: Module that owns the parameter to prune.
		name: Name of the parameter within ``module``.

	Returns:
		The same ``module`` object that was passed in.
	"""
	# The value returned by apply() is discarded on purpose; the caller
	# gets the module itself back.
	GroupThresholdPruning.apply(module, name)
	return module
