import os
import pathlib
import random
import shutil
import time
import json
import numpy as np
import csv

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.tensorboard import SummaryWriter

from utils.logging import AverageMeter, ProgressMeter
from utils.net_utils import save_checkpoint, get_lr, LabelSmoothing
from utils.schedulers import get_policy, cosine_lr, assign_learning_rate, warmup_lr, constant_lr, multistep_lr, multistep_lr_drops, imagenet_lr_drops, cosine_lr_warmup, multistep_lr_warmup, imagenet_lr_drops_warmup, cyclic_lr, one_cycle_lr
from utils.conv_type import STRConv, STRConvER, ConvER, ConvMask, LinearMask, ConvMaskSign, ConvMaskMW
from utils.conv_type import sparseFunction
from utils.compensation import CompensatePrune
from utils.custom_activation import TrackActReLU
from utils.bin_opt import MomentumWithThresholdBinaryOptimizer


from args import args
from trainer import train, validate, get_preds, hessian_trace, train_kd, train_with_fixed_signs

import data
from data import cifar10, imagenet

from pruners import prune_width_gradual, PruneRandInMask, make_regular_random_graph, make_config_from_mask, prune_width_gradual_nonuniform_ratio

import models
from models import resnet18
from models import resnet18c
from models import resnet
from models import fc
from models import resnet20
from models import vgg
from models import cnn
from models import vit

#from models import resnet_ima

# Making faster implementation on A100 GPUs
# # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# # in PyTorch 1.12 and later.

# # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
from prettytable import PrettyTable

def main():
    """Print the run configuration, seed the RNGs and launch training.

    Every entry of the parsed ``args`` namespace is shown as one row of a
    two-column table. When ``args.seed`` is set, the Python, PyTorch (CPU and
    CUDA) and NumPy generators are all seeded with it before ``main_worker``
    is called.
    """
    summary = PrettyTable()
    summary.field_names = ["Argument Index", "Argument Value"]
    for name, setting in vars(args).items():
        summary.add_row([name, setting])
    print(summary)

    seed = args.seed
    if seed is not None:
        # Same seeding calls, in the same order, for every RNG the run uses.
        seeders = (
            random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed,
            torch.cuda.manual_seed_all,
            np.random.seed,
        )
        for seed_fn in seeders:
            seed_fn(seed)

    # Hand over to the single-process training routine.
    main_worker(args)


