## pythonic imports
import numpy as np
import os

## torch
import torch
import torch.nn as nn

## torchvision
import torchvision.models as models
from torch.amp import autocast
import torch.distributed as dist

## file-based imports
from utils.conv_type import ConvMask, Conv1dMask, LinearMask, STRConv, Conv1dMaskMW, ConvMaskMW, LinearMaskMW, replace_layers, replace_vit_layers
from utils.harness_params import get_current_params
from utils.custom_models import PreActResNet, PreActBlock
from utils.dataset import imagenet, imagenet_pytorch, CIFARLoader, CIFARLoader_subsampled, imagenet_subsampled, OxfordPetsLoader, Flowers102Loader
from utils.harness_utils import get_model_mask, convert_dict_to_mw
# import deit model
from utils.deit import deit_tiny_patch16_224, deit_small_patch16_224
from utils.res20 import resnet20

## fastargs
from fastargs import get_current_config
from fastargs.decorators import param
from typing import Optional, Dict, Any


# Called at import time, before any @param-decorated function below is used.
# NOTE(review): presumably this registers the project's fastargs parameter
# sections (utils.harness_params) — confirm.
get_current_params()

# ImageNet per-channel normalisation statistics, scaled from [0, 1] to the
# [0, 255] pixel range (presumably RGB channel order — confirm against loaders).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
# Ratio 224 / 256 (= 0.875); presumably crop size over resize size for evaluation.
DEFAULT_CROP_RATIO = 224 / 256

def perturb_signs(model: nn.Module, perturb_ratio: float) -> nn.Module:
    """Randomly flip the signs of unmasked weights in every ConvMask layer.

    For each ConvMask layer, every weight whose mask entry equals 1 has its sign
    flipped independently with probability ``perturb_ratio`` (Bernoulli draw), so
    on average a ``perturb_ratio`` fraction of the kept weights is flipped.
    Weights whose mask entry is not 1 are multiplied by 0, i.e. zeroed.

    Only ``ConvMask`` instances are visited; Conv1dMask / LinearMask layers are
    left untouched unless they are subclasses of ConvMask.
    NOTE(review): confirm that skipping non-conv masked layers is intended.

    Args:
        model (nn.Module): The model whose signs to perturb (modified in place).
        perturb_ratio (float): Per-weight probability of a sign flip.

    Returns:
        nn.Module: The same model instance.
    """
    for n, m in model.named_modules():
        if isinstance(m, ConvMask):
            # -1 where a flip was drawn, +1 otherwise; 0 outside the mask.
            sign = torch.where(m.mask == 1, torch.where(torch.ones_like(m.mask).bernoulli_(perturb_ratio) == 1, -1, 1), 0).to(m.weight.device)
            m.weight.data = sign * m.weight.data
    return model

def get_sparsity(model: nn.Module) -> float:
    """Return the density (fraction of non-zero mask entries) of the model.

    Despite its name this computes ``sum(mask) / numel(mask)`` over all
    ConvMask, Conv1dMask and LinearMask layers. For binary 0/1 masks, a fully
    dense model gives 1.0 and a model with 90% of its masked weights pruned
    gives 0.1.

    Args:
        model (nn.Module): The model to inspect.

    Returns:
        float: The density of the masked layers. Because ``mask.sum()`` is a
        tensor, the value is typically a 0-dim tensor.

    Raises:
        ValueError: If the model contains no masked layers.
    """
    nz = 0
    total = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            nz += m.mask.sum()
            total += m.mask.numel()

    if total == 0:
        # Avoid an opaque ZeroDivisionError for models without masked layers.
        raise ValueError("get_sparsity: the model has no masked layers")
    return nz / total


class PruningStuff:
    """Class to handle pruning-related functionalities.

    Builds the data loaders for ``dataset.dataset_name`` and the masked model.
    Provides helpers to load masks and initialisations from checkpoints, and to
    prune the model at initialisation or at the end of a level.
    """

    def __init__(self, model: Optional[nn.Module] = None) -> None:
        """Initialize the PruningStuff class.

        Args:
            model (nn.Module, optional): The model to prune. Default is None.
                If no model is provided, one is built via ``acquire_model``
                from the current configuration.
        """
        self.this_device = "cuda:0"
        self.config = get_current_config()

        if self.config['dataset.dataset_name'] == 'ImageNet':
            if self.config['dataset.use_ffcv']:
                self.loaders = imagenet(distributed=False, this_device='cuda:0')
            else:
                self.loaders = imagenet_pytorch(distributed=False, this_device='cuda:0')
            self.train_loader = self.loaders.train_loader

        elif self.config["dataset.dataset_name"] == "OxfordPets":
            # NOTE(review): built with distributed=True, unlike ImageNet/CIFAR — confirm.
            self.loaders = OxfordPetsLoader(distributed=True, this_device='cuda:0')
            self.train_loader = self.loaders.train_loader
            self.test_loader = self.loaders.test_loader

        elif self.config["dataset.dataset_name"] == "Flowers102":
            # NOTE(review): built with distributed=True, unlike ImageNet/CIFAR — confirm.
            self.loaders = Flowers102Loader(distributed=True, this_device='cuda:0')
            self.train_loader = self.loaders.train_loader
            self.test_loader = self.loaders.test_loader

        elif 'CIFAR' in self.config['dataset.dataset_name']:
            self.loaders = CIFARLoader(distributed=False)
            self.train_loader = self.loaders.train_loader
        # NOTE(review): for any other dataset name no loader is created, so
        # self.train_loader stays unset and loader-based pruners will fail.

        if model is None:
            self.model = self.acquire_model()
        else:
            self.model = model
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _get_pruner(prune_method: str):
        """Return the module-level ``prune_<prune_method>`` function or raise ValueError."""
        pruner = globals().get(f"prune_{prune_method}")
        if pruner is None:
            raise ValueError(
                f"Unknown pruning method: {prune_method!r} (no function prune_{prune_method} defined)"
            )
        return pruner

    @param("model_params.model_name")
    @param("model_params.tf_pretrained")
    @param("dataset.dataset_name")
    def acquire_model(
        self, model_name: str, dataset_name: str, tf_pretrained: bool
    ) -> nn.Module:
        """Build the model for the configured dataset, swap in masked layers and move it to the device.

        Args:
            model_name (str): Name of the model ("preresnet", "deit-small",
                "deit-tiny", "resnet20" or a ``torchvision.models`` constructor name).
            dataset_name (str): Name of the dataset.
            tf_pretrained (bool): Passed as ``pretrained`` to the DeiT constructors.

        Returns:
            nn.Module: The acquired model, with layers replaced by their masked counterparts.

        Raises:
            ValueError: If a torchvision model is requested for a dataset whose
                number of classes is unknown.
        """
        # None for unknown datasets; only the torchvision path needs it.
        num_classes = {
            'CIFAR10': 10,
            'CIFAR100': 100,
            'ImageNet': 1000,
            'OxfordPets': 37,
            'Flowers102': 102,
        }.get(dataset_name)

        if model_name == "preresnet":
            model = PreActResNet(block=PreActBlock, num_blocks=[2, 2, 2, 2])

        elif any(name in dataset_name for name in ["ImageNet", "OxfordPets", "Flowers102"]) and "deit-small" in model_name:
            model = deit_small_patch16_224(pretrained=tf_pretrained)
            # Replace the classification head for the smaller fine-grained datasets.
            if "OxfordPets" in dataset_name:
                model.head = nn.Linear(384, 37)
            if "Flowers102" in dataset_name:
                model.head = nn.Linear(384, 102)
        elif any(name in dataset_name for name in ["ImageNet", "OxfordPets", "Flowers102"]) and "deit-tiny" in model_name:
            print('Initializing a DeiT')
            # NOTE(review): unlike deit-small, the head is not replaced for
            # OxfordPets / Flowers102 here — confirm this is intended.
            model = deit_tiny_patch16_224(pretrained=tf_pretrained)
        elif "CIFAR10" in dataset_name and "resnet20" in model_name:
            # NOTE(review): "CIFAR10" is also a substring of "CIFAR100"; resnet20()
            # is built with its default number of classes in both cases.
            model = resnet20()
        else:
            if num_classes is None:
                raise ValueError(
                    f"Cannot determine the number of classes for dataset {dataset_name!r} "
                    f"(model {model_name!r})"
                )
            model = getattr(models, model_name)(num_classes=num_classes)

        # CIFAR-style stem for torchvision ResNets: 3x3 conv, no max-pool.
        if ("CIFAR" in dataset_name) and ("resnet" in model_name) and ("resnet20" not in model_name):
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()

        if "CIFAR" in dataset_name and "vgg11" in model_name:
            model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            model.classifier = nn.Sequential(
                nn.Linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

        # Replacing the layers.
        if "deit" in model_name:
            print('Replacing the layers')
            model = replace_vit_layers(model=model)
        else:
            model = replace_layers(model=model)

        model.to(self.this_device)
        return model

    @param("prune_params.load_level")
    def load_init_mask_from_str_master_model(self, load_level: int) -> nn.Module:
        """Load initial weights and a level mask from checkpoints created in the old codebase.

        Both checkpoint locations are hard-coded below.

        Args:
            load_level (int): Pruning level whose mask file is loaded.

        Returns:
            nn.Module: ``self.model`` with the weights and masks loaded.
        """
        original_dict = torch.load(
                os.path.join("/home/c01adga/CISPA-projects/sparse_lottery-2022/TurboPrune/imagenet-res50-ckpt", "model_imagenet-imp-rewind-warmup-every-cycle-0.001-seed-42_init.pt"))
        mask_list = torch.load("/home/c01adga/CISPA-projects/sparse_lottery-2022/TurboPrune/imagenet-res50-ckpt/mask_imagenet-imp-rewind-warmup-every-cycle-0.001-seed-42_{}.pt".format(load_level))

        new_dict = {}
        for k in original_dict.keys():
            # Drop the first 7 characters of each key (presumably a "module." prefix — confirm).
            new_k = k[7:]
            if 'fc.weight' in k:
                # The old checkpoint stores fc weights with an extra singleton dim 3.
                new_dict[new_k] = original_dict[k].squeeze(dim=3)
            else:
                new_dict[new_k] = original_dict[k]

        model_dict = self.model.state_dict()

        # mask_list is positional; downsample weights consume an index but get no mask entry.
        mask_dict = {}
        cnt = 0
        for k in new_dict.keys():
            if 'weight' in k and (('conv' in k) or ('fc' in k) or ('downsample.0.' in k)):
                if 'downsample' in k:
                    pass
                elif 'fc' in k:
                    mask_dict[k] = mask_list[cnt].squeeze(dim=3)
                else:
                    mask_dict[k] = mask_list[cnt]
                cnt += 1
        print('Mask keys are: ', mask_dict.keys())

        model_dict.update(new_dict)
        self.model.load_state_dict(model_dict)

        for n, m in self.model.named_modules():
            if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
                print(n)
                print(m.mask.shape, mask_dict[n + '.weight'].shape)
                m.mask = mask_dict[n + '.weight']

        print('Loaded from imp-rewind-warmup-every-cycle-0.001-seed-42 level {}'.format(load_level))
        return self.model

    @param("prune_params.init_type")
    @param("prune_params.load_level")
    @param("prune_params.target_dir")
    @param("prune_params.load_only_warmup_sign")
    def load_init_and_mask(self, init_type: str, load_level: int, target_dir: str, load_only_warmup_sign: bool) -> nn.Module:
        """Load a mask and an initialization that can be trained with cyclic training.

        The mask (``*.mask`` entries) always comes from
        ``<target_dir>/checkpoints/model_level_<load_level>.pt``. The weights
        (``*.weight`` / ``*.bias`` entries) come from ``model_rewind.pt`` for
        ``init_type == "warmup"`` and from ``model_init.pt`` for ``"init"``.
        For any other value they come from the same level checkpoint as the mask.

        Args:
            init_type (str): Which weights to load: "warmup", "init", or anything
                else for the level checkpoint's own weights.
            load_level (int): Pruning level whose checkpoint provides the mask.
            target_dir (str): Experiment directory containing ``checkpoints/``.
            load_only_warmup_sign (bool): With ``init_type == "warmup"``, keep each
                weight's warmup sign but randomly permute the magnitudes within
                each masked layer.

        Returns:
            nn.Module: ``self.model`` with the weights and masks loaded.
        """
        ckpt_dir = os.path.join(target_dir, "checkpoints")
        level_dict = torch.load(os.path.join(ckpt_dir, "model_level_{}.pt".format(load_level)))
        if init_type == "warmup":
            original_dict = torch.load(os.path.join(ckpt_dir, "model_rewind.pt"))
        elif init_type == "init":
            original_dict = torch.load(os.path.join(ckpt_dir, "model_init.pt"))
        else:
            print("loading mask and model from the same level.")
            # The weights come from the same level checkpoint that provides the mask.
            original_dict = level_dict

        original_weights = {k: v for k, v in original_dict.items() if k.endswith((".weight", ".bias"))}
        mask_dict = {k: v for k, v in level_dict.items() if k.endswith(".mask")}

        model_dict = self.model.state_dict()
        model_dict.update(original_weights)
        model_dict.update(mask_dict)
        self.model.load_state_dict(model_dict)

        if load_only_warmup_sign and init_type == 'warmup':
            print('Loading only warmup signs and random weights with the mask')
            # Collect every sign before touching any weight, so tied weights still see original signs.
            init_sign_list = [
                m.weight.detach().sign()
                for n, m in self.model.named_modules()
                if isinstance(m, (ConvMask, Conv1dMask, LinearMask))
            ]
            cnt = 0
            for n, m in self.model.named_modules():
                if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
                    perm = torch.randperm(m.weight.numel())
                    weight = m.weight.data.abs_().view(-1)[perm].view(m.weight.shape)
                    m.weight.data = weight.to(m.weight.device) * init_sign_list[cnt].to(m.weight.device)
                    cnt += 1
        print('Loaded init {} and mask from level {} of target {}'.format(init_type, load_level, target_dir))
        return self.model

    @param("prune_params.er_method")
    @param("prune_params.er_init")
    @param("prune_params.update_sign_init")
    @param("prune_params.small_weight_frac")
    def prune_at_initialization(
        self, er_method: str, er_init: float, update_sign_init: bool, small_weight_frac: float) -> nn.Module:
        """Rebuild the model and prune it at initialization.

        Args:
            er_method (str): Method of pruning (selects the branch / ``prune_<er_method>``).
            er_init (float): Density target passed to the pruning routine.
            update_sign_init (bool): If True, call ``update_sign_from_grad`` afterwards.
            small_weight_frac (float): Forwarded to ``prune_small_and_large``.

        Returns:
            nn.Module: The pruned model (also stored in ``self.model``).

        Raises:
            ValueError: If ``er_method`` maps to no ``prune_<er_method>`` function.
        """
        self.model = self.acquire_model()
        total = 0
        nonzero = 0
        for n, m in self.model.named_modules():
            if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
                total += m.mask.numel()
                nonzero += m.mask.sum()

            if isinstance(m, STRConv):
                mask = m.get_mask()
                total += mask.numel()
                nonzero += mask.sum()

        if total > 0:
            # Density in percent: kept entries over all entries.
            print(f'density is {(nonzero / total) * 100:.3f}')

        print('prior to pruning at init')
        er_method_name = f"prune_{er_method}"
        # NOTE: "rel_grad" is handled by this first branch (non-distributed call).
        if er_method in {"synflow", "snip", "rel_grad"}:
            self.model = self._get_pruner(er_method)(self.model, self.train_loader, er_init)
        elif er_method == 'load_sign_and_mask':
            self.model = load_final_mask_and_intermediate_sign(self.model)
        elif er_method == 'load_mag_and_mask':
            self.model = load_final_mask_and_intermediate_mag(self.model)
        elif er_method == 'load_weight_and_mask':
            self.model = load_final_mask_and_intermediate_weight(self.model)
        elif er_method == 'load_mw_sign_and_mask':
            self.model = load_mw_final_mask_and_intermediate_sign(self.model)
        elif er_method == 'load_only_mask':
            self.model = load_only_final_mask(self.model)
        elif er_method == 'load_dense_and_prune_mag':
            self.model = load_dense_and_prune_mag(self.model, er_init)
        elif er_method == 'load_dense_and_prune_mag_mw':
            self.model = load_dense_and_prune_mag_mw(self.model, er_init)
        elif er_method == 'load_mw_trained_dense_and_prune_mag_mw':
            self.model = load_mw_trained_dense_and_prune_mag_mw(self.model, er_init)
        elif er_method == 'uniform':
            self.model = prune_uniform(self.model, er_init)
        elif er_method == 'grad_mask_train_small_and_large':
            self.model = prune_small_and_large(self.model, er_init, prune_extremes=False, small_weight_frac=small_weight_frac)
        elif er_method == 'grad_mask_train_mid':
            self.model = prune_small_and_large(self.model, er_init, prune_extremes=True, small_weight_frac=small_weight_frac)
        elif er_method == 'grad_mask_train_small_rel_grad':
            self.model = prune_rel_grad(self.model, self.train_loader, 1-er_init, distributed=True)
            self.model = flip_mask(self.model)
        elif er_method == 'rel_grad_uniform':
            self.model = prune_rel_grad_uniform(self.model, self.train_loader, er_init, distributed=True)
        elif er_method == 'grad_mask_train_small_rel_grad_uniform':
            self.model = prune_rel_grad_uniform(self.model, self.train_loader, 1-er_init, distributed=True)
            self.model = flip_mask(self.model)
        elif er_method == "just dont":
            print('We dont prune at init')
        elif er_method == 'anneal_balanced':
            print('Determining the random mask for sparse training post mask decay')
            prune_er_balanced(self.model, er_init)
            self.target_mask_list = get_model_mask(self.model)
            # now make the mask all ones again
            for n, m in self.model.named_modules():
                if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
                    m.mask = torch.ones_like(m.weight)
        else:
            print('Prune method: {}'.format(er_method_name))
            self._get_pruner(er_method)(self.model, er_init)

        if update_sign_init:
            update_func = globals().get('update_sign_from_grad')
            self.model = update_func(self.model, self.train_loader, frac=0.05)
        print('We just pruned at init, woohoo!')
        return self.model

    @param("prune_params.prune_method")
    @param("prune_params.update_sign_every_level")
    def level_pruner(self, prune_method: str, density: float, level: int, update_sign_every_level: bool) -> None:
        """Prune the model at a specific density level.

        Args:
            prune_method (str): Method of pruning (calls ``prune_<prune_method>``).
            density (float): Desired density after pruning.
            level (int): Current pruning level (not used in this method).
            update_sign_every_level (bool): If True, call ``update_sign_from_grad`` afterwards.

        Raises:
            ValueError: If no ``prune_<prune_method>`` function exists.
        """
        print("---" * 20)
        print(f"Density before pruning: {get_sparsity(self.model)}")
        print("---" * 20)

        if prune_method in {"synflow", "snip"}:
            self.model = self._get_pruner(prune_method)(self.model, self.train_loader, density)
        elif prune_method == "just dont":
            print('We dont prune at the end of the level')
        else:
            self._get_pruner(prune_method)(self.model, density)

        if update_sign_every_level:
            update_func = globals().get('update_sign_from_grad')
            self.model = update_func(self.model, self.train_loader, frac=0.05)

        # put the model back on the GPU
        self.model.to(self.this_device)

        print("---" * 20)
        print(f"Density after pruning: {get_sparsity(self.model)}")
        print("---" * 20)

    def load_from_ckpt(self, path: str) -> None:
        """Load the model from a checkpoint.

        Args:
            path (str): Path to a saved state_dict.
        """
        self.model.load_state_dict(torch.load(path))


def set_target_mask(model, target_mask_list):
    """Install the stored target masks on the model's masked layers.

    Masked layers (ConvMask / Conv1dMask / LinearMask) are visited in
    ``named_modules()`` order and paired by position with ``target_mask_list``.
    Each mask is moved to the device of its layer's weight.

    Returns:
        The same model, modified in place.
    """
    print('Setting the target random mask')
    masked_layers = (
        module
        for _, module in model.named_modules()
        if isinstance(module, (ConvMask, Conv1dMask, LinearMask))
    )
    for position, layer in enumerate(masked_layers):
        layer.mask = target_mask_list[position].to(layer.weight.device)
    return model

@param("prune_params.anneal_scale_floor")
@param("prune_params.anneal_scale_ceil")
def anneal_scale(step, total_anneal_steps, anneal_scale_floor=0, anneal_scale_ceil=1):
    """Linearly ramp a scale from ``anneal_scale_floor`` towards ``anneal_scale_ceil``.

    The value is ``floor + (ceil - floor) * step / total_anneal_steps`` while
    ``step < total_anneal_steps``. Once ``step >= total_anneal_steps`` it drops
    to 0, so ``anneal_scale_ceil`` itself is never returned.
    """
    annealing_done = step >= total_anneal_steps
    if annealing_done:
        return 0
    progress = step / total_anneal_steps
    span = anneal_scale_ceil - anneal_scale_floor
    return anneal_scale_floor + span * progress


def masked_l2_decay(model, target_mask_list, scale):
    """Return ``scale`` times the squared L2 norm of weights lying outside the target mask.

    Masked layers (ConvMask / Conv1dMask / LinearMask / STRConv) are visited in
    ``named_modules()`` order and paired by position with ``target_mask_list``.
    Only weight entries whose target mask is 0 contribute to the penalty.
    NOTE(review): STRConv is counted here but not in set_target_mask; confirm
    that target_mask_list indexing matches when STRConv layers are present.
    """
    penalty = 0
    layer_idx = 0
    for _, module in model.named_modules():
        if not isinstance(module, (ConvMask, Conv1dMask, LinearMask, STRConv)):
            continue
        target = target_mask_list[layer_idx].to(module.weight.device)
        outside_weights = module.weight * (target == 0)
        penalty += (outside_weights ** 2).sum()
        layer_idx += 1
    return scale * penalty

def make_dense(model: nn.Module) -> nn.Module:
    """Reset every layer mask to all ones, turning the network dense again.

    Applies to ConvMask, Conv1dMask and LinearMask layers. The model is
    modified in place and returned.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    for module in model.modules():
        if isinstance(module, masked_types):
            module.mask = torch.ones_like(module.mask)

    print('The mask is now dense.')
    return model
    
def prune_mag(model: nn.Module, density: float, distributed=False) -> nn.Module:
    """Global magnitude pruning of the model.

    Every entry of the masked layers (ConvMask / Conv1dMask / LinearMask) is
    scored by ``|mask * weight|``; the *MaskMW variants use
    ``|mask * weight * m|``. Entries scoring at or below the
    ``int((1 - density) * N)``-th smallest score across all layers get mask 0,
    and all others get mask 1. Already-masked entries score 0, so they are
    pruned first. If that count is below 1, the masks are left unchanged.

    Args:
        model (nn.Module): The model to prune (masks are updated in place).
        density (float): Desired fraction of masked-layer entries to keep.
        distributed (bool): If True, average every parameter across ranks
            (all_reduce SUM divided by the world size) before scoring.

    Returns:
        nn.Module: The pruned model.
    """
    score_list = {}

    # Sync model weights across devices if distributed
    if distributed:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data = param.data / dist.get_world_size()

    curr_total = 0
    curr_nz = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if isinstance(m, (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)):
                score_list[n] = (m.mask.to(m.weight.device) * m.weight * m.m).detach().abs_()
            else:
                score_list[n] = (m.mask.to(m.weight.device) * m.weight).detach().abs_()
            curr_nz += m.mask.sum()
            curr_total += m.mask.numel()
    print(f'Before pruning at {density}, the density of the model is {curr_nz / curr_total}')

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])

    # Number of entries to prune. torch.kthvalue requires 1 <= k <= numel, so
    # a threshold is only computed when there is something to prune.
    k = int((1 - density) * global_scores.numel())
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after magnitude pruning at current iteration = ",
        (total_num / total_den).item(),
    )

    return model

def prune_largest(model: nn.Module, density: float, distributed=False) -> nn.Module:
    """Prune the largest-magnitude parameters of the model.

    Entries are scored by ``-|mask * weight|``; the *MaskMW variants use
    ``-|mask * weight * m|``. Entries scoring at or below the
    ``int((1 - density) * N)``-th smallest score, i.e. the largest magnitudes,
    get mask 0 and all others get mask 1. If that count is below 1, the masks
    are left unchanged.

    NOTE(review): already-masked entries score 0, the maximum, so they are
    always given mask 1 again. Confirm this is intended when masks are not all
    ones on entry.

    Args:
        model (nn.Module): The model to prune (masks are updated in place).
        density (float): Desired fraction of masked-layer entries to keep.
        distributed (bool): If True, average every parameter across ranks
            (all_reduce SUM divided by the world size) before scoring.

    Returns:
        nn.Module: The pruned model.
    """

    # Sync model weights across devices if distributed
    if distributed:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data = param.data / dist.get_world_size()

    score_list = {}

    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if isinstance(m, (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)):
                score_list[n] = -1 * (m.mask.to(m.weight.device) * m.weight * m.m).detach().abs_()
            else:
                score_list[n] = -1 * (m.mask.to(m.weight.device) * m.weight).detach().abs_()

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])

    # torch.kthvalue requires 1 <= k <= numel; skip thresholding when nothing is pruned.
    k = int((1 - density) * global_scores.numel())
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after magnitude pruning at current iteration = ",
        (total_num / total_den).item(),
    )

    return model

@param("prune_params.start_percentile")
def prune_percentile_window(model: nn.Module, density: float, start_percentile: float, distributed=False) -> nn.Module:
    """Keep a contiguous window of weights, in ascending score order, starting at a percentile.

    Entries of all masked layers are scored by ``|mask * weight|`` (times ``m``
    for the *MaskMW variants) and sorted ascending. Let
    ``s = int(start_percentile * N)`` and ``k = int(density * N)``. Entries with
    sorted rank in ``[s, s + k)`` get mask 1 and every other entry gets mask 0.
    If the window runs past the end of the sorted scores it is truncated, so the
    resulting density can be lower than ``density``.

    Args:
        model (torch.nn.Module): The model to prune (masks are replaced in place).
        density (float): Fraction of all masked-layer entries to keep, i.e. the window width.
        start_percentile (float): Fraction of the sorted scores at which the window starts.
        distributed (bool): If True, average every parameter across ranks before scoring.

    Returns:
        torch.nn.Module: The pruned model.
    """
    score_list = {}

    # Sync model weights across devices if distributed
    if distributed:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data = param.data / dist.get_world_size()

    curr_total = 0
    curr_nz = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if isinstance(m, (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)):
                score_list[n] = (m.mask.to(m.weight.device) * m.weight * m.m).detach().abs_()
            else:
                score_list[n] = (m.mask.to(m.weight.device) * m.weight).detach().abs_()
            curr_nz += m.mask.sum()
            curr_total += m.mask.numel()
    print(f'Before pruning at {density}, the density of the model is {curr_nz / curr_total}')

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    mask_list = torch.zeros_like(global_scores)
    num_params = global_scores.numel()
    sorted_indices = torch.argsort(global_scores)
    # Window start, as a rank in ascending score order.
    start_index = int(start_percentile * num_params)
    # Number of entries to KEEP (set to one); the slice truncates at num_params.
    k = int((density) * global_scores.numel())
    one_indices = sorted_indices[start_index:start_index + k]
    # Make only the params in the window 1, the others zero.
    mask_list[one_indices] = 1.0

    # Split the flat mask back into per-layer masks (views into mask_list).
    start_idx = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            end_idx = start_idx + m.weight.numel()
            m.mask = mask_list[start_idx:end_idx].reshape(m.weight.shape)
            start_idx = end_idx

    print(f"Created mask with params nonzero from {start_percentile} to {start_percentile+density} in ascending order")

    return model

@param("prune_params.er_method")
@param("prune_params.er_init")
@param("prune_params.small_weight_frac")
def prune_and_update_grad_mask(model: nn.Module, grad_mask: list, train_loader, step_counter, total_steps, er_method, er_init, small_weight_frac) -> tuple:
    """Compute a mask with ``er_method``, store it as the gradient mask, then make the weight mask dense.

    The chosen pruning routine is applied to ``model`` with ``er_init`` as the
    density. The ``*small_rel_grad*`` variants instead prune with
    ``1 - er_init`` and then flip the mask via ``flip_mask``. Afterwards, for
    every masked layer in ``named_modules()`` order, the resulting mask is
    copied into ``grad_mask[i]`` and the layer's mask is reset to all ones.

    Args:
        model (nn.Module): The model to prune.
        grad_mask (list): Per-layer gradient masks; entries are overwritten in place.
        train_loader: Data loader forwarded to the rel_grad-based methods.
        step_counter: Current training step (not used in this function).
        total_steps: Total number of training steps (not used in this function).
        er_method (str): Pruning method name (``prune_params.er_method``).
        er_init (float): Density passed to the pruning routine (``prune_params.er_init``).
        small_weight_frac (float): Forwarded to ``prune_small_and_large``.

    Returns:
        tuple: ``(model, grad_mask)``.

    Raises:
        ValueError: If ``er_method`` is not one of the supported methods.
    """
    # First apply the pruning method
    if er_method == 'mag':
        model = prune_mag(model, density=er_init, distributed=True)
    elif er_method == 'largest':
        model = prune_largest(model, density=er_init, distributed=True)
    elif er_method == 'er_balanced':
        prune_random_balanced(model, density=er_init, distributed=True)
    elif er_method == 'percentile_window':
        model = prune_percentile_window(model, density=er_init, distributed=True)
    elif er_method == 'grad_mask_train_small_and_large':
        model = prune_small_and_large(model, density=er_init, prune_extremes=False, distributed=True, small_weight_frac=small_weight_frac)
    elif er_method == 'grad_mask_train_mid':
        model = prune_small_and_large(model, density=er_init, prune_extremes=True, distributed=True, small_weight_frac=small_weight_frac)
    elif er_method == 'rel_grad':
        model = prune_rel_grad(model, train_loader, density=er_init, distributed=True)
    elif er_method == 'grad_mask_train_small_rel_grad':
        # prunes with rel grad and then flips the mask
        model = prune_rel_grad(model, train_loader, density=1-er_init, distributed=True)
        model = flip_mask(model)
    elif er_method == 'rel_grad_uniform':
        model = prune_rel_grad_uniform(model, train_loader, er_init, distributed=True)
    elif er_method == 'grad_mask_train_small_rel_grad_uniform':
        model = prune_rel_grad_uniform(model, train_loader, 1-er_init, distributed=True)
        model = flip_mask(model)
    else:
        raise ValueError(f"Unknown pruning method: {er_method}")

    # Save mask to grad_mask and set mask to 1
    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            # Save current mask as grad_mask
            grad_mask[cnt] = m.mask.clone()
            # Set mask to all ones so the forward pass stays dense
            m.mask = torch.ones_like(m.weight)
            cnt += 1

    print(f"Pruned with {er_method} and changed grad_mask")
    return model, grad_mask


@param("optimizer.update_grad_mask_every")
@param("prune_params.train_first_and_last_grad_mask")
def apply_mask_to_grad(model, grad_mask, step_counter, train_loader, total_steps, update_grad_mask_every, train_first_and_last_grad_mask):
    """Multiply the weight gradients of the masked layers by ``grad_mask`` before an optimizer step.

    Every ``update_grad_mask_every`` steps, ``grad_mask`` is first recomputed
    in place via ``prune_and_update_grad_mask``. A non-positive period,
    including the conventional -1, disables the refresh. If
    ``train_first_and_last_grad_mask`` is set, the first and last entries of
    ``grad_mask`` are replaced by all-ones masks. Layers whose weight has no
    gradient (``grad is None``) are skipped.
    """
    # Check the period first so a value of 0 does not raise ZeroDivisionError.
    if update_grad_mask_every > 0 and step_counter % update_grad_mask_every == 0:
        model, grad_mask = prune_and_update_grad_mask(model, grad_mask, train_loader, step_counter, total_steps)

    if train_first_and_last_grad_mask:
        # Set first and last layer masks to all ones, i.e, trainable
        print('Making the first and last layers trainable with grad mask 1')
        for idx in [0, -1]:
            grad_mask[idx] = torch.ones_like(grad_mask[idx])

    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
            # Unused/frozen layers have no gradient this step; nothing to mask.
            if m.weight.grad is not None:
                m.weight.grad *= grad_mask[cnt]
            cnt += 1
    

def prune_random_erk(model: nn.Module, density: float) -> nn.Module:
    """Random ERK-based pruning of the model.

    Each masked layer (ConvMask / Conv1dMask / LinearMask) gets a keep-ratio
    proportional to ``sum(weight.shape) / weight.numel()``. The ratios are
    rescaled by one global factor so the kept total targets
    ``density * total_params``, and each is then clamped to [0, 1]. Within a
    layer, entries are ranked by ``|mask * N(0, 1) noise|`` and the
    lowest-ranked ones are masked out. Already-masked entries score 0 and are
    pruned first.

    Args:
        model (nn.Module): The model to prune (masks are replaced in place).
        density (float): Desired overall density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    layers = [(name, mod) for name, mod in model.named_modules() if isinstance(mod, masked_types)]

    scores = {}
    erk_ratios = []
    layer_sizes = []
    for name, layer in layers:
        dev = layer.weight.device
        noise = torch.randn_like(layer.weight).to(dev)
        scores[name] = (layer.mask.to(dev) * noise).detach().abs_()
        erk_ratios.append(torch.tensor(layer.weight.shape).sum() / layer.weight.numel())
        layer_sizes.append(layer.weight.numel())
    total_params = sum(layer_sizes)

    # Global factor turning the raw ERK ratios into the requested budget.
    raw_kept = (torch.tensor(erk_ratios) * torch.tensor(layer_sizes)).sum()
    C = (total_params * density) / raw_kept
    print("Factor: ", C)
    erk_ratios = [torch.clamp(C * ratio, 0, 1) for ratio in erk_ratios]

    kept = 0
    seen = 0
    for idx, (name, layer) in enumerate(layers):
        flat = torch.flatten(scores[name])
        k = int((1 - erk_ratios[idx]) * flat.numel())
        threshold = 0 if k == 0 else torch.kthvalue(flat, k)[0]
        print("Layer", name, " params ", k, flat.numel())

        dev = layer.weight.device
        layer.mask = torch.where(
            scores[name].to(dev) <= threshold,
            torch.tensor([0.0]).to(dev),
            torch.tensor([1.0]).to(dev),
        )
        kept += (layer.mask == 1).sum()
        seen += layer.mask.numel()

    print(
        "Overall model density after random global (ERK) pruning at current iteration = ",
        kept / seen,
    )
    return model


def prune_uniform_mag(model: nn.Module, density: float, distributed=False) -> nn.Module:
    """Layerwise magnitude pruning.

    Every ConvMask / Conv1dMask / LinearMask layer is pruned on its own: the
    ``int((1 - density) * numel)`` entries with the smallest masked magnitude
    (together with any ties at the threshold) get a 0 in ``m.mask`` and all
    other entries get a 1. For the MW variants the magnitude is
    ``|mask * weight * m|``, otherwise ``|mask * weight|``, so entries that are
    already pruned score 0 and are dropped first.

    Args:
        model (nn.Module): The model to prune.
        density (float): Desired per-layer density after pruning.
        distributed (bool): If True, parameters are averaged across ranks
            (all_reduce SUM divided by the world size) before scoring.

    Returns:
        nn.Module: The same model with its masks replaced.
    """
    if distributed:
        world_size = dist.get_world_size()
        for p in model.parameters():
            dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
            p.data = p.data / world_size

    masked_types = (ConvMask, Conv1dMask, LinearMask)
    mw_types = (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)
    kept = 0
    seen = 0
    for name, layer in model.named_modules():
        if not isinstance(layer, masked_types):
            continue
        device = layer.weight.device
        product = layer.mask.to(device) * layer.weight
        if isinstance(layer, mw_types):
            product = product * layer.m
        magnitude = product.detach().abs_()

        flat = torch.flatten(magnitude)
        n_prune = int((1 - density) * flat.numel())
        threshold = 0 if n_prune == 0 else torch.kthvalue(flat, n_prune)[0]
        print("Layer", name, " params ", n_prune, flat.numel())

        # Entries at or below the threshold are pruned, the rest are kept.
        layer.mask = torch.where(
            magnitude <= threshold,
            torch.tensor([0.0]).to(device),
            torch.tensor([1.0]).to(device),
        )
        kept += (layer.mask == 1).sum()
        seen += layer.mask.numel()

    print(
        "Overall model density layerwise mag pruning ",
        kept / seen,
    )
    return model


def prune_uniform_largest(model: nn.Module, density: float, distributed=False) -> nn.Module:
    """Layerwise pruning of the LARGEST-magnitude weights.

    Every ConvMask / Conv1dMask / LinearMask layer is pruned on its own: the
    ``int((1 - density) * numel)`` entries with the largest masked magnitude
    (``|mask * weight * m|`` for MW variants, ``|mask * weight|`` otherwise)
    get a 0 in ``m.mask``; all other entries get a 1.

    NOTE(review): entries that are already masked out score 0, which is the
    largest possible (negated) score, so they come back with mask 1. This
    function therefore does not preserve an existing sparsity pattern.

    Args:
        model (nn.Module): The model to prune.
        density (float): Desired per-layer density after pruning.
        distributed (bool): If True, parameters are averaged across ranks
            (all_reduce SUM divided by the world size) before scoring.

    Returns:
        nn.Module: The same model with its masks replaced.
    """
    if distributed:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data = param.data / dist.get_world_size()

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if isinstance(m, (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)):
                magnitude = (m.mask.to(m.weight.device) * m.weight * m.m).detach().abs_()
            else:
                magnitude = (m.mask.to(m.weight.device) * m.weight).detach().abs_()
            # Negate so that kthvalue (k-th smallest) selects the k largest magnitudes.
            score = -1 * magnitude
            flat_scores = torch.flatten(score)
            k = int((1 - density) * flat_scores.numel())
            if k == 0:
                # Nothing to prune. All scores are <= 0, so the old threshold of 0
                # matched every entry and wiped the whole layer; -inf matches none.
                threshold = float("-inf")
            else:
                threshold, _ = torch.kthvalue(flat_scores, k)
            print("Layer", n, " params ", k, flat_scores.numel())

            zero = torch.tensor([0.0]).to(m.weight.device)
            one = torch.tensor([1.0]).to(m.weight.device)
            m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density layerwise largest pruning ",
        total_num / total_den,
    )
    return model


def get_small_weight_schedule(initial_small_weight=0.5, current_step=0, max_steps=100):
    """
    Linearly anneal ``small_weight`` from ``initial_small_weight`` towards 1.0.

    Args:
        initial_small_weight: Starting value for small_weight (default: 0.5).
        current_step: Current training step.
        max_steps: Number of steps over which the schedule reaches 1.0.

    Returns:
        ``initial_small_weight + (1 - initial_small_weight) * progress`` where
        ``progress = clamp(current_step / max_steps, 0, 1)``. A non-positive
        ``max_steps`` is treated as an already-finished schedule, giving
        progress 1.0 instead of dividing by zero.
    """
    if max_steps <= 0:
        progress = 1.0
    else:
        # Clamp on both sides so negative steps never go below the start value.
        progress = min(max(current_step / max_steps, 0.0), 1.0)
    return initial_small_weight + (1.0 - initial_small_weight) * progress


def prune_small_and_large(model: nn.Module, density: float, prune_extremes=True, small_weight_frac=0.5, distributed=False) -> nn.Module:
    """Prunes either the extreme (smallest and largest) weights or the middle weights.

    Scores are global masked magnitudes (``|mask * weight * m|`` for MW
    variants, ``|mask * weight|`` otherwise).

    Args:
        model (nn.Module): The model to prune.
        density (float): Desired density after pruning.
        prune_extremes (bool): If True, prunes smallest and largest weights (keeps middle).
                               If False, prunes middle weights (keeps smallest and largest).
        small_weight_frac (float): Share of the tail budget given to the
            small-magnitude tail; the rest goes to the large-magnitude tail. The
            tail budget is the number of weights pruned (extremes mode) or kept
            (middle mode).
        distributed (bool): Whether to synchronize across distributed processes.

    Returns:
        nn.Module: The pruned model.
    """
    score_list = {}

    # Sync model weights across devices if distributed
    if distributed:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data = param.data / dist.get_world_size()

    curr_total = 0
    curr_nz = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if isinstance(m, (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)):
                score_list[n] = (m.mask.to(m.weight.device) * m.weight * m.m).detach().abs_()
            else:
                score_list[n] = (m.mask.to(m.weight.device) * m.weight).detach().abs_()
            curr_nz += m.mask.sum()
            curr_total += m.mask.numel()

    print(f'Before pruning at {density}, the density of the model is {curr_nz / curr_total}')

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    total_params = global_scores.numel()

    # Extremes mode: the tails are removed, so keep density * total.
    # Middle mode: the tails are what is kept, so the "prune" budget below
    # ends up being the number of weights kept (~density * total).
    if prune_extremes:
        target_nonzero = int(density * total_params)
    else:
        target_nonzero = int((1 - density) * total_params)
    total_to_prune = total_params - target_nonzero

    # Split the tail budget between small and large weights
    k_small = int(small_weight_frac * total_to_prune)
    k_large = total_to_prune - k_small

    if k_small > 0:
        # Exactly k_small entries (absent ties) satisfy score <= small_threshold.
        small_threshold, _ = torch.kthvalue(global_scores, k_small)
    else:
        small_threshold = -1  # no small-tail entries selected

    if k_large > 0:
        # Smallest of the k_large largest scores, so exactly k_large entries
        # (absent ties) satisfy score >= large_threshold. The former index
        # (total_params - k_large) selected one entry too many and was 0 (an
        # invalid kthvalue index) when k_large == total_params.
        large_threshold, _ = torch.kthvalue(global_scores, total_params - k_large + 1)
    else:
        large_threshold = float('inf')  # no large-tail entries selected

    total_num = 0
    total_den = 0

    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            score = score_list[n].to(m.weight.device)
            zero = torch.tensor([0.0]).to(m.weight.device)
            one = torch.tensor([1.0]).to(m.weight.device)

            if prune_extremes:
                # Keep weights that are BOTH: > small_threshold AND < large_threshold
                m.mask = torch.where((score > small_threshold) & (score < large_threshold), one, zero)
            else:
                # Keep weights that are EITHER: <= small_threshold OR >= large_threshold
                m.mask = torch.where((score <= small_threshold) | (score >= large_threshold), one, zero)

            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    if prune_extremes:
        print(
            "Overall model density after pruning smallest and largest weights = ",
            (total_num / total_den).item(),
        )
    else:
        print(
            "Overall model density after pruning middle weights = ",
            (total_num / total_den).item(),
        )

    return model

def prune_uniform(model: nn.Module, density: float) -> nn.Module:
    """Random uniform (layerwise) pruning of the model.

    Every masked layer keeps roughly a ``density`` fraction of its entries,
    chosen at random. Scores are ``|mask * N(0, 1)|``, matching the other
    random pruners in this file. Entries that are already pruned therefore
    score 0 and are removed first, instead of being revived as happened with
    signed noise.

    Args:
        model (nn.Module): The model to prune.
        density (float): Desired per-layer density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            device = m.weight.device
            # Move the mask to the weight's device before combining them.
            score = (m.mask.to(device) * torch.randn_like(m.weight)).detach().abs_()
            flat_scores = torch.flatten(score)
            k = int((1 - density) * flat_scores.numel())
            if k == 0:
                threshold = 0
            else:
                threshold, _ = torch.kthvalue(flat_scores, k)
            print("Layer", n, " params ", k, flat_scores.numel())

            zero = torch.tensor([0.0]).to(device)
            one = torch.tensor([1.0]).to(device)
            m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after random uniform pruning at current iteration = ",
        total_num / total_den,
    )
    return model



def update_sign_from_grad(model: nn.Module, trainloader: Any, frac: float) -> nn.Module:
    """Flip the sign of the smallest-magnitude active weights based on their gradients.

    Gradients are accumulated over the batches of ``trainloader`` up to and
    including index ``num_steps`` (CUDA, bf16 autocast). Among the entries that
    are still active in the masks, the ``int(frac * N)`` smallest ``|weight|``
    values are selected, where N is the total number of masked-layer entries.
    Each selected weight whose sign matches its accumulated gradient's sign is
    negated.

    Args:
        model (nn.Module): Model containing ConvMask / Conv1dMask / LinearMask layers.
        trainloader (Any): Iterable of ``(images, target)`` batches.
        frac (float): Fraction of all masked-layer entries eligible for a sign flip.

    Returns:
        nn.Module: The same model with weights modified in place.
    """
    num_steps = 10
    criterion = nn.CrossEntropyLoss()
    model.zero_grad()
    print('Updating the Sign of the mask')
    for i, (images, target) in enumerate(trainloader):
        images = images.to(torch.device("cuda"))
        target = target.to(torch.device("cuda")).long()
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(images)
            loss = criterion(output, target)
            # no zero_grad inside the loop: gradients accumulate across batches
            loss.backward()
        if i == num_steps:
            break

    # Small-magnitude weights are closest to zero and most likely to change sign.
    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            # Out-of-place abs(): ``weight.detach().abs_()`` shares storage with the
            # parameter and would overwrite every weight with its absolute value.
            score = m.weight.detach().abs()
            # Entries outside the mask get +inf so they are never among the smallest.
            score = torch.where(m.mask.to(m.weight.device) == 0, float('inf'), score)
            score_list[n] = score

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int(frac * global_scores.numel())
    if k < 1:
        # kthvalue needs k >= 1; nothing to flip.
        return model
    threshold, _ = torch.kthvalue(global_scores, k)

    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            active = m.mask.to(m.weight.device) != 0
            # Compare scores against the magnitude threshold (not the count k).
            candidates = (score_list[n] <= threshold) & active
            # Same sign as the gradient: a descent step shrinks the weight toward 0, so flip it.
            # Opposite sign: keep the weight sign.
            same_sign = m.weight.data.sign() * m.weight.grad.sign() == 1
            flip = torch.where(candidates & same_sign, -1, 1)
            m.weight.data = flip * m.weight.data
    return model
    

def prune_snip(model: nn.Module, trainloader: Any, density: float) -> nn.Module:
    """SNIP method for pruning of the model.

    Gradients are accumulated over the first three batches of ``trainloader``
    (indices 0-2; CUDA, bf16 autocast). Each masked entry is scored by
    ``|mask * weight.grad * weight|``; MW variants use
    ``|mask * (weight.grad * weight + m * m.grad)|``. The
    ``int((1 - density) * N)`` lowest scores across all masked layers are
    pruned, plus any ties at the threshold. If that count is below one, masks
    are left unchanged.

    Args:
        model (nn.Module): The model to prune.
        trainloader (Any): The training data loader.
        density (float): Desired density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    criterion = nn.CrossEntropyLoss()

    model.zero_grad()
    for i, (images, target) in enumerate(trainloader):
        images = images.to(torch.device("cuda"))
        target = target.to(torch.device("cuda")).long()
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(images)
            criterion(output, target).backward()
        if i == 2:
            break

    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            mask = m.mask.to(m.weight.device)
            if isinstance(m, (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)):
                saliency = m.weight.grad * m.weight + m.m * m.m.grad
            else:
                saliency = m.weight.grad * m.weight
            # No explicit normalization: dividing every score by the global sum
            # (as in the paper) does not change the ranking or the threshold.
            score_list[n] = (saliency * mask).detach().abs_()

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # kthvalue requires k >= 1; with nothing to prune the masks stay as they are.
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after snip pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def flip_mask(model: nn.Module) -> nn.Module:
    """Invert every layer mask in place (each entry becomes ``1 - mask``).

    Entries that were pruned become active and active entries become pruned.
    The resulting density and the density before the flip are printed.

    Args:
        model (nn.Module): The model whose masks to flip.

    Returns:
        nn.Module: The same model with complemented masks.
    """
    active_after = 0
    entries = 0
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    for module in model.modules():
        if not isinstance(module, masked_types):
            continue
        module.mask = 1 - module.mask
        active_after += module.mask.sum()
        entries += module.mask.numel()

    new_density = active_after / entries
    print(
        f"Mask flipped: new density = {new_density:.4f}, "
        f"previous density was {1 - new_density:.4f}"
    )

    return model

def prune_rel_grad(model: nn.Module, trainloader: Any, density: float, distributed=False) -> nn.Module:
    """Global pruning by relative gradient magnitude.

    Each masked entry is scored by ``|grad / weight|``; MW variants use
    ``|(weight.grad * weight + m * m.grad) / (m * weight) ** 2|``. Scores are
    restricted to the current mask. The ``int((1 - density) * N)`` lowest scores
    across all masked layers are pruned, plus ties at the threshold. If that
    count is below one, masks are left unchanged.

    If no parameter holds a gradient yet, one batch from ``trainloader`` is
    forwarded and backpropagated (CUDA, bf16 autocast) to populate them.
    Otherwise the existing gradients are used as-is.

    Args:
        model (nn.Module): The model to prune.
        trainloader (Any): The training data loader.
        density (float): Desired density after pruning.
        distributed (bool): Accepted for interface compatibility; unused here.

    Returns:
        nn.Module: The pruned model.
    """
    criterion = nn.CrossEntropyLoss()
    if not any(p.grad is not None for p in model.parameters()):
        # Populate the gradients from a single batch (avoids OOM from accumulating).
        for images, target in trainloader:
            images = images.to(torch.device("cuda"))
            target = target.to(torch.device("cuda")).long()
            with autocast(dtype=torch.bfloat16, device_type='cuda'):
                output = model(images)
                criterion(output, target).backward()
            break

    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            mask = m.mask.to(m.weight.device)
            if isinstance(m, (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)):
                rel = (m.weight.grad * m.weight + m.m * m.m.grad) / ((m.m * m.weight) ** 2)
            else:
                rel = m.weight.grad / m.weight
            # Zero weights give inf/NaN ratios, and ``mask * inf`` / ``mask * NaN`` is NaN.
            # NaN fails the ``<=`` test below, so pruned entries would be revived.
            # Select instead of multiplying so pruned entries (assumed mask == 0) score
            # exactly 0, and map 0/0 NaNs to 0 as well.
            keep = (mask != 0) & ~torch.isnan(rel)
            score_list[n] = torch.where(keep, rel, torch.zeros_like(rel)).detach().abs_()

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # kthvalue requires k >= 1; with nothing to prune the masks stay as they are.
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after pruning with relative gradients at current iteration = ",
        total_num / total_den,
    )
    return model

def prune_rel_grad_uniform(model: nn.Module, trainloader: Any, density: float, distributed=False) -> nn.Module:
    """Prune based on relative gradients, layer by layer.

    Each masked entry is scored by ``|grad / weight|``; MW variants use
    ``|(weight.grad * weight + m * m.grad) / (m * weight) ** 2|``. Scores are
    restricted to the current mask. Within every layer the
    ``int((1 - density) * numel)`` lowest scores are pruned, plus ties. Layers
    where that count is below one are left unchanged.

    If no parameter holds a gradient yet, one batch from ``trainloader`` is
    forwarded and backpropagated (CUDA, bf16 autocast) to populate them.

    Args:
        model (nn.Module): The model to prune.
        trainloader (Any): The training data loader.
        density (float): Desired per-layer density after pruning.
        distributed (bool): Accepted for interface compatibility; unused here.

    Returns:
        nn.Module: The pruned model.
    """
    criterion = nn.CrossEntropyLoss()
    if not any(p.grad is not None for p in model.parameters()):
        # Populate the gradients from a single batch (avoids OOM from accumulating).
        for images, target in trainloader:
            images = images.to(torch.device("cuda"))
            target = target.to(torch.device("cuda")).long()
            with autocast(dtype=torch.bfloat16, device_type='cuda'):
                output = model(images)
                criterion(output, target).backward()
            break

    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            mask = m.mask.to(m.weight.device)
            if isinstance(m, (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)):
                rel = (m.weight.grad * m.weight + m.m * m.m.grad) / ((m.m * m.weight) ** 2)
            else:
                rel = m.weight.grad / m.weight
            # Zero weights give inf/NaN ratios, and ``mask * NaN`` stays NaN, which fails
            # the ``<=`` test and revives pruned entries. Select so pruned entries
            # (assumed mask == 0) and 0/0 NaNs score exactly 0.
            keep = (mask != 0) & ~torch.isnan(rel)
            score = torch.where(keep, rel, torch.zeros_like(rel)).detach().abs_()

            flattened_score = torch.flatten(score)
            k = int((1 - density) * flattened_score.numel())
            if k < 1:
                # kthvalue requires k >= 1; nothing to prune in this (small) layer.
                continue
            threshold, _ = torch.kthvalue(flattened_score, k)

            zero = torch.tensor([0.0]).to(m.weight.device)
            one = torch.tensor([1.0]).to(m.weight.device)
            m.mask = torch.where(score <= threshold, zero, one)

    print('pruning uniform relative gradients')
    return model

def prune_synflow(model: nn.Module, trainloader: Any, density: float) -> nn.Module:
    """SynFlow method pruning of the model.

    Steps:
      1. Replace every ``state_dict`` tensor (parameters and buffers) by its
         absolute value, remembering the signs.
      2. Clear any stale gradients. Feed a single all-ones input shaped like
         one sample of the first ``trainloader`` batch and backpropagate
         ``sum(output)`` (CUDA, bf16 autocast).
      3. Score each masked entry by ``|mask * weight.grad * weight|``; MW
         variants use ``|mask * (weight.grad * weight + m * m.grad)|``.
      4. Clear gradients, restore the signs, and globally prune the
         ``int((1 - density) * N)`` lowest scores, plus ties. If that count is
         below one, masks are left unchanged.

    Args:
        model (nn.Module): The model to prune.
        trainloader (Any): The training data loader (only used for the input shape).
        density (float): Desired density after pruning.

    Returns:
        nn.Module: The pruned model.
    """

    @torch.no_grad()
    def linearize(model: nn.Module) -> Dict[str, torch.Tensor]:
        """Linearize the model by taking the absolute value of its parameters.

        Args:
            model (nn.Module): The model to linearize.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of parameter signs.
        """
        signs = {}
        for name, param in model.state_dict().items():
            signs[name] = torch.sign(param)
            # state_dict tensors share storage with the model, so this is in place.
            param.abs_()
        return signs

    @torch.no_grad()
    def nonlinearize(model: nn.Module, signs: Dict[str, torch.Tensor]) -> None:
        """Restore the signs of the model parameters.

        Args:
            model (nn.Module): The model to restore.
            signs (Dict[str, torch.Tensor]): Dictionary of parameter signs.
        """
        for n, param in model.state_dict().items():
            param.mul_(signs[n])

    signs = linearize(model)
    # Drop gradients left over from training; otherwise they would be added to the
    # SynFlow gradients by backward() and contaminate the scores.
    model.zero_grad()

    for images, _ in trainloader:
        input_dim = list(images[0, :].shape)
        ones_input = torch.ones([1] + input_dim).to("cuda")
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(ones_input)
            torch.sum(output).backward()
        break

    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if isinstance(m, (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)):
                score_list[n] = (
                    (m.mask.to(m.weight.device) * (m.weight.grad * m.weight + m.m * m.m.grad)).detach().abs_()
                )
            else:
                score_list[n] = (
                    (m.mask.to(m.weight.device) * m.weight.grad * m.weight).detach().abs_()
                )

    model.zero_grad()
    nonlinearize(model, signs)

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # kthvalue requires k >= 1; with nothing to prune the masks stay as they are.
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after synflow pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_random_balanced(model: nn.Module, density: float, distributed: bool) -> nn.Module:
    """Random balanced pruning of the model.

    The global keep budget ``density * total_params`` is split evenly across
    the masked layers. A layer smaller than its share is kept fully dense, and
    its unused share is redistributed evenly over the layers after it. Within
    each layer, the kept entries are chosen at random: scores are
    ``|mask * N(0, 1)|``, so already-pruned entries score 0 and go first.

    Args:
        model (nn.Module): The model to prune.
        density (float): Desired overall density after pruning.
        distributed (bool): If True, parameters are averaged across ranks first.

    Returns:
        nn.Module: The pruned model.
    """
    # Sync model weights across devices if distributed
    if distributed:
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data = param.data / dist.get_world_size()

    masked_types = (ConvMask, Conv1dMask, LinearMask)
    total_params = 0
    num_layers = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            total_params += m.weight.numel()
            num_layers += 1

    # Number of weights each not-yet-assigned layer should keep.
    X = density * total_params / num_layers
    score_list = {}
    # Per-layer densities (fraction of entries kept); the name is historical.
    sparsity_list = []
    idx = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            score_list[n] = (
                (
                    m.mask.to(m.weight.device)
                    * torch.randn_like(m.weight).to(m.weight.device)
                )
                .detach()
                .abs_()
            )

            if X / m.weight.numel() < 1.0:
                sparsity_list.append(X / m.weight.numel())
            else:
                sparsity_list.append(1)
                # Layer is smaller than its share: keep it dense and give the unused
                # budget to the layers still to be assigned (num_layers - idx - 1 of
                # them). Dividing by num_layers - idx undershoots the target density.
                remaining = num_layers - idx - 1
                if remaining > 0:
                    X = X + (X - m.mask.numel()) / remaining
            idx += 1

    total_num = 0
    total_den = 0
    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            global_scores = torch.flatten(score_list[n])
            k = int((1 - sparsity_list[cnt]) * global_scores.numel())
            if k == 0:
                threshold = 0
            else:
                threshold, _ = torch.kthvalue(global_scores, k)
            print("Layer", n, " params ", k, global_scores.numel())

            score = score_list[n].to(m.weight.device)
            zero = torch.tensor([0.0]).to(m.weight.device)
            one = torch.tensor([1.0]).to(m.weight.device)
            m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()
            cnt += 1

    print(
        "Overall model density after random global (balanced) pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_er_erk(model: nn.Module, er_sparse_init: float) -> None:
    """ERK-based pruning at initialization.

    Each masked layer gets a raw density proportional to
    ``sum(weight.shape) / weight.numel()``. A single global factor C rescales
    these so that ``sum(raw_i * numel_i) * C == er_sparse_init * total_params``.
    The result is clamped to [0, 1] and applied via ``set_er_mask``.

    NOTE(review): clamping is done once without redistribution, so if any layer
    saturates at 1 the overall density ends up below ``er_sparse_init``.

    Args:
        model (nn.Module): The model to prune.
        er_sparse_init (float): Initial target density (fraction of weights kept).
    """
    layers = [m for m in model.modules() if isinstance(m, (ConvMask, Conv1dMask, LinearMask))]
    raw_density = [torch.tensor(m.weight.shape).sum() / m.weight.numel() for m in layers]
    layer_sizes = [m.weight.numel() for m in layers]
    total_params = sum(layer_sizes)

    raw_kept = (torch.tensor(raw_density) * torch.tensor(layer_sizes)).sum()
    scale = total_params * er_sparse_init / raw_kept
    densities = [torch.clamp(scale * d, 0, 1) for d in raw_density]

    for layer, layer_density in zip(layers, densities):
        layer.set_er_mask(layer_density)
    print(densities)


def prune_er_balanced(model: nn.Module, er_sparse_init: float) -> None:
    """ER-balanced pruning at initialization.

    The global keep budget ``er_sparse_init * total_params`` is split evenly
    across the masked layers. A layer smaller than its share is kept fully
    dense, and its unused share is redistributed evenly over the layers after
    it. Per-layer densities are applied via ``set_er_mask``.

    Args:
        model (nn.Module): The model to prune.
        er_sparse_init (float): Initial target density (fraction of weights kept).
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    total_params = 0
    num_layers = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            total_params += m.weight.numel()
            num_layers += 1

    # Number of weights each not-yet-assigned layer should keep.
    X = er_sparse_init * total_params / num_layers
    # Per-layer densities (fraction of entries kept); the name is historical.
    sparsity_list = []
    idx = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if X / m.weight.numel() < 1.0:
                sparsity_list.append(X / m.weight.numel())
            else:
                sparsity_list.append(1)
                # Layer is smaller than its share: keep it dense and give the unused
                # budget to the layers still to be assigned (num_layers - idx - 1 of
                # them). Dividing by num_layers - idx undershoots the target density.
                remaining = num_layers - idx - 1
                if remaining > 0:
                    X = X + (X - m.mask.numel()) / remaining
            idx += 1

    idx = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            m.set_er_mask(sparsity_list[idx])
            idx += 1
    print(sparsity_list)

# Load only the final learned mask of a previous experiment; weights are left untouched.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
def load_only_final_mask(model: nn.Module, target_expt: str, target_dir: str):
    """Copy the final mask of ``<target_dir>/<target_expt>`` into ``model``.

    Reads ``checkpoints/model_level_0.pt``, stripping any ``'module.'`` prefix
    from the keys. For experiments whose name contains ``'str'``, masks are
    rebuilt as the non-zero pattern of ``sparseFunction(weight, sparseThreshold)``
    for every key containing ``'weight'`` and ``'conv'`` or ``'fc'``. Otherwise
    every checkpoint entry whose key contains ``'mask'`` is used.

    Masks are matched to the model's masked layers purely by order. This
    assumes the checkpoint key order lines up with ``model.named_modules()``;
    TODO confirm.
    """
    model_path = os.path.join(target_dir, target_expt)
    checkpoint = torch.load(os.path.join(model_path, 'checkpoints/model_level_0.pt'))
    state = {key.replace('module.', ''): value for key, value in checkpoint.items()}

    is_str_expt = 'str' in target_expt
    if is_str_expt:
        masks = []
        for key in state:
            if 'weight' in key and ('conv' in key or 'fc' in key):
                pruned = sparseFunction(state[key], state[key.replace("weight", "sparseThreshold")])
                masks.append(torch.where(pruned != 0, 1, 0))
    else:
        masks = [state[key] for key in state if 'mask' in key]

    masked_layers = [m for m in model.modules() if isinstance(m, (ConvMask, Conv1dMask, LinearMask))]
    for layer, mask in zip(masked_layers, masks):
        # STR stores Conv1d weights with a trailing singleton dimension.
        if is_str_expt and isinstance(layer, Conv1dMask):
            mask = mask.squeeze(dim=-1)
        layer.mask = mask
    if len(masks) < len(masked_layers):
        # Preserve the original failure mode when the checkpoint has too few masks.
        raise IndexError('list index out of range')

    print(f'Loaded only final mask from expt {target_expt}')
    return model


@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_mw_trained_dense_and_prune_mag_mw(model: nn.Module, er_init: float, target_expt: str, target_dir: str, load_sign_at: str):
    """Load dense weights from an earlier checkpoint, then magnitude-prune them.

    The checkpoint is selected by ``load_sign_at``:

    * ``'final'`` uses ``checkpoints/model_level_0.pt``;
    * ``'warmup'`` uses ``checkpoints/model_0_9.pt``;
    * an integer string strictly between 9 and 99 uses
      ``checkpoints/model_0_<n>.pt`` (asserted).

    The checkpoint entries (``'module.'`` prefix stripped) are merged over the
    current state_dict before loading. ``prune_mag(model, er_init)`` is then
    applied, and its result is returned.
    """
    model_path = os.path.join(target_dir, target_expt)
    if load_sign_at == 'final':
        checkpoint_name = 'checkpoints/model_level_0.pt'
    elif load_sign_at == 'warmup':
        checkpoint_name = 'checkpoints/model_0_9.pt'
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) <  99)
        checkpoint_name = 'checkpoints/model_0_{}.pt'.format(load_sign_at)
    sign_path = os.path.join(model_path, checkpoint_name)

    # STR checkpoints are only warned about, not rejected.
    if 'str' in target_expt:
        print('Does not support STR models yet')

    loaded = {key.replace('module.', ''): value for key, value in torch.load(sign_path).items()}
    merged_state = model.state_dict()
    merged_state.update(loaded)
    model.load_state_dict(merged_state)
    print(f'Loaded dense weights from expt {target_expt}')

    print(f'Pruning magnitudes to {er_init} density')
    return prune_mag(model, er_init)

@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_dense_and_prune_mag_mw(model: nn.Module, er_init: float, target_expt: str, target_dir: str, load_sign_at: str):
    """Load dense weights, convert them with ``convert_dict_to_mw``, then magnitude-prune.

    The checkpoint is selected by ``load_sign_at``:

    * ``'final'`` uses ``checkpoints/model_level_0.pt``;
    * ``'warmup'`` uses ``checkpoints/model_0_9.pt``;
    * an integer string strictly between 9 and 99 uses
      ``checkpoints/model_0_<n>.pt`` (asserted).

    The checkpoint entries (``'module.'`` prefix stripped) are merged over the
    current state_dict before loading. ``convert_dict_to_mw(model)`` is called
    next, then ``prune_mag(model, er_init)`` is applied and its result returned.
    """
    model_path = os.path.join(target_dir, target_expt)
    if load_sign_at == 'final':
        checkpoint_name = 'checkpoints/model_level_0.pt'
    elif load_sign_at == 'warmup':
        checkpoint_name = 'checkpoints/model_0_9.pt'
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) <  99)
        checkpoint_name = 'checkpoints/model_0_{}.pt'.format(load_sign_at)
    sign_path = os.path.join(model_path, checkpoint_name)

    # STR checkpoints are only warned about, not rejected.
    if 'str' in target_expt:
        print('Does not support STR models yet')

    loaded = {key.replace('module.', ''): value for key, value in torch.load(sign_path).items()}
    merged_state = model.state_dict()
    merged_state.update(loaded)
    model.load_state_dict(merged_state)

    print('Initializing mw from weight')
    convert_dict_to_mw(model)
    print(f'Loaded dense weights from expt {target_expt}')

    print(f'Pruning magnitudes to {er_init} density')
    return prune_mag(model, er_init)

@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_dense_and_prune_mag(model: nn.Module, er_init: float, target_expt: str, target_dir: str, load_sign_at: str):
    """Load dense weights from an earlier checkpoint (strict load), then magnitude-prune.

    The checkpoint is selected by ``load_sign_at``:

    * ``'final'`` uses ``checkpoints/model_level_0.pt``;
    * ``'warmup'`` uses ``checkpoints/model_0_9.pt``;
    * an integer string strictly between 9 and 99 uses
      ``checkpoints/model_0_<n>.pt`` (asserted).

    Unlike the MW variants, the checkpoint (``'module.'`` prefix stripped) is
    loaded directly with ``load_state_dict`` rather than merged into the
    current state_dict. ``prune_mag(model, er_init)`` is then applied, and its
    result is returned.
    """
    model_path = os.path.join(target_dir, target_expt)
    if load_sign_at == 'final':
        checkpoint_name = 'checkpoints/model_level_0.pt'
    elif load_sign_at == 'warmup':
        checkpoint_name = 'checkpoints/model_0_9.pt'
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) <  99)
        checkpoint_name = 'checkpoints/model_0_{}.pt'.format(load_sign_at)
    sign_path = os.path.join(model_path, checkpoint_name)

    loaded = {key.replace('module.', ''): value for key, value in torch.load(sign_path).items()}
    # STR checkpoints are only warned about, not rejected.
    if 'str' in target_expt:
        print('Does not support STR models yet')
    model.load_state_dict(loaded)
    print(f'Loaded dense weights from expt {target_expt}')

    print(f'Pruning magnitudes to {er_init} density')
    return prune_mag(model, er_init)

### Adding mask ablations
def sparseFunction(x, s=-12800, activation=torch.relu, f=torch.sigmoid):
    """Soft-threshold ``x`` by ``f(s)``: ``sign(x) * activation(|x| - f(s))``.

    ``s`` is broadcast against a ``[1, 1]`` tensor on ``x``'s device before ``f`` is
    applied, so the result is broadcast to at least two dimensions.
    """
    shift = f(s * torch.ones([1, 1]).to(x.device))
    shrunk = activation(x.abs() - shift)
    return x.sign() * shrunk


# Load the final mask and an intermediate weight configuration (magnitudes and signs) to finetune the model with a fixed sign.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_final_mask_and_intermediate_weight(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Load the final learnt mask together with intermediate weights (magnitude and sign).

    The mask comes from ``checkpoints/model_level_0.pt`` of the target experiment.
    Weights of every ConvMask / Conv1dMask / LinearMask layer, and the affine
    parameters of every ``nn.BatchNorm2d``, come from the checkpoint selected by
    ``load_sign_at``.

    Args:
        model: Model whose masked layers and BatchNorm2d layers are overwritten in place.
        target_expt: Experiment folder under ``target_dir``. If it contains ``'str'``,
            the checkpoints are treated as STR checkpoints: masks and weights are
            derived with ``sparseFunction`` and the stored ``sparseThreshold`` tensors.
        target_dir: Root directory of experiment folders.
        load_sign_at: ``'final'``, ``'warmup'`` or an epoch string strictly between 9 and 99.

    Returns:
        The same ``model`` instance.
    """
    model_path = os.path.join(target_dir, target_expt)
    mask_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    if load_sign_at == 'final':
        sign_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    elif load_sign_at == 'warmup':
        sign_path = os.path.join(model_path, 'checkpoints/model_0_9.pt')
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) < 99)
        sign_path = os.path.join(model_path, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    # strip the DDP "module." prefix so keys match the bare model
    sign_dict = torch.load(sign_path)
    sign_dict = {key.replace('module.', ''): value for key, value in sign_dict.items()}
    mask_dict = torch.load(mask_path)
    mask_dict = {key.replace('module.', ''): value for key, value in mask_dict.items()}

    is_str = 'str' in target_expt
    mask_list = []
    if is_str:
        # STR checkpoints store no explicit mask: recover it as the non-zero
        # pattern of the soft-thresholded conv/fc weights.
        for k in mask_dict.keys():
            if ('weight' in k) and (('conv' in k) or ('fc' in k)):
                weight = mask_dict[k]
                k_s = k.replace("weight", "sparseThreshold")
                temp = sparseFunction(weight, mask_dict[k_s])
                mask = torch.where(temp != 0, 1, 0)
                mask_list.append(mask)
    else:
        for k in mask_dict.keys():
            if 'mask' in k:
                mask_list.append(mask_dict[k])

    # NOTE(review): masks are matched to layers purely by position; this assumes the
    # checkpoint key order matches the order of masked layers in model.named_modules().
    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print('The batch norm weights are also being updated to the learnt ones')
            m.weight.data = sign_dict[n + '.weight']
            m.bias.data = sign_dict[n + '.bias']
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            weight = sign_dict[n + '.weight']
            if is_str:
                # use this layer's own threshold
                weight = sparseFunction(weight, sign_dict[n + '.sparseThreshold'])
            mask = mask_list[cnt]
            if is_str and isinstance(m, Conv1dMask):
                weight = weight.squeeze(dim=-1)
                mask = mask.squeeze(dim=-1)
            # initialize the weight (both mag and sign) with the intermediate checkpoint
            m.weight.data = weight
            # load the final, learnt mask
            m.mask = mask
            cnt += 1

    print(f'Loaded sign {load_sign_at} and final mask from expt {target_expt}')
    return model

# Load the final mask, and an intermediate sign configuration with random magnitudes from the MW trained model.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_mw_final_mask_and_intermediate_sign(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Combine the model's current magnitudes with signs taken from an MW checkpoint.

    For each ConvMask / Conv1dMask / LinearMask layer, the sign of ``m * weight`` (both
    read from the selected checkpoint) is applied to the model's current absolute
    weights. The final mask is loaded from ``checkpoints/model_level_0.pt``.

    Args:
        model: Model whose masked layers are edited in place.
        target_expt: Experiment folder under ``target_dir``.
        target_dir: Root directory of experiment folders.
        load_sign_at: ``'final'``, ``'warmup'`` or an epoch string strictly between 9 and 99.

    Returns:
        The same ``model`` instance.
    """
    expt_dir = os.path.join(target_dir, target_expt)
    final_ckpt = os.path.join(expt_dir, 'checkpoints/model_level_0.pt')
    named_ckpts = {
        'final': 'checkpoints/model_level_0.pt',
        'warmup': 'checkpoints/model_0_9.pt',
    }
    if load_sign_at in named_ckpts:
        sign_ckpt = os.path.join(expt_dir, named_ckpts[load_sign_at])
    else:
        assert 9 < int(load_sign_at) < 99
        sign_ckpt = os.path.join(expt_dir, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    def load_stripped(path):
        # strip the DDP "module." prefix so keys match the bare model
        return {key.replace('module.', ''): value for key, value in torch.load(path).items()}

    sign_state = load_stripped(sign_ckpt)
    mask_state = load_stripped(final_ckpt)
    final_masks = [tensor for key, tensor in mask_state.items() if 'mask' in key]

    # Keep the current magnitudes, take the learnt sign from mw, and install the final mask.
    layer_idx = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            print(f'The batch norm weights are {module.weight}, {module.bias}')
        if not isinstance(module, (ConvMask, Conv1dMask, LinearMask)):
            continue
        latent = sign_state[name + '.weight']
        scale = sign_state[name + '.m']
        final_mask = final_masks[layer_idx]
        learnt_sign = (scale * latent).sign()
        module.weight.data = module.weight.data.abs() * learnt_sign
        module.mask = final_mask
        layer_idx += 1

    print(f'Loaded sign {load_sign_at} and final mask from expt {target_expt}')
    return model


# Load the final mask, and an intermediate sign configuration with random magnitudes to finetune the model.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_final_mask_and_intermediate_sign(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Load the final learnt mask and an intermediate sign pattern, keeping current magnitudes.

    For each ConvMask / Conv1dMask / LinearMask layer, the sign of the weight in the
    checkpoint selected by ``load_sign_at`` is applied to the model's current absolute
    weights. The mask comes from ``checkpoints/model_level_0.pt``.

    Args:
        model: Model whose masked layers are edited in place.
        target_expt: Experiment folder under ``target_dir``. If it contains ``'str'``,
            the checkpoints are treated as STR checkpoints: masks and weights are
            derived with ``sparseFunction`` and the stored ``sparseThreshold`` tensors.
        target_dir: Root directory of experiment folders.
        load_sign_at: ``'final'``, ``'warmup'`` or an epoch string strictly between 9 and 99.

    Returns:
        The same ``model`` instance.
    """
    model_path = os.path.join(target_dir, target_expt)
    mask_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    if load_sign_at == 'final':
        sign_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    elif load_sign_at == 'warmup':
        sign_path = os.path.join(model_path, 'checkpoints/model_0_9.pt')
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) < 99)
        sign_path = os.path.join(model_path, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    # strip the DDP "module." prefix so keys match the bare model
    sign_dict = torch.load(sign_path)
    sign_dict = {key.replace('module.', ''): value for key, value in sign_dict.items()}
    mask_dict = torch.load(mask_path)
    mask_dict = {key.replace('module.', ''): value for key, value in mask_dict.items()}

    is_str = 'str' in target_expt
    mask_list = []
    if is_str:
        # STR checkpoints store no explicit mask: recover it as the non-zero
        # pattern of the soft-thresholded conv/fc weights.
        for k in mask_dict.keys():
            if ('weight' in k) and (('conv' in k) or ('fc' in k)):
                weight = mask_dict[k]
                k_s = k.replace("weight", "sparseThreshold")
                temp = sparseFunction(weight, mask_dict[k_s])
                mask = torch.where(temp != 0, 1, 0)
                mask_list.append(mask)
    else:
        for k in mask_dict.keys():
            if 'mask' in k:
                mask_list.append(mask_dict[k])

    # NOTE(review): masks are matched to layers purely by position; this assumes the
    # checkpoint key order matches the order of masked layers in model.named_modules().
    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print(f'The batch norm weights are {m.weight}, {m.bias}')
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            weight = sign_dict[n + '.weight']
            if is_str:
                # use this layer's own threshold
                weight = sparseFunction(weight, sign_dict[n + '.sparseThreshold'])
            mask = mask_list[cnt]
            if is_str and isinstance(m, Conv1dMask):
                weight = weight.squeeze(dim=-1)
                mask = mask.squeeze(dim=-1)
            # learnt sign on top of the model's current magnitudes
            m.weight.data = m.weight.data.abs() * weight.sign()
            m.mask = mask
            cnt += 1

    print(f'Loaded sign {load_sign_at} and final mask from expt {target_expt}')
    return model

# Load the final mask and an intermediate magnitude configuration (keeping the model's current, e.g. random-init, signs) to finetune the model.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_final_mask_and_intermediate_mag(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Load the final learnt mask and intermediate magnitudes, keeping the model's current signs.

    For each ConvMask / Conv1dMask / LinearMask layer, the absolute weight from the
    checkpoint selected by ``load_sign_at`` is combined with the sign of the model's
    current weight. The mask comes from ``checkpoints/model_level_0.pt``.

    Args:
        model: Model whose masked layers are edited in place.
        target_expt: Experiment folder under ``target_dir``. If it contains ``'str'``,
            the checkpoints are treated as STR checkpoints: masks and weights are
            derived with ``sparseFunction`` and the stored ``sparseThreshold`` tensors.
        target_dir: Root directory of experiment folders.
        load_sign_at: ``'final'``, ``'warmup'`` or an epoch string strictly between 9 and 99.

    Returns:
        The same ``model`` instance.
    """
    model_path = os.path.join(target_dir, target_expt)
    mask_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    if load_sign_at == 'final':
        sign_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    elif load_sign_at == 'warmup':
        sign_path = os.path.join(model_path, 'checkpoints/model_0_9.pt')
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) < 99)
        sign_path = os.path.join(model_path, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    # strip the DDP "module." prefix so keys match the bare model
    mag_dict = torch.load(sign_path)
    mag_dict = {key.replace('module.', ''): value for key, value in mag_dict.items()}
    mask_dict = torch.load(mask_path)
    mask_dict = {key.replace('module.', ''): value for key, value in mask_dict.items()}

    is_str = 'str' in target_expt
    mask_list = []
    if is_str:
        # STR checkpoints store no explicit mask: recover it as the non-zero
        # pattern of the soft-thresholded conv/fc weights.
        for k in mask_dict.keys():
            if ('weight' in k) and (('conv' in k) or ('fc' in k)):
                weight = mask_dict[k]
                k_s = k.replace("weight", "sparseThreshold")
                temp = sparseFunction(weight, mask_dict[k_s])
                mask = torch.where(temp != 0, 1, 0)
                mask_list.append(mask)
    else:
        for k in mask_dict.keys():
            if 'mask' in k:
                mask_list.append(mask_dict[k])

    # NOTE(review): masks are matched to layers purely by position; this assumes the
    # checkpoint key order matches the order of masked layers in model.named_modules().
    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print(f'The batch norm weights are {m.weight}, {m.bias}')
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            weight = mag_dict[n + '.weight']
            if is_str:
                # use this layer's own threshold
                weight = sparseFunction(weight, mag_dict[n + '.sparseThreshold'])
            mask = mask_list[cnt]
            if is_str and isinstance(m, Conv1dMask):
                weight = weight.squeeze(dim=-1)
                mask = mask.squeeze(dim=-1)
            # current (e.g. random-init) sign with the learnt magnitude from mag_dict
            m.weight.data = m.weight.data.sign() * weight.abs()
            m.mask = mask
            cnt += 1

    print(f'Loaded magnitude {load_sign_at} and final mask from expt {target_expt}')
    return model

@param("prune_params.update_sign_every_level")
@param("prune_params.dst_reinit")
@param("prune_params.dst_method")
def dst_pruner(model: nn.Module, step: int, dst_reinit: str, world_size: int, T_end: int, acc_grads: dict, dst_method: str, update_sign_every_level: bool):
    """Apply one dynamic-sparse-training update chosen by ``dst_method``.

    Before ``T_end`` the update fraction follows a schedule. ``'rigl'``, ``'mest'`` and
    ``'set'`` use cosine annealing; ``'sign-flip'`` uses a 10% linear warmup followed
    by cosine annealing. Both schedules peak at 0.3. From ``T_end`` on, the model is
    returned unchanged. ``update_sign_every_level`` is accepted but not read here.

    Returns:
        The (possibly updated) model.
    """
    def cosine_annealing():
        # Hyperparams chosen in accordance with https://arxiv.org/pdf/1911.11134
        alpha = 0.3
        return (alpha / 2) * (1 + np.cos((step * np.pi) / T_end))

    def linear_warmup_cosine_annealing():
        alpha = 0.3
        warmup_steps = int(0.1 * T_end)
        if step >= warmup_steps:
            progress = (step - warmup_steps) / (T_end - warmup_steps)
            return (alpha / 2) * (1 + np.cos(np.pi * progress))
        return alpha * (step / warmup_steps)

    if step >= T_end:
        print('Mask shuffling is complete, finetuning the final mask now.')
        return model

    regrow_methods = {
        'rigl': ('Pruning and regrowing with RiGL', prune_grow_rigl),
        'mest': ('Pruning and regrowing with MEST', prune_grow_mest),
        'set': ('Pruning and regrowing with SET', prune_grow_set),
    }
    if dst_method in regrow_methods:
        message, regrow_fn = regrow_methods[dst_method]
        dst_frac = cosine_annealing()
        print(f'{message}: {dst_frac}')
        return regrow_fn(model, dst_frac, dst_reinit, acc_grads, world_size)

    if dst_method == 'sign-flip':
        dst_frac = linear_warmup_cosine_annealing()
        print(f'Flipping the sign of the mask based on the saliency: {dst_frac}')
        return update_signs_in_random_mask(model, dst_frac, acc_grads, world_size)

    print('No DST pruner provided')
    return model

def _mean_across_ranks(t: torch.Tensor, world_size: int) -> torch.Tensor:
    """Return the cross-rank mean of ``t``; return ``t`` unchanged when no process group is initialised.

    ``t`` is all-reduced in place, so callers must pass a tensor they own.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t / world_size
    return t


def _prune_regrow_flat(mask_flat: torch.Tensor, orig_weight: torch.Tensor, weight_score: torch.Tensor,
                       grad_score: torch.Tensor, dst_frac: float, dst_reinit: str,
                       signed_grad: Optional[torch.Tensor] = None, prune_grad_coef: float = 0.0,
                       random_regrow: bool = False) -> None:
    """Prune then regrow entries of a flat mask in place (and possibly edit ``orig_weight`` in place).

    ``int(dst_frac * #active)`` active entries with the lowest prune score are set to 0.
    The prune score is ``weight_score``, plus ``prune_grad_coef * grad_score`` when the
    coefficient is non-zero. The same number of entries is then reactivated among all
    inactive ones, including those just pruned. Entries are chosen by the largest
    ``grad_score``, or by a uniform random score when ``random_regrow``.

    Regrown weights are zeroed for ``dst_reinit == 'zero'``. For ``'flip'`` they are
    negated where ``weight * signed_grad > 0``. Any other value leaves them unchanged.
    """
    # as_tuple indexing always yields a 1-D index tensor, even with a single active entry
    nonzero_inds = torch.nonzero(mask_flat, as_tuple=True)[0]
    num_prune = int(dst_frac * nonzero_inds.numel())
    if num_prune <= 0:
        return

    prune_score = weight_score[nonzero_inds]
    if prune_grad_coef:
        prune_score = prune_score + prune_grad_coef * grad_score[nonzero_inds]
    pruned = nonzero_inds[prune_score.argsort()[:num_prune]]
    mask_flat[pruned] = 0

    # regrow from the complement of the pruned mask,
    # i.e. pruned params can be grown back in the same step
    zero_inds = torch.nonzero(mask_flat == 0, as_tuple=True)[0]
    if random_regrow:
        regrow_score = torch.rand(zero_inds.size())
    else:
        regrow_score = grad_score[zero_inds]
    grown = zero_inds[regrow_score.argsort(descending=True)[:num_prune]]
    mask_flat[grown] = 1

    if dst_reinit == 'zero':
        # reactivate at zero for minimal change
        orig_weight[grown] = 0
    elif dst_reinit == 'flip':
        if signed_grad is None:
            raise ValueError("dst_reinit='flip' requires the signed gradient")
        saliency = orig_weight[grown] * signed_grad[grown]
        orig_weight[grown] = torch.where(saliency > 0, -orig_weight[grown], orig_weight[grown])


def _prune_regrow_masked_layers(model: nn.Module, dst_frac: float, dst_reinit: str, acc_grads: dict,
                                world_size: int, prune_grad_coef: float = 0.0,
                                random_regrow: bool = False) -> nn.Module:
    """Run ``_prune_regrow_flat`` on every ConvMask / Conv1dMask / LinearMask layer of ``model``.

    Weight magnitudes and absolute accumulated gradients are averaged across ranks
    when a process group is active. The updated mask is then also applied to
    ``weight.grad``.
    """
    for n, m in model.named_modules():
        if not isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            continue
        mask_flat = m.mask.flatten()
        orig_weight = m.weight.data.flatten()

        weight_score = _mean_across_ranks(orig_weight.abs(), world_size)
        # the accumulated gradient is extracted from the hook
        grad_score = _mean_across_ranks(acc_grads[n].flatten().abs(), world_size)
        signed_grad = None
        if dst_reinit == 'flip':
            # clone so the in-place all-reduce never touches acc_grads itself
            signed_grad = _mean_across_ranks(acc_grads[n].flatten().clone(), world_size)

        _prune_regrow_flat(mask_flat, orig_weight, weight_score, grad_score, dst_frac, dst_reinit,
                           signed_grad=signed_grad, prune_grad_coef=prune_grad_coef,
                           random_regrow=random_regrow)

        m.mask = mask_flat.view_as(m.mask)
        if m.weight.grad is not None:
            m.weight.grad *= m.mask
        m.weight.data = orig_weight.view_as(m.weight)
    return model


def prune_grow_rigl(model: nn.Module, dst_frac: float, dst_reinit: str, acc_grads: dict, world_size: int) -> nn.Module:
    """RiGL step: prune the smallest-magnitude weights and regrow where |grad| is largest.

    Returns:
        The same ``model`` instance with updated masks.
    """
    return _prune_regrow_masked_layers(model, dst_frac, dst_reinit, acc_grads, world_size,
                                       prune_grad_coef=0.0, random_regrow=False)


def prune_grow_mest(model: nn.Module, dst_frac: float, dst_reinit: str, acc_grads: dict, world_size: int) -> nn.Module:
    """MEST step: prune by ``|w| + 0.001 * |grad|`` and regrow uniformly at random.

    Returns:
        The same ``model`` instance with updated masks.
    """
    return _prune_regrow_masked_layers(model, dst_frac, dst_reinit, acc_grads, world_size,
                                       prune_grad_coef=0.001, random_regrow=True)


def prune_grow_set(model: nn.Module, dst_frac: float, dst_reinit: str, acc_grads: dict, world_size: int) -> nn.Module:
    """SET step: prune the smallest-magnitude weights and regrow uniformly at random.

    Returns:
        The same ``model`` instance with updated masks.
    """
    return _prune_regrow_masked_layers(model, dst_frac, dst_reinit, acc_grads, world_size,
                                       prune_grad_coef=0.0, random_regrow=True)

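# The idea is to flip the signs of a small fraction of params while keeping the mask fixed (unlike RiGL); the current implementation zeroes, rather than flips, the smallest-magnitude active params whose weight*gradient saliency is positive.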
def update_signs_in_random_mask(model: nn.Module, dst_frac: float, acc_grads: dict, world_size: int) -> nn.Module:
    """Edit the smallest-magnitude active weights of every masked layer while keeping masks fixed.

    For each ConvMask / Conv1dMask / LinearMask layer, the ``int(dst_frac * #active)``
    active (mask != 0) weights with the smallest magnitude are selected. Magnitudes are
    averaged across ranks when a process group is initialised. Of the selected weights,
    those whose saliency ``weight * grad`` is positive are set to zero. Masks are not
    modified.

    Args:
        model: Model whose masked layers are edited in place.
        dst_frac: Fraction of active weights per layer to consider.
        acc_grads: Mapping from module name to its accumulated gradient tensor.
        world_size: Number of ranks used to average the weight magnitudes.

    Returns:
        The same ``model`` instance.
    """
    distributed = dist.is_available() and dist.is_initialized()
    for n, m in model.named_modules():
        if not isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            continue
        mask_flat = m.mask.flatten()
        orig_weight = m.weight.data.flatten()
        # NOTE(review): the signed gradient is this rank's local acc_grads entry (not
        # reduced); confirm it is identical across ranks, otherwise ranks may zero
        # different weights.
        orig_grad = acc_grads[n].flatten()

        weight_score = orig_weight.abs()
        if distributed:
            dist.all_reduce(weight_score, op=dist.ReduceOp.SUM)
            weight_score = weight_score / world_size

        # as_tuple indexing always yields a 1-D index tensor, even with a single active weight
        nonzero_inds = torch.nonzero(mask_flat, as_tuple=True)[0]
        num_select = int(dst_frac * nonzero_inds.numel())
        if num_select > 0:
            selected = nonzero_inds[weight_score[nonzero_inds].argsort()[:num_select]]
            saliency = orig_weight[selected] * orig_grad[selected]
            orig_weight[selected] = torch.where(saliency.sign() == 1, 0, orig_weight[selected])

        m.weight.data = orig_weight.view_as(m.weight)

    return model

