## pythonic imports
import numpy as np
import os

## torch
import torch
import torch.nn as nn

## torchvision
import torchvision.models as models
from torch.amp import autocast
import torch.distributed as dist

## file-based imports
from utils.conv_type import ConvMask, Conv1dMask, LinearMask, STRConv, ConvMaskMW, MWConv1d, replace_layers, replace_vit_layers
from utils.harness_params import get_current_params
from utils.custom_models import PreActResNet, PreActBlock
from utils.dataset import imagenet, imagenet_pytorch, CIFARLoader, CIFARLoader_subsampled, imagenet_subsampled

# import deit model
from utils.deit import deit_tiny_patch16_224, deit_small_patch16_224 
from utils.res20 import resnet20

## fastargs
from fastargs import get_current_config
from fastargs.decorators import param
from typing import Optional, Dict, Any

import airbench

# Called once at import time for its side effects only (the return value is discarded).
# NOTE(review): presumably registers the fastargs parameter sections that the
# @param decorators below read - confirm in utils.harness_params.
get_current_params()

# Three-element normalisation constants, scaled from the [0, 1] range to the [0, 255] range.
# NOTE(review): presumably per-channel RGB statistics for ImageNet - confirm against the loaders.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
# Ratio 224 / 256 (= 0.875); not referenced elsewhere in the visible part of this file.
DEFAULT_CROP_RATIO = 224 / 256

def perturb_signs(model: nn.Module, perturb_ratio: float) -> nn.Module:
    """Randomly flip the sign of a fraction of the weights in every ConvMask layer.

    For each ``ConvMask`` module an independent Bernoulli draw with probability
    ``perturb_ratio`` decides, per element, whether the weight is negated.
    Elements whose mask entry is not 1 are multiplied by 0, i.e. zeroed out.
    Only ``ConvMask`` layers are touched; other masked layer types are ignored.

    Args:
        model (nn.Module): The model whose signs to perturb (modified in place).
        perturb_ratio (float): Per-element probability of flipping a sign.

    Returns:
        nn.Module: The same model instance.
    """
    for _, layer in model.named_modules():
        if not isinstance(layer, ConvMask):
            continue
        # 1 where the sign should be flipped, 0 elsewhere.
        flip_draw = torch.ones_like(layer.mask).bernoulli_(perturb_ratio)
        flip_factor = torch.where(flip_draw == 1, -1, 1)
        # Keep the factor only inside the mask; everything else becomes 0.
        sign = torch.where(layer.mask == 1, flip_factor, 0).to(layer.weight.device)
        layer.weight.data = sign * layer.weight.data
    return model