def main_worker(args):

    ### Modifications for speeding up training    
    torch.backends.cuda.matmul.allow_tf32 = True

    args.gpu = 0 
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    # model = get_model(args)
    if args.set in {'cifar10', 'label-noise-cifar10'}:
        print('PAPERBACK WRITER FOR THIS CIFAR10 training!')
        if args.resnet_type == 'small-dense':
            model = resnet18.ResNetWidth18([3, 32, 32], num_classes=10, width = args.width)
        elif args.resnet_type == 'cnn':
            model = cnn.SmallCNN(num_classes=10, width = args.width)
        elif args.resnet_type == 'conv-5':
            model = cnn.CNN5(num_classes=10, width=args.width)
        elif args.resnet_type == 'small-dense-inc':
            model = resnet18.ResNetWidth18Inc([3, 32, 32], num_classes=10, width = args.width)
        elif args.resnet_type == 'small-sub':
            model = resnet18.ResNetWidthSub18([3, 32, 32], num_classes=10, width = args.width)
        elif args.resnet_type == 'vgg16':
            model = vgg.vgg16_bn(num_classes=10)
        elif args.resnet_type == 'large-c':
            model = resnet18c.ResNet18C([3, 32, 32], num_classes=10)
        elif args.resnet_type == 'res20':
            model = resnet20.ResNet20([3, 32, 32], num_classes=10)
            print(model)
        elif args.resnet_type == 'res20-width':
            model = resnet20.ResNetWidth20([3, 32, 32], width=args.width, num_classes=10)
            print(model)
        elif args.resnet_type == 'vit':
            model = vit.ViT(image_size = 32, patch_size = 4, num_classes = 10, dim = int(512),
                    depth = 6, heads = 8, mlp_dim = 512, dropout = 0.1, emb_dropout = 0.1)
        else:
            model = resnet18.ResNet18([3, 32, 32], num_classes=10)
            if args.kd:
                teacher = resnet18.ResNet18([3, 32, 32], num_classes=10)

    if args.set in {'cifar100', 'label-noise-cifar100'}:
        print('HALLELUJAH THIS RAN FOR CIFAR100 ALLES GUT')
        if args.resnet_type == 'small-dense':
            model = resnet18.ResNetWidth18([3, 32, 32], num_classes=100, width = args.width)
        elif args.resnet_type == 'small-dense-inc':
            model = resnet18.ResNetWidth18Inc([3, 32, 32], num_classes=100, width = args.width)
        elif args.resnet_type == 'r50-small-dense-inc':
            model = resnet18.ResNetWidth50Inc([3, 32, 32], num_classes=100, width = args.width)
        elif args.resnet_type == 'r50-small-dense':
            model = resnet18.ResNetWidth50([3, 32, 32], num_classes=100, width = args.width)
        elif args.resnet_type == 'small-sub':
            model = resnet18.ResNetWidthSub18([3, 32, 32], num_classes=100, width = args.width)
        elif args.resnet_type == 'res50':
            model = resnet18.ResNet50([3, 32, 32], num_classes=100)
        elif args.resnet_type == 'large-c':
            model = resnet18c.ResNet18C([3, 32, 32], num_classes=100)
        else:
            model = resnet18.ResNet18([3, 32, 32], num_classes=100)
            if args.kd:
                teacher = resnet18.ResNet18([3, 32, 32], num_classes=100)

    if args.set == 'tiny-imagenet':
        if args.resnet_type == 'res18':
            model = resnet.ResNet18(num_classes=200)
        elif args.resnet_type == 'small-dense-inc':
            model = resnet.ResNetWidth50Inc(width=args.width, num_classes=200)
        else:
            model = resnet.ResNet50(num_classes=200)

    if args.set == 'imagenet':
        if args.resnet_type == 'small-dense-inc':
            model = resnet.ResNetWidth50Inc(width=args.width, num_classes=1000)
        elif args.resnet_type == 'res18':
            model = resnet.ResNet18(num_classes=1000)
        else:
            model = resnet.ResNet50()

    if args.set == 'mnist':
        hidden = 256
        if args.resnet_type == 'linear':
            model = fc.MLP_linear(hidden)
        elif args.resnet_type == 'fc-width':
            model = fc.MLP(args.width)
        elif args.resnet_type == 'single':
            model = fc.SingleLayer(args.width)
        elif args.resnet_type == 'double':
            model = fc.TwoLayer(args.width)
        elif args.resnet_type == 'double-relu':
            model = fc.TwoLayerReLU(args.width)
        else:
            model = fc.MLP(hidden)
    
    if args.set == 'fashion-mnist':
        model = fc.MLP_BN(args.width)

    #Setting the base directory for running experiments
    base_dir = ''
    
    
    # push to utils
    def get_sparsity(model):
        """Return the density of the model's masks.

        Sums the entries of ``mask`` over every ``ConvMask`` / ``LinearMask``
        module and divides by the total number of mask entries. Despite the
        function name, the returned value is the kept fraction (density), not
        the pruned fraction.
        """
        kept = 0
        count = 0
        masked_layers = (
            layer
            for _, layer in model.named_modules()
            if isinstance(layer, (ConvMask, LinearMask))
        )
        for layer in masked_layers:
            kept += layer.mask.sum()
            count += layer.mask.numel()

        return kept / count
    ## pruners.py

    def prune_mag(model, density):
        """Globally magnitude-prune all masked layers to the given density.

        The score of each weight is ``|mask * weight|``. Scores of all
        ``ConvMask`` / ``LinearMask`` layers are pooled, the k-th smallest
        score (``k = int((1 - density) * num_scores)``) becomes the threshold,
        and every entry whose score is ``<=`` the threshold gets mask 0, all
        others mask 1.

        When ``k < 1`` (e.g. ``density == 1``) there is nothing to prune: the
        masks are left untouched and only the current density is reported.

        Args:
            model: module whose masked layers are pruned in place.
            density: target fraction of weights to keep.

        Returns:
            The same ``model`` object.
        """
        layers = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, (ConvMask, LinearMask))
        ]
        score_list = {
            name: (module.mask.to(module.weight.device) * module.weight).detach().abs_()
            for name, module in layers
        }

        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int((1 - density) * global_scores.numel())

        if k >= 1:
            # torch.kthvalue requires k >= 1, so it must only run inside this guard.
            threshold, _ = torch.kthvalue(global_scores, k)
            for name, module in layers:
                score = score_list[name].to(module.weight.device)
                zero = torch.tensor([0.]).to(module.weight.device)
                one = torch.tensor([1.]).to(module.weight.device)
                module.mask = torch.where(score <= threshold, zero, one)

        # Always compute the statistics so the report below never hits unbound names.
        total_num = 0
        total_den = 0
        for _, module in layers:
            total_num += (module.mask == 1).sum()
            total_den += module.mask.numel()

        print('Overall model density after magnitude pruning at current iteration = ', total_num / total_den)
        return model
    
    if args.er_sparse_method == 'uniform':
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                m.set_er_mask(args.er_sparse_init)

        print(args.er_sparse_init)

    if args.er_sparse_method == 'mag':
        model = prune_mag(model, args.er_sparse_init)
        print('Initialized model at (random) init with one shot magnitude pruning with density: ', get_sparsity(model))

    if args.er_sparse_method == 'ERK':
        sparsity_list = []
        num_params_list = []
        total_params = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                sparsity_list.append(torch.tensor(m.weight.shape).sum() / m.weight.numel())
                num_params_list.append(m.weight.numel())
                total_params += m.weight.numel()
        
        num_params_kept = (torch.tensor(sparsity_list) * torch.tensor(num_params_list)).sum()
        num_params_to_keep = total_params * args.er_sparse_init
        C = num_params_to_keep / num_params_kept
        sparsity_list = [torch.clamp(C*s, 0, 1) for s in sparsity_list]
        print(sparsity_list)
        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                m.set_er_mask(sparsity_list[l])
                l += 1
        print(sparsity_list)
    

    if args.er_sparse_method == 'balanced':
        total_params = 0
        l = 0
        sparsity_list = []
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                total_params += m.weight.numel()
                l += 1
        L = l
        X = args.er_sparse_init * total_params / l
        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                if X / m.weight.numel() < 1.0:
                    sparsity_list.append(X / m.weight.numel())
                else: 
                    sparsity_list.append(1)
                    # correction for taking care of exact sparsity
                    diff = X - m.mask.numel()
                    X = X + diff / (L - l)
                l += 1

        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                m.set_er_mask(sparsity_list[l])
                l += 1
        print(sparsity_list)

    if args.er_sparse_method == 'regular-graph':
        model = make_regular_random_graph(model, args)
        print('Initialized model as a regular graph with (approx) fixed degrees for each node')

    if args.er_sparse_method == 'configuration-ref':
        print('Creates a configuration model from a reference mask')
        print('Mask used is:  ', args.ref_sparsity_mask_name)
        model = make_config_from_mask(model, args)
        print('Density of the model is: ', get_sparsity(model))

    # 'uniform-regular': every ConvMask layer except the last one(s) (see the
    # `l < L-1` test) gets, for each kernel position (i, j), a separate
    # (out_channels x in_channels) mask produced by make_config_model(p, ni, no);
    # the remaining layers keep an all-ones mask.
    if args.er_sparse_method == 'uniform-regular':
        l = 0
        total = 0
        nz = 0
        p = args.er_sparse_init
        for n, m in model.named_modules():
            # print(m.shape)
            if isinstance(m, ConvMask):
                print('Layer: ', p, m.weight.shape)
                curr_mask = torch.ones_like(m.weight)
                # NOTE(review): `L` is never assigned in this branch. In the visible
                # code it is only set inside the 'balanced', 'balanced-scale-free'
                # and 'balanced-healthy' branches, which cannot run together with
                # this one, so this comparison looks like it raises at runtime.
                # Presumably L should be the number of prunable layers — confirm
                # whether LinearMask layers are meant to be counted.
                if l < L-1:
                    if p == 1:
                        # Density 1: keep the layer fully dense.
                        curr_mask[:, :, :, :] = 1
                        total += curr_mask.numel()
                        nz += curr_mask.sum()

                    else:
                        no, ni, k1, k2 = curr_mask.shape[0], curr_mask.shape[1], curr_mask.shape[2], curr_mask.shape[3]
                        for i in range(k1):
                            for j in range(k2):  
                                # NOTE(review): make_config_model is not among the names
                                # imported at the top of the visible file; verify that it
                                # is defined or imported elsewhere.
                                Al = make_config_model(p, ni, no)
                                curr_mask[:, :, i, j] = Al[:, :]
                        total += curr_mask.numel()
                        nz += curr_mask.sum()
                        # 
                        # print('Mask density: ', curr_mask.shape, (curr_mask.sum() / curr_mask.numel()).item())
                else:
                    curr_mask[:, :, :, :] = 1
                    total += curr_mask.numel()
                    nz += curr_mask.sum()
                # Written in place, so the existing mask tensor object is reused.
                m.mask[:, :, :, :] = curr_mask
                l += 1
                # This message sits inside the loop, so it is printed once per ConvMask layer.
                print('Initialized ER network with uniform sparsity ratios layerwise and a regular graph in each layer')

    if args.er_sparse_method == 'balanced-scale-free':
        # initialize the network with balanced layerwise ratios
        # but within each layer, neurons are assigned degrees in an exponentially decreasing ratio, not uniformly
        # This will ensure different degrees for neurons within a layer
        total_params = 0
        l = 0
        sparsity_list = []
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                total_params += m.weight.numel()
                l += 1
        L = l
        X = args.er_sparse_init * total_params / l
        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                if X / m.weight.numel() < 1.0:
                    sparsity_list.append(X / m.weight.numel())
                else: 
                    sparsity_list.append(1)
                    # correction for taking care of exact sparsity
                    diff = X - m.mask.numel()
                    X = X + diff / (L - l)
                l += 1

        l = 0
        nz = 0
        total = 0

        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                p = sparsity_list[l]
                n = m.weight.shape[0]
                mask = m.mask
                # solve the GP to assign neuron wise ratios
                a = np.ones(n + 1)
                a[-1] = -p*n
                r = np.roots(a)
                for i in r:
                    if np.isreal(i) and np.real(i) > 0 and np.real(i) < 1:
                        base = np.real(i)
                        print('base is: ', base)
                        neuron = base ** np.arange(n)
                    else:
                        # print('Cannot find base for: ', n, sparsity_list[l])
                        neuron = sparsity_list[l] * np.ones(n)
                        
                mask[:, :, :, :] = torch.tensor(neuron)[:, None, None, None]
                m.mask = torch.bernoulli(mask)
                l += 1
                total += m.mask.numel()
                nz += m.mask.sum()
        print(sparsity_list)
        print('The overall density of the initialized network is: ', nz / total)

    if args.er_sparse_method == 'pyramidal':
        # Pyramidal initialisation: the l-th prunable layer (1-indexed, in
        # named_modules order) is given the value r**l (passed to set_er_mask),
        # where r in (0, 1) solves
        #     sum_l n_l * r**l = er_sparse_init * sum_l n_l
        # and n_l is the number of weights in layer l.
        num_params = 0
        layerwise_params = []
        for name, m in model.named_modules():
            # Fixed: the original tuple referenced the undefined name `LinearMaskk`.
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                num_params += m.weight.numel()
                layerwise_params.append(m.weight.numel())
        # np.roots expects coefficients from the highest degree down, so the
        # per-layer counts are reversed and the constant term is appended last.
        layerwise_params = layerwise_params[::-1]
        layerwise_params.append(-num_params*args.er_sparse_init)
        roots = np.roots(np.array(layerwise_params))
        print('Roots of np solver', roots)
        sparsity_list = None
        for r in roots:
            # Only a real root strictly between 0 and 1 is usable as a base.
            if np.imag(r) == 0 and 0 < np.real(r) < 1:
                print('root identified: ', r)
                layerwise_sparsities = np.arange(1, len(layerwise_params) + 1)
                sparsity_list = np.real(r) ** layerwise_sparsities
        if sparsity_list is None:
            # Fail with a clear message instead of a NameError further down.
            raise ValueError(
                'pyramidal initialisation found no real root in (0, 1) for '
                'er_sparse_init={}'.format(args.er_sparse_init)
            )
        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                m.set_er_mask(sparsity_list[l])
                l += 1
        
        print(layerwise_sparsities)
    
    if args.er_sparse_method == 'str-induced-er':
        with open(base_dir + 'runs/layerwise_sparsity/' + args.er_sparsity_file) as f:
           data = json.load(f)
        sparsity_list = list(data.values())
        alpha = sparsity_list[-1] / args.er_sparse_init
        sparsity_list = [s / alpha for s in sparsity_list[:-1]]
        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask)):
                m.set_er_mask(sparsity_list[l])
                l += 1
        print(sparsity_list)


    if args.er_sparse_method == 'balanced-healthy':
        total_params = 0
        l = 0
        sparsity_list = []
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                total_params += m.weight.numel()
                l += 1
        L = l
        X = args.er_sparse_init * total_params / l
        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                if X / m.weight.numel() < 1.0:
                    sparsity_list.append(X / m.weight.numel())
                else: 
                    sparsity_list.append(1)
                    # correction for taking care of exact sparsity
                    diff = X - m.mask.numel()
                    X = X + diff / (L - l)
                l += 1

        l = 0
        for n, m in model.named_modules():
            if isinstance(m, (STRConvER, ConvER, ConvMask, LinearMask)):
                m.set_er_mask(sparsity_list[l])
                l += 1
        print(sparsity_list)
        print('Pruning ER width healthy with healthy degree: ', args.healthy_degree)
        model = prune_width_healthy(model, args)


    if args.set == 'imagenet':
        # speeding up cudnn
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    
    print('student before')
    print(model.state_dict().keys())
    model = set_gpu(args, model)
    print('student after')
    print(model.state_dict().keys())
    if args.kd:
        print('Teacher before')
        args.gpu = None
        print(teacher.state_dict().keys())
        teacher = set_gpu(args, teacher)
        print('Teacher after ')
        print(teacher.state_dict().keys())
    
    ## save the model random init before warmup ##
    torch.save(model.state_dict(), "{}runs/model_{}_init_scratch.pt".format(base_dir, args.name))

    print('The model definition is:')
    print(model)
    total_num_narrow = 0
    for n, m in model.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            total_num_narrow += m.weight.numel()
    print('num params', total_num_narrow)
    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)

    
    resnet18_params = 11164352

    optimizer = get_optimizer(args, model)
    # data = get_dataset(args)
    if args.set == 'cifar10':  
        data = cifar10.CIFAR10(args)
    if args.set == 'cifar100':
        data = cifar10.CIFAR100(args)
    if args.set == 'tiny-imagenet':
        data = cifar10.TinyImagenet(args)        
        # data = imagenet.FFCVTinyImageNet(args)

    if args.set == 'imagenet':
        if args.compression == 'default':
            data = imagenet.ImageNet(args)
        if args.compression == 'ffcv':
        # FFCVImageNet
            data = imagenet.FFCVImageNet(args)


    if args.set == 'mnist':
        data = cifar10.MNIST(args)
    if args.set == 'fashion-mnist':
        data = cifar10.FashionMNIST(args)

    if args.set == 'label-noise-cifar10':
        data = cifar10.LabelNoiseCIFAR10(args)
    
    if args.set == 'label-noise-cifar100':
        data = cifar10.LabelNoiseCIFAR100(args)

   
    criterion = nn.CrossEntropyLoss().cuda()
    
    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    # if args.resume:
    #     best_acc1 = resume(args, model, optimizer)

    # Evaluation of a model
    if args.evaluate:
        acc1, acc5 = validate(
            data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch
        )
        return

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(
        1, [epoch_time, validation_time, train_time], prefix="Overall Timing"
    )

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / f"initial.state",
        save=False,
    )


    # prune method
    def prune_flow(model, thresh):
        """Prune each masked layer by per-output-unit relative magnitude.

        For every ``ConvMask`` / ``LinearMask`` layer, ``|mask * weight|`` is
        reshaped to one row per output unit (first weight dimension) and each
        row is divided by its own maximum. Entries whose normalised score is
        ``>= thresh`` are kept; the resulting boolean tensor replaces the
        layer's mask. Per-layer and overall densities are printed.

        Returns the same ``model`` object.
        """
        kept_total = 0
        size_total = 0

        for _, layer in model.named_modules():
            if not isinstance(layer, (ConvMask, LinearMask)):
                continue

            print('Pruning for every layer of ConvMask')
            weight = layer.weight
            masked = layer.mask.to(weight.device) * weight
            per_unit = masked.detach().abs_().view(weight.shape[0], -1)
            row_max = per_unit.max(dim=1).values
            keep = (per_unit / row_max[:, None]) >= thresh

            layer.mask = keep.view(weight.shape)
            print('New Density: ', (keep.sum() / keep.numel()).item())

            kept = (layer.mask == 1).sum()
            print('Mask Density of Layer: ', (kept / layer.mask.to(weight.device).numel()).item())
            kept_total += kept
            size_total += layer.mask.numel()

        print('Overall model density = ', kept_total / size_total)
        return model
    

    # random prune iterative
    def prune_random(model, density):
        """Randomly prune masked layers towards a balanced layerwise density.

        Every ``ConvMask`` / ``LinearMask`` layer is allotted the same number
        of kept weights, ``X = density * total_params / num_layers``, capped
        at the layer being fully dense. For each layer whose current number
        of active mask entries exceeds its target, every active entry is kept
        independently with probability ``target_nz / curr_nz``, so the target
        is met in expectation. Entries that are already 0 stay 0.

        Fixes over the previous version:
          * the keep probability was ``target / (curr - target)``, which does
            not give the target count in expectation;
          * the random tensor is now created on the mask's device;
          * a model without masked layers is returned unchanged instead of
            raising ZeroDivisionError.

        Args:
            model: module whose masks are updated in place.
            density: overall target fraction of weights to keep.

        Returns:
            The same ``model`` object.
        """
        layers = [
            m for _, m in model.named_modules()
            if isinstance(m, (ConvMask, LinearMask))
        ]
        if not layers:
            return model

        total_params = sum(m.weight.numel() for m in layers)
        X = density * total_params / len(layers)

        # Per-layer target density: equal share of kept weights, at most dense.
        sparsity_list = []
        for m in layers:
            if X / m.weight.numel() < 1.0:
                sparsity_list.append(X / m.weight.numel())
            else:
                sparsity_list.append(1)

        for m, layer_density in zip(layers, sparsity_list):
            curr_nz = (m.mask == 1).sum()
            target_nz = int(layer_density * m.mask.numel())
            print(curr_nz, target_nz, m.mask.numel())

            curr_count = int(curr_nz)
            if curr_count == 0:
                # Nothing left to prune in this layer.
                continue

            keep_prob = target_nz / curr_count
            if 0 < keep_prob < 1:
                keep = torch.empty(m.mask.shape, device=m.mask.device).bernoulli_(keep_prob)
                m.mask = torch.where(m.mask == 0, m.mask, keep)
        return model
    
    def prune_random_with_imp_sparsity(model, density, args):
        """Randomly prune with layerwise ratios taken from a reference mask.

        The reference masks are read with ``torch.load`` from
        ``base_dir + 'runs/' + args.ref_sparsity_mask_name``. Their per-layer
        densities are rescaled by a common factor so that the weighted total
        matches ``density`` and are then clamped to [0, 1]. Within each
        ``ConvMask`` / ``LinearMask`` layer a random score is drawn per weight,
        multiplied by the current mask, and entries whose score is ``<=`` the
        k-th smallest score are set to 0 in the new mask (all others to 1).

        Returns the same ``model`` object.
        """
        mask_list = torch.load(base_dir + 'runs/' + args.ref_sparsity_mask_name)

        # One random score tensor per reference mask, drawn in list order.
        score_list = [torch.rand_like(ref) for ref in mask_list]
        sparsity_list = [ref.sum() / ref.numel() for ref in mask_list]
        num_params_list = [ref.numel() for ref in mask_list]
        total_params = sum(num_params_list)

        num_params_kept = (torch.tensor(sparsity_list) * torch.tensor(num_params_list)).sum()
        scale = (total_params * density) / num_params_kept
        sparsity_list = [torch.clamp(scale * s, 0, 1) for s in sparsity_list]

        kept_total = 0
        size_total = 0
        masked_layers = (
            (name, layer)
            for name, layer in model.named_modules()
            if isinstance(layer, (ConvMask, LinearMask))
        )
        for idx, (name, layer) in enumerate(masked_layers):
            device = layer.weight.device
            score = layer.mask.to(device) * score_list[idx].to(device)
            flat = torch.flatten(score)

            k = int((1 - sparsity_list[idx]) * flat.numel())
            threshold = torch.kthvalue(flat, k).values if k != 0 else 0
            print('Layer', name, ' params ', k, flat.numel())

            zero = torch.tensor([0.]).to(device)
            one = torch.tensor([1.]).to(device)
            layer.mask = torch.where(score <= threshold, zero, one)

            kept_total += layer.mask.sum()
            size_total += layer.mask.numel()
            print('layer density: ', layer.mask.sum() / layer.mask.numel())

        print('Overall density of the model is: ', kept_total / size_total)
        return model

    

    class PruneMagInMask:
        """Magnitude pruner that protects a randomly drawn target mask.

        On construction a random (Bernoulli) mask with balanced layerwise
        densities is drawn for every ``ConvMask`` / ``LinearMask`` layer,
        optionally narrowed with ``prune_width``; the indices of its non-zero
        entries are stored in ``self.prune_idx`` and the model's masks are
        reset to all ones. ``prune_mag`` then prunes by magnitude while giving
        the stored indices a fixed large score so they are not removed.

        Attributes:
            prune_idx: dict mapping module name to a stacked index tensor of
                the target mask's non-zero positions.
            threshold_list: densities ``0.8 ** l`` for ``l = 1..L`` where
                ``L = int(log(target_density) / log(0.8))``.
        """

        def __init__(self, model, args):
            """Draw the target mask, record it, and restore a dense model.

            Reads ``args.er_sparse_init``, ``args.reduce_width_target`` and
            ``args.max_width``.
            """
            print('Generating target mask of density: ', args.er_sparse_init)
            total_params = 0
            l = 0
            sparsity_list = []
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask, LinearMask)):
                    total_params += m.weight.numel()
                    l += 1
            L = l
            # Equal share of kept weights per layer.
            X = args.er_sparse_init * total_params / l

            l = 0
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask, LinearMask)):
                    if X / m.weight.numel() < 1.0:
                        sparsity_list.append(X / m.weight.numel())
                    else:
                        sparsity_list.append(1)
                        # Layer is smaller than its share: adjust the share used
                        # for the following layers.
                        diff = X - m.mask.numel()
                        X = X + (diff / (L - l))
                    l += 1

            l = 0
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask, LinearMask)):
                    m.mask = torch.ones_like(m.mask).bernoulli_(sparsity_list[l])
                    l += 1

            # Fix a smaller width for the target random network, by limiting the max width. The layers with width larger than max_width will
            # have zeroed out neurons (args.max_width)
            print('sparsity of before reducing width (only ER): ', get_sparsity(model))

            if args.reduce_width_target:
                model = self.prune_width(model, args.max_width)
            print('sparsity of after reducing width: ', get_sparsity(model))
            target_density = get_sparsity(model)

            # Remember where the target mask is non-zero, then go back to dense.
            self.prune_idx = {}
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask, LinearMask)):
                    mask = m.mask
                    self.prune_idx[n] = torch.stack(torch.where(mask == 1), dim=0)
                    m.mask = torch.ones_like(m.mask)

            print('dense network restored: ', get_sparsity(model))

            # Geometric density schedule from 1 down towards the target density.
            curr_density = 1
            prune_frac = 0.8
            L = int((np.log(target_density) - np.log(curr_density))/ np.log(prune_frac))
            l = np.arange(1, L+1)
            self.threshold_list = curr_density * prune_frac ** l

        def prune_width(self, model, max_width):
            """Zero mask slices of ``ConvMask`` layers wider than ``max_width``.

            For each layer, ``num = max(out_channels - max_width, 0)``. The
            first ``num`` output slices are zeroed for every ConvMask layer
            whose running index is below ``L - 1`` (``L`` counts ConvMask and
            LinearMask layers), and the first ``num`` input slices are zeroed
            for every ConvMask layer after the first.

            Returns the same ``model`` object.
            """
            prune_idx_list = {}
            l = 0
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask, LinearMask)):
                    k = m.weight.shape[0]
                    if k > max_width:
                        num = k - max_width
                    else:
                        num = 0
                    prune_idx_list[n] = num
                    l += 1
            L = l

            l = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    if l < L - 1:
                        print('Pruning the outgoing edges of prev layer: ', prune_idx_list[n])
                        # remove the outgoing edges of the previous layer
                        num = prune_idx_list[n]
                        if num > 0:
                            print('Num pruning now: ', num)
                            m.mask[:num, :, :, :] = 0
                            print('incoming, right after prunign', (m.mask.sum(dim=(1, 2, 3))==0).sum().item(), m.mask.shape[0])

                    if l > 0:
                        print('Pruning the incoming edges of prev layer: ', prune_idx_list[n])
                        # remove the incoming edges
                        num = prune_idx_list[n]
                        if num > 0:
                            print('Num pruning now: ', num)
                            m.mask[:, :num, :, :] = 0

                    print('incoming', (m.mask.sum(dim=(1, 2, 3))==0).sum().item(), m.mask.shape[0])
                    print('outgoing', (m.mask.sum(dim=(0, 2, 3))==0).sum().item(), m.mask.shape[1])

                    l += 1

            return model

        def fix_target(self, model):
            """Set the stored target-mask positions of every ConvMask layer to 1.

            Returns the same ``model`` object.
            """
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask)):
                    curr_idx = self.prune_idx[n]
                    mask = m.mask
                    mask[curr_idx[0, :], curr_idx[1, :], curr_idx[2, :], curr_idx[3, :]] = 1
                    m.mask = mask.to(m.weight.device)
            print('Fixed the target mask')
            return model

        def prune_mag(self, model, density):
            """Globally magnitude-prune ConvMask layers, sparing the target mask.

            Scores are ``|mask * weight|``, except that positions stored in
            ``self.prune_idx`` get the fixed score 5 so that they are not
            pruned away. The k-th smallest pooled score
            (``k = int((1 - density) * num_scores)``) is the threshold; entries
            with score ``<=`` threshold get mask 0, the others mask 1.

            When ``k < 1`` nothing is pruned: masks are left as they are and
            only the densities are reported. (Previously ``torch.kthvalue``
            was called before this check and the totals could be unbound.)

            Returns the same ``model`` object.
            """
            score_list = {}
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask)):
                    curr_idx = self.prune_idx[n]

                    score_list[n] = (m.mask.to(m.weight.device) * m.weight).detach().abs_()
                    # make the mask large magnitudes, so that they are not pruned away
                    score_list[n][curr_idx[0, :], curr_idx[1, :], curr_idx[2, :], curr_idx[3, :]] = 5

            global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
            k = int((1 - density) * global_scores.numel())

            if k >= 1:
                # torch.kthvalue requires k >= 1, so it must only run inside this guard.
                threshold, _ = torch.kthvalue(global_scores, k)
                for n, m in model.named_modules():
                    if isinstance(m, (ConvMask)):
                        score = score_list[n].to(m.weight.device)
                        zero = torch.tensor([0.]).to(m.weight.device)
                        one = torch.tensor([1.]).to(m.weight.device)
                        m.mask = torch.where(score <= threshold, zero, one)

            # Always compute the statistics so the report below never hits unbound names.
            total_num = 0
            total_den = 0
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask)):
                    total_num += (m.mask == 1).sum()
                    total_den += m.mask.numel()

            print('Overall model density after magnitude pruning at current iteration = ', total_num / total_den)
            # model = self.fix_target(model)
            print('Density after fixing the mask: ', get_sparsity(model))
            return model

    

    class RGP:
        """Pruner that alternates random unstructured pruning with width reduction.

        Every ``prune_every``-th call to :meth:`prune_step` removes whole
        neurons according to a pre-computed per-level, per-layer list loaded
        from disk; all other calls remove random individual weights with a
        per-layer balanced budget.
        """

        def __init__(self, model, args):
            """Load the neuron schedule and store the pruning hyper-parameters.

            Args:
                model: unused here; kept for interface compatibility.
                args: namespace providing ``neuron_width_list`` (file name under
                    ``base_dir + 'runs/'``), ``prune_ratio`` and ``prune_every``.
            """
            # instead of randomness. remove neurons sequentially
            # create a list of all the active neurons
            self.layerwise_neurons = {}
            # Indexed first by structured level, then by ConvMask layer position.
            self.neuron_sparsity_list = torch.load(base_dir + 'runs/' + args.neuron_width_list)
            self.struct_level = 0
            self.prune_ratio = args.prune_ratio
            self.prune_every = args.prune_every

        def prune_step(self, model, level):
            """Run one pruning step and return ``model``.

            The step is structured when ``level % prune_every == prune_every - 1``
            and structured levels remain; otherwise it is unstructured.
            """
            # Use the value captured in __init__ rather than the global args so
            # the instance honours the configuration it was built with.
            is_struct_turn = (level % self.prune_every == self.prune_every - 1)

            # alternate between struct and unstruct
            if not is_struct_turn:
                model = self.prune_unstruct_random_balanced(model, self.prune_ratio)
            else:
                if self.struct_level < len(self.neuron_sparsity_list):
                    model = self.prune_neurons(model)
                    self.struct_level += 1
                else:
                    print('structured part of pruning is complete')
                    model = self.prune_unstruct_random_balanced(model, self.prune_ratio)
            print('overall density after prune step is: ', get_sparsity(model))
            return model

        def prune_unstruct_random_balanced(self, model, prune_ratio):
            """Randomly prune ``ConvMask`` weights with a balanced per-layer budget.

            The target density is ``prune_ratio * get_sparsity(model)``. Every
            layer is initially given the same absolute number of weights to
            keep; a layer smaller than that budget is kept whole and the
            budget is adjusted for the layers that follow.
            """
            density = prune_ratio * get_sparsity(model)
            print('Pruning unstructured to target density: ', density)

            #use the self.neuron_sparsity_list to get the number of active neurons in the current layer and only
            # sparsify those in the unstructured pruning setup
            # work out how sparsity will change with each prune step!

            total_params = 0
            l = 0
            sparsity_list = []
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask)):
                    total_params += m.weight.numel()
                    l += 1
            L = l
            # X is the number of weights each layer is allowed to keep.
            X = density * total_params / l
            score_list = {}
            l = 0
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask)):
                    # Already-pruned weights get score 0 and therefore stay pruned.
                    score_list[n] = (m.mask.to(m.weight.device) * torch.rand_like(m.weight).to(m.weight.device)).detach()

                    if X / m.weight.numel() < 1.0:
                        sparsity_list.append(X / m.weight.numel())
                    else:
                        sparsity_list.append(1)
                        # correction for taking care of exact sparsity
                        diff = X - m.mask.numel()
                        X = X + (diff / (L - l))
                    l += 1
                    print('Balanced X: ', X)

            total_num = 0
            total_den = 0
            cnt = 0
            for n, m in model.named_modules():
                if isinstance(m, (ConvMask)):
                    global_scores = torch.flatten(score_list[n])
                    k = int((1 - sparsity_list[cnt]) * global_scores.numel())
                    # kthvalue needs k >= 1; with k == 0 only zero scores are dropped.
                    if k == 0:
                        threshold = 0
                    else:
                        threshold, _ = torch.kthvalue(global_scores, k)
                    print('Layer', n, ' params ', k, global_scores.numel())

                    score = score_list[n].to(m.weight.device)
                    zero = torch.tensor([0.]).to(m.weight.device)
                    one = torch.tensor([1.]).to(m.weight.device)
                    m.mask = torch.where(score <= threshold, zero, one)
                    total_num += (m.mask == 1).sum()
                    total_den += m.mask.numel()
                    cnt += 1

            print('Overall model density after random global (balanced) pruning at current iteration = ', total_num / total_den)
            return model

        def prune_neurons(self, model):
            """Zero the leading rows/columns of every ``ConvMask`` mask.

            For the current structured level, layer ``l`` gets
            ``k = int(neuron_sparsity_list[level][l] * mask.shape[0])``. All
            layers but the last have their first ``k`` rows (dim 0) zeroed and
            all layers but the first have their first ``k`` columns (dim 1)
            zeroed.
            """
            # get the proportions of pruning from the ref imp sparsity list and prune away
            # base_dir + 'runs/'

            neuron_density = self.neuron_sparsity_list[self.struct_level]
            prune_idx_list = {}
            l = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    k = int(neuron_density[l] * m.mask.shape[0])
                    # k is dynamic, each level will have its own k, so we can ensure the width requirement is enforced
                    prune_idx_list[n] = k
                    l += 1
            L = l

            l = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    num = prune_idx_list[n]

                    if l < L - 1:
                        print('Pruning the outgoing edges of prev layer: ', num)
                        # remove the outgoing edges of the previous layer
                        if num > 0:
                            print('Num pruning now: ', num)
                            m.mask[:num, :, :, :] = 0
                            print('incoming, right after prunign', (m.mask.sum(dim=(1, 2, 3))==0).sum().item(), m.mask.shape[0])

                    if l > 0:
                        print('Pruning the incoming edges of prev layer: ', num)
                        # remove the incoming edges
                        # NOTE(review): the column count is this layer's own k
                        # (derived from mask.shape[0]), not the k of the layer
                        # feeding it — confirm this is the intended width match.
                        if num > 0:
                            print('Num pruning now: ', num)
                            m.mask[:, :num, :, :] = 0

                    print('incoming', (m.mask.sum(dim=(1, 2, 3))==0).sum().item(), m.mask.shape[0])
                    print('outgoing', (m.mask.sum(dim=(0, 2, 3))==0).sum().item(), m.mask.shape[1])

                    l += 1

            return model

        
    class PruneIdxStruct:
        """Gradual structured pruner that empties a fixed random set of channels.

        At construction a random subset of input channels is chosen for every
        ``ConvMask`` layer except the first. The active mask entries of those
        channels (and of the same-numbered rows of the preceding ``ConvMask``
        layer) are recorded, and each call to
        :meth:`prune_incoming_and_update` zeroes a ``1 / num_levels`` share of
        the entries that are still recorded.
        """

        def __init__(self, model, density, num_levels, args):
            """Select the channels to remove and record their active mask entries.

            Args:
                model: network containing ``ConvMask`` modules.
                density: fraction of channels to keep; ``1 - density`` is
                    spread over the layers with a balanced per-layer budget.
                num_levels: divisor used for the per-step share.
                args: unused here; kept for interface compatibility.
            """
            # create first a list of all the indices that are to be pruned away
            # each layer neurons are chosen by the balanced sparsity
            self.levels = torch.tensor(num_levels)
            total_params = 0
            l = 0
            sparsity_list = []
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    total_params += m.mask.shape[1]
                    l += 1
            L = l
            prune_frac = (1 - density)
            # X is the number of input channels each layer should lose.
            X = prune_frac * total_params / l

            l = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    if X / m.mask.shape[1] < 1.0:
                        sparsity_list.append(X / m.mask.shape[1])
                    else:
                        sparsity_list.append(1)
                        # correction for taking care of exact sparsity
                        diff = X - m.mask.shape[1]
                        X = X + diff / (L - l)
                    l += 1

            print(sparsity_list)
            # layer name -> chosen input-channel indices of that layer
            self.neurons = {}
            # layer name -> row indices of that layer chosen via the next layer
            self.prev_neurons = {}

            # layer name -> remaining (neuron_pos, out_ch, kh, kw) entries to zero
            self.total_idxs = {}
            # layer name -> remaining (neuron_pos, in_ch, kh, kw) entries to zero
            self.total_outgoing_idxs = {}

            self.L = len(sparsity_list)
            print('Initializing the Struct Pruner')
            print(self.L, density, len(self.total_idxs), len(self.total_outgoing_idxs))

            l = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    if l == 0:
                        prev = m.mask
                        prev_name = n
                    if l > 0:
                        # Only sparsify the middle layers, not the first and the last
                        k = int(np.ceil(torch.tensor(sparsity_list[l] * m.mask.shape[1])))

                        # The extra randperm inside this print is kept on purpose:
                        # dropping it would shift the RNG stream and change which
                        # channels are selected under a fixed seed.
                        print(m.mask.shape[1], k, torch.randperm(m.mask.shape[1])[:k].shape)
                        neuron = torch.randperm(m.mask.shape[1])[:k]
                        self.neurons[n] = neuron
                        self.prev_neurons[prev_name] = neuron
                        # Active entries inside the chosen input channels,
                        # indexed relative to the permuted sub-mask.
                        msub = m.mask[:, neuron, :, :]
                        msub = msub.permute(1, 0, 2, 3)
                        idx = torch.nonzero(msub)
                        self.total_idxs[n] = idx
                        # Active entries inside the matching rows of the
                        # preceding ConvMask layer.
                        bsub = prev[neuron, :, :, :]
                        idx = torch.nonzero(bsub)
                        self.total_outgoing_idxs[prev_name] = idx
                        prev = m.mask
                        prev_name = n
                    l += 1

            print(self.L, density, len(self.total_idxs), len(self.total_outgoing_idxs))

        def update_masked_idx(self, model):
            """Recompute the recorded entries from the current masks.

            Uses the channel selections made in ``__init__``; returns ``model``.
            """
            print('Updating the total index list of weights to be pruned')
            l = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    if l == 0:
                        prev = m.mask
                        prev_name = n
                    if l > 0:
                        neuron = self.neurons[n]
                        msub = m.mask[:, neuron, :, :]
                        msub = msub.permute(1, 0, 2, 3)
                        idx = torch.nonzero(msub)
                        self.total_idxs[n] = idx

                        neuron = self.prev_neurons[prev_name]
                        bsub = prev[neuron, :, :, :]
                        idx = torch.nonzero(bsub)

                        self.total_outgoing_idxs[prev_name] = idx
                        prev = m.mask
                        prev_name = n

                    l += 1

            return model

        def prune_incoming_and_update(self, model):
            """Zero the next share of recorded entries in every ``ConvMask`` mask.

            For each layer, ``ceil(remaining / num_levels)`` of the still
            recorded entries are set to 0 (rows for all but the last layer,
            columns for all but the first) and removed from the record.
            Returns ``model``.
            """
            l = 0
            total = 0
            nz = 0
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    if l < self.L - 1:
                        print('Pruning the outgoing edges of prev layer: ', self.total_outgoing_idxs[n].shape)
                        # remove the outgoing edges of the previous layer
                        idx = self.total_outgoing_idxs[n]

                        if idx.shape[0] > 0:
                            num = int(np.ceil(idx.shape[0] * (1 / self.levels)))
                            if num > idx.shape[0]:
                                num = idx.shape[0]
                            print('Num pruning now: ', num)
                            # num is the number of weights within the designated neurons that should be zeroed out in this step
                            neuron = self.prev_neurons[n]
                            # Advanced indexing copies, so edit the copy and write it back.
                            bsub = m.mask[neuron, :, :, :]
                            # set the designated neuron weights to zero
                            # only prunes the incoming weights for a neuron
                            bsub[idx[:num, 0], idx[:num, 1], idx[:num, 2], idx[:num, 3]] = 0
                            # update the mask itself after removing the edges
                            m.mask[neuron, :, :, :] = bsub
                            print('incoming, right after prunign', (m.mask.sum(dim=(1, 2, 3))==0).sum().item(), m.mask.shape[0])
                            # update the remaining indexes
                            if num == idx.shape[0]:
                                self.total_outgoing_idxs[n] = torch.empty(0,4)
                            else:
                                self.total_outgoing_idxs[n] = idx[num:]

                    if l > 0:
                        print('Pruning the incoming edges of prev layer: ', self.total_idxs[n].shape)
                        # remove the incoming edges
                        idx = self.total_idxs[n]
                        if idx.shape[0] > 0:
                            num = int(np.ceil(idx.shape[0] * (1 / self.levels)))
                            if num > idx.shape[0]:
                                num = idx.shape[0]
                            print('Num pruning now: ', num)
                            # num is the number of weights within the designated neurons that should be zeroed out in this step
                            neuron = self.neurons[n]
                            msub = m.mask[:, neuron, :, :]
                            # The recorded indices refer to the permuted layout.
                            msub = msub.permute(1, 0, 2, 3)
                            # set the designated neuron weights to zero
                            # only prunes the incoming weights for a neuron
                            msub[idx[:num, 0], idx[:num, 1], idx[:num, 2], idx[:num, 3]] = 0
                            msub = msub.permute(1, 0, 2, 3)

                            # update the mask itself after removing the edges
                            m.mask[:, neuron, :, :] = msub
                            # The count runs over dim 1, so report it against shape[1].
                            print('outgoign, right after prunign', (m.mask.sum(dim=(0, 2, 3))==0).sum().item(), m.mask.shape[1])

                            # update the remaining indexes
                            if num == idx.shape[0]:
                                self.total_idxs[n] = torch.empty(0,4)
                            else:
                                self.total_idxs[n] = idx[num:]

                    print('incoming', (m.mask.sum(dim=(1, 2, 3))==0).sum().item(), m.mask.shape[0])
                    print('outgoing', (m.mask.sum(dim=(0, 2, 3))==0).sum().item(), m.mask.shape[1])
                    total += m.mask.numel()
                    nz += m.mask.sum()
                    l += 1

            print('the density of the model after gradual struct pruning is: ', nz / total)

            return model

    def prune_layer(model, layer_idx, density):
        """Randomly prune every ``ConvMask`` conv of one block to ``density``.

        Args:
            model: network whose masks are modified in place.
            layer_idx: block identifier; modules whose name contains
                ``'layer' + str(layer_idx)`` and ``'conv'`` are pruned.
            density: fraction of mask entries to keep in each matched layer.
                The same density is applied to the entire block.

        Returns:
            The same ``model`` object. If no module matches, nothing is pruned
            and a notice is printed instead of the density.
        """
        # the layer index determines the residual block that we want to get rid off.
        # The idea is to remove the full block, keeping only the skip connection
        # The density is passed to gradually anneal all the parameters in the layer to 0
        # layer_idx: a string of the form 1.1, 1.2, 2.1 etc.
        block_name = 'layer'+str(layer_idx)
        total_num = 0
        total_den = 0
        for n, m in model.named_modules():
            print(n)
            if (block_name in n) and ('conv' in n) and isinstance(m, ConvMask):
                print('pruning the entire layer: {} with density {}'.format(n, density))
                # Already-pruned entries score 0 and therefore stay pruned.
                global_scores = (m.mask.to(m.weight.device) * torch.rand_like(m.mask).to(m.weight.device))
                k = int((1 - density) * global_scores.numel())
                # kthvalue needs k >= 1; with k == 0 only zero scores are dropped.
                if k == 0:
                    threshold = 0
                else:
                    threshold, _ = torch.kthvalue(global_scores.view(-1), k)
                print('Layer', n, ' params ', k, global_scores.numel())

                zero = torch.tensor([0.]).to(m.weight.device)
                one = torch.tensor([1.]).to(m.weight.device)
                m.mask = torch.where(global_scores <= threshold, zero, one)
                total_num += m.mask.sum()
                total_den += m.mask.numel()

        if total_den == 0:
            # Nothing matched: avoid dividing 0 by 0 and say so explicitly.
            print('No ConvMask conv layer matched block: ', block_name)
        else:
            print('pruned layer: ', block_name)
            print('Density of the pruned layer is: ', total_num / total_den)

        return model

    def prune_neuron_gradual(model, density):
        """Zero whole neurons in the inner ``ConvMask`` layers.

        In every non-downsample ``ConvMask`` layer the
        ``int((1 - density) * out_channels)`` neurons with the lowest random
        score are selected (neurons whose mask rows are already all zero score
        0 and are selected first). For every layer except the first and the
        last, the selected rows are zeroed together with the columns selected
        in the preceding layer.

        Args:
            model: network whose masks are modified in place.
            density: fraction of neurons to keep per layer.

        Returns:
            The same ``model`` object.
        """
        # prune density of the layer at a neuron level proportional to the size of thr layer
        print('identifying random neurons')
        layers = [
            m for n, m in model.named_modules()
            if isinstance(m, ConvMask) and 'downsample' not in n
        ]

        idxs = []
        for m in layers:
            # Now assign scores accordig to how many neurons are active at the moment
            # A dead neuron will have zero score (ie indegree will be zero)
            scores = (m.mask.to(m.weight.device) * torch.rand_like(m.weight).to(m.weight.device)).sum(dim=(1, 2, 3))
            k = int((1 - density) * scores.numel())
            # k counts the neurons selected for removal below, not the survivors.
            print('Num neurons to prune: ', k, scores.numel())
            # kthvalue needs k >= 1; with k == 0 only zero-score neurons are picked.
            if k == 0:
                threshold = 0
            else:
                threshold, _ = torch.kthvalue(scores, k)

            idxs.append(torch.where(scores <= threshold)[0])

        last = len(layers) - 1
        total_num = 0
        total_den = 0
        for l, m in enumerate(layers):
            # The first and last layers keep their width.
            if 0 < l < last:
                idx = idxs[l]
                print(idx.shape)
                idx_out = idxs[l-1]
                print(m.mask.shape, idx, idx_out)
                m.mask[idx, :, :, :] = 0
                m.mask[:, idx_out, :, :] = 0

            total_num += m.mask.sum()
            total_den += m.mask.numel()

        print('Pruned neurons in each layer as per their size and now the post pruning density is: ', total_num / total_den)
        return model


    def prune_random_balanced(model, density):
        """Randomly prune ``ConvMask``/``LinearMask`` layers with a balanced budget.

        Each layer starts with the same absolute number of weights to keep
        (``density * total_params / num_layers``). A layer smaller than that
        budget is kept whole and the budget for later layers is adjusted.
        Within a layer the weights with the smallest ``|mask * randn|`` scores
        are removed.

        Returns the same ``model`` object with updated masks.
        """
        maskable = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, (ConvMask, LinearMask))
        ]
        num_layers = len(maskable)
        total_params = sum(module.weight.numel() for _, module in maskable)
        budget = density * total_params / num_layers

        random_scores = {}
        keep_fractions = []
        for position, (name, module) in enumerate(maskable):
            device = module.weight.device
            noise = torch.randn_like(module.weight).to(device)
            random_scores[name] = (module.mask.to(device) * noise).detach().abs_()

            fraction = budget / module.weight.numel()
            if fraction < 1.0:
                keep_fractions.append(fraction)
                continue
            keep_fractions.append(1)
            # correction for taking care of exact sparsity
            surplus = budget - module.mask.numel()
            budget = budget + surplus / (num_layers - position)

        kept = 0
        seen = 0
        for (name, module), fraction in zip(maskable, keep_fractions):
            layer_scores = random_scores[name]
            flat_scores = torch.flatten(layer_scores)
            k = int((1 - fraction) * flat_scores.numel())
            # kthvalue needs k >= 1; fall back to a zero threshold otherwise.
            threshold = torch.kthvalue(flat_scores, k)[0] if k != 0 else 0
            print('Layer', name, ' params ', k, flat_scores.numel())

            device = module.weight.device
            layer_scores = layer_scores.to(device)
            off = torch.tensor([0.]).to(device)
            on = torch.tensor([1.]).to(device)
            module.mask = torch.where(layer_scores <= threshold, off, on)
            kept += (module.mask == 1).sum()
            seen += module.mask.numel()

        print('Overall model density after random global (balanced) pruning at current iteration = ', kept / seen)
        return model


    def prune_random_uniform(model, density):
        """Randomly prune every ``ConvMask`` layer to the same ``density``.

        Per layer, the ``int((1 - density) * numel)`` entries with the smallest
        ``|mask * randn|`` score are removed; already-pruned entries score 0
        and therefore stay pruned.

        Returns the same ``model`` object with updated masks.
        """
        kept = 0
        seen = 0

        for name, module in model.named_modules():
            if not isinstance(module, (ConvMask)):
                continue

            device = module.weight.device
            noise = torch.randn_like(module.weight).to(device)
            layer_scores = (module.mask.to(device) * noise).detach().abs_()
            flat_scores = torch.flatten(layer_scores)

            k = int((1 - density) * flat_scores.numel())
            # kthvalue needs k >= 1; fall back to a zero threshold otherwise.
            threshold = torch.kthvalue(flat_scores, k)[0] if k != 0 else 0
            print('Layer', name, ' params ', k, flat_scores.numel())

            layer_scores = layer_scores.to(device)
            off = torch.tensor([0.]).to(device)
            on = torch.tensor([1.]).to(device)
            module.mask = torch.where(layer_scores <= threshold, off, on)
            kept += (module.mask == 1).sum()
            seen += module.mask.numel()

        print('Overall model density after random global (uniform) pruning at current iteration = ', kept / seen)
        return model


    def prune_random_erk(model, density):
        """Randomly prune ``ConvMask`` layers with shape-dependent densities.

        Each layer's raw density is ``sum(weight.shape) / weight.numel()``.
        The raw densities are scaled by a common factor so that the expected
        number of kept weights is ``density * total_params`` and then clamped
        to ``[0, 1]``. Within a layer the entries with the smallest
        ``|mask * randn|`` scores are removed.

        Returns the same ``model`` object with updated masks.
        """
        conv_layers = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, (ConvMask))
        ]

        random_scores = {}
        raw_densities = []
        layer_sizes = []
        total_params = 0
        for name, module in conv_layers:
            device = module.weight.device
            noise = torch.randn_like(module.weight).to(device)
            random_scores[name] = (module.mask.to(device) * noise).detach().abs_()

            size = module.weight.numel()
            raw_densities.append(torch.tensor(module.weight.shape).sum() / size)
            layer_sizes.append(size)
            total_params += size

        num_params_kept = (torch.tensor(raw_densities) * torch.tensor(layer_sizes)).sum()
        num_params_to_keep = total_params * density
        C = num_params_to_keep / num_params_kept
        print('Factor: ', C)
        layer_densities = [torch.clamp(C*raw, 0, 1) for raw in raw_densities]

        kept = 0
        seen = 0
        for (name, module), layer_density in zip(conv_layers, layer_densities):
            layer_scores = random_scores[name]
            flat_scores = torch.flatten(layer_scores)
            k = int((1 - layer_density) * flat_scores.numel())
            # kthvalue needs k >= 1; fall back to a zero threshold otherwise.
            threshold = torch.kthvalue(flat_scores, k)[0] if k != 0 else 0
            print('Layer', name, ' params ', k, flat_scores.numel())

            device = module.weight.device
            layer_scores = layer_scores.to(device)
            off = torch.tensor([0.]).to(device)
            on = torch.tensor([1.]).to(device)
            module.mask = torch.where(layer_scores <= threshold, off, on)
            kept += (module.mask == 1).sum()
            seen += module.mask.numel()

        print('Overall model density after random global (ERK) pruning at current iteration = ', kept / seen)
        return model


    def prune_random_global(model, density):
        """Randomly prune all ``ConvMask`` layers with one global threshold.

        Scores are ``|mask * randn|`` so already-pruned entries score 0 and
        stay pruned. The ``int((1 - density) * numel)`` lowest-scoring entries
        over all layers together are removed.

        Args:
            model: network whose masks are modified in place.
            density: fraction of mask entries to keep over all layers.

        Returns:
            The same ``model`` object. When the number of entries to remove
            rounds down to zero the masks are left untouched and only the
            current density is reported.
        """
        conv_layers = [(n, m) for n, m in model.named_modules() if isinstance(m, ConvMask)]

        score_list = {}
        for n, m in conv_layers:
            score_list[n] = (m.mask.to(m.weight.device) * torch.randn_like(m.weight).to(m.weight.device)).detach().abs_()

        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int((1 - density) * global_scores.numel())

        # torch.kthvalue rejects k < 1, so only look up a threshold when there
        # is at least one entry to remove.
        if k >= 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for n, m in conv_layers:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.]).to(m.weight.device)
                one = torch.tensor([1.]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)

        # Measure the density in every case so the report never divides 0 by 0.
        total_num = 0
        total_den = 0
        for n, m in conv_layers:
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

        print('Overall model density after random global pruning at current iteration = ', total_num / total_den)
        return model

    def prune_snip(model, trainloader, loss, density):
        """Prune ``ConvMask`` layers globally by ``|grad * weight * mask|``.

        Gradients come from a single forward/backward pass on the first batch
        of ``trainloader``.

        Args:
            model: network whose masks are modified in place.
            trainloader: iterable of ``(images, target)`` batches; only the
                first batch is used.
            loss: callable ``loss(output, target)`` returning a tensor that
                supports ``backward()``.
            density: fraction of mask entries to keep over all layers.

        Returns:
            The same ``model`` object. When the number of entries to remove
            rounds down to zero the masks are left untouched and only the
            current density is reported.
        """
        device = torch.device(f'cuda:{args.gpu}')
        for images, target in trainloader:
            images = images.to(device)
            target = target.to(device).long()
            model.zero_grad()
            output = model(images)
            loss(output, target).backward()
            # One batch is enough for the saliency estimate.
            break

        conv_layers = [(n, m) for n, m in model.named_modules() if isinstance(m, ConvMask)]

        score_list = {}
        for n, m in conv_layers:
            score_list[n] = (m.weight.grad * m.weight * m.mask.to(m.weight.device)).detach().abs_()

        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int((1 - density) * global_scores.numel())

        # torch.kthvalue rejects k < 1, so only look up a threshold when there
        # is at least one entry to remove.
        if k >= 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for n, m in conv_layers:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.]).to(m.weight.device)
                one = torch.tensor([1.]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)

        # Measure the density in every case so the report below never depends
        # on variables bound only inside the pruning branch.
        total_num = 0
        total_den = 0
        for n, m in conv_layers:
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

        print('Overall model density after snip pruning at current iteration = ', total_num / total_den)
        return model


    def prune_synflow(model, trainloader, density):
        """Prune ``ConvMask`` layers globally by ``|mask * grad * weight|``.

        Every entry of ``model.state_dict()`` is temporarily replaced by its
        absolute value, an all-ones input shaped like one sample of the first
        batch is pushed through the model, and the sum of the output is
        back-propagated. The original signs are restored afterwards, even if
        the forward/backward pass fails.

        Args:
            model: network whose masks are modified in place.
            trainloader: iterable of ``(images, target)`` batches; only the
                shape of the first batch's samples is used.
            density: fraction of mask entries to keep over all layers.

        Returns:
            The same ``model`` object. When the number of entries to remove
            rounds down to zero the masks are left untouched and only the
            current density is reported.
        """

        @torch.no_grad()
        def linearize(model):
            # Remember every sign, then make all state-dict tensors non-negative.
            signs = {}
            for name, param in model.state_dict().items():
                signs[name] = torch.sign(param)
                param.abs_()
            return signs

        @torch.no_grad()
        def nonlinearize(model, signs):
            # Multiply the stored signs back in to undo linearize().
            for n, param in model.state_dict().items():
                param.mul_(signs[n])

        conv_layers = [(n, m) for n, m in model.named_modules() if isinstance(m, ConvMask)]

        signs = linearize(model)
        score_list = {}
        try:
            # Clear stale gradients so the scores reflect this pass only.
            model.zero_grad()
            for images, _ in trainloader:
                input_dim = list(images[0,:].shape)
                ones_input = torch.ones([1] + input_dim).to(args.gpu)
                output = model(ones_input)
                torch.sum(output).backward()
                break

            for n, m in conv_layers:
                score_list[n] = (m.mask.to(m.weight.device) * m.weight.grad * m.weight).detach().abs_()

            model.zero_grad()
        finally:
            # Always restore the signs so a failure cannot leave the model
            # with absolute-valued weights.
            nonlinearize(model, signs)

        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int((1 - density) * global_scores.numel())

        # torch.kthvalue rejects k < 1, so only look up a threshold when there
        # is at least one entry to remove.
        if k >= 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for n, m in conv_layers:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.]).to(m.weight.device)
                one = torch.tensor([1.]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)

        # Measure the density in every case so the report below never depends
        # on variables bound only inside the pruning branch.
        total_num = 0
        total_den = 0
        for n, m in conv_layers:
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

        print('Overall model density after synflow pruning at current iteration = ', total_num / total_den)
        return model

    def perturb_signs(model, perturb_ratio):
        """Randomly flip weight signs in every ``ConvMask`` layer.

        Each entry whose mask equals 1 is multiplied by -1 with probability
        ``perturb_ratio`` and by 1 otherwise; entries whose mask is not 1 are
        multiplied by 0. Returns the same ``model`` object.
        """
        for _, layer in model.named_modules():
            if not isinstance(layer, ConvMask):
                continue
            is_active = layer.mask == 1
            flip_drawn = torch.ones_like(layer.mask).bernoulli_(perturb_ratio) == 1
            flip_or_keep = torch.where(flip_drawn, -1, 1)
            multiplier = torch.where(is_active, flip_or_keep, 0).to(layer.mask.device)
            layer.weight.data = multiplier * layer.weight.data
        return model

    def dead_neurons(model):
        """Report all-zero mask rows and columns for every ``ConvMask`` layer.

        Returns three lists with one entry per ``ConvMask`` layer:
        ``(num_dead, num_dead / mask.shape[0])`` tuples for rows whose mask
        sums to zero, 0/1 indicator tensors over dim 0 (rows summing to zero),
        and 0/1 indicator tensors over dim 1 (columns summing to zero).
        """
        dead_list, dead_mask_list_in, dead_mask_list_out = [], [], []
        for _, layer in model.named_modules():
            if not isinstance(layer, (ConvMask)):
                continue

            row_is_dead = layer.mask.sum(dim=(1,2,3)) == 0
            dead_mask_list_in.append(torch.where(row_is_dead, 1, 0))

            col_is_dead = layer.mask.sum(dim=(0,2,3)) == 0
            dead_mask_list_out.append(torch.where(col_is_dead, 1, 0))

            num_dead = row_is_dead.nonzero(as_tuple=True)[0].numel()
            dead_list.append((num_dead, num_dead/layer.mask.shape[0]))

        return dead_list, dead_mask_list_in, dead_mask_list_out

    def prune_dead(model):
        """Propagate dead neurons across neighbouring ``ConvMask`` layers.

        A mask row (dim 0) that sums to zero marks a neuron with no incoming
        weights; a mask column (dim 1) that sums to zero marks an input
        channel with no outgoing weights. Both sets are measured once, before
        any change, and then applied to every layer except the first and the
        last:

        * columns matching the dead rows of the preceding layer are zeroed;
        * rows matching the dead columns of the following layer are zeroed.

        NOTE(review): this assumes consecutive ``ConvMask`` modules in
        ``named_modules()`` order feed each other with matching channel
        counts — confirm for architectures with skip/downsample branches.

        Returns the same ``model`` object.
        """
        print('if there are any dead neurons in the network, all the incoming and outgoing weights to it are zeroed out')
        conv_layers = [m for _, m in model.named_modules() if isinstance(m, ConvMask)]

        # 1-D index tensors of all-zero rows / columns per layer.
        dead_idx_in = [(m.mask.sum(dim=(1,2,3)) == 0).nonzero(as_tuple=True)[0] for m in conv_layers]
        dead_idx_out = [(m.mask.sum(dim=(0,2,3)) == 0).nonzero(as_tuple=True)[0] for m in conv_layers]
        num_in = sum(idx.shape[0] for idx in dead_idx_in)
        num_out = sum(idx.shape[0] for idx in dead_idx_out)

        print('Num in and num out: ', num_in, num_out)
        last = len(conv_layers) - 1

        # Removing dead incoming neurons
        for l, m in enumerate(conv_layers):
            if l == last and l != 0:
                continue
            m.mask[dead_idx_in[l], :, :, :] = 0
            if l > 0:
                # Neurons of the preceding layer that receive no input cannot
                # contribute, so drop the weights that read from them.
                idx_out = dead_idx_in[l-1]
                m.mask[:, idx_out, :, :] = 0

        # Removing dead outgoing neurons
        for l, m in enumerate(conv_layers):
            if l == last and l != 0:
                continue
            m.mask[:, dead_idx_out[l], :, :] = 0
            if l > 0:
                # Input channels the following layer never reads correspond to
                # this layer's output neurons, so drop the weights feeding them.
                idx_out = dead_idx_out[l+1]
                m.mask[idx_out, :, :, :] = 0

        return model

    
    def prune_neuron_random(model):
        """Zero one randomly chosen neuron per non-downsample ``ConvMask`` layer.

        One output index is drawn per layer. Inner layers lose that row and
        the column drawn for the preceding layer. The first layer (row only)
        and the last layer (column only) are touched only when
        ``args.fix_first_last_structured`` is falsy.

        Returns the same ``model`` object.
        """
        num = 1
        print('identifying random neurons')
        targets = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, (ConvMask)) and 'downsample' not in name
        ]

        idxs = []
        for name, module in targets:
            print(name, module.weight.shape)
            permutation = torch.randperm(module.weight.shape[0])
            print('Index layer: ', permutation[:num])
            idxs.append(permutation[:num])

        print('pruning these neurons to zero')
        last = len(idxs) - 1
        kept = 0
        seen = 0
        for position, (name, module) in enumerate(targets):
            print('layer name: ', name)
            if position == 0:
                # does not prune the first and last layer and maintains their width if flag is set to True
                if not args.fix_first_last_structured:
                    module.mask[idxs[position], :, :, :] = 0
            elif position == last:
                if not args.fix_first_last_structured:
                    module.mask[:, idxs[position-1], :, :] = 0
            else:
                module.mask[idxs[position], :, :, :] = 0
                module.mask[:, idxs[position-1], :, :] = 0

            kept += (module.mask == 1).sum()
            seen += module.mask.numel()

        print('Pruned one neuron in each layer and now the post pruning density is: ', kept / seen)
        return model

    def prune_neuron_sequential(model):
        """Zero the mask of a random fraction of neurons in every prunable layer.

        Prunable layers are the ``ConvMask`` modules whose name does not contain
        'downsample'. For each one, ``int(args.structured_prune_ratio * width)``
        output channels are drawn with ``torch.randperm`` (``width`` being
        ``weight.shape[0]``). The drawn channels are masked out along dim 0 of
        the layer's own mask and along dim 1 of the next layer's mask. The first
        layer only loses its own channels, the last layer only loses the
        channels drawn for its predecessor, and both are left untouched when
        ``args.fix_first_last_structured`` is set.

        Prints the resulting density and returns the model (masks are edited
        in place).
        """
        prune_ratio = args.structured_prune_ratio

        # Exactly ONE index tensor per layer: the second pass relies on
        # idxs[l] / idxs[l-1] lining up with layer l and its predecessor.
        # (Previously a second, single-index tensor was appended per layer as
        # well, which shifted every lookup onto the wrong layer.)
        idxs = []
        print('identifying random neurons')
        for n, m in model.named_modules():
            if isinstance(m, ConvMask) and 'downsample' not in n:
                print(n, m.weight.shape)
                width = m.weight.shape[0]
                idx = torch.randperm(width)[:int(prune_ratio * width)]
                print('Index layer: ', idx)
                idxs.append(idx)

        print('pruning these neurons to zero')
        l = 0
        total_num = 0
        total_den = 0
        for n, m in model.named_modules():
            if isinstance(m, ConvMask) and 'downsample' not in n:
                print('layer name: ', n)
                if l == 0:
                    # does not prune the first and last layer and maintains their width if flag is set to True
                    if not args.fix_first_last_structured:
                        m.mask[idxs[l], :, :, :] = 0

                elif l == len(idxs) - 1:
                    if not args.fix_first_last_structured:
                        # only the inputs coming from the previous layer's pruned neurons
                        m.mask[:, idxs[l - 1], :, :] = 0

                else:
                    m.mask[idxs[l], :, :, :] = 0
                    m.mask[:, idxs[l - 1], :, :] = 0

                l += 1
                total_num += (m.mask == 1).sum()
                total_den += m.mask.numel()

        print('Pruned one neuron sequential in each layer and now the post pruning density is: ', total_num / total_den)
        return model
    


    def prune_mask_ref(model, level, name):
        """Install the masks stored in ``runs/mask_<name>_<level>.pt`` on the model.

        The file is read with ``torch.load`` and indexed by position: the i-th
        ``ConvMask`` module yielded by ``model.named_modules()`` receives entry
        ``i`` as its new ``mask``. Prints the fraction of mask entries equal to
        one and returns the model.
        """
        mask_file = 'runs/mask_{}_{}.pt'.format(name, level)
        stored_masks = torch.load(mask_file)

        conv_layers = [mod for _, mod in model.named_modules() if isinstance(mod, ConvMask)]

        ones_count = 0
        entry_count = 0
        for position, layer in enumerate(conv_layers):
            layer.mask = stored_masks[position]
            ones_count += (layer.mask == 1).sum()
            entry_count += layer.mask.numel()

        print('Overall model density after loading IMP mask = ', ones_count / entry_count)
        return model


    def integrate_bn_params(model):
        '''
        Fold the BatchNorm scaling into the weights of the preceding masked conv layer.

        For every ``nn.BatchNorm2d`` module the scale ``gamma / sqrt(var + eps)``
        is moved into a conv weight, and the BN layer is reset to weight 1,
        running mean 0 and running variance 1. Since the conv layers carry no
        bias, the shift is kept in the BN bias as
        ``beta - gamma * mu / sqrt(var + eps)``.

        Pairing is positional: the k-th ``ConvMask`` module yielded by
        ``model.named_modules()`` is scaled with the statistics of the k-th
        ``BatchNorm2d`` module (NOTE(review): this assumes both appear in
        matching order -- confirm for architectures with extra BN/conv layers).

        Returns the model, modified in place.
        '''
        eps = 1e-5
        bn_params = []
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                # Take independent snapshots: the running statistics are reset
                # in place just below, and the conv rescaling in the second loop
                # must still see the ORIGINAL variance rather than the reset one.
                gamma = m.weight.detach().clone()
                beta = m.bias.detach().clone()
                mu = m.running_mean.detach().clone()
                var = m.running_var.detach().clone()
                bn_params.append((gamma, var))

                m.weight.data.fill_(1)
                m.bias.data = beta - (gamma * mu / torch.sqrt(var + eps))
                m.running_mean.zero_()
                m.running_var.fill_(1)

        cnt = 0
        for n, m in model.named_modules():
            if isinstance(m, ConvMask) and (cnt < len(bn_params)):
                gamma, var = bn_params[cnt]
                # per-output-channel scale, broadcast over (in, kH, kW)
                scale = gamma[:, None, None, None] / torch.sqrt(var[:, None, None, None] + eps)
                m.weight.data = m.weight.data * scale
                cnt += 1
        print('integrated the batch norm parameters in the Conv Weights, but kept the Bias of the BN')

        return model


    def prune_mag_integrated_bn(model, density, include_gamma=False):
        """Global magnitude pruning with BN-scaled weights as scores.

        The score of a weight is ``|mask * weight * gamma|`` where ``gamma`` is
        the weight of the positionally matching ``nn.BatchNorm2d`` module (k-th
        BN for the k-th ``ConvMask``). With ``include_gamma=True`` the score is
        additionally divided by ``sqrt(eps + running_var)``. ``ConvMask`` layers
        beyond the number of BN layers are scored by ``|mask * weight|`` only.

        Args:
            model: network containing ``ConvMask`` layers.
            density: target fraction of weights to keep, over all ``ConvMask``
                layers together.
            include_gamma: if True, scale by ``gamma / sqrt(eps + var)`` instead
                of ``gamma`` alone.

        Returns:
            The model. Masks are replaced unless the number of weights to
            remove rounds down to zero, in which case they are left as they are.
        """
        eps = 1e-5
        bn_params = []
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_params.append((m.weight.clone(), m.running_var))

        score_list = {}
        cnt = 0
        for n, m in model.named_modules():
            if isinstance(m, ConvMask):
                score = m.mask.to(m.weight.device) * m.weight
                if cnt < len(bn_params):
                    gamma, var = bn_params[cnt]
                    score = score * gamma[:, None, None, None]
                    if include_gamma:
                        score = score / torch.sqrt(eps + var[:, None, None, None])
                score_list[n] = score.detach().abs_()
                cnt += 1

        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int((1 - density) * global_scores.numel())

        # torch.kthvalue only accepts k >= 1, so when nothing has to be removed
        # the threshold is skipped and the existing masks are kept.
        threshold = None
        if k >= 1:
            threshold, _ = torch.kthvalue(global_scores, k)

        total_num = 0
        total_den = 0
        for n, m in model.named_modules():
            if isinstance(m, ConvMask):
                if threshold is not None:
                    score = score_list[n].to(m.weight.device)
                    zero = torch.tensor([0.]).to(m.weight.device)
                    one = torch.tensor([1.]).to(m.weight.device)
                    m.mask = torch.where(score <= threshold, zero, one)
                # tallied on every path so the report below is always defined
                total_num += (m.mask == 1).sum()
                total_den += m.mask.numel()

        print('Overall model density after magnitude pruning with integrated BN gamma in scores at current iteration = ', total_num / total_den)
        return model

    
    # Optional sparse initialisation: build the starting mask with the pruner
    # named by args.er_sparse_method at density args.er_sparse_init.
    if args.er_sparse_method == 'snip':
        print('Initializing with ER snip mask')
        model = prune_snip(
            model, data.train_loader, criterion, args.er_sparse_init
        )
    elif args.er_sparse_method == 'synflow':
        print('Initializing with ER synflow mask')
        model = prune_synflow(
            model, data.train_loader, args.er_sparse_init
        )

    ### Define a new pruner that prunes based on the value of the BN parameter gamma and removes filters (the channels with the lowest |gamma|). The draft below is unfinished and disabled.
    # def prune_bn_struct(model, density):
    #     score_list = {}
    #     cnt = 0
    #     for n, m in model.named_modules():
    #         if isinstance(m, (nn.BatchNorm2d)):
    #             score_list[n] = m.weight.detach().abs_()

                
    #     global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    #     k = int((1 - density) * global_scores.numel())
    #     threshold, _ = torch.kthvalue(global_scores, k)
    ######
    ## Set all the conv parameters to non trainable, and only train the BN parameters
    #####
    # Freeze every masked conv weight so that only the remaining parameters train.
    if args.train_only_bn:
        for n, m in model.named_modules():
            if isinstance(m, ConvMask):
                m.weight.requires_grad_(False)
        print('Setting all the weights of the conv layers to not require gradient i.e. freezing them')

    # Same as above, but masked layers whose name contains 'linear' stay trainable.
    if args.train_only_bn_and_linear:
        for n, m in model.named_modules():
            if 'linear' in n:
                continue
            if isinstance(m, ConvMask):
                m.weight.requires_grad_(False)
        print('Freezing all conv layers, only training BN and last linear layeer')
    # Iterative Prune Train before final train
    
    # Warm-up phase: optionally trains the model for args.warmup_epochs epochs
    # with the warm-up LR schedule before any pruning level is run.
    # NOTE(review): the scheduler is built BEFORE args.warmup_epochs is forced
    # to 0 below -- confirm warmup_lr does not read it at construction time.
    warmup_scheduler = warmup_lr(optimizer, args)
    # Validation top-1 accuracy of every epoch run so far (warm-up included).
    val_acc_total = []
    sparsity_actual = []
    sparsity_actual.append(1.0)
    # Adding the first level sparsity of the dense network to the sparsity list
    ############## 
    # Prefix for every path under runs/ written or read below.
    base_dir = ''

        
    # warmup training
    if not args.warmup:
        args.warmup_epochs = 0
    # Counts every epoch trained so far, across warm-up and pruning levels.
    curr_epoch = 0
    if args.warmup:
        print('Warm Up training for the model')
        for epoch in range(args.warmup_epochs):
            # Set the learning rate for this epoch, then read it back for logging.
            warmup_scheduler(epoch)
            lr = get_lr(optimizer)
            print('The curent learning rate is: ', lr)

            start_train = time.time()
            train_acc1, train_acc5 = train(
                data.train_loader, model, criterion, optimizer, epoch, args, writer=writer
            )
            # elapsed wall-clock time, recorded in minutes
            train_time.update((time.time() - start_train) / 60)
            # evaluate on validation set
            start_validation = time.time()
            acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
            validation_time.update((time.time() - start_validation) / 60)
            
            val_acc_total.append(acc1)
            curr_epoch += 1

    # Snapshot the (possibly warmed-up) model and optimizer so later levels can
    # rewind to them. Resumed or spawned runs reuse the snapshot already on disk.
    if not (args.resume_training_from_level or args.spawn_training_from_level):
        print('saving the model init')
        torch.save(model.state_dict(), f"{base_dir}runs/model_{args.name}_init.pt")
        torch.save(optimizer.state_dict(), f"{base_dir}runs/optimizer_{args.name}.pt")
    else:
        print('Not saving model inits since they are already saved from a previous run of the expt')


    #### Initializing the class for Gradual pruning within a predefined set of neurons
    # Stateful pruners are constructed once here, keyed by args.pruner; the
    # stateless pruning functions are dispatched later inside the level loop.
    # PruneIdxStruct
    if args.pruner == 'struct-gradual':
        pruner = PruneIdxStruct(model, args.struct_density, len(args.threshold_list), args)

    if args.pruner == 'mag-fixed-mask':
        print('Fix a target mask with balanced sparsity and do magnitude pruning around it to reach the target')
        pruner = PruneMagInMask(model, args)
        # This pruner supplies its own density schedule, replacing the configured one.
        args.threshold_list = pruner.threshold_list
        print('The threshold list for this prune mag in mask experiment is:', args.threshold_list)

    if args.pruner == 'rgp':
        pruner = RGP(model, args)

    if args.pruner == 'rand-fixed-mask':
        print('Fixing a mask and pruning towards it randomly')
        pruner = PruneRandInMask(model, args)

    # Hessian traces collected when args.hessian is set (one entry per measurement).
    hessian_list = []

    ####################
    # Start from a stored checkpoint (weights and mask) instead of the fresh model.
    if args.load_before_prune:
        print('Loading model from: ', args.load_model_name, args.load_mask_name)
        model = resume_from_checkpoint(model, args.load_model_name, args.load_mask_name)
        if args.fix_sign_and_train:
            # Record the sign pattern of every masked conv weight of the loaded
            # model; BN parameter signs are not tracked.
            init_sign_list = [
                m.weight.sign()
                for n, m in model.named_modules()
                if isinstance(m, ConvMask)
            ]
    
    
    # Load weights only; the masks are not touched by this branch.
    if args.load_only_model:
        print('Loading ONLY model from: ', args.load_model_name)
        base_dir = ''

        model_name = f'{base_dir}runs/{args.load_model_name}'
        loaded_state = torch.load(model_name)
        model.load_state_dict(loaded_state)

        # The loaded weights become the new "init" snapshot used when weights are reset.
        print('over writing model weights init with loaded model')
        init_path = f"{base_dir}runs/model_{args.name}_init.pt"
        torch.save(model.state_dict(), init_path)
        ####################

    ### Install only the masks of a reference run (for ablations); weights stay as they are.
    if args.load_only_mask:
        print('Loading ONLY mask: ', args.load_mask_name)
        mask_name = f'{base_dir}runs/{args.load_mask_name}'
        mask_list = torch.load(mask_name)

        # Masks are matched to the masked conv layers by position.
        masked_layers = [m for n, m in model.named_modules() if isinstance(m, ConvMask)]
        for cnt, m in enumerate(masked_layers):
            m.mask = mask_list[cnt].to(m.weight.device)

        # Re-save the init snapshot after the masks have been installed.
        torch.save(model.state_dict(), f"{base_dir}runs/model_{args.name}_init.pt")
        


    # Combine two runs: weight magnitudes and mask come from the checkpoint
    # (load_model_name / load_mask_name), weight signs come from the model
    # stored under load_model_sign_name.
    if args.load_model_mask_and_sign:
        print('Loading ONLY model and mask from: ', args.load_model_name, args.load_mask_name)
        base_dir = ''

        print('Changing the sign of the reset weight according to the expt: ', args.load_expt_sign)
        # Step 1: load the sign-donor model and remember the sign of every
        # masked conv weight (BN parameter signs are ignored).
        model_name = f'{base_dir}runs/{args.load_model_sign_name}'
        model.load_state_dict(torch.load(model_name))
        init_sign_list = [
            m.weight.sign()
            for n, m in model.named_modules()
            if isinstance(m, ConvMask)
        ]

        # Step 2: load the actual checkpoint, then impose the donor signs on it.
        model = resume_from_checkpoint(model, args.load_model_name, args.load_mask_name)
        masked_layers = [m for n, m in model.named_modules() if isinstance(m, ConvMask)]
        for cnt, m in enumerate(masked_layers):
            donor_sign = init_sign_list[cnt].to(m.weight.device)
            m.weight.data = m.weight.data.abs_() * donor_sign

        # Keep a copy of the sign-swapped weights so they can be reloaded on reset.
        print('over writing model weights init with loaded model')
        torch.save(model.state_dict(), f"{base_dir}runs/model_{args.name}_sign_changed_init.pt")

    ######
    # Like load_model_mask_and_sign, but the checkpoint's weights are first
    # shuffled among the unmasked positions of each layer before the donor
    # signs are imposed.
    if args.load_shuffled_model_mask_and_sign:
        print('Loading model and mask from, and shuffling the model weights: ', args.load_model_name, args.load_mask_name)
        base_dir = ''

        print('Changing the sign of the reset weight according to the expt: ', args.load_model_sign_name)
        # sign_list = torch.load("{}runs/sign_list_{}.pt".format(base_dir, args.load_expt_sign))
        # Step 1: load the sign-donor model and record its conv weight signs.
        model_name = base_dir + 'runs/' + args.load_model_sign_name
        model.load_state_dict(torch.load(model_name))
        init_sign_list = []
        for n, m in model.named_modules():
            # Only looking at conv weight signs for now, not at BN parameter signs
            if isinstance(m, ConvMask):
                init_sign_list.append(m.weight.sign())

        # Step 2: only now load the checkpoint whose weights are to be shuffled.
        # Doing this before step 1 (as was done previously) lets the donor's
        # state dict overwrite the checkpoint weights, so the donor's own
        # weights would be the ones shuffled.
        model = resume_from_checkpoint(model, args.load_model_name, args.load_mask_name)

        cnt = 0
        for n, m in model.named_modules():
            if isinstance(m, ConvMask):
                # Permute the values sitting on unmasked positions; masked
                # positions keep their value.
                idx = torch.where(m.mask == 1)
                buf = m.weight.data.clone()
                shuffle = buf[idx]
                perm = torch.randperm(shuffle.numel())
                buf[idx] = shuffle[perm]

                # magnitudes from the shuffled checkpoint, signs from the donor
                m.weight.data = buf.abs_() * init_sign_list[cnt].to(m.weight.device)
                cnt += 1

        # Save the model again as init so that it can be loaded for reset weights
        print('over writing model weights init with loaded model')
        torch.save(model.state_dict(), "{}runs/model_{}_model_shuffled_sign_changed_init.pt".format(base_dir, args.name))

    ##### Loading a teacher model for self knowledge distillation
    # The teacher's weights and mask come from args.teacher_model / args.teacher_mask.
    if args.kd:
        teacher = resume_from_checkpoint(teacher, args.teacher_model, args.teacher_mask)
        kd_train_loss = []

    # Record a Hessian trace of the starting model when requested.
    if args.hessian:
        trace = hessian_trace(data.train_loader, model, criterion, optimizer, args)
        hessian_list.append(trace)
    
    # Build the compensation helper for the current model when requested.
    if args.lbfgs_compensate:
        lbfgs_compensate = CompensatePrune(model)

    
    
    ### Level for pruning
    # Index of the current pruning level; overwritten below when a run is
    # spawned or resumed from a later level.
    level = 0

    ### Spawn a different training run from a level loading the current model and mask
    if args.spawn_training_from_level:
        level = args.resume_level
        # First load the source experiment's init weights and store them, with
        # the current optimizer state, under THIS experiment's name so that the
        # reset branches below (and later levels) can find them.
        model_name = base_dir + 'runs/' + args.load_model_init_name
        model.load_state_dict(torch.load(model_name))
        torch.save(model.state_dict(), "{}runs/model_{}_init.pt".format(base_dir, args.name))
        torch.save(optimizer.state_dict(), "{}runs/optimizer_{}.pt".format(base_dir, args.name))

        print('saved the model init for the spawned experiment')
        # Then load the checkpoint (weights + mask) of the level to continue from.
        model = resume_from_checkpoint(model, args.load_model_name, args.load_mask_name)
        print('loaded the resumption level for the spawned experiment')

        if args.reset_only_bn:
            print('resetting BN parameters and the optimizer')
            original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
            # original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
            # Entries are selected by key name: anything containing 'bn' or 'downsample.1'.
            original_weights = {k:v for (k,v) in original_dict.items() if ('bn' in k) or ('downsample.1' in k)}
            print('resets the BN parameters along with the running mean and variance to the init values')
            model_dict = model.state_dict()
            model_dict.update(original_weights)
            model.load_state_dict(model_dict)
            
            # Reset Optimizer and Scheduler
            optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
        
        if args.reset_bn_default:
            print('resetting BN parameters back to unity and the optimizer')
            for n, m in model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    # NOTE: this message is printed once per BatchNorm2d layer.
                    print('resetting BN parameters back to unity and the optimizer')
                    # weight 1, bias 0, running mean 0, running variance 1
                    m.weight.data = torch.ones_like(m.weight)
                    m.bias.data = torch.zeros_like(m.bias)
                    m.running_mean = torch.zeros_like(m.running_mean)
                    m.running_var = torch.ones_like(m.running_var)
            # Reset Optimizer and Scheduler
            optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
            

    # Resume THIS experiment from a stored level: load that level's weights and
    # mask, optionally rewind parts of the model to the saved init, and restore
    # the validation-accuracy history.
    if args.resume_training_from_level:
        print('resuming model and mask from: ', args.load_model_name, args.load_mask_name)

        model = resume_from_checkpoint(model, args.load_model_name, args.load_mask_name)
        # reset the weights after loading from the previous level so that the exact weights and biases are rewound
        if args.reset_weights:
            print('resetting the weights and resuming training')
            original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
            # Only entries whose key ends in '.weight' or '.bias' are rewound;
            # every other state-dict entry keeps the resumed value.
            original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
            model_dict = model.state_dict()
            model_dict.update(original_weights)
            model.load_state_dict(model_dict)
            optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))

        if args.reset_only_bn:
            print('resetting BN parameters and the optimizer')
            original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
            # original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
            # Entries are selected by key name: anything containing 'bn' or 'downsample.1'.
            original_weights = {k:v for (k,v) in original_dict.items() if ('bn' in k) or ('downsample.1' in k)}
            print('resets the BN parameters along with the running mean and variance to the init values')
            model_dict = model.state_dict()
            model_dict.update(original_weights)
            model.load_state_dict(model_dict)
            
            # Reset Optimizer and Scheduler
            optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
        
        if args.reset_bn_default:
            print('resetting BN parameters back to unity and the optimizer')
            for n, m in model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    # NOTE: this message is printed once per BatchNorm2d layer.
                    print('resetting BN parameters back to unity and the optimizer')
                    # weight 1, bias 0, running mean 0, running variance 1
                    m.weight.data = torch.ones_like(m.weight)
                    m.bias.data = torch.zeros_like(m.bias)
                    m.running_mean = torch.zeros_like(m.running_mean)
                    m.running_var = torch.ones_like(m.running_var)
            # Reset Optimizer and Scheduler
            optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
            

        # Continue the accuracy history written by the earlier part of this run.
        val_acc_total = torch.load("{}runs/val_acc_{}.pt".format(base_dir, args.name))
        print('Loaded val acc list from previous checkpoint to continue training')
        level = args.resume_level
        print('Level is: ', level)

    # Persist the accuracy history and the density list before the level loop
    # starts (for a resumed run this re-saves the list just loaded).
    torch.save(val_acc_total, base_dir + 'runs/val_acc_'+ args.name + '.pt')
    torch.save(sparsity_actual, base_dir + 'runs/sparsity_list_'+ args.name + '.pt')
    
    ###
    # Transfer learning: initialise every layer except the classifier head
    # ('linear') from a stored model, and install the stored masks on the
    # masked layers outside the head.
    if args.transfer_task:
        base_dir = ''

        model_name = base_dir + 'runs/' + args.load_model_name
        mask_name = base_dir + 'runs/' + args.load_mask_name
        model_weight = torch.load(model_name)
        mask_list = torch.load(mask_name)

        # 1. filter out unnecessary keys: take the LOADED weights and drop the
        #    head. (This used to read from model.state_dict(), so the loaded
        #    weights were never applied.)
        pretrained_dict = {k: v for k, v in model_weight.items() if 'linear' not in k}
        # 2. overwrite entries in the existing state dict
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

        # Masks are matched by position to the masked layers outside the head.
        cnt = 0
        for n, m in model.named_modules():
            if isinstance(m, ConvMask) and 'linear' not in n:
                m.mask = mask_list[cnt].to(m.weight.device)
                cnt += 1
        # Reported once, after the transfer (previously printed for every module).
        print('Transferring Task')
        
    #####
    # For gradual layer removal the schedule becomes a list of
    # [layer index, density] pairs: every density is visited for one layer
    # before moving on to the next layer index.
    if args.pruner == 'layer-remove':
        density_list = [
            [idx, density]
            for idx in args.layer_idx_list
            for density in args.threshold_list
        ]
        args.threshold_list = density_list
        print('Changed the threshold list for removing layers gradually: ', args.threshold_list)

    ######
    # FIX THE FIRST AND LAST LAYER
    ######
    # Keep the layers named exactly 'conv1' and 'linear' dense by replacing
    # their masks with all-ones tensors.
    if args.fix_first_and_last:
        for n, m in model.named_modules():
            if n not in ('conv1', 'linear'):
                continue
            if isinstance(m, ConvMask):
                m.mask = torch.ones_like(m.weight).to(m.weight.device)
        print('Fixed the mask of the first and last layer to be dense')

    def learn_signs(model, args):
        """Train the model for 10 epochs with the binary (sign) optimizer.

        The weights of all ``ConvMask`` layers are handed to
        ``MomentumWithThresholdBinaryOptimizer`` as its first parameter group,
        and the weights and biases of all ``nn.BatchNorm2d`` layers as its
        second. The model is validated once before training and once after
        every epoch. Afterwards the ``sign`` parameter of every
        ``ConvMaskSign`` layer is frozen.

        Args:
            model: network to optimize; runs on ``cuda:{args.gpu}``.
            args: run configuration (reads ``gpu`` and ``start_epoch``).

        Returns:
            The same model, updated in place.
        """
        criterion = nn.CrossEntropyLoss()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch)

        sign_dict = []
        bn_dict = []
        for n, m in model.named_modules():
            if isinstance(m, ConvMask):
                sign_dict.append(m.weight)
            if isinstance(m, nn.BatchNorm2d):
                bn_dict.append(m.weight)
                bn_dict.append(m.bias)

        sign_optimizer = MomentumWithThresholdBinaryOptimizer(sign_dict, bn_dict)
        device = torch.device(f'cuda:{args.gpu}')

        print('Optimizing the signs and learning only the signs of the network, given the mask is fixed')
        for _ in range(10):
            # validate() ran just before; make sure we are back in training mode
            model.train()
            for i, (images, target) in enumerate(data.train_loader):
                sign_optimizer.zero_grad()
                images = images.to(device)
                target = target.to(device).long()
                output = model(images)
                loss = criterion(output, target)
                preds = output.argmax(dim=1)
                # The per-batch loss is deliberately not summed into a running
                # tensor: that sum was never read and kept autograd history
                # alive across the whole epoch.
                loss.backward()
                flips = sign_optimizer.step()
                if i % 50 == 49:
                    print('Train Loss while only optimizing signs: {:.4f}, {:.2f}, {}'.format(loss, 100*(preds == target).sum()/target.shape[0], flips))

            model.eval()
            acc1, acc5 = validate(data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch)

        print('Freezing the sign parameters')
        for n, m in model.named_modules():
            if isinstance(m, ConvMaskSign):
                m.sign.requires_grad = False
        return model

    #######
    # Optional sign-only optimisation pass before the pruning levels start;
    # the resulting weights are stored under their own file name.
    if args.optimize_signs:
        # The idea is to optimize only the signs
        model = learn_signs(model, args)
        torch.save(model.state_dict(), "{}runs/model_{}_optimized_signs_init.pt".format(base_dir, args.name))
        
    #######
    
    if args.conv_type == 'ConvMask' or args.conv_type == 'LinearMask':
        # Cycle training: repeat the same LR schedule of args.final_prune_epoch
        # epochs for args.cycle_train_levels rounds before any pruning happens.
        if args.cycle_train:
            assign_learning_rate(optimizer, args.lr)
            # Pick the per-epoch LR schedule by name. The scheduler is built
            # once and reused for every cycle.
            if args.prune_scheduler == 'cosine':
                prune_scheduler = cosine_lr(optimizer, args.final_prune_epoch, args)
            if args.prune_scheduler == 'step-warmup':
                prune_scheduler = multistep_lr_warmup(optimizer, args)
            if args.prune_scheduler == 'cosine-warmup':
                prune_scheduler = cosine_lr_warmup(optimizer, args.final_prune_epoch, args)
            if args.prune_scheduler == 'constant':
                prune_scheduler = constant_lr(optimizer, args)
                assign_learning_rate(optimizer, args.constant_prune_lr)
                print('Learning rate during pruning is constant at: ', args.constant_prune_lr)
            if args.prune_scheduler == 'step':
                prune_scheduler = multistep_lr(optimizer, args)
            if args.prune_scheduler == 'step-drops':
                prune_scheduler = multistep_lr_drops(optimizer, args)
                print('Step drops LR scheduler')
            if args.prune_scheduler == 'imagenet-step':
                prune_scheduler = imagenet_lr_drops(optimizer, args)
                print('ImageNet Step drops LR scheduler')
            if args.prune_scheduler == 'imagenet-step-warmup':
                prune_scheduler = imagenet_lr_drops_warmup(optimizer, args)
                print('ImageNet Step drops with warmup every cycle LR scheduler')
            if args.prune_scheduler == 'cyclic':
                print('Training with cyclic LR with a cycle duration of 50 epochs')
                prune_scheduler = cyclic_lr(optimizer, args)
            if args.prune_scheduler == 'one-cycle':
                print('Training with a one cycle LR with a warmup cycle of 10 epochs and then a linear decrease')
                prune_scheduler = one_cycle_lr(optimizer, args)

            print('Performing cycle training, repeating same schedule for multiple levels: ', args.cycle_train_levels)
            for iter in range(args.cycle_train_levels):
                for epoch in range(args.final_prune_epoch):
                    print('chai me dooba')
                    prune_scheduler(epoch)
                    lr = get_lr(optimizer)
                    print('The curent learning rate is: ', lr)

                    #####
                    # Periodically perturb the weight signs: on the last epoch
                    # of every window of args.perturb_sign_every epochs.
                    if args.perturb_signs:
                        if (epoch % args.perturb_sign_every) == args.perturb_sign_every-1:
                            print('Perturbing signs of model while training')
                            model = perturb_signs(model, args.perturb_sign_ratio)
                    #####

                    start_train = time.time()
                    train_acc1, train_acc5 = train(data.train_loader, model, criterion, optimizer, epoch, args, writer=writer)
                    # elapsed wall-clock time, recorded in minutes
                    train_time.update((time.time() - start_train) / 60)

                    # evaluate on validation set
                    start_validation = time.time()
                    # The measured accuracies are kept as returned; they used to
                    # be overwritten with the constant 0.2 right after this
                    # call, which made val_acc_total and the best-accuracy
                    # tracking below meaningless.
                    acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
                    validation_time.update((time.time() - start_validation) / 60)
                    val_acc_total.append(acc1)
                    curr_epoch += 1
                    if args.hessian:
                        trace = hessian_trace(data.train_loader, model, criterion, optimizer, args)
                        hessian_list.append(trace)

                # Best-so-far bookkeeping uses the last epoch of the cycle.
                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)
                best_acc5 = max(acc5, best_acc5)
                best_train_acc1 = max(train_acc1, best_train_acc1)
                best_train_acc5 = max(train_acc5, best_train_acc5)
            ###
        
        print('Threshold list: ', args.threshold_list)
        for threshold in args.threshold_list:
            assign_learning_rate(optimizer, args.lr)
            if args.prune_scheduler == 'cosine':
                prune_scheduler = cosine_lr(optimizer, args.final_prune_epoch, args)
            if args.prune_scheduler == 'step-warmup':
                prune_scheduler = multistep_lr_warmup(optimizer, args)
            if args.prune_scheduler == 'cosine-warmup':
                prune_scheduler = cosine_lr_warmup(optimizer, args.final_prune_epoch, args)
            if args.prune_scheduler == 'constant':
                prune_scheduler = constant_lr(optimizer, args)
                assign_learning_rate(optimizer, args.constant_prune_lr)
                print('Learning rate during pruning is constant at: ', args.constant_prune_lr)
            if args.prune_scheduler == 'step':
                prune_scheduler = multistep_lr(optimizer, args)
            if args.prune_scheduler == 'step-drops':
                prune_scheduler = multistep_lr_drops(optimizer, args)
                print('Step drops LR scheduler')
            if args.prune_scheduler == 'imagenet-step':
                prune_scheduler = imagenet_lr_drops(optimizer, args)
                print('ImageNet Step drops LR scheduler')
            if args.prune_scheduler == 'imagenet-step-warmup':
                prune_scheduler = imagenet_lr_drops_warmup(optimizer, args)
                print('ImageNet Step drops with warmup every cycle LR scheduler')
            if args.prune_scheduler == 'cyclic':
                print('Training with cyclic LR with a cycle duration of 50 epochs')
                prune_scheduler = cyclic_lr(optimizer, args)
            if args.prune_scheduler == 'one-cycle':
                print('Training with a one cycle LR with a warmup cycle of 10 epochs and then a linear decrease')
                prune_scheduler = one_cycle_lr(optimizer, args)


            for epoch in range(args.final_prune_epoch):
                print('chai me dooba')
                prune_scheduler(epoch)
                lr = get_lr(optimizer)
                print('The curent learning rate is: ', lr)

                #####
                if args.perturb_signs:
                    if (epoch % args.perturb_sign_every) == args.perturb_sign_every-1:
                        print('Perturbing signs of model while training')
                        model = perturb_signs(model, args.perturb_sign_ratio)
                #####

                start_train = time.time()
                train_acc1, train_acc5 = train(
                    data.train_loader, model, criterion, optimizer, epoch, args, writer=writer
                )
                train_time.update((time.time() - start_train) / 60)

                # evaluate on validation set
                start_validation = time.time()
                acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
                validation_time.update((time.time() - start_validation) / 60)
                val_acc_total.append(acc1)
                curr_epoch += 1
                if args.hessian:
                    trace = hessian_trace(data.train_loader, model, criterion, optimizer, args)
                    hessian_list.append(trace)
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            best_acc5 = max(acc5, best_acc5)
            best_train_acc1 = max(train_acc1, best_train_acc1)
            best_train_acc5 = max(train_acc5, best_train_acc5)
            ###
            if args.integrate_bn:
                print('Integrating the BN parameters in the model weights')
                model = integrate_bn_params(model)
            ###
            if args.pruner == 'neuron-flow':
                print('Pruning at threshold: ', threshold)
                model = prune_flow(model, threshold) 
            
            if args.pruner == 'random-balanced':
                print('Pruning randomly with balanced at density: ', threshold)
                model = prune_random_balanced(model, threshold)
            
            if args.pruner == 'random-global':
                print('Pruning globally randomly at density: ', threshold)
                model = prune_random_global(model, threshold) 
            
            if args.pruner == 'random-erk':
                print('Pruning with random (ERK) with density: ', threshold)
                model = prune_random_erk(model, threshold)

            if args.pruner == 'random-uniform':
                print('Pruning with random (Uniform) with density: ', threshold)
                model = prune_random_uniform(model, threshold) 

            if args.pruner == 'mag':
                print('Pruning by magnitude at density: ', threshold)
                model = prune_mag(model, threshold)    

            if args.pruner == 'snip':
                print('Pruning by SNIP at density: ', threshold)
                model = prune_snip(model, data.train_loader, criterion, threshold)     
            
            if args.pruner == 'synflow':
                print('Pruning by SynFlow at density: ', threshold)
                model = prune_synflow(model, data.train_loader, threshold)     
            
            if args.pruner == 'mask-ref':
                print('Pruning by loading an equivalent mask from a different experiment')
                model = prune_mask_ref(model, level, args.target_expt_name)

            if args.pruner == 'mag-with-bn':
                print('Pruning with magnitude corrected with BN gamma')
                model = prune_mag_integrated_bn(model, threshold)
            
            if args.pruner == 'random-imp-ref':
                print('Pruning randomly with layerwise sparsities proportional to IMP')
                model = prune_random_with_imp_sparsity(model, threshold, args)
            
            if args.pruner == 'neuron-gradual':
                print('Pruning neurons in each layer with pruning fraction: ', threshold)
                model = prune_neuron_gradual(model, threshold)

            if args.pruner == 'layer-remove':
                # The threshold_list must contain both the layer index and the step in which it is pruned
                print('pruning the entire layer: ', threshold[0], threshold[1])
                model = prune_layer(model, threshold[0], threshold[1])

            if args.pruner == 'struct-gradual':
                
                print('Updating the remaining indexes of the weights to be pruned')
                model = pruner.update_masked_idx(model)

                print('Pruning weights within a predefined set of neurons')
                model = pruner.prune_incoming_and_update(model)

            if args.pruner == 'rgp':
                print('Pruning with RGP, alternating struct and random unstruct')
                model = pruner.prune_step(model, level)

            if args.pruner == 'mag-fixed-mask':
                print('Pruning by magnitude for a target mask')
                model = pruner.prune_mag(model, threshold)

            if args.pruner == 'rand-fixed-mask':
                print('pruning randomly to a fixed mask')
                model = pruner.prune_random_balanced(model, threshold)

            if args.pruner == 'prune-width-gradual':
                if level % 2 == 0:
                    print('Pruning width gradually with prune fraction: {}'.format(args.prune_ratio))
                    model = prune_width_gradual(model, args)
                else:
                    density = get_sparsity(model)
                    model = prune_random_balanced(model, args.prune_ratio_unstruct*density)
                    print('pruning unstruct at density: ', args.prune_ratio_unstruct*density)
            
            if args.pruner == 'prune-width-gradual-nonuniform':
                if level % 2 == 0:
                    print('Pruning width gradually with nonuniform neurons and prune fraction: {}'.format(args.prune_ratio))
                    model = prune_width_gradual_nonuniform_ratio(model, args)
                else:
                    density = get_sparsity(model)
                    model = prune_random_balanced(model, args.prune_ratio_unstruct*density)
                    print('pruning unstruct at density: ', args.prune_ratio_unstruct*density)

            if args.pruner == 'prune-width-healthy':
                if level % 2 == 0:
                    print('Pruning width to maintain healthy sparsity')
                    prev = get_sparsity(model)
                    model = prune_width_healthy(model, args)
                    curr = get_sparsity(model)
                    print('Pruning width before and after density: ', prev, curr)
                else:
                    model = prune_random_balanced(model, args.prune_ratio_unstruct*curr)
                    print('pruning unstruct at density: ', args.prune_ratio_unstruct*curr)
                    print('Pruning unstruct after: ', get_sparsity(model))

            if args.pruner == 'mag-max-perturb':
                model = prune_max_perturb(model, prune_mag, data.val_loader, criterion, writer, val_acc_total[-1], args)

            if args.pruner == 'balanced-max-perturb':
                model = prune_max_perturb(model, prune_random_balanced, data.val_loader, criterion, writer, val_acc_total[-1], args)

            if args.pruner == 'memorized':
                print('Pruning the memorized neurons from the network')
                model = prune_memorized(args, model)
                print('Density after pruning memorized neurons is: {}'.format(get_sparsity(model)))

            # else:
            #     density = get_sparsity(model)
            #     model = prune_random_balanced(model, args.prune_ratio_unstruct*density)

            # compensation with LBFGS
            if args.lbfgs_compensate:
                print('compensation with LBFGS')
                lbfgs_compensate.step(model)

            mask_list = []
            for name, m in model.named_modules():
                if isinstance(m, (ConvMask, LinearMask)):
                    mask_list.append(m.mask)
            torch.save(mask_list, '{}runs/mask_{}_{}.pt'.format(base_dir, args.name, level))
            torch.save(model.state_dict(),"{}runs/model_{}_{}.pt".format(base_dir, args.name, level))
            level += 1

            torch.save(val_acc_total, base_dir + 'runs/val_acc_'+ args.name + '.pt')
            sparsity_actual.append(get_sparsity(model))
            torch.save(sparsity_actual, base_dir + 'runs/sparsity_list_'+ args.name + '.pt')

            if args.reset_weights:
                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
                print('Weights of the model reset to initialization weights')
                
                if args.reset_weight_sign:
                    print('Changing the sign of the reset weight according to the expt: ', args.load_expt_sign)
                    sign_list = torch.load("{}runs/sign_list_{}_{}.pt".format(base_dir, args.load_expt_sign, level))
                    cnt = 0
                    for n, m in model.named_modules():
                        if isinstance(m, (ConvMask)):
                            m.weight.data = m.weight.data.abs_() * sign_list[cnt].to(m.weight.device)
                            cnt += 1
            
            if args.reset_shuffled_amplitude:

                sign_list = []
                for n, m in model.named_modules():
                    if isinstance(m, (ConvMask)):
                        sign_list.append(m.weight.sign())

                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
                print('Resetting only the amplitude during IMP and shuffling it, while keeping the signs!!!')

                cnt = 0
                for n, m in model.named_modules():
                    if isinstance(m, (ConvMask)):
                        buf = m.weight.data.clone()
                        perm = torch.randperm(buf.numel())
                        buf = buf.view(-1)[perm].view(buf.shape)
                        m.weight.data = buf.abs_() * sign_list[cnt].to(m.weight.device)
                        cnt += 1

            if args.reset_only_bn:
                print('resetting BN parameters and the optimizer')
                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                # original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
                original_weights = {k:v for (k,v) in original_dict.items() if ('bn' in k) or ('downsample.1' in k)}
                print('resets the BN parameters along with the running mean and variance to the init values')
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
            
            if args.reset_bn_default:
                print('resetting BN parameters back to unity and the optimizer')
                for n, m in model.named_modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.weight.data = torch.ones_like(m.weight)
                        m.bias.data = torch.zeros_like(m.bias)
                        m.running_mean = torch.zeros_like(m.running_mean)
                        m.running_var = torch.ones_like(m.running_var)

                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
                
            if args.reset_only_weights_not_bn:
                print('resetting weights of the model and the optimizer, but not resetting the BN parameters')

                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                # original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
                original_weights = {k:v for (k,v) in original_dict.items() if ('weight' in k) and (('conv' in k) or ('linear' in k) or ('downsample.0.' in k))}
                print('resets the only the weights to the initial values')
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
            

            if args.reset_only_bn_amplitude:
                print('resetting only BN parameter amplitudes and the optimizer')

                sign_list = []
                for n, m in model.named_modules():
                    if isinstance(m, (nn.BatchNorm2d)):
                        sign_list.append(m.weight.sign())

                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                # original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
                original_weights = {k:v for (k,v) in original_dict.items() if ('bn' in k) or ('downsample.1' in k)}
                
                print('resets the BN parameters along with the running mean and variance to the init values')
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                cnt = 0
                for n, m in model.named_modules():
                    if isinstance(m, (nn.BatchNorm2d)):
                        m.weight.data = m.weight.data.abs_() * sign_list[cnt].to(m.weight.device)
                        cnt += 1

                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
               

            if args.reset_weight_amplitude:

                sign_list = []
                for n, m in model.named_modules():
                    if isinstance(m, (ConvMask)):
                        sign_list.append(m.weight.sign())

                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                original_weights = dict(filter(lambda v: (v[0].endswith(('.weight', '.bias'))), original_dict.items()))
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
                print('Resetting only the amplitude during IMP, while keeping the signs!!!')

                cnt = 0
                for n, m in model.named_modules():
                    if isinstance(m, (ConvMask)):
                        m.weight.data = m.weight.data.abs_() * sign_list[cnt].to(m.weight.device)
                        cnt += 1

            if args.reset_only_weight_sign:
                # Get the signs from the initial model.
                original_dict = torch.load("{}runs/model_{}_init.pt".format(base_dir, args.name))
                sign_list = []
                for n in original_dict.keys():
                    if 'weight' in n and(('conv' in n) or ('linear' in n) or ('downsample.0.' in n)):
                        sign_list.append(original_dict[n].sign())

                optimizer.load_state_dict(torch.load("{}runs/optimizer_{}.pt".format(base_dir, args.name)))
                print('resetting only the signs of the weights to the signs of IMP init and continuing LRR')
                cnt = 0
                for n, m in model.named_modules():
                    if isinstance(m, (ConvMask)):
                        m.weight.data = m.weight.data.abs_() * sign_list[cnt].to(m.weight.device)
                        cnt += 1

            if args.perturb_signs_once:
                if level == args.perturb_level:
                    # Perturb signs in every layer with
                    print('Perturbing signs at level, ', level)
                    model = perturb_signs(model, args.perturb_sign_ratio)


    ### Perturb the mask and train final
    if args.prune_perturb:
        # Load only the model weights trained to that point. And then perturb with the ratio given as argument
        print('Loading model from: ', args.load_model_name, args.load_mask_name)
        model = resume_from_checkpoint(model, args.load_model_name, args.load_mask_name)

        print('Pruning the current model to perturb prune and train to completion with perturb ratio: ', args.prune_ratio)
        curr_density = get_sparsity(model)
        if args.pruner == 'random-balanced':
            print('Pruning randomly with balanced at density: ', curr_density*args.prune_ratio)
            model = prune_random_balanced(model, curr_density*args.prune_ratio)
        if args.pruner == 'mag':
            print('Pruning randomly with balanced at density: ', curr_density*args.prune_ratio)
            model = prune_mag(model, curr_density*args.prune_ratio)
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch=0)
        torch.save(acc1, '{}runs/perturbed_val_acc_{}.pt'.format(base_dir, args.name))

        if args.reset_bn_default:
            print('resetting BN parameters back to unity and the optimizer')
            for n, m in model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.weight.data = torch.ones_like(m.weight)
                    m.bias.data = torch.zeros_like(m.bias)
                    m.running_mean = torch.zeros_like(m.running_mean)
                    m.running_var = torch.ones_like(m.running_var)
                
    #For C100, num params for R18 = 11210432 and R50 = 23652032, R18 / R50 = 0.474

    # save the mask of the sparse structure
    mask_list = []
    total_num = 0
    total_den = 0
    for name, m in model.named_modules():
        if isinstance(m, (ConvMask, LinearMask)):
            mask_list.append(m.mask)
            total_num += m.mask.sum()
            total_den += m.mask.numel()
    print('Density before full training is: ', total_num / total_den)
    torch.save(mask_list, '{}runs/mask_{}.pt'.format(base_dir, args.name))

    
    # Start training
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    # Tracking time per epoch    
    expt_start_time = time.time()
    total_time = []

    # Final training for the pruned network follows a cosine LR schedule
    assign_learning_rate(optimizer, args.lr)
    if args.prune_scheduler == 'cosine':
        prune_scheduler = cosine_lr(optimizer, args.final_prune_epoch, args)
    if args.prune_scheduler == 'step':
        prune_scheduler = multistep_lr(optimizer, args)
    if args.prune_scheduler == 'step-drops':
        prune_scheduler = multistep_lr_drops(optimizer, args)
    if args.prune_scheduler == 'imagenet-step':
        prune_scheduler = imagenet_lr_drops(optimizer, args)
        print('ImageNet Step drops LR scheduler')
    if args.prune_scheduler == 'step-warmup':
        prune_scheduler = multistep_lr_warmup(optimizer, args)
    if args.prune_scheduler == 'cosine-warmup':
        prune_scheduler = cosine_lr_warmup(optimizer, args.final_prune_epoch, args)
    if args.prune_scheduler == 'imagenet-step-warmup':
        prune_scheduler = imagenet_lr_drops_warmup(optimizer, args)
        print('ImageNet Step drops with warmup every cycle LR scheduler')
    if args.prune_scheduler == 'cyclic':
        print('Training with cyclic LR with a cycle duration of 50 epochs')
        prune_scheduler = cyclic_lr(optimizer, args)
    if args.prune_scheduler == 'one-cycle':
        print('Training with a one cycle LR with a warmup cycle of 10 epochs and then a linear decrease')
        prune_scheduler = one_cycle_lr(optimizer, args)

    level = 'final'

    if args.track_bn_running_stats:
        curr_mean_list = []
        running_mean_list = []
        curr_var_list = []
        running_var_list = []

    for epoch in range(args.start_epoch, args.epochs):
        prune_scheduler(epoch)
        cur_lr = get_lr(optimizer)
        print('The curent learning rate is: ', cur_lr)

        # train for one epoch
        start_train = time.time()

        if args.perturb_signs:
            if (epoch % args.perturb_sign_every) == args.perturb_sign_every - 1:
                print('Perturbing signs of model while training')
                model = perturb_signs(model, args.perturb_sign_ratio)

        if args.kd:
            print('Knowledge distillation epoch: ', epoch)
            train_acc1, train_acc5, kd_loss = train_kd(
                data.train_loader, model, teacher, optimizer, epoch, args, writer=writer)
            kd_train_loss.append(kd_loss)

        elif args.fix_sign_and_train:
            print('Training with weight signs fixed to initial values')
            train_acc1, train_acc5 = train_with_fixed_signs(
                init_sign_list, data.train_loader, model, criterion, optimizer, epoch, args, writer=writer
            )

        else:
            print('zoobie doobie pa paraaaaa') 
            train_acc1, train_acc5 = train(
                data.train_loader, model, criterion, optimizer, epoch, args, writer=writer)
        train_time.update((time.time() - start_train) / 60)

        
        ## Tracking the running BN stats on the training data
        if args.track_bn_running_stats:
            ####
            cnt = 0
            for n, m in model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    # if m.running_mean is not None:
                    if args.bn_layer_idx == cnt:
                        curr_mean_list.append(m.curr_mean)
                        running_mean_list.append(m.running_mean)
                        curr_var_list.append(m.curr_var)
                        running_var_list.append(m.running_var)
                    cnt += 1

        # evaluate on validation set
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch)
        validation_time.update((time.time() - start_validation) / 60)
        val_acc_total.append(acc1)
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)
        curr_epoch += 1
        
        if (epoch % args.save_every) == 0 or epoch == (args.epochs - 1):
            print('Saving model for tracking signs')
            torch.save(model.state_dict(), f"{base_dir}/runs/model_{args.name}_track_epoch_{epoch}_.pt")

        # get activation after validation
        if args.track_activation:
            get_activations(model, epoch, level, args)

        if args.hessian:
            trace = hessian_trace(data.train_loader, model, criterion, optimizer, args)
            hessian_list.append(trace)
        
        

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(
            writer, prefix="diagnostics", global_step=epoch
        )

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()
        torch.save(val_acc_total, base_dir + 'runs/val_acc_'+ args.name + '.pt')  

        # sparsity_actual.append(get_sparsity(model))
        # torch.save(sparsity_actual, base_dir + 'runs/sparsity_list_'+ args.name + '.pt')

        
        if args.kd:
            torch.save(kd_train_loss, base_dir + 'runs/kd_train_loss_'+ args.name + '.pt')    

        # Storing sparsity and threshold statistics for STRConv models
        if args.conv_type == 'STRConv' or args.conv_type == 'STRConvER' or args.conv_type == 'ConvER' or args.conv_type == 'ConvMask':
            count = 0
            sum_sparse = 0.0
            for n, m in model.named_modules():
                if isinstance(m, (STRConv, STRConvER, ConvER, ConvMask, LinearMask)):
                    sparsity, total_params, thresh = m.getSparsity()
                    writer.add_scalar("sparsity/{}".format(n), sparsity, epoch)
                    writer.add_scalar("thresh/{}".format(n), thresh, epoch)
                    sum_sparse += int(((100 - sparsity) / 100) * total_params)
                    count += total_params
            total_sparsity = 100 - (100 * sum_sparse / count)
            writer.add_scalar("sparsity/total", total_sparsity, epoch)
        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

        total_time.append(time.time() - expt_start_time)

    torch.save(total_time, base_dir + 'runs/total_expt_time_'+ args.name + '.pt')

    torch.save(hessian_list, '{}runs/hessian_{}.pt'.format(base_dir, args.name))
    torch.save(model.state_dict(),"{}runs/model_{}_trained.pt".format(base_dir, args.name))
    torch.save(get_sparsity(model), base_dir + 'runs/final_sparsity_'+ args.name + '.pt')
    
    if args.track_bn_running_stats:
        bn_stats = [
            curr_mean_list,
            running_mean_list,
            curr_var_list,
            running_var_list
            ]
        torch.save(bn_stats, "{}runs/bn_stats_{}_{}.pt".format(base_dir, args.name, args.bn_type))

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
        sparsity=total_sparsity,
    )

    

    if args.conv_type == "STRConv" or args.conv_type == 'STRConvER' or args.conv_type == 'ConvER' or args.conv_type == 'ConvMask':
        json_data = {}
        json_thres = {}
        for n, m in model.named_modules():
            if isinstance(m, (STRConv, STRConvER, ConvER, ConvMask, LinearMask)):
                sparsity = m.getSparsity()
                json_data[n] = sparsity[0]
                sum_sparse += int(((100 - sparsity[0]) / 100) * sparsity[1])
                count += sparsity[1]
                json_thres[n] = sparsity[2]
        json_data["total"] = 100 - (100 * sum_sparse / count)
        if not os.path.exists("runs/layerwise_sparsity"):
            os.mkdir("runs/layerwise_sparsity")
        if not os.path.exists("runs/layerwise_threshold"):
            os.mkdir("runs/layerwise_threshold")
        with open("runs/layerwise_sparsity/{}.json".format(args.name), "w") as f:
            json.dump(json_data, f)
        with open("runs/layerwise_threshold/{}.json".format(args.name), "w") as f:
            json.dump(json_thres, f)


