import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

def GMPChooseEdges(weight, prune_rate):
    """Return a copy of ``weight`` with its smallest-magnitude entries zeroed.

    Args:
        weight: Tensor to prune. It is not modified.
        prune_rate: Fraction of entries to zero. The number of pruned entries
            is ``int(prune_rate * weight.numel())``, i.e. rounded toward zero.

    Returns:
        A new tensor with the shape of ``weight`` in which the entries with the
        smallest absolute values are 0 and all other entries are copied.
    """
    flat_weight = weight.reshape(-1)
    # Only the ordering is needed, so keep the ranking out of autograd.
    _, order = flat_weight.detach().abs().sort()
    num_pruned = int(prune_rate * flat_weight.numel())
    # Write into an owned, contiguous 1-D buffer and reshape afterwards.
    # Zeroing through ``clone().flatten()`` only works when the clone is
    # contiguous: for a transposed / channels_last weight ``flatten()`` returns
    # a copy and the zeroing would be silently lost.
    pruned = flat_weight.clone()
    pruned[order[:num_pruned]] = 0
    return pruned.view_as(weight)

def GMPChangeMasks(weight, mask, curr_prune_rate):
    """Return an updated pruning mask for the given prune rate.

    Entries are ranked by ``|weight * mask|``, so positions that are already
    masked out have magnitude 0 and are selected first; the remaining budget
    goes to the smallest surviving weights.

    Args:
        weight: Weight tensor the mask applies to. It is not modified.
        mask: Current mask; multiplied element-wise with ``weight``. It is not
            modified. Expected to have the same shape as ``weight``.
        curr_prune_rate: Fraction of entries that should be zero in the
            returned mask. The count is ``int(curr_prune_rate * numel)``,
            i.e. rounded toward zero.

    Returns:
        A new tensor with the shape of ``mask`` where the selected positions
        are 0 and every other position keeps its value from ``mask``.
    """
    # The ranking needs no gradient; detach avoids building an autograd graph.
    scores = (weight.detach() * mask).reshape(-1).abs()
    _, order = scores.sort()
    num_pruned = int(curr_prune_rate * scores.numel())
    # Write into an owned, contiguous 1-D buffer and reshape afterwards.
    # Zeroing through ``clone().flatten()`` is lost whenever the clone is
    # non-contiguous, because ``flatten()`` then returns a copy, not a view.
    new_mask = mask.reshape(-1).clone()
    new_mask[order[:num_pruned]] = 0
    return new_mask.view_as(mask)

class GMPConv(nn.Conv2d):
    """``nn.Conv2d`` whose weight is multiplied by a pruning mask on every forward.

    Constructor arguments are exactly those of ``nn.Conv2d``.
    ``set_prune_rate`` must be called before the first forward pass or before
    ``set_curr_prune_rate``: it is what creates ``self.mask``.
    """

    def set_prune_rate(self, prune_rate):
        """Store the target prune rate and reset the mask to all ones.

        Args:
            prune_rate: Target prune rate; stored on ``self.prune_rate``.
        """
        self.prune_rate = prune_rate
        self.curr_prune_rate = 0.0
        # new_ones() allocates on the weight's own device and dtype, so the
        # mask can be multiplied with the weight wherever the layer lives.
        self.mask = self.weight.data.new_ones(self.weight.shape)
        print(f"=> Setting prune rate to {prune_rate}")

    def set_curr_prune_rate(self, curr_prune_rate):
        """Store the current prune rate and recompute the mask for it.

        Args:
            curr_prune_rate: Fraction of weights that should be masked now;
                passed to ``GMPChangeMasks`` with the current weight and mask.
        """
        self.curr_prune_rate = curr_prune_rate

        self.mask = GMPChangeMasks(self.weight, self._aligned_mask(), self.curr_prune_rate)

    def _aligned_mask(self):
        """Return ``self.mask`` on the weight's device and dtype."""
        # The mask is a plain attribute, so .to()/.cuda()/.half() on the module
        # do not move it. Tensor.to() is a no-op when nothing has to change.
        w = self.weight.data
        self.mask = self.mask.to(device=w.device, dtype=w.dtype)
        return self.mask

    def forward(self, x):
        """Zero the masked weights in place, then run the regular convolution.

        Note that the stored weight itself is overwritten: masked entries are
        set to 0 in ``self.weight.data``, not just for this call.
        """
        self.weight.data = self.weight.data * self._aligned_mask()
        # Delegate to nn.Conv2d so every padding mode is handled by the
        # installed PyTorch version instead of through private attributes.
        return super().forward(x)

