import torch
import torch.nn as nn
from Utils import load
from Prune import Scoring
from Utils import train
from Prune import biased_randwalk_utils
import copy
import numpy as np
import pickle as pkl

def _results_path(args, prune_perc):
    """Return the pickle file name for one pruning level.

    The name is the plain concatenation ``experiment + prune_perc + model +
    dataset + '.pkl'`` with no separators, matching the names produced by
    earlier runs of this script.
    """
    return f"{args.experiment}{prune_perc}{args.model}{args.dataset}.pkl"


def _save_results(path, results):
    """Pickle *results* to *path* using the highest available pickle protocol."""
    with open(path, "wb") as fout:
        pkl.dump(results, fout, protocol=pkl.HIGHEST_PROTOCOL)


def run(args):
    """Prune one initial model at several levels via biased random walks, train each copy, save curves.

    Steps:
      1. Build the model and compute gradient scores on a pruning loader.
      2. Turn those scores into random-walk probabilities once.
      3. For each entry of ``args.prune_perc``:
         a. Deep-copy the unpruned model.
         b. Generate weight/bias masks and apply them.
         c. Train and evaluate the copy.
         d. Pickle ``[train_curve, test_loss, accuracy1, accuracy5]``.

    Args:
        args: namespace-like object. Attributes read here: ``seed``, ``gpu``,
            ``dataset``, ``prune_batch_size``, ``workers``,
            ``prune_dataset_size``, ``train_batch_size``, ``model``,
            ``prune_perc`` (iterable; each value is multiplied by 100 before
            being passed to ``Wrand_walk_masks``, so presumably fractions in
            [0, 1] — confirm), ``experiment``, ``optimizer``, ``lr``,
            ``weight_decay``, ``lr_drops``, ``lr_drop_rate``, ``epochs``.

    Side effects:
        Writes one ``.pkl`` file per pruning level into the current working
        directory (see ``_results_path`` for the naming scheme).
    """
    torch.manual_seed(args.seed)
    # NOTE(review): only torch's RNG is seeded. If biased_randwalk_utils draws
    # from numpy / python `random`, the masks are not reproducible — confirm.
    dev = load.device(args.gpu)

    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_size)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    test_loader = load.dataloader(args.dataset, args.train_batch_size, False, args.workers)
    # Stateless criterion, shared by gradient scoring and every training run.
    loss = nn.CrossEntropyLoss()

    model = load.model(args.model, args.dataset)(input_shape, num_classes).to(dev)
    grad_model = Scoring.grad_scores(model, prune_loader, loss, dev)
    # Probabilities depend only on the initial model, so compute them once for all levels.
    prob, reverse_prob, kernel_prob = biased_randwalk_utils.generate_probability(grad_model, verbose=True)

    # Optimizer class and its extra kwargs do not depend on the pruning level.
    opt, opt_kwargs = load.optimizer(args.optimizer)

    for prune_perc in args.prune_perc:
        print(args.experiment, prune_perc, args.model, args.dataset)
        # Fresh copy per level so every run starts from the same initial weights.
        sparse_model = copy.deepcopy(model)

        # Masks are derived from the unpruned `model`; `sparse_model` is an identical copy at this point.
        weight_masks, bias_masks = biased_randwalk_utils.Wrand_walk_masks(
            model, prune_perc * 100.0, prob, reverse_prob, kernel_prob, verbose=True)
        sparse_model.set_masks(weight_masks, bias_masks)
        # Move after set_masks — presumably so any newly registered mask tensors land on `dev` too.
        sparse_model.to(dev)

        optimizer = opt(sparse_model.parameters(), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

        train_curve, test_loss, accuracy1, accuracy5 = train.train(
            sparse_model, loss, optimizer, train_loader, test_loader, dev, args.epochs, scheduler)

        results = [train_curve, test_loss, accuracy1, accuracy5]
        _save_results(_results_path(args, prune_perc), results)