def get_sparsity(model: nn.Module) -> float:
    """Return the fraction of mask entries that are kept across all masked layers.

    Despite the name, the value computed is ``sum(mask) / numel(mask)``, i.e.
    the *density* of the masks (1.0 means fully dense), accumulated over every
    ConvMask / Conv1dMask / LinearMask / ConvMaskMW module.

    Args:
        model (nn.Module): The model to inspect.

    Returns:
        float: Kept entries divided by total entries. When at least one masked
        layer exists the result is whatever ``mask.sum() / int`` yields
        (a 0-dim tensor for tensor masks).
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    masked_layers = [mod for _, mod in model.named_modules() if isinstance(mod, masked_types)]

    kept = sum(mod.mask.sum() for mod in masked_layers)
    available = sum(mod.mask.numel() for mod in masked_layers)

    print(f'mask size:{kept}')
    return kept / available


class PruningStuff:
    """Class to handle pruning-related functionalities.

    Holds the model, the training loader selected from the fastargs
    configuration, and the entry points that load masks/initialisations and
    dispatch to the module-level ``prune_*`` functions.
    """

    def __init__(self, model: Optional[nn.Module] = None) -> None:
        """Initialize the PruningStuff class.

        Args:
            model (nn.Module, optional): The model to prune. Default is None.
                                         If no model is provided, then one is automatically set based on configuration.
        """
        self.this_device = "cuda:0"
        self.config = get_current_config()

        dataset_name = self.config['dataset.dataset_name']
        if dataset_name == 'ImageNet':
            # FFCV and plain PyTorch loaders are constructed with the same arguments.
            if self.config['dataset.use_ffcv']:
                self.loaders = imagenet(distributed=False, this_device=self.this_device)
            else:
                self.loaders = imagenet_pytorch(distributed=False, this_device=self.this_device)
            self.train_loader = self.loaders.train_loader
        elif 'CIFAR' in dataset_name:
            self.loaders = CIFARLoader(distributed=False)
            self.train_loader = self.loaders.train_loader
        # NOTE: for any other dataset name no loader attributes are set.

        if model is None:
            self.model = self.acquire_model()
        else:
            self.model = model
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _lookup_pruner(method: str):
        """Return the module-level ``prune_<method>`` callable or raise ValueError."""
        pruner_name = f"prune_{method}"
        pruner = globals().get(pruner_name)
        if pruner is None:
            raise ValueError(
                f"Unknown pruning method '{method}': no function named '{pruner_name}' is defined."
            )
        return pruner

    @param("model_params.model_name")
    @param("dataset.dataset_name")
    def acquire_model(
        self, model_name: str, dataset_name: str
    ) -> nn.Module:
        """Acquire the model based on the provided parameters.

        Args:
            model_name (str): Name of the model.
            dataset_name (str): Name of the dataset.

        Returns:
            nn.Module: The acquired model, with its layers replaced by their
            masked counterparts and moved to ``self.this_device``.

        Raises:
            ValueError: If the number of classes is needed but ``dataset_name``
                is not one of CIFAR10, CIFAR100 or ImageNet.
        """
        # None means "unknown dataset"; only an error if a branch actually needs it.
        num_classes = {'CIFAR10': 10, 'CIFAR100': 100, 'ImageNet': 1000}.get(dataset_name)

        def require_num_classes() -> int:
            if num_classes is None:
                raise ValueError(
                    f"Cannot infer the number of classes for dataset '{dataset_name}'."
                )
            return num_classes

        if model_name == "preresnet":
            model = PreActResNet(block=PreActBlock, num_blocks=[2, 2, 2, 2])

        elif "ImageNet" in dataset_name and "deit-small" in model_name:
            model = deit_small_patch16_224(pretrained=False)
        elif "ImageNet" in dataset_name and "deit-tiny" in model_name:
            print('Initializing a DeiT')
            model = deit_tiny_patch16_224(pretrained=False)
        elif "CIFAR10" in dataset_name and "resnet20" in model_name:
            model = resnet20()
        else:
            # Fall back to the torchvision constructor of the same name.
            model = getattr(models, model_name)(num_classes=require_num_classes())

        if ("CIFAR" in dataset_name) and ("resnet" in model_name) and ("resnet20" not in model_name):
            # Swap the stem for a 3x3 stride-1 convolution and drop the max-pool.
            model.conv1 = nn.Conv2d(
                    3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
                )
            model.maxpool = nn.Identity()

        if "CIFAR" in dataset_name and "vgg11" in model_name:
            # Pool to 1x1 so the classifier input is 512 features.
            model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            model.classifier = nn.Sequential(
                nn.Linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, require_num_classes()),
            )

        # Replacing the layers.
        if "deit" in model_name:
            model = replace_vit_layers(model=model)
        else:
            model = replace_layers(model=model)

        model.to(self.this_device)
        return model

    @param("prune_params.load_level")
    def load_init_mask_from_str_master_model(self, load_level: int) -> None:
        """Load a mask and model separately from checkpoints created in the old codebase.

        The weights checkpoint is a state dict whose keys carry a 7-character
        prefix that is stripped (presumably ``module.`` from a wrapped model -
        confirm against the checkpoint). The mask checkpoint is indexed as a
        list, one entry per conv / fc / downsample.0 weight in key order.

        Args:
            load_level (int): Pruning level whose mask list is loaded.

        Returns:
            None. ``self.model`` is updated in place.
        """
        # NOTE(review): both checkpoint paths below are empty placeholders in this file.
        original_dict = torch.load(
                os.path.join("-", ""))
        mask_list = torch.load("".format(load_level))

        new_dict = {}
        for k in original_dict.keys():
            new_k = k[7:]
            if 'fc.weight' in k:
                # fc weights carry a trailing singleton dimension in the old checkpoint.
                new_dict[new_k] = original_dict[k].squeeze(dim=3)
            else:
                new_dict[new_k] = original_dict[k]

        model_dict = self.model.state_dict()

        mask_dict = {}
        cnt = 0
        for k in new_dict.keys():
            if 'weight' in k and (('conv' in k) or ('fc' in k) or ('downsample.0.' in k)):
                # Downsample entries consume a slot in mask_list but get no mask.
                if 'downsample' in k:
                    pass
                elif 'fc' in k:
                    mask_dict[k] = mask_list[cnt].squeeze(dim=3)
                else:
                    mask_dict[k] = mask_list[cnt]
                cnt += 1
        print('Mask keys are: ', mask_dict.keys())

        model_dict.update(new_dict)
        self.model.load_state_dict(model_dict)

        # NOTE(review): a masked module without an entry in mask_dict (e.g. a
        # masked downsample layer) raises KeyError here.
        for n, m in self.model.named_modules():
            if isinstance(m, (Conv1dMask, ConvMask, LinearMask, ConvMaskMW)):
                print(n)
                print(m.mask.shape, mask_dict[n + '.weight'].shape)
                m.mask = mask_dict[n + '.weight']

        print('Loaded from imp-rewind-warmup-every-cycle-0.001-seed-42 level {}'.format(load_level))

    @param("prune_params.init_type")
    @param("prune_params.load_level")
    @param("prune_params.target_dir")
    @param("prune_params.load_only_warmup_sign")
    def load_init_and_mask(self, init_type: str, load_level: int, target_dir: str, load_only_warmup_sign: bool) -> None:
        """Load a mask and initialization that can be trained with cyclic training.

        Args:
            init_type (str): Which weights to load: "warmup" (model_rewind.pt),
                "init" (model_init.pt), or anything else to take the weights
                from the same level checkpoint as the mask.
            load_level (int): Pruning level whose checkpoint provides the mask.
            target_dir (str): Experiment directory containing ``checkpoints/``.
            load_only_warmup_sign (bool): With ``init_type == "warmup"``, keep
                only the signs of the loaded weights and randomly permute the
                weight magnitudes within each masked layer.

        Returns:
            None. ``self.model`` is updated in place.
        """
        checkpoint_dir = os.path.join(target_dir, "checkpoints")
        if init_type == "warmup":
            original_dict = torch.load(os.path.join(checkpoint_dir, "model_rewind.pt"))
        elif init_type == "init":
            original_dict = torch.load(os.path.join(checkpoint_dir, "model_init.pt"))
        else:
            print("loading mask and model from the same level.")
            original_dict = None

        # loading a mask
        level_dict = torch.load(
            os.path.join(checkpoint_dir, "model_level_{}.pt".format(load_level))
        )
        if original_dict is None:
            # Previously this path crashed with a NameError; use the level
            # checkpoint for the weights as the message above announces.
            original_dict = level_dict

        original_weights = {
            key: value
            for key, value in original_dict.items()
            if key.endswith((".weight", ".bias"))
        }
        mask_dict = {
            key: value for key, value in level_dict.items() if key.endswith(".mask")
        }

        model_dict = self.model.state_dict()
        model_dict.update(original_weights)
        model_dict.update(mask_dict)
        self.model.load_state_dict(model_dict)

        if load_only_warmup_sign and init_type == 'warmup':
            masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
            print('Loading only warmup signs and random weights with the mask')
            for n, m in self.model.named_modules():
                if isinstance(m, masked_types):
                    # Capture the signs before abs_() destroys them in place.
                    init_sign = m.weight.sign()
                    perm = torch.randperm(m.weight.numel())
                    weight = m.weight.data.abs_().view(-1)[perm].view(m.weight.shape)
                    m.weight.data = weight.to(m.weight.device) * init_sign.to(m.weight.device)
        print('Loaded init {} and mask from level {} of target {}'.format(init_type, load_level, target_dir))

    @param("prune_params.er_method")
    @param("prune_params.er_init")
    @param("prune_params.update_sign_init")
    def prune_at_initialization(
        self, er_method: str, er_init: float, update_sign_init: bool) -> None:
        """Prune the model at initialization.

        A fresh model is acquired first (replacing ``self.model``), then the
        module-level function ``prune_<er_method>`` is applied to it.

        Args:
            er_method (str): Method of pruning; "just dont" skips pruning.
            er_init (float): Initial sparsity target passed to the pruner.
            update_sign_init (bool): If True, run ``update_sign_from_grad``
                on the pruned model afterwards.

        Returns:
            None. The pruned model is stored in ``self.model``.

        Raises:
            ValueError: If no ``prune_<er_method>`` function exists.
        """
        self.model = self.acquire_model()
        total = 0
        nonzero = 0
        for n, m in self.model.named_modules():
            if isinstance(m, (Conv1dMask, ConvMask, LinearMask, ConvMaskMW)):
                total += m.mask.numel()
                nonzero += m.mask.sum()

            if isinstance(m, STRConv):
                mask = m.get_mask()
                total += mask.numel()
                nonzero += mask.sum()
        print(f'{total}')
        if total > 0:
            # Density is kept / total (the original printed the inverse ratio).
            print(f'density is {(nonzero / total) * 100:.3f}')

        print('prior to pruning at init')
        if er_method == "just dont":
            print('We dont prune at init')
        elif er_method in {"synflow", "snip"}:
            # Data-dependent pruners need the training loader and return the model.
            pruner_method = self._lookup_pruner(er_method)
            self.model = pruner_method(self.model, self.train_loader, er_init)
        else:
            pruner_method = self._lookup_pruner(er_method)
            print('Prune method: {}'.format(f"prune_{er_method}"))
            pruner_method(self.model, er_init)

        if update_sign_init:
            update_func = globals().get('update_sign_from_grad')
            self.model = update_func(self.model, self.train_loader, frac=0.05)
        print('We just pruned at init, woohoo!')

    @param("prune_params.prune_method")
    @param("prune_params.update_sign_every_level")
    def level_pruner(self, prune_method: str, density: float, level: int, update_sign_every_level: bool) -> None:
        """Prune the model at a specific density level.

        Args:
            prune_method (str): Method of pruning; "just dont" skips pruning.
            density (float): Desired density after pruning.
            level (int): Current pruning level (not used by this method).
            update_sign_every_level (bool): If True, run
                ``update_sign_from_grad`` after pruning.

        Raises:
            ValueError: If no ``prune_<prune_method>`` function exists.
        """
        print("---" * 20)
        print(f"Density before pruning: {get_sparsity(self.model)}")
        print("---" * 20)

        if prune_method == "just dont":
            print('We dont prune at the end of the level')
        elif prune_method in {"synflow", "snip"}:
            pruner_method = self._lookup_pruner(prune_method)
            self.model = pruner_method(self.model, self.train_loader, density)
        else:
            pruner_method = self._lookup_pruner(prune_method)
            pruner_method(self.model, density)

        if update_sign_every_level:
            update_func = globals().get('update_sign_from_grad')
            self.model = update_func(self.model, self.train_loader, frac=0.05)

        # put the model back on the GPU
        self.model.to(self.this_device)

        print("---" * 20)
        print(f"Density after pruning: {get_sparsity(self.model)}")
        print("---" * 20)

    def load_from_ckpt(self, path: str) -> None:
        """Load the model from a checkpoint.

        Args:
            path (str): Path to a file holding a state dict for ``self.model``.
        """
        self.model.load_state_dict(torch.load(path))

def make_dense(model: nn.Module) -> nn.Module:
    """Reset every masked layer to a dense (all-ones) mask.

    Args:
        model (nn.Module): Model whose ConvMask / Conv1dMask / LinearMask /
            ConvMaskMW masks are overwritten in place.

    Returns:
        nn.Module: The same model instance.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    for _, layer in model.named_modules():
        if not isinstance(layer, masked_types):
            continue
        layer.mask = torch.ones_like(layer.mask)

    print('The mask is now dense.')
    return model
    
