from .networks import *
import torch
from torch.nn import functional as F
def shrink_2D(a, b, weight, eps=1e-6):
    """Shrink the element-wise 2-D vector ``(a, b)`` toward zero.

    For every position the Euclidean norm ``sqrt(a**2 + b**2)`` is computed
    (floored at ``eps``) and both components are multiplied by
    ``max(0, 1 - weight / norm)``.

    Args:
        a: Tensor holding the first component of each 2-D vector.
        b: Tensor holding the second component; combined element-wise
            with ``a``.
        weight: Threshold divided by the norm; anything that can be
            divided by a tensor (Python number or tensor).
        eps: Lower bound applied to the norm so the division is always
            defined. Defaults to ``1e-6``.

    Returns:
        A tuple ``(a_shrunk, b_shrunk)`` of the two scaled components.
    """
    squared_norm = torch.pow(a, 2) + torch.pow(b, 2)
    # Floor the squared norm *before* the square root: the derivative of
    # sqrt at exactly 0 is infinite, and autograd turns the resulting
    # 0 * inf into NaN gradients wherever a == b == 0.
    norm = torch.sqrt(torch.clamp(squared_norm, min=eps * eps))
    # Keep the floor on the norm itself so the forward lower bound is eps
    # even if eps * eps is not representable in the tensor's dtype.
    norm = torch.clamp(norm, min=eps)
    scale = F.relu(1 - weight / norm)

    return a * scale, b * scale