import torch
import random
import copy

def magnitude_scores(network):
    """Replace every parameter of ``network`` with its absolute value, in place.

    This does not copy the module: each parameter (of any dimensionality)
    is overwritten through ``.data``, so the caller's ``network`` itself is
    modified and the operation is not recorded by autograd.

    Returns:
        The same ``network`` object that was passed in.
    """
    for weight in network.parameters():
        # abs_() works on the parameter's own storage; rebinding .data to
        # the result leaves the parameter pointing at that same storage.
        magnitude = weight.data.abs_()
        weight.data = magnitude
    return network

def grad_scores(network, dataloader, loss, dev):
    """Return a copy of ``network`` whose multi-dimensional parameters hold |gradient|.

    ``network`` is deep-copied and left unmodified. The copy is moved to
    ``dev``, switched to eval mode, and ``loss(output, target).backward()``
    is run once per ``(data, target)`` batch from ``dataloader``. Gradients
    accumulate (sum) across all batches. Afterwards, every parameter whose
    dimensionality is not 1 has its values replaced by the absolute value of
    its accumulated gradient. One-dimensional parameters keep their original
    (copied) values.

    Args:
        network: Module to score; not modified.
        dataloader: Iterable yielding ``(data, target)`` pairs of tensors.
        loss: Callable ``loss(output, target)`` whose result is
            back-propagated (presumably a scalar criterion such as an
            ``nn.*Loss`` instance -- TODO confirm with callers).
        dev: Device passed to ``.to()`` for the model and each batch.

    Returns:
        The scored copy of ``network``. A scored parameter that received no
        gradient (unused in the forward pass, ``requires_grad=False``, or an
        empty ``dataloader``) is treated as having zero accumulated gradient
        and gets an all-zero score.
    """
    net = copy.deepcopy(network)
    net.to(dev)
    # Eval mode so layers such as dropout / batch-norm behave deterministically.
    net.eval()
    # Make sure accumulation starts from cleared gradients.
    net.zero_grad()

    for data, target in dataloader:
        output = net(data.to(dev))
        # Each backward() adds into .grad, so scores are summed over batches.
        loss(output, target.to(dev)).backward()

    for p in net.parameters():
        if p.dim() == 1:
            # 1-D parameters are intentionally left with their values.
            continue
        if p.grad is None:
            # No gradient reached this parameter: score it as zero instead of
            # crashing on ``None.abs_()``.
            p.data = torch.zeros_like(p.data)
        else:
            # Out-of-place abs on a detached view so the score does not share
            # storage with p.grad (a later backward() would otherwise
            # accumulate straight into the scores).
            p.data = p.grad.detach().abs()

    return net