def prune_mag(model: nn.Module, density: float, world_size=1, distributed=False) -> nn.Module:
    """Magnitude-based global pruning of the model.

    Scores are ``|mask * weight|`` for every masked layer. The
    ``(1 - density)`` fraction of entries with the smallest scores (across all
    layers together) is removed from the masks.

    Args:
        model (nn.Module): The model to prune (masks are replaced in place).
        density (float): Desired density after pruning.
        world_size (int): Number of processes; only used when ``distributed``.
        distributed (bool): If True, the concatenated scores are averaged
            across processes with an all-reduce before the threshold is chosen.

    Returns:
        nn.Module: The pruned model.
    """
    prunable = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW, MWConv1d)
    score_list = {}

    for n, m in model.named_modules():
        if isinstance(m, prunable):
            score_list[n] = (m.mask.to(m.weight.device) * m.weight).detach().abs_()

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])

    # takes scores across all GPUs into account while pruning.
    if distributed:
        dist.all_reduce(global_scores, op=dist.ReduceOp.SUM)
        global_scores = global_scores / world_size

    k = int((1 - density) * global_scores.numel())
    # torch.kthvalue requires k >= 1; k < 1 means there is nothing to prune.
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, prunable):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            # Always tally, so the report below is valid even when nothing was pruned.
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()
    print(
        "Overall model density after magnitude pruning at current iteration = ",
        (total_num / total_den).item(),
    )
    return model