def set_gpu(args, model):
    """Place ``model`` on the CUDA device(s) selected by ``args``.

    Selection order:

    1. ``args.gpu`` is not ``None``: that device is made the current CUDA
       device and the model is moved onto it.
    2. ``args.multigpu`` exists and is truthy: the model is wrapped in
       ``torch.nn.DataParallel`` over those device ids, and the wrapper is
       then moved onto the first id of the list.
    3. Otherwise: ``model.cuda()`` is called with no explicit device.

    ``cudnn.benchmark`` is enabled in all three cases.

    Args:
        args: Namespace providing ``gpu`` and, optionally, ``multigpu``.
        model: The module to place.

    Returns:
        The placed model (the ``DataParallel`` wrapper in case 2).
    """
    if args.gpu is not None:
        # A single, explicitly requested GPU.
        torch.cuda.set_device(args.gpu)
        placed = model.cuda(args.gpu)
    elif hasattr(args, 'multigpu') and args.multigpu:
        print(args.multigpu)
        # DataParallel splits every batch across the listed devices.
        print(f"=> Parallelizing on {len(args.multigpu)} GPUs: {args.multigpu}")
        wrapped = torch.nn.DataParallel(model, device_ids=args.multigpu)
        # The wrapper (and the parameters it holds) is moved to the first
        # listed device after wrapping.
        placed = wrapped.cuda(args.multigpu[0])
    else:
        # No device requested explicitly: use the default CUDA device.
        placed = model.cuda()

    # Lets cuDNN auto-tune kernels; may help when input sizes are fixed.
    cudnn.benchmark = True

    return placed


