import torch.nn as nn

class SlimmableConv2d(nn.Conv2d):
    """2-D convolution whose active input/output channel ranges can be narrowed.

    The full weight is allocated for ``in_channel`` x ``out_channel``.
    ``forward`` only uses the sub-tensor
    ``weight[start_out_channel:real_out_channels,
    start_in_channel:real_in_channels]`` and the matching bias slice.
    Note that ``real_in_channels`` / ``real_out_channels`` act as slice *end*
    indices, not as channel counts.

    Args:
        in_channel: maximum number of input channels.
        out_channel: maximum number of output channels.
        kernel_size: convolution kernel size, as for ``nn.Conv2d``.
        stride: convolution stride, as for ``nn.Conv2d``.
        padding: implicit zero padding, as for ``nn.Conv2d``.
        dilation: kernel dilation, as for ``nn.Conv2d`` (default 1).
        bias: whether to allocate a learnable bias (default ``True``).
    """

    def __init__(self, in_channel, out_channel,
                 kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(SlimmableConv2d, self).__init__(
            in_channel, out_channel,
            kernel_size, stride=stride, padding=padding,
            dilation=dilation, bias=bias)
        # Slice bounds into the full weight; start out covering every channel.
        self.real_in_channels = self.in_channels
        self.real_out_channels = self.out_channels
        self.start_in_channel = 0
        self.start_out_channel = 0

    def forward(self, input):
        """Convolve ``input`` using only the currently active channel slices.

        The channel dimension of ``input`` must equal
        ``real_in_channels - start_in_channel``.
        """
        out_slice = slice(self.start_out_channel, self.real_out_channels)
        in_slice = slice(self.start_in_channel, self.real_in_channels)
        weight = self.weight[out_slice, in_slice, :, :]
        # nn.Conv2d stores ``bias = None`` when built with bias=False; slicing
        # None would raise, so pass None straight through to conv2d.
        bias = None if self.bias is None else self.bias[out_slice]

        return nn.functional.conv2d(
            input, weight, bias, self.stride, self.padding,
            self.dilation, self.groups)
    

class Scaler(nn.Module):
    """Rescale activations to compensate for a narrowed channel width.

    ``forward`` divides its input by ``real_in_channels / max_channel``.
    When ``fixed`` is true it divides by 1 instead. The divisor used on the
    most recent call is stored on ``self.rate``.

    Args:
        max_channel: full (maximum) channel count used as the denominator.
        fixed: if true, always use a rate of 1.
    """

    def __init__(self, max_channel, fixed=False):
        super().__init__()
        self.max_channel = max_channel
        # Initialised to full width; presumably adjusted externally when the
        # network is slimmed — TODO confirm against callers.
        self.real_in_channels = max_channel
        self.fixed = fixed

    def forward(self, input):
        """Return ``input / rate`` and record the rate on ``self.rate``."""
        self.rate = 1 if self.fixed else self.real_in_channels / self.max_channel
        return input / self.rate