def prune_random_erk(model: nn.Module, density: float) -> nn.Module:
    """Random pruning with layer-wise keep ratios derived from the ERK rule.

    Every masked layer gets a raw ratio ``sum(weight.shape) / weight.numel()``.
    All ratios are multiplied by one common factor so that the implied number
    of kept parameters matches ``density`` times the total, then clamped to
    [0, 1]. Inside each layer the entries to drop are chosen by thresholding
    ``|mask * N(0, 1)|`` random scores.

    Args:
        model (nn.Module): The model to prune (masks are replaced in place).
        density (float): Desired overall density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    layers = [
        (name, mod) for name, mod in model.named_modules() if isinstance(mod, masked_types)
    ]

    random_scores = {}
    raw_ratios = []
    layer_sizes = []
    for name, mod in layers:
        device = mod.weight.device
        masked_noise = mod.mask.to(device) * torch.randn_like(mod.weight).to(device)
        random_scores[name] = masked_noise.detach().abs_()
        raw_ratios.append(torch.tensor(mod.weight.shape).sum() / mod.weight.numel())
        layer_sizes.append(mod.weight.numel())
    param_count = sum(layer_sizes)

    # Common scaling factor that brings the raw ratios to the requested budget.
    kept_at_raw_ratio = (torch.tensor(raw_ratios) * torch.tensor(layer_sizes)).sum()
    keep_budget = param_count * density
    C = keep_budget / kept_at_raw_ratio
    print("Factor: ", C)
    keep_ratios = [torch.clamp(C * ratio, 0, 1) for ratio in raw_ratios]

    total_num = 0
    total_den = 0
    for (name, mod), keep_ratio in zip(layers, keep_ratios):
        flat_scores = torch.flatten(random_scores[name])
        k = int((1 - keep_ratio) * flat_scores.numel())
        # kthvalue needs k >= 1; with k == 0 only zero scores are dropped.
        threshold = torch.kthvalue(flat_scores, k)[0] if k != 0 else 0
        print("Layer", name, " params ", k, flat_scores.numel())

        device = mod.weight.device
        layer_score = random_scores[name].to(device)
        zero = torch.tensor([0.0]).to(device)
        one = torch.tensor([1.0]).to(device)
        mod.mask = torch.where(layer_score <= threshold, zero, one)
        total_num += (mod.mask == 1).sum()
        total_den += mod.mask.numel()

    print(
        "Overall model density after random global (ERK) pruning at current iteration = ",
        total_num / total_den,
    )
    return model

def update_sign_from_grad(model: nn.Module, trainloader: Any, frac: float) -> nn.Module:
    """Flip the sign of the smallest-magnitude in-mask weights based on their gradient.

    Gradients are accumulated over the first ``num_steps + 1`` batches of
    ``trainloader``. The ``frac`` fraction of entries with the smallest
    ``|weight|`` (entries outside the mask are given a score of 10.0 so they
    rank last) are candidates; a candidate is negated when its weight and its
    accumulated gradient have the same sign.

    Args:
        model (nn.Module): The model to update (weights are modified in place).
        trainloader (Any): Iterable yielding ``(images, target)`` batches.
        frac (float): Fraction of all scored entries considered for flipping.

    Returns:
        nn.Module: The same model instance.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    num_steps = 10
    criterion = nn.CrossEntropyLoss()
    model.zero_grad()
    print('Updating the Sign of the mask')
    for i, (images, target) in enumerate(trainloader):
        images = images.to(torch.device("cuda"))
        target = target.to(torch.device("cuda")).long()
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(images)
            loss = criterion(output, target)
            # accumulate the gradient over multiple steps
            loss.backward()
        if i == num_steps:
            break

    # based on the value of the magnitude, flip the smallest fraction of weights inside the mask
    # since these values are close to zero and are more likely to flip
    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            score = (m.weight).detach().abs_()
            # elements outside the mask should have a high score so they are ignored
            score = torch.where(m.mask.to(m.weight.device) == 0, 10.0, score)
            score_list[n] = score

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int(frac * global_scores.numel())
    if k < 1:
        # torch.kthvalue requires k >= 1; nothing to flip.
        return model
    threshold, _ = torch.kthvalue(global_scores, k)

    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            # if the weight and the gradient have same sign: change the weight sign,
            # if the weight and the gradient have opposite signs: keep the weight sign
            same_sign = m.weight.data.sign() * m.weight.grad.sign() == 1
            # Compare against the k-th smallest score, not the count k itself.
            flip = torch.where(score_list[n] <= threshold, torch.where(same_sign, -1, 1), 1)
            m.weight.data = flip * m.weight.data
    return model
    

