import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import copy
import types


"""
SNIP: Single-shot Network Pruning based on Connection Sensitivity (ICLR 2019)
https://github.com/namhoonlee/snip-public

Strongly inspired by https://github.com/mi-lad/snip
"""

 
def snip_forward_conv2d(self, x):
    """Masked replacement for ``nn.Conv2d.forward`` used while scoring SNIP saliency.

    Intended to be bound to a Conv2d instance with ``types.MethodType``. The
    layer must carry a ``weight_mask_w`` attribute (same shape as ``weight``)
    and, when it has a bias, a ``weight_mask_b`` attribute (same shape as
    ``bias``). The masks multiply the parameters element-wise, so the gradient
    of the loss w.r.t. each mask entry is the connection sensitivity.

    Args:
        self: the Conv2d layer this function is bound to.
        x: input tensor for the convolution.

    Returns:
        The convolution output computed with the masked weight and bias.
    """
    masked_weight = self.weight * self.weight_mask_w
    # A layer without bias has no bias mask; pass None through unchanged.
    masked_bias = None if self.bias is None else self.bias * self.weight_mask_b

    if getattr(self, 'padding_mode', 'zeros') != 'zeros':
        # F.conv2d can only zero-pad. For circular/reflect/replicate layers,
        # pad explicitly first (as nn.Conv2d's own forward does) so the
        # saliency is computed on the same function the real layer implements.
        x = F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode)
        padding = 0
    else:
        padding = self.padding

    return F.conv2d(x, masked_weight, masked_bias,
                    self.stride, padding, self.dilation, self.groups)


def snip_forward_linear(self, x):
    """Masked replacement for ``nn.Linear.forward`` used while scoring SNIP saliency.

    Meant to be bound to a Linear instance that carries ``weight_mask_w`` and,
    when the layer has a bias, ``weight_mask_b``. Each mask is multiplied
    element-wise into the matching parameter before the affine map is applied.

    Args:
        self: the Linear layer this function is bound to.
        x: input tensor.

    Returns:
        ``F.linear`` evaluated with the masked weight and (if present) masked bias.
    """
    masked_weight = self.weight * self.weight_mask_w

    # Bias-free layers have no bias mask to apply.
    if self.bias is None:
        return F.linear(x, masked_weight, None)

    masked_bias = self.bias * self.weight_mask_b
    return F.linear(x, masked_weight, masked_bias)


def get_saliency_snip(args, origin_model, batch, device, num_alive, alive_idx,
                      alive_mask, data_shape, num_classes):
    """Compute SNIP connection-sensitivity scores on a single batch.

    A deep copy of ``origin_model`` is instrumented: every Conv2d/Linear layer
    gets an all-ones multiplicative mask on its weight (and bias, if any), the
    original parameters are frozen, and the layer's forward is replaced by a
    masked version. One forward/backward pass of the summed cross-entropy loss
    then yields the gradient w.r.t. each mask entry; its absolute value is the
    saliency. ``origin_model`` itself is not modified.

    Args:
        args: not used by this function.
        origin_model: model to score; it is deep-copied before instrumentation.
        batch: ``(data, target)`` pair; both are moved to ``device``.
        device: device the batch is moved to. The model copy is not moved, so
            it is presumably already on this device (confirm against caller).
        num_alive, alive_idx, alive_mask, data_shape, num_classes: not used by
            this function.

    Returns:
        dict mapping the parameter name of each Conv2d/Linear weight or bias
        (``"<module path>.weight"`` / ``"<module path>.bias"``, or plain
        ``"weight"`` / ``"bias"`` when the model itself is such a layer) to a
        numpy array of ``|dLoss/dmask|`` with the parameter's shape. Masks
        whose layer did not take part in the forward pass get all-zero scores.
    """
    model = copy.deepcopy(origin_model)

    # The single batch the sensitivity is estimated on.
    data, target = batch
    data, target = data.to(device), target.to(device)

    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            masked_forward = snip_forward_conv2d
        elif isinstance(layer, nn.Linear):
            masked_forward = snip_forward_linear
        else:
            continue

        # Only the masks should receive gradients; freeze the real parameters.
        layer.weight_mask_w = nn.Parameter(torch.ones_like(layer.weight))
        layer.weight.requires_grad = False

        if layer.bias is not None:
            layer.weight_mask_b = nn.Parameter(torch.ones_like(layer.bias))
            layer.bias.requires_grad = False

        # Override the forward method on this instance only.
        layer.forward = types.MethodType(masked_forward, layer)

    # Compute gradients (but don't apply them)
    model.zero_grad()
    output = model.forward(data)
    loss = F.cross_entropy(output, target, reduction='none')
    loss.sum().backward()

    # Map each mask parameter back to the parameter it gates.
    mask_targets = {'weight_mask_w': 'weight', 'weight_mask_b': 'bias'}
    saliency = {}

    for name, p in model.named_parameters():
        # Match on the final path component only: substring checks on the full
        # dotted name misfire when a module name itself contains "_w" / "_b".
        prefix, sep, leaf = name.rpartition('.')
        if leaf not in mask_targets:
            continue

        # sep is '' when the model itself is the layer, giving a bare name.
        param_name = prefix + sep + mask_targets[leaf]

        # A mask never reached in the forward pass has no gradient; its
        # sensitivity is zero rather than an error.
        grad = p.grad if p.grad is not None else torch.zeros_like(p)

        saliency[param_name] = np.abs(grad.detach().cpu().numpy())

    return saliency