def resume_from_checkpoint(model, model_name, mask_name):
    """Restore weights and pruning masks into ``model`` from files under ``runs/``.

    Args:
        model: Module to restore in place. Its ``ConvMask`` / ``LinearMask``
            sub-modules receive the saved masks.
        model_name: File name (relative to ``runs/``) of a saved state dict.
        mask_name: File name (relative to ``runs/``) of a saved list of
            masks, one entry per ``ConvMask`` / ``LinearMask`` module in
            ``model.named_modules()`` order.

    Returns:
        The same ``model`` object, with weights and masks restored.

    Raises:
        IndexError: if the mask list has fewer entries than the model has
            masked layers. Surplus entries are ignored.
    """
    base_dir = ''

    model_path = base_dir + 'runs/' + model_name
    mask_path = base_dir + 'runs/' + mask_name

    # Deserialize onto the CPU so that files written on a CUDA device that
    # does not exist on this machine still load. ``load_state_dict`` copies
    # the values into the model's existing parameters, and every mask is
    # moved to its layer's device below, so the final placement is unchanged.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    mask_list = torch.load(mask_path, map_location='cpu')

    # Masks are matched to layers purely by traversal order.
    cnt = 0
    for _, m in model.named_modules():
        if isinstance(m, (ConvMask, LinearMask)):
            m.mask = mask_list[cnt].to(m.weight.device)
            cnt += 1

    return model




