import math

import torch
import torch.nn as nn


class SRLinear(nn.Linear):
    """Linear layer whose weight is divided by a spectral-norm estimate.

    The weight used in ``forward`` is ``(self.sigma / s) * self.weight``.
    Here ``s`` is a one-step power-iteration estimate of the largest singular
    value of ``self.weight``, and ``self.sigma`` is a learnable scalar
    initialised to 1.

    State:
        u: buffer of shape ``(in_features,)``, the persistent right-singular
            vector estimate. It is advanced by one power-iteration step on
            *every* ``get_sigma`` call, in training and eval mode alike.
        spectral_norm: 0-dim buffer holding the latest estimate. It is only
            refreshed by ``get_weight`` while ``self.training`` is true.
        sigma: learnable parameter of shape ``(1,)``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        initial_u = nn.functional.normalize(torch.randn(in_features), dim=0)
        self.register_buffer('u', initial_u)
        # No autograd here: the recorded estimate must be a plain tensor.
        with torch.no_grad():
            initial_norm = self.get_sigma()
        self.register_buffer('spectral_norm', initial_norm)

        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        """Advance the power iteration by one step and return ``v^T W u``.

        The singular-vector estimates are computed without gradient tracking.
        The final bilinear form is differentiable with respect to
        ``self.weight``.
        """
        weight = self.weight
        with torch.no_grad():
            left = nn.functional.normalize(torch.mv(weight, self.u), dim=0)
            right = nn.functional.normalize(torch.mv(weight.t(), left), dim=0)
            # Persist the refreshed right vector (happens in eval mode too).
            self.u.data.copy_(right)
        return torch.einsum('c,cd,d->', left, weight, right)

    def get_weight(self):
        """Return the reparameterised weight ``(sigma / ||W||_2) * W``."""
        current_norm = self.get_sigma()
        if self.training:
            # Only training forwards update the recorded estimate.
            self.spectral_norm.data.copy_(current_norm)
        scale = self.sigma / current_norm
        return scale * self.weight

    def forward(self, x):
        """Apply ``linear`` using the reparameterised weight and the bias."""
        effective_weight = self.get_weight()
        return nn.functional.linear(x, effective_weight, self.bias)

class SRConv2d(SRLinear):
    """2-D convolution whose flattened kernel is spectrally reparameterised.

    The kernel is stored as an ``SRLinear`` weight of shape
    ``(out_features, in_features * kh * kw)``. It is viewed as
    ``(out_features, in_features, kh, kw)`` at forward time.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        kernel_size: int or 2-sequence ``(kh, kw)``.
        stride: int or 2-sequence.
        padding: int, 2-sequence, or a string accepted by
            ``torch.nn.functional.conv2d`` (``'valid'`` / ``'same'``).
        bias: whether to add a learnable bias.
    """

    def __init__(self, in_features, out_features, kernel_size, stride=1, padding=0, bias=True):
        kernel_size = self._to_2tuple(kernel_size)
        stride = self._to_2tuple(stride)
        # SRLinear treats the kernel as a flat matrix: one column per
        # (input channel, kernel row, kernel column) tap.
        in_features = in_features * kernel_size[0] * kernel_size[1]
        super().__init__(in_features, out_features, bias=bias)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # The weight was just re-drawn. Refresh the estimate that
        # SRLinear.__init__ recorded for the previous weight.
        with torch.no_grad():
            self.spectral_norm.copy_(self.get_sigma())
        self.kernel_size = kernel_size
        self.stride = stride
        # Previously accepted but silently ignored. Strings pass straight to conv2d.
        self.padding = padding if isinstance(padding, str) else self._to_2tuple(padding)

    @staticmethod
    def _to_2tuple(x):
        """Return ``x`` as a tuple: sequences are converted, scalars are repeated twice."""
        if isinstance(x, (tuple, list)):
            return tuple(x)
        return (x, x)

    def forward(self, x):
        """Convolve ``x`` with the reparameterised kernel.

        ``x`` is any input accepted by ``torch.nn.functional.conv2d`` with
        ``in_features`` channels, e.g. ``(N, in_features, H, W)``.
        """
        weight = self.get_weight().view(self.out_features, -1, self.kernel_size[0], self.kernel_size[1])
        return nn.functional.conv2d(x, weight, bias=self.bias, stride=self.stride, padding=self.padding)