def prune_snip(model: nn.Module, trainloader: Any, density: float) -> nn.Module:
    """SNIP method for pruning of the model.

    A single batch is used to compute gradients. Each masked entry is scored
    with ``|grad * weight * mask|`` (for ConvMaskMW layers:
    ``|(w * w.grad + m * m.grad) * mask|``), and the ``(1 - density)``
    fraction with the smallest scores across all layers is masked out.

    Args:
        model (nn.Module): The model to prune (masks are replaced in place).
        trainloader (Any): The training data loader; only its first batch is used.
        density (float): Desired density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    criterion = nn.CrossEntropyLoss()
    for i, (images, target) in enumerate(trainloader):
        images = images.to(torch.device("cuda"))
        target = target.to(torch.device("cuda")).long()
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            model.zero_grad()
            output = model(images)
            criterion(output, target).backward()
        break

    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            score_list[n] = (
                (m.weight.grad * m.weight * m.mask.to(m.weight.device)).detach().abs_()
            )
    # Second pass so the ConvMaskMW score takes precedence for those layers.
    for n, m in model.named_modules():
        if isinstance(m, (ConvMaskMW)):
            score_list[n] = (
                ((m.w * m.w.grad + m.m * m.m.grad) * m.mask.to(m.weight.device)).detach().abs_()
            )

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # torch.kthvalue requires k >= 1; k < 1 means there is nothing to prune.
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            # Always tally, so the report below is valid even when nothing was pruned.
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after snip pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_synflow(model: nn.Module, trainloader: Any, density: float) -> nn.Module:
    """SynFlow method pruning of the model.

    All tensors in the state dict are replaced by their absolute values, a
    single all-ones input (shaped like one sample of the first batch) is
    pushed through the model, and each masked entry is scored with
    ``|mask * grad * weight|``. Signs are restored afterwards and the
    ``(1 - density)`` fraction with the smallest scores is masked out.

    Args:
        model (nn.Module): The model to prune (masks are replaced in place).
        trainloader (Any): The training data loader; only used to read the
            shape of one input sample.
        density (float): Desired density after pruning.

    Returns:
        nn.Module: The pruned model.
    """

    @torch.no_grad()
    def linearize(model: nn.Module) -> Dict[str, torch.Tensor]:
        """Take the absolute value of every state-dict tensor in place.

        Args:
            model (nn.Module): The model to linearize.

        Returns:
            Dict[str, torch.Tensor]: The original sign of every tensor, by key.
        """
        signs = {}
        for name, param in model.state_dict().items():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs

    @torch.no_grad()
    def nonlinearize(model: nn.Module, signs: Dict[str, torch.Tensor]) -> None:
        """Restore the signs recorded by ``linearize``.

        Args:
            model (nn.Module): The model to restore.
            signs (Dict[str, torch.Tensor]): Dictionary of parameter signs.
        """
        for n, param in model.state_dict().items():
            param.mul_(signs[n])

    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    signs = linearize(model)
    try:
        # Drop stale gradients so the scores reflect only the pass below.
        model.zero_grad()
        for i, (images, target) in enumerate(trainloader):
            input_dim = list(images[0, :].shape)
            ones_input = torch.ones([1] + input_dim).to("cuda")
            with autocast(dtype=torch.bfloat16, device_type='cuda'):
                output = model(ones_input)
                torch.sum(output).backward()
            break

        score_list = {}
        for n, m in model.named_modules():
            if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
                score_list[n] = (
                    (m.mask.to(m.weight.device) * m.weight.grad * m.weight).detach().abs_()
                )
            if isinstance(m, (ConvMaskMW)):
                score_list[n] = (
                    (m.mask.to(m.weight.device) * (m.w * m.w.grad + m.m * m.m.grad)).detach().abs()
                )

        model.zero_grad()
    finally:
        # Restore the signs even if the scoring pass failed, so the model is
        # never left with absolute-valued weights.
        nonlinearize(model, signs)

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # torch.kthvalue requires k >= 1; k < 1 means there is nothing to prune.
    threshold = torch.kthvalue(global_scores, k)[0] if k >= 1 else None

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            # Always tally, so the report below is valid even when nothing was pruned.
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after synflow pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_random_balanced(model: nn.Module, density: float) -> nn.Module:
    """Random pruning with a (roughly) equal parameter budget per layer.

    The overall budget ``density * total_params`` is split evenly across the
    masked layers. A layer smaller than its share stays fully dense and the
    running per-layer budget is increased for the layers that follow. Inside
    each layer the entries to drop are chosen by thresholding
    ``|mask * N(0, 1)|`` random scores.

    Args:
        model (nn.Module): The model to prune (masks are replaced in place).
        density (float): Desired overall density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    layers = [
        (name, mod) for name, mod in model.named_modules() if isinstance(mod, masked_types)
    ]
    layer_count = len(layers)
    param_count = sum(mod.weight.numel() for _, mod in layers)

    # Running per-layer budget of kept parameters.
    budget = density * param_count / layer_count

    random_scores = {}
    keep_ratios = []
    for position, (name, mod) in enumerate(layers):
        device = mod.weight.device
        masked_noise = mod.mask.to(device) * torch.randn_like(mod.weight).to(device)
        random_scores[name] = masked_noise.detach().abs_()

        share = budget / mod.weight.numel()
        if share < 1.0:
            keep_ratios.append(share)
            continue
        keep_ratios.append(1)
        # correction for taking care of exact sparsity
        surplus = budget - mod.mask.numel()
        budget = budget + surplus / (layer_count - position)

    total_num = 0
    total_den = 0
    for (name, mod), keep_ratio in zip(layers, keep_ratios):
        flat_scores = torch.flatten(random_scores[name])
        k = int((1 - keep_ratio) * flat_scores.numel())
        # kthvalue needs k >= 1; with k == 0 only zero scores are dropped.
        threshold = torch.kthvalue(flat_scores, k)[0] if k != 0 else 0
        print("Layer", name, " params ", k, flat_scores.numel())

        device = mod.weight.device
        layer_score = random_scores[name].to(device)
        zero = torch.tensor([0.0]).to(device)
        one = torch.tensor([1.0]).to(device)
        mod.mask = torch.where(layer_score <= threshold, zero, one)
        total_num += (mod.mask == 1).sum()
        total_den += mod.mask.numel()

    print(
        "Overall model density after random global (balanced) pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_er_erk(model: nn.Module, er_sparse_init: float) -> None:
    """ERK-based pruning at initialization.

    Each masked layer gets a raw ratio ``sum(weight.shape) / weight.numel()``;
    all ratios are scaled by one common factor so the implied number of kept
    parameters equals ``er_sparse_init`` times the total, clamped to [0, 1],
    and handed to the layer's ``set_er_mask``.

    Args:
        model (nn.Module): The model to prune.
        er_sparse_init (float): Fraction of the total parameter count used as
            the keep budget.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    layers = [mod for _, mod in model.named_modules() if isinstance(mod, masked_types)]

    raw_ratios = [
        torch.tensor(mod.weight.shape).sum() / mod.weight.numel() for mod in layers
    ]
    layer_sizes = [mod.weight.numel() for mod in layers]
    param_count = sum(layer_sizes)

    kept_at_raw_ratio = (torch.tensor(raw_ratios) * torch.tensor(layer_sizes)).sum()
    keep_budget = param_count * er_sparse_init
    C = keep_budget / kept_at_raw_ratio
    sparsity_list = [torch.clamp(C * ratio, 0, 1) for ratio in raw_ratios]

    for mod, ratio in zip(layers, sparsity_list):
        mod.set_er_mask(ratio)
    print(sparsity_list)


def prune_er_balanced(model: nn.Module, er_sparse_init: float) -> None:
    """ER-balanced pruning at initialization.

    The budget ``er_sparse_init * total_params`` is split evenly across the
    masked layers. A layer smaller than its share gets ratio 1 and the running
    per-layer budget is increased for the layers that follow. Each resulting
    ratio is handed to the layer's ``set_er_mask``.

    Args:
        model (nn.Module): The model to prune.
        er_sparse_init (float): Fraction of the total parameter count used as
            the keep budget.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)
    layers = [mod for _, mod in model.named_modules() if isinstance(mod, masked_types)]
    layer_count = len(layers)
    param_count = sum(mod.weight.numel() for mod in layers)

    # Running per-layer budget of kept parameters.
    budget = er_sparse_init * param_count / layer_count

    sparsity_list = []
    for position, mod in enumerate(layers):
        share = budget / mod.weight.numel()
        if share < 1.0:
            sparsity_list.append(share)
            continue
        sparsity_list.append(1)
        # correction for taking care of exact sparsity
        surplus = budget - mod.mask.numel()
        budget = budget + surplus / (layer_count - position)

    for mod, ratio in zip(layers, sparsity_list):
        mod.set_er_mask(ratio)
    print(sparsity_list)