def pretrained(args, model):
    """Load pretrained weights from ``args.pretrained`` into ``model`` in place.

    The checkpoint file must hold a dict with a ``"state_dict"`` entry. If the
    file does not exist, a message is printed and nothing else happens.

    Behaviour when the file exists:

    * Unless ``args.ignore_pretrained_weights`` is set, every checkpoint entry
      whose key exists in the model with an identical size is copied over.
      When ``args.conv_type`` is not ``"STRConv"``, each checkpoint key
      containing ``sparseThreshold`` additionally causes the sibling
      ``...weight`` entry to be replaced by ``sparseFunction(weight, thresh)``.
    * If ``args.use_budget`` is set, the per-layer density of those
      thresholded weights is computed and ``1 - density`` is passed to
      ``set_prune_rate`` on every module that exposes it.

    Args:
        args: Namespace with ``pretrained``, ``ignore_pretrained_weights``,
            ``conv_type``, ``use_budget`` and, optionally, ``multigpu`` /
            ``gpu`` (used only to pick where the checkpoint is deserialized).
        model: Module that receives the weights.

    Raises:
        KeyError: with ``args.use_budget``, if a module exposing
            ``set_prune_rate`` has no matching ``sparseThreshold`` entry in
            the checkpoint.
    """
    if not os.path.isfile(args.pretrained):
        print("=> no pretrained weights found at '{}'".format(args.pretrained))
        return

    print("=> loading pretrained weights from '{}'".format(args.pretrained))

    # Pick the deserialization device the same way set_gpu picks the model's
    # device. Previously ``args.multigpu[0]`` was indexed unconditionally,
    # which crashed whenever only ``args.gpu`` (or neither) was configured.
    multigpu = getattr(args, "multigpu", None)
    gpu = getattr(args, "gpu", None)
    if multigpu:
        map_location = torch.device("cuda:{}".format(multigpu[0]))
    elif gpu is not None:
        map_location = torch.device("cuda:{}".format(gpu))
    else:
        map_location = torch.device("cpu")

    checkpoint_state = torch.load(
        args.pretrained,
        map_location=map_location,
    )["state_dict"]

    model_state_dict = model.state_dict()

    if not args.ignore_pretrained_weights:

        # Keep only entries the model can accept as-is.
        pretrained_final = {
            k: v
            for k, v in checkpoint_state.items()
            if (k in model_state_dict and v.size() == model_state_dict[k].size())
        }

        if args.conv_type != "STRConv":
            # The target model has no learnable thresholds, so the thresholds
            # stored in the checkpoint are applied to the weights up front.
            for k, v in checkpoint_state.items():
                if 'sparseThreshold' in k:
                    wkey = k.split('sparse')[0] + 'weight'
                    weight = checkpoint_state[wkey]
                    pretrained_final[wkey] = sparseFunction(weight, v)

        model_state_dict.update(pretrained_final)
        model.load_state_dict(model_state_dict)

    # Using the budgets of STR models for other models like DNW and GMP
    if args.use_budget:
        budget = {}
        for k, v in checkpoint_state.items():
            if 'sparseThreshold' in k:
                wkey = k.split('sparse')[0] + 'weight'
                weight = checkpoint_state[wkey]
                sparse_weight = sparseFunction(weight, v)
                # Fraction of weights that survive thresholding.
                budget[wkey] = (sparse_weight.abs() > 0).float().mean().item()

        for n, m in model.named_modules():
            if hasattr(m, 'set_prune_rate'):
                pr = 1 - budget[n + '.weight']
                m.set_prune_rate(pr)
                print('set prune rate', n, pr)

