"""
    Main script for pruning functions and utils.
"""
import torch



def norm_prune_weights(model, p=0.1, tol=1e-4, max_iter=50):
    """
        Norm-based weight pruning function.

        For every parameter whose name contains 'weight', a magnitude
        threshold ``tau`` is searched by bisection so that the L2 norm of the
        entries with ``|w| > tau`` is approximately ``(1 - p)`` times the
        original L2 norm of that parameter. Entries with ``|w| <= tau`` are
        then zeroed in place.

        Args:
            model: object exposing ``named_parameters()`` (e.g. an
                ``torch.nn.Module``); its weight tensors are modified in place.
            p: fraction of each weight tensor's L2 norm to remove
                (intended range is [0, 1]; the value is not validated).
            tol: early-stop tolerance, relative to the original norm.
            max_iter: maximum number of bisection steps per parameter.

        Returns:
            The same ``model`` object, after in-place pruning.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' not in name:
                continue
            # A zero-element tensor has no maximum to bisect on; skip it
            # instead of letting ``max()`` raise.
            if param.numel() == 0:
                continue

            # Magnitudes never change during the search, so compute them once.
            magnitude = param.abs()
            high = magnitude.max().item()
            # An all-zero tensor stays all-zero whatever the threshold is.
            if high == 0.0:
                continue

            orig = param.norm(p=2)
            target = (1 - p) * orig
            low = 0.0
            tau = high
            for _ in range(max_iter):
                mid = (low + high) / 2
                mask = (magnitude > mid).float()
                surv = (param * mask).norm(p=2)
                # The most recently probed threshold is always the one kept.
                tau = mid
                if abs(surv - target) < tol * orig:
                    break
                if surv > target:
                    # Too much norm survives: raise the threshold.
                    low = mid
                else:
                    # Too little norm survives: lower the threshold.
                    high = mid
            param.mul_((magnitude > tau).float())
    return model


# Script entry point: intentionally a no-op; this module only provides
# pruning utilities for import.
if __name__ == "__main__":
    pass