@param("prune_params.update_sign_every_level")
@param("experiment_params.epochs_per_level")
def dst_pruner(model: nn.Module, step: int, steps_per_epoch: int, epochs_per_level: int, world_size: int, update_sign_every_level: bool):
    """Run one dynamic-sparse-training mask update if the schedule is still active.

    Mask updates happen during the first 75% of the level's steps. The
    fraction of weights pruned and regrown follows a cosine decay starting at
    ``alpha = 0.3``. After that point the mask is left untouched.

    Args:
        model (nn.Module): The model whose masks are updated.
        step (int): Current step within the level.
        steps_per_epoch (int): Number of steps in one epoch.
        epochs_per_level (int): Number of epochs in one level.
        world_size (int): Number of processes, forwarded to ``prune_grow_rigl``.
        update_sign_every_level (bool): Injected from the configuration; not
            used by this function.

    Returns:
        nn.Module: The (possibly updated) model.
    """
    total_steps = steps_per_epoch * epochs_per_level
    T_end = int(0.75 * total_steps)

    if step >= T_end:
        print('Mask shuffling is complete, finetuning the final mask now.')
        return model

    # Hyperparams chosen in accordance with https://arxiv.org/pdf/1911.11134
    alpha = 0.3
    # Evaluated only while the schedule is active, so a T_end of 0 (very short
    # levels) can no longer trigger a division by zero.
    dst_frac = (alpha / 2) * (1 + np.cos((step * np.pi) / T_end))

    # right now only support RiGL but other DST methods can be included here
    print(f'Pruning and regrowing with RiGL: {dst_frac}')
    model = prune_grow_rigl(model, dst_frac, world_size)
    return model