# Write a custom train loop and data loader to get activations sample wise and class wise for full filter and mean filter

def get_activations(model, epoch, level, args):
    """Save the ``avg_preacts`` of every ``TrackActReLU`` module to disk.

    The values are gathered in ``model.named_modules()`` order and written as
    one list to ``runs/activations/activation_<name>_<level>_<epoch>.pt``,
    where ``<name>`` is ``args.name``.

    Args:
        model: Module whose ``TrackActReLU`` sub-modules are inspected.
        epoch: Epoch tag embedded in the file name.
        level: Pruning-level tag embedded in the file name.
        args: Namespace providing ``name``.
    """
    base_dir = ''

    tracked = [
        module.avg_preacts
        for _, module in model.named_modules()
        if isinstance(module, TrackActReLU)
    ]
    out_path = '{}runs/activations/activation_{}_{}_{}.pt'.format(base_dir, args.name, level, epoch)
    torch.save(tracked, out_path)

def get_dataset(args):
    """Instantiate the dataset named by ``args.set``.

    The name is looked up as an attribute of the ``data`` module and the
    resulting callable is invoked with ``args``.

    Returns:
        Whatever that callable returns.
    """
    print(f"=> Getting {args.set} dataset")
    dataset_factory = getattr(data, args.set)
    return dataset_factory(args)


