import torch
import torch.nn as nn


class SR(nn.Module):
    """Rescale vectors so their L2 norm along dim 1 lies in ``[r, r + d]``.

    Vectors whose norm is below ``r`` are scaled up to norm ``r``; vectors
    whose norm is above ``r + d`` are scaled down to norm ``r + d``; vectors
    already inside the band are returned unchanged. Only the magnitude is
    altered, never the direction.

    Args:
        r: Lower bound on the norm (stored as ``self.r1``).
        d: Width of the allowed band; the upper bound ``self.r2`` is ``r + d``.
    """

    def __init__(self, r=20, d=5):
        super().__init__()
        self.r1 = r
        self.r2 = r + d

    def forward(self, x):
        """Return ``x`` with its dim-1 L2 norm clamped into ``[r1, r2]``.

        Args:
            x: Tensor with at least two dimensions; the norm is taken over
                dim 1 with ``keepdim=True`` so the scale broadcasts back.

        Returns:
            A tensor of the same shape as ``x``. Zero-norm slices have no
            direction to scale along and are returned as zeros.
        """
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
        # A zero norm would give r / 0 = inf and then 0 * inf = NaN. Swap in a
        # dummy divisor of 1 for those slices: the vector is all zeros, so any
        # finite scale leaves it at zero. Non-zero norms are used unchanged.
        safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
        # Scale up anything shorter than r1 (factor >= 1) ...
        x = x * torch.clamp(self.r1 / safe_norm, min=1)
        # ... then scale down anything longer than r2 (factor <= 1). Both
        # factors use the original norm; at most one of them differs from 1
        # when r1 <= r2.
        x = x * torch.clamp(self.r2 / safe_norm, max=1)
        return x