def prune_grow_rigl(model: nn.Module, dst_frac: float, world_size: int) -> nn.Module:
    """Perform one RiGL prune-and-regrow step on every masked layer.

    Per layer, the ``dst_frac`` fraction of active weights with the smallest
    magnitude is removed from the mask, and the same number of inactive
    positions with the largest gradient magnitude is activated. Regrown
    weights are set to zero. The number of swaps is capped by the number of
    inactive positions, so the layer's density is unchanged; fully dense
    layers are skipped.

    Args:
        model (nn.Module): The model to update (masks and weights in place).
        dst_frac (float): Fraction of each layer's active weights to swap.
        world_size (int): Number of processes used to average magnitudes and
            gradients when a process group is initialised.

    Returns:
        nn.Module: The same model instance.
    """
    # Only reduce when a process group exists, so single-process runs work too.
    reduce_across_ranks = dist.is_available() and dist.is_initialized()

    for n, m in model.named_modules():
        if not isinstance(m, (ConvMask, Conv1dMask, LinearMask, ConvMaskMW)):
            continue

        # Flatten mask and gradients for easier manipulation
        mask_flat = m.mask.flatten()
        # we choose the smallest magnitude weights and the largest magnitude gradients
        weight_flat = m.weight.detach().flatten().abs()
        orig_weight = m.weight.data.flatten()
        gradients_flat = m.weight.grad.flatten().abs()

        # All reduce across multi GPUs (done before any skip so that every
        # rank issues the same sequence of collectives).
        if reduce_across_ranks:
            dist.all_reduce(weight_flat, op=dist.ReduceOp.SUM)
            dist.all_reduce(gradients_flat, op=dist.ReduceOp.SUM)
            weight_flat = weight_flat / world_size
            gradients_flat = gradients_flat / world_size

        # flatten() keeps these 1-D even when there is exactly one index.
        nonzero_inds = mask_flat.nonzero().flatten()
        zero_inds = (mask_flat == 0).nonzero().flatten()

        # Never prune more than can be regrown; a dense layer yields 0 here.
        num_prune = min(int(dst_frac * nonzero_inds.numel()), zero_inds.numel())
        if num_prune < 1:
            continue

        # Prune: Set a fraction of the smallest non-zero elements to zero
        nonzero_weights = weight_flat[nonzero_inds]
        smallest_nonzero_inds = nonzero_inds[nonzero_weights.argsort()[:num_prune]]

        # Regrow: Set a fraction of the largest gradient elements from zero to one
        zero_gradients = gradients_flat[zero_inds]
        largest_zero_gradient_inds = zero_inds[zero_gradients.argsort(descending=True)[:num_prune]]

        mask_flat[smallest_nonzero_inds] = 0
        mask_flat[largest_zero_gradient_inds] = 1
        # set the reactivated parameters to zero for minimal change as recommended by authors.
        orig_weight[largest_zero_gradient_inds] = 0

        # Reshape the mask and weights back to their original shapes
        m.mask = mask_flat.view_as(m.mask)
        m.weight.data = orig_weight.view_as(m.weight)

    return model