def get_model(args):
    """Build the architecture named by ``args.arch`` and configure pruning.

    Steps:

    1. If ``args.first_layer_dense`` is set, ``args.first_layer_type`` is
       overwritten with ``"DenseConv"`` (this mutates ``args``).
    2. The constructor is looked up in ``models.__dict__`` and called with no
       arguments.
    3. Unless ``args.conv_type`` is ``"DenseConv"``, ``args.prune_rate`` is
       pushed to every sub-module exposing ``set_prune_rate``.
    4. If ``args.freeze_weights`` is set, gradients are disabled for the
       weight (and non-``None`` bias) of every sub-module that has a ``mask``
       attribute.

    Returns:
        The constructed model.
    """
    if args.first_layer_dense:
        args.first_layer_type = "DenseConv"

    print("=> Creating model '{}'".format(args.arch))
    build = models.__dict__[args.arch]
    model = build()

    print(f"=> Num model params {sum(p.numel() for p in model.parameters())}")

    # applying sparsity to the network
    if args.conv_type != "DenseConv":

        print(f"==> Setting prune rate of network to {args.prune_rate}")

        def _apply_prune_rate(module):
            # Only prunable layers expose this setter.
            if not hasattr(module, "set_prune_rate"):
                return
            module.set_prune_rate(args.prune_rate)

        model.apply(_apply_prune_rate)

    # freezing the weights if we are only doing mask training
    if args.freeze_weights:
        print(f"=> Freezing model weights")

        def _disable_weight_grads(module):
            # Layers without a mask are left trainable.
            if not hasattr(module, "mask"):
                return
            module.weight.requires_grad = False
            if getattr(module, "bias", None) is not None:
                module.bias.requires_grad = False

        model.apply(_disable_weight_grads)

    return model