class GMPLinear(nn.Linear):
    """``nn.Linear`` whose weight is multiplied by a pruning mask on every forward.

    Constructor arguments are exactly those of ``nn.Linear``.
    ``set_prune_rate`` must be called before the first forward pass or before
    ``set_curr_prune_rate``: it is what creates ``self.mask``.
    """

    def set_prune_rate(self, prune_rate):
        """Store the target prune rate and reset the mask to all ones.

        Args:
            prune_rate: Target prune rate; stored on ``self.prune_rate``.
        """
        self.prune_rate = prune_rate
        self.curr_prune_rate = 0.0
        # new_ones() allocates on the weight's own device and dtype, so the
        # mask can be multiplied with the weight wherever the layer lives.
        self.mask = self.weight.data.new_ones(self.weight.shape)
        print(f"=> Setting prune rate to {prune_rate}")

    def set_curr_prune_rate(self, curr_prune_rate):
        """Store the current prune rate and recompute the mask for it.

        Args:
            curr_prune_rate: Fraction of weights that should be masked now;
                passed to ``GMPChangeMasks`` with the current weight and mask.
        """
        self.curr_prune_rate = curr_prune_rate

        self.mask = GMPChangeMasks(self.weight, self._aligned_mask(), self.curr_prune_rate)

    def _aligned_mask(self):
        """Return ``self.mask`` on the weight's device and dtype."""
        # The mask is a plain attribute, so .to()/.cuda()/.half() on the module
        # do not move it. Tensor.to() is a no-op when nothing has to change.
        w = self.weight.data
        self.mask = self.mask.to(device=w.device, dtype=w.dtype)
        return self.mask

    def forward(self, x):
        """Zero the masked weights in place, then apply the linear map.

        Note that the stored weight itself is overwritten: masked entries are
        set to 0 in ``self.weight.data``, not just for this call.
        """
        self.weight.data = self.weight.data * self._aligned_mask()
        # nn.Linear.forward applies self.bias as well; calling F.linear with
        # the weight alone would silently ignore a bias the layer owns.
        return super().forward(x)


def gmp_prune_conv_linear(model):
    """Recursively replace ``nn.Linear`` / ``nn.Conv2d`` children with GMP layers.

    The module tree is modified in place. Each replacement layer is built with
    the same configuration as the layer it replaces and reuses that layer's
    ``weight`` and ``bias`` parameters, so existing values (and their device
    and dtype) are kept. Layers that are already ``GMPLinear`` / ``GMPConv``
    are left alone so their masks survive a second call.

    The new layers have no mask yet: ``set_prune_rate`` must be called on each
    of them before the first forward pass.

    Args:
        model: Root module whose children are converted.

    Returns:
        The same ``model`` object.
    """
    # Snapshot the entries: values of ``_modules`` are reassigned in the loop.
    for name, module in reversed(list(model._modules.items())):
        # ``_modules`` may hold None placeholders (e.g. add_module(name, None)).
        if module is None:
            continue

        if len(list(module.children())) > 0:
            model._modules[name] = gmp_prune_conv_linear(model=module)

        if isinstance(module, nn.Linear):
            if isinstance(module, GMPLinear):
                continue
            layer_new = GMPLinear(module.in_features, module.out_features,
                                  module.bias is not None)
        elif isinstance(module, nn.Conv2d):
            if isinstance(module, GMPConv):
                continue
            layer_new = GMPConv(module.in_channels, module.out_channels, module.kernel_size, module.stride,
                                padding=module.padding, dilation=module.dilation, groups=module.groups,
                                bias=module.bias is not None, padding_mode=module.padding_mode)
        else:
            continue

        # Reuse the original Parameter objects instead of the fresh random
        # initialisation; bias is None on both sides when the layer has none.
        layer_new.weight = module.weight
        layer_new.bias = module.bias
        model._modules[name] = layer_new

    return model