def get_optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer`` for ``model``.

    ``"sgd"`` creates a ``torch.optim.SGD`` over the trainable parameters
    (``requires_grad`` is true), split by parameter name into three groups,
    in this order:

      0. names containing ``"bn"``: weight decay is 0 when
         ``args.no_bn_decay`` is truthy, otherwise ``args.weight_decay``;
      1. names containing ``"sparseThreshold"``: weight decay is
         ``args.st_decay`` when it is not ``None``, otherwise
         ``args.weight_decay``;
      2. names containing neither substring: ``args.weight_decay``.

    ``"adam"`` creates a ``torch.optim.Adam`` over the trainable parameters
    using only ``lr=args.lr``.

    Args:
        args: Run configuration providing ``optimizer`` and ``lr``; for SGD
            also ``momentum``, ``weight_decay``, ``nesterov``, ``no_bn_decay``
            and ``st_decay``.
        model: Module whose parameters are optimized.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``"sgd"`` nor ``"adam"``.
    """
    if args.optimizer == "sgd":
        trainable = [(n, v) for n, v in model.named_parameters() if v.requires_grad]
        sparse_thresh = [v for n, v in trainable if "sparseThreshold" in n]
        bn_params = [v for n, v in trainable if "bn" in n]
        rest_params = [
            v for n, v in trainable
            if "bn" not in n and "sparseThreshold" not in n
        ]
        return torch.optim.SGD(
            [
                {
                    "params": bn_params,
                    "weight_decay": 0 if args.no_bn_decay else args.weight_decay,
                },
                {
                    "params": sparse_thresh,
                    "weight_decay": args.st_decay if args.st_decay is not None else args.weight_decay,
                },
                {"params": rest_params, "weight_decay": args.weight_decay},
            ],
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

    if args.optimizer == "adam":
        return torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad], lr=args.lr
        )

    # Fail with a clear message instead of an UnboundLocalError on `optimizer`.
    raise ValueError(
        "Unsupported optimizer '{}'; expected 'sgd' or 'adam'".format(args.optimizer)
    )


def _run_dir_exists(run_base_dir):
    log_base_dir = run_base_dir / "logs"
    ckpt_base_dir = run_base_dir / "checkpoints"

    return log_base_dir.exists() or ckpt_base_dir.exists()


def get_directories(args):
    """Pick (and create) the run directory for this experiment.

    The base path is
    ``<root>/<config stem>/<args.name>/prune_rate=<args.prune_rate>`` where
    ``<root>`` is ``args.log_dir`` or ``"runs"`` when ``args.log_dir`` is
    ``None``; a ``width_mult=<value>`` component is appended when
    ``args.width_mult != 1.0``. If that directory already contains ``logs`` or
    ``checkpoints``, the first numbered subdirectory (``0``, ``1``, ...) that
    does not contain them is used instead. The chosen directory is created if
    missing and ``str(args)`` is written to ``settings.txt`` inside it.

    Args:
        args: Run configuration providing ``config``, ``name``, ``log_dir``,
            ``prune_rate`` and ``width_mult``.

    Returns:
        Tuple ``(run_base_dir, ckpt_base_dir, log_base_dir)`` of
        ``pathlib.Path`` objects; only ``run_base_dir`` is created here.

    Raises:
        ValueError: If ``args.config`` or ``args.name`` is ``None``.
    """
    if args.config is None or args.name is None:
        raise ValueError("Must have name and config")

    config_stem = pathlib.Path(args.config).stem
    root = "runs" if args.log_dir is None else args.log_dir
    run_dir = pathlib.Path(
        f"{root}/{config_stem}/{args.name}/prune_rate={args.prune_rate}"
    )
    if args.width_mult != 1.0:
        run_dir = run_dir / f"width_mult={args.width_mult!s}"

    # An earlier run already used this directory: move to the first free
    # numbered repetition below it.
    if _run_dir_exists(run_dir):
        repetition = 0
        candidate = run_dir / str(repetition)
        while _run_dir_exists(candidate):
            repetition += 1
            candidate = run_dir / str(repetition)
        run_dir = candidate

    if not run_dir.exists():
        run_dir.mkdir(parents=True)

    settings_file = run_dir / "settings.txt"
    settings_file.write_text(str(args))

    return run_dir, run_dir / "checkpoints", run_dir / "logs"


def write_result_to_csv(**kwargs):
    """Append one result row to ``runs/<args.result_dir>.csv``.

    The file name comes from the module-level ``args.result_dir``. When the
    file does not exist yet, its parent directory is created if necessary and
    the header line is written first. The row is prefixed with the current
    local time formatted as ``%m-%d-%y_%H:%M:%S``.

    Keyword Args:
        base_config, name, prune_rate: Written verbatim.
        curr_acc1, curr_acc5, best_acc1, best_acc5, best_train_acc1,
        best_train_acc5: Numbers, written with two decimals.
        sparsity: Number, written with four decimals.

    Raises:
        KeyError: If one of the keyword arguments above is missing.
    """
    filename = args.result_dir + '.csv'
    results = pathlib.Path("runs", filename)

    if not results.exists():
        # The destination directory may not exist on a fresh checkout.
        results.parent.mkdir(parents=True, exist_ok=True)
        # write_text opens and closes the file itself; no extra handle needed.
        results.write_text(
            "Date Finished, "
            "Base Config, "
            "Name, "
            "Prune Rate, "
            "Current Val Top 1, "
            "Current Val Top 5, "
            "Best Val Top 1, "
            "Best Val Top 5, "
            "Best Train Top 1, "
            "Best Train Top 5,"
            "Sparsity\n"
        )

    now = time.strftime("%m-%d-%y_%H:%M:%S")

    row = (
        "{now}, "
        "{base_config}, "
        "{name}, "
        "{prune_rate}, "
        "{curr_acc1:.02f}, "
        "{curr_acc5:.02f}, "
        "{best_acc1:.02f}, "
        "{best_acc5:.02f}, "
        "{best_train_acc1:.02f}, "
        "{best_train_acc5:.02f},"
        "{sparsity:.04f}\n"
    ).format(now=now, **kwargs)

    with open(results, "a") as f:
        f.write(row)


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
