## pythonic imports
import numpy as np
import os

## torch
import torch
import torch.nn as nn

## torchvision
import torchvision.models as models
from torch.amp import autocast
import torch.distributed as dist

## file-based imports
from utils.conv_type import ConvMask, Conv1dMask, LinearMask, STRConv, Conv1dMaskMW, ConvMaskMW, LinearMaskMW, replace_layers, replace_vit_layers
from utils.harness_params import get_current_params
from utils.custom_models import PreActResNet, PreActBlock
from utils.dataset import imagenet, imagenet_pytorch, CIFARLoader, CIFARLoader_subsampled, imagenet_subsampled
from utils.harness_utils import get_model_mask
# import deit model
from utils.deit import deit_tiny_patch16_224, deit_small_patch16_224 
from utils.res20 import resnet20

## fastargs
from fastargs import get_current_config
from fastargs.decorators import param
from typing import Optional, Dict, Any


# NOTE(review): presumably declares the fastargs parameter sections that the
# @param decorators below read; this runs at import time — confirm.
get_current_params()

# Standard ImageNet per-channel mean / std, rescaled from [0, 1] to [0, 255].
IMAGENET_MEAN = 255 * np.array([0.485, 0.456, 0.406])
IMAGENET_STD = 255 * np.array([0.229, 0.224, 0.225])
# Crop-to-resize ratio (presumably a 224 px crop taken from a 256 px image).
DEFAULT_CROP_RATIO = 224 / 256

def perturb_signs(model: nn.Module, perturb_ratio: float) -> nn.Module:
    """Randomly flip the signs of a fraction of the unmasked weights.

    For every module that is an instance of ``ConvMask``, each weight whose
    mask entry equals 1 has its sign flipped independently with probability
    ``perturb_ratio``. Weights whose mask entry is not 1 are multiplied by 0,
    i.e. they are zeroed in place.

    Args:
        model (nn.Module): The model whose signs to perturb (modified in place).
        perturb_ratio (float): Per-weight probability of flipping the sign.

    Returns:
        nn.Module: The same ``model`` instance.
    """
    for _, module in model.named_modules():
        if not isinstance(module, ConvMask):
            continue
        # One Bernoulli(perturb_ratio) draw per mask entry; True -> flip.
        flip_draw = torch.ones_like(module.mask).bernoulli_(perturb_ratio) == 1
        per_weight_sign = torch.where(flip_draw, -1, 1)
        # Inside the mask: +/-1; outside the mask: 0.
        multiplier = torch.where(module.mask == 1, per_weight_sign, 0)
        module.weight.data = multiplier.to(module.weight.device) * module.weight.data
    return model

def get_sparsity(model: nn.Module) -> float:
    """Return the kept fraction of mask entries, i.e. the model *density*.

    Despite its name, this computes ``sum(mask) / numel(mask)`` over every
    ``ConvMask`` / ``Conv1dMask`` / ``LinearMask`` module. For binary masks
    that is the density (1 - sparsity); callers in this module print it as
    "Density".

    Args:
        model (nn.Module): The model to inspect.

    Returns:
        float: The kept fraction. It is produced by tensor arithmetic, so in
        practice it is a 0-dim tensor whenever a masked layer exists.

    Raises:
        ZeroDivisionError: If the model contains no masked layers.
    """
    masked_layers = [
        layer
        for layer in model.modules()
        if isinstance(layer, (ConvMask, Conv1dMask, LinearMask))
    ]
    kept = sum(layer.mask.sum() for layer in masked_layers)
    total = sum(layer.mask.numel() for layer in masked_layers)
    return kept / total


class PruningStuff:
    """Class to handle pruning-related functionalities.

    Holds the model being pruned (``self.model``), the training loader used by
    data-driven pruners (SNIP, SynFlow, gradient-based sign updates) and a
    cross-entropy criterion.
    """

    def __init__(self, model: Optional[nn.Module] = None) -> None:
        """Initialize the PruningStuff class.

        Args:
            model (nn.Module, optional): The model to prune. Default is None.
                If no model is provided, one is built from the configuration
                via ``acquire_model``.
        """
        self.this_device = "cuda:0"
        self.config = get_current_config()
        # Always define these so unsupported datasets don't leave the
        # attributes missing entirely.
        self.loaders = None
        self.train_loader = None

        dataset_name = self.config['dataset.dataset_name']
        if dataset_name == 'ImageNet':
            if self.config['dataset.use_ffcv']:
                self.loaders = imagenet(distributed=False, this_device='cuda:0')
            else:
                self.loaders = imagenet_pytorch(distributed=False, this_device='cuda:0')
            self.train_loader = self.loaders.train_loader
        elif 'CIFAR' in dataset_name:
            self.loaders = CIFARLoader(distributed=False)
            self.train_loader = self.loaders.train_loader

        self.model = self.acquire_model() if model is None else model
        self.criterion = nn.CrossEntropyLoss()

    @param("model_params.model_name")
    @param("model_params.tf_pretrained")
    @param("dataset.dataset_name")
    def acquire_model(
        self, model_name: str, dataset_name: str, tf_pretrained: bool
    ) -> nn.Module:
        """Build the model described by the configuration.

        The dense layers are swapped for their masked counterparts
        (``replace_vit_layers`` for DeiT, ``replace_layers`` otherwise), and the
        model is moved to ``self.this_device``.

        Args:
            model_name (str): Name of the model (a torchvision constructor name,
                or one of "preresnet", "deit-small", "deit-tiny", "resnet20").
            dataset_name (str): Name of the dataset.
            tf_pretrained (bool): Passed as ``pretrained`` to the DeiT constructors.

        Returns:
            nn.Module: The acquired model.

        Raises:
            ValueError: If a torchvision model is requested for a dataset whose
                number of classes is unknown.
        """
        num_classes = {"CIFAR10": 10, "CIFAR100": 100, "ImageNet": 1000}.get(dataset_name)

        if model_name == "preresnet":
            # NOTE(review): num_classes is not passed here; confirm the
            # PreActResNet default matches the dataset.
            model = PreActResNet(block=PreActBlock, num_blocks=[2, 2, 2, 2])
        elif "ImageNet" in dataset_name and "deit-small" in model_name:
            model = deit_small_patch16_224(pretrained=tf_pretrained)
        elif "ImageNet" in dataset_name and "deit-tiny" in model_name:
            print('Initializing a DeiT')
            model = deit_tiny_patch16_224(pretrained=tf_pretrained)
        elif "CIFAR10" in dataset_name and "resnet20" in model_name:
            # NOTE(review): "CIFAR10" is a substring of "CIFAR100" too, and
            # resnet20() is built with its default class count — confirm.
            model = resnet20()
        else:
            if num_classes is None:
                raise ValueError(
                    f"Cannot infer num_classes for dataset {dataset_name!r} "
                    f"(model {model_name!r})."
                )
            model = getattr(models, model_name)(num_classes=num_classes)

        # Small-image stem for torchvision ResNets on CIFAR (32x32 inputs).
        if ("CIFAR" in dataset_name) and ("resnet" in model_name) and ("resnet20" not in model_name):
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()

        if "CIFAR" in dataset_name and "vgg11" in model_name:
            model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            model.classifier = nn.Sequential(
                nn.Linear(512 * 1 * 1, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

        # Replacing the layers with masked versions.
        if "deit" in model_name:
            model = replace_vit_layers(model=model)
        else:
            model = replace_layers(model=model)

        model.to(self.this_device)
        return model

    @param("prune_params.load_level")
    def load_init_mask_from_str_master_model(
        self, load_level: int, mask_path_template: Optional[str] = None
    ) -> nn.Module:
        """Load a mask and model separately from checkpoints created in the old codebase.

        Args:
            load_level (int): Pruning level whose mask should be loaded.
            mask_path_template (str, optional): Path of the legacy mask checkpoint,
                containing a ``{}`` placeholder that is filled with ``load_level``.
                The original hard-coded path was syntactically broken and could not
                be recovered, so it must now be supplied explicitly.

        Returns:
            nn.Module: ``self.model`` with the legacy weights and masks loaded.

        Raises:
            ValueError: If ``mask_path_template`` is not provided.
        """
        if mask_path_template is None:
            raise ValueError(
                "mask_path_template is required: pass the legacy mask checkpoint "
                "path with a '{}' placeholder for the pruning level."
            )
        # NOTE(review): PATH is not defined in the visible part of this module;
        # presumably a module-level constant defined elsewhere — confirm.
        original_dict = torch.load(
            os.path.join(PATH, "model_imagenet-imp-rewind-warmup-every-cycle-0.001-seed-42_init.pt"))
        mask_list = torch.load(mask_path_template.format(load_level))

        # Drop the first 7 characters of every key (presumably a "module."
        # prefix); the fc weight is squeezed along dim 3.
        new_dict = {}
        for k, v in original_dict.items():
            new_k = k[7:]
            new_dict[new_k] = v.squeeze(dim=3) if 'fc.weight' in k else v

        model_dict = self.model.state_dict()

        # Legacy masks are stored as a list in key order; downsample entries
        # consume a slot but are not used.
        mask_dict = {}
        cnt = 0
        for k in new_dict.keys():
            if 'weight' in k and (('conv' in k) or ('fc' in k) or ('downsample.0.' in k)):
                if 'downsample' in k:
                    pass
                elif 'fc' in k:
                    mask_dict[k] = mask_list[cnt].squeeze(dim=3)
                else:
                    mask_dict[k] = mask_list[cnt]
                cnt += 1
        print('Mask keys are: ', mask_dict.keys())

        model_dict.update(new_dict)
        self.model.load_state_dict(model_dict)

        for n, m in self.model.named_modules():
            if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
                print(n)
                print(m.mask.shape, mask_dict[n + '.weight'].shape)
                m.mask = mask_dict[n + '.weight']

        print('Loaded from imp-rewind-warmup-every-cycle-0.001-seed-42 level {}'.format(load_level))
        return self.model

    @param("prune_params.init_type")
    @param("prune_params.load_level")
    @param("prune_params.target_dir")
    @param("prune_params.load_only_warmup_sign")
    def load_init_and_mask(self, init_type: str, load_level: int, target_dir: str, load_only_warmup_sign: bool) -> nn.Module:
        """Load a mask and an initialization that can be trained with cyclic training.

        Weights/biases come from ``checkpoints/model_rewind.pt`` ("warmup"),
        ``checkpoints/model_init.pt`` ("init") or, for any other value, from
        the level checkpoint itself. Masks always come from
        ``checkpoints/model_level_{load_level}.pt``.

        Args:
            init_type (str): Type of weights: "warmup", "init", or anything else
                to load weights and mask from the same level.
            load_level (int): Pruning level whose mask is loaded.
            target_dir (str): Experiment directory containing ``checkpoints/``.
            load_only_warmup_sign (bool): With ``init_type == "warmup"``, keep only
                the warmup signs and randomly permute the magnitudes per layer.

        Returns:
            nn.Module: ``self.model`` with the loaded weights and masks.
        """
        checkpoint_dir = os.path.join(target_dir, "checkpoints")
        level_path = os.path.join(checkpoint_dir, "model_level_{}.pt".format(load_level))
        same_level = init_type not in ("warmup", "init")

        if init_type == "warmup":
            original_dict = torch.load(os.path.join(checkpoint_dir, "model_rewind.pt"))
        elif init_type == "init":
            original_dict = torch.load(os.path.join(checkpoint_dir, "model_init.pt"))
        else:
            print("loading mask and model from the same level.")
            # Previously original_dict was never assigned here (NameError).
            original_dict = torch.load(level_path)
        original_weights = {
            k: v for k, v in original_dict.items() if k.endswith((".weight", ".bias"))
        }

        # loading a mask (reuse the level checkpoint if it is already loaded)
        level_dict = original_dict if same_level else torch.load(level_path)
        mask_dict = {k: v for k, v in level_dict.items() if k.endswith(".mask")}

        model_dict = self.model.state_dict()
        model_dict.update(original_weights)
        model_dict.update(mask_dict)
        self.model.load_state_dict(model_dict)

        if load_only_warmup_sign and init_type == 'warmup':
            init_sign_list = []
            print('Loading only warmup signs and random weights with the mask')
            # Capture all signs first, then rewrite magnitudes in a second pass.
            for n, m in self.model.named_modules():
                if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
                    init_sign_list.append(m.weight.sign())
            cnt = 0
            for n, m in self.model.named_modules():
                if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
                    perm = torch.randperm(m.weight.numel())
                    weight = m.weight.data.abs_().view(-1)[perm].view(m.weight.shape)
                    m.weight.data = weight.to(m.weight.device) * init_sign_list[cnt].to(m.weight.device)
                    cnt += 1
        print('Loaded init {} and mask from level {} of target {}'.format(init_type, load_level, target_dir))
        return self.model

    @param("prune_params.er_method")
    @param("prune_params.er_init")
    @param("prune_params.update_sign_init")
    def prune_at_initialization(
        self, er_method: str, er_init: float, update_sign_init: bool) -> nn.Module:
        """Rebuild the model and prune it at initialization.

        Args:
            er_method (str): Method of pruning; special names are dispatched
                explicitly, otherwise ``prune_<er_method>`` from this module is used.
            er_init (float): Initial density target.
            update_sign_init (bool): If True, run ``update_sign_from_grad``
                (``frac=0.05``) after pruning.

        Returns:
            nn.Module: The pruned model (also stored in ``self.model``).

        Raises:
            ValueError: If no ``prune_<er_method>`` function exists for a
                method that is not dispatched explicitly.
        """
        self.model = self.acquire_model()
        total = 0
        nonzero = 0
        for n, m in self.model.named_modules():
            if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
                total += m.mask.numel()
                nonzero += m.mask.sum()

            if isinstance(m, STRConv):
                mask = m.get_mask()
                total += mask.numel()
                nonzero += mask.sum()

        if total > 0:
            # density = kept / total (previously printed the inverse).
            print(f'density is {float(nonzero) / total * 100:.3f}')

        print('prior to pruning at init')
        er_method_name = f"prune_{er_method}"
        pruner_method = globals().get(er_method_name)
        if er_method in {"synflow", "snip"}:
            self.model = pruner_method(self.model, self.train_loader, er_init)
        elif er_method == 'load_sign_and_mask':
            self.model = load_final_mask_and_intermediate_sign(self.model)
        elif er_method == 'load_mag_and_mask':
            self.model = load_final_mask_and_intermediate_mag(self.model)
        elif er_method == 'load_weight_and_mask':
            self.model = load_final_mask_and_intermediate_weight(self.model)
        elif er_method == 'load_mw_sign_and_mask':
            self.model = load_mw_final_mask_and_intermediate_sign(self.model)
        elif er_method == 'load_only_mask':
            self.model = load_only_final_mask(self.model)
        elif er_method == 'uniform':
            self.model = prune_uniform(self.model, er_init)
        elif er_method == "just dont":
            print('We dont prune at init')
        elif er_method == 'anneal_balanced':
            print('Determining the random mask for sparse training post mask decay')
            prune_er_balanced(self.model, er_init)
            self.target_mask_list = get_model_mask(self.model)
            # now make the mask all ones again
            for n, m in self.model.named_modules():
                if isinstance(m, (Conv1dMask, ConvMask, LinearMask)):
                    m.mask = torch.ones_like(m.weight)
        else:
            print('Prune method: {}'.format(er_method_name))
            if pruner_method is None:
                raise ValueError(
                    f"Unknown pruning method {er_method!r}: no function {er_method_name!r}."
                )
            pruner_method(self.model, er_init)

        if update_sign_init:
            self.model = update_sign_from_grad(self.model, self.train_loader, frac=0.05)
        print('We just pruned at init, woohoo!')
        return self.model

    @param("prune_params.prune_method")
    @param("prune_params.update_sign_every_level")
    def level_pruner(self, prune_method: str, density: float, level: int, update_sign_every_level: bool) -> None:
        """Prune the model to a target density at the end of a level.

        Args:
            prune_method (str): Method of pruning; ``prune_<prune_method>`` must
                exist in this module, or ``"just dont"`` to skip pruning.
            density (float): Desired density after pruning.
            level (int): Current pruning level (not used by this method).
            update_sign_every_level (bool): If True, run ``update_sign_from_grad``
                (``frac=0.05``) after pruning.

        Raises:
            ValueError: If no ``prune_<prune_method>`` function exists.
        """
        print("---" * 20)
        print(f"Density before pruning: {get_sparsity(self.model)}")
        print("---" * 20)

        prune_method_name = f"prune_{prune_method}"
        pruner_method = globals().get(prune_method_name)
        if prune_method == "just dont":
            print('We dont prune at the end of the level')
        elif pruner_method is None:
            raise ValueError(
                f"Unknown pruning method {prune_method!r}: no function {prune_method_name!r}."
            )
        elif prune_method in {"synflow", "snip"}:
            self.model = pruner_method(self.model, self.train_loader, density)
        else:
            pruner_method(self.model, density)

        if update_sign_every_level:
            self.model = update_sign_from_grad(self.model, self.train_loader, frac=0.05)

        # put the model back on the GPU
        self.model.to(self.this_device)

        print("---" * 20)
        print(f"Density after pruning: {get_sparsity(self.model)}")
        print("---" * 20)

    def load_from_ckpt(self, path: str) -> None:
        """Load the model from a checkpoint.

        Args:
            path (str): Path to a saved ``state_dict``.
        """
        self.model.load_state_dict(torch.load(path))


def set_target_mask(model, target_mask_list):
    """Install pre-computed masks into the model's masked layers, in module order.

    The i-th ``ConvMask`` / ``Conv1dMask`` / ``LinearMask`` module receives
    ``target_mask_list[i]``, moved to that module's weight device.

    Args:
        model: The model whose masks are replaced in place.
        target_mask_list: Sequence of mask tensors, one per masked layer.

    Returns:
        The same ``model`` instance.
    """
    print('Setting the target random mask')
    masked_layers = (
        module
        for _, module in model.named_modules()
        if isinstance(module, (ConvMask, Conv1dMask, LinearMask))
    )
    for position, layer in enumerate(masked_layers):
        layer.mask = target_mask_list[position].to(layer.weight.device)
    return model

@param("prune_params.anneal_scale_floor")
@param("prune_params.anneal_scale_ceil")
def anneal_scale(step, total_anneal_steps, anneal_scale_floor=0, anneal_scale_ceil=1):
    """Linearly ramp a scale from ``anneal_scale_floor`` toward ``anneal_scale_ceil``.

    Returns ``floor + (ceil - floor) * step / total_anneal_steps`` while
    ``step < total_anneal_steps``, and 0 once the annealing window is over.
    """
    if step >= total_anneal_steps:
        return 0
    progress = step / total_anneal_steps
    span = anneal_scale_ceil - anneal_scale_floor
    return anneal_scale_floor + span * progress


def masked_l2_decay(model, target_mask_list, scale):
    """Return ``scale`` times the squared L2 norm of weights outside the target mask.

    The i-th ``ConvMask`` / ``Conv1dMask`` / ``LinearMask`` / ``STRConv`` module is
    paired with ``target_mask_list[i]``. Only weights where that mask is 0
    contribute to the penalty.
    """
    penalty = 0
    decayed_types = (ConvMask, Conv1dMask, LinearMask, STRConv)
    layers = [mod for _, mod in model.named_modules() if isinstance(mod, decayed_types)]
    for position, layer in enumerate(layers):
        outside_mask = target_mask_list[position].to(layer.weight.device) == 0
        penalty += ((layer.weight * outside_mask) ** 2).sum()
    return scale * penalty

def make_dense(model: nn.Module) -> nn.Module:
    """Reset every layer mask to all ones, turning the network back into a dense one.

    Args:
        model (nn.Module): The model whose masks are reset in place.

    Returns:
        nn.Module: The same ``model`` instance.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    for module in model.modules():
        if isinstance(module, masked_types):
            module.mask = torch.ones_like(module.mask)

    print('The mask is now dense.')
    return model
    
def prune_mag(model: nn.Module, density: float, world_size=1, distributed=False) -> nn.Module:
    """Global magnitude-based pruning of the model.

    Every masked weight is scored by ``|mask * weight|``. For the ``*MW``
    layers the score is ``|mask * weight * m|``. The
    ``k = int((1 - density) * N)`` lowest-scoring entries across all masked
    layers are then masked out, where N is the total number of entries.
    Entries whose score ties the threshold are pruned too. When ``k < 1``
    the masks are left unchanged.

    Args:
        model (nn.Module): The model to prune (masks replaced in place).
        density (float): Desired density after pruning.
        world_size (int): Number of processes, used to average scores.
        distributed (bool): If True, all-reduce and average the global scores.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if isinstance(m, (Conv1dMaskMW, LinearMaskMW, ConvMaskMW)):
                score_list[n] = (m.mask.to(m.weight.device) * m.weight * m.m).detach().abs_()
            else:
                score_list[n] = (m.mask.to(m.weight.device) * m.weight).detach().abs_()

    if not score_list:
        print("No masked layers found; skipping magnitude pruning.")
        return model

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])

    # takes scores across all GPUs into account while pruning.
    # NOTE(review): only the threshold uses the averaged scores; per-layer masks
    # below use local scores, so ranks agree only if their weights match — confirm.
    if distributed:
        dist.all_reduce(global_scores, op=dist.ReduceOp.SUM)
        global_scores = global_scores / world_size

    k = int((1 - density) * global_scores.numel())
    # torch.kthvalue requires 1 <= k <= numel; k < 1 means nothing to prune.
    threshold = None
    if k >= 1:
        threshold, _ = torch.kthvalue(global_scores, k)

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after magnitude pruning at current iteration = ",
        (total_num / total_den).item(),
    )
    return model


def prune_random_erk(model: nn.Module, density: float) -> nn.Module:
    """Randomly prune each masked layer with an ERK-style per-layer density.

    Each layer's raw ratio is ``sum(weight.shape) / weight.numel()``. All
    ratios are multiplied by one factor so that, before clamping, the kept
    parameter count equals ``density * total_params``. They are then clamped
    to ``[0, 1]``, with no redistribution of the clamped budget. Within each
    layer, the entries with the lowest ``|mask * randn|`` scores are masked
    out, so previously pruned entries (score 0) stay pruned.

    Args:
        model (nn.Module): The model to prune (masks replaced in place).
        density (float): Target fraction of weights to keep overall.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    layers = [(name, mod) for name, mod in model.named_modules() if isinstance(mod, masked_types)]

    scores = {}
    erk_ratios = []
    layer_sizes = []
    for name, mod in layers:
        dev = mod.weight.device
        noisy = mod.mask.to(dev) * torch.randn_like(mod.weight).to(dev)
        scores[name] = noisy.detach().abs_()
        erk_ratios.append(torch.tensor(mod.weight.shape).sum() / mod.weight.numel())
        layer_sizes.append(mod.weight.numel())
    total_params = sum(layer_sizes)

    unscaled_kept = (torch.tensor(erk_ratios) * torch.tensor(layer_sizes)).sum()
    scale = (total_params * density) / unscaled_kept
    print("Factor: ", scale)
    layer_density = [torch.clamp(scale * ratio, 0, 1) for ratio in erk_ratios]

    kept = 0
    seen = 0
    for (name, mod), keep_frac in zip(layers, layer_density):
        flat = torch.flatten(scores[name])
        k = int((1 - keep_frac) * flat.numel())
        threshold = torch.kthvalue(flat, k)[0] if k != 0 else 0
        print("Layer", name, " params ", k, flat.numel())

        dev = mod.weight.device
        zero = torch.tensor([0.0]).to(dev)
        one = torch.tensor([1.0]).to(dev)
        mod.mask = torch.where(scores[name].to(dev) <= threshold, zero, one)
        kept += (mod.mask == 1).sum()
        seen += mod.mask.numel()

    print(
        "Overall model density after random global (ERK) pruning at current iteration = ",
        kept / seen,
    )
    return model


def prune_uniform(model: nn.Module, density: float) -> nn.Module:
    """Random uniform (per-layer) pruning of the model.

    In every masked layer, each entry is scored by ``|mask * randn|``. The
    ``k = int((1 - density) * numel)`` lowest scores are then masked out, so
    every layer ends at the same density. Scores are non-negative, so
    already-pruned entries (score 0) always fall at or below the threshold
    and stay pruned. With signed scores they could be revived, and a k of 0
    would prune every negative draw.

    Args:
        model (nn.Module): The model to prune (masks replaced in place).
        density (float): Desired density of every layer after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            device = m.weight.device
            score = (m.mask.to(device) * torch.randn_like(m.weight)).detach().abs_()
            flat_scores = torch.flatten(score)
            k = int((1 - density) * flat_scores.numel())
            if k == 0:
                threshold = 0
            else:
                threshold, _ = torch.kthvalue(flat_scores, k)
            print("Layer", n, " params ", k, flat_scores.numel())

            zero = torch.tensor([0.0]).to(device)
            one = torch.tensor([1.0]).to(device)
            m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after random uniform pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def update_sign_from_grad(model: nn.Module, trainloader: Any, frac: float) -> nn.Module:
    """Flip the signs of the smallest unmasked weights based on their gradients.

    Gradients are accumulated over the first ``num_steps + 1`` batches of
    ``trainloader``. Among weights inside the mask, the ``k = int(frac * N)``
    smallest by |w| are selected, where N counts every entry of every masked
    layer. A selected weight has its sign flipped when its sign agrees with
    its gradient's sign, because a gradient step would push it toward and
    across zero. Gradients are cleared before returning.

    Args:
        model (nn.Module): The model to update in place.
        trainloader (Any): Iterable of ``(images, target)`` batches.
        frac (float): Fraction of all masked-layer entries to consider.

    Returns:
        nn.Module: The same model with updated weight signs.
    """
    num_steps = 10
    criterion = nn.CrossEntropyLoss()
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    model.zero_grad()
    print('Updating the Sign of the mask')
    for i, (images, target) in enumerate(trainloader):
        images = images.to(torch.device("cuda"))
        target = target.to(torch.device("cuda")).long()
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(images)
            loss = criterion(output, target)
        # Backward outside autocast (as recommended); grads accumulate across batches.
        loss.backward()
        if i == num_steps:
            break

    score_list = {}
    in_mask = {}
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            inside = m.mask.to(m.weight.device) != 0
            magnitude = m.weight.detach().abs()
            # Entries outside the mask get +inf so they are never among the k smallest.
            score_list[n] = torch.where(inside, magnitude, float("inf"))
            in_mask[n] = inside

    if score_list:
        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int(frac * global_scores.numel())
        # torch.kthvalue requires k >= 1.
        if k >= 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for n, m in model.named_modules():
                if isinstance(m, masked_types) and m.weight.grad is not None:
                    # Compare against the threshold (previously against k, a count).
                    selected = (score_list[n] <= threshold) & in_mask[n]
                    same_sign = m.weight.data.sign() * m.weight.grad.sign() == 1
                    flip = torch.where(selected & same_sign, -1, 1)
                    m.weight.data = flip * m.weight.data

    # Don't leak the scoring gradients into the next optimizer step.
    model.zero_grad()
    return model
    

def prune_snip(model: nn.Module, trainloader: Any, density: float) -> nn.Module:
    """SNIP method for pruning of the model.

    Gradients are accumulated over the first three batches. Each masked weight
    is scored by ``|grad * weight * mask|``. For ``*MW`` layers the ``m``
    term's saliency is added inside the absolute value. The
    ``k = int((1 - density) * N)`` lowest scores across all masked layers are
    masked out, and when ``k < 1`` the masks are left unchanged. Gradients are
    cleared before returning.

    Args:
        model (nn.Module): The model to prune.
        trainloader (Any): Iterable of ``(images, target)`` batches.
        density (float): Desired density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    criterion = nn.CrossEntropyLoss()
    masked_types = (ConvMask, Conv1dMask, LinearMask)

    model.zero_grad()
    for i, (images, target) in enumerate(trainloader):
        images = images.to(torch.device("cuda"))
        target = target.to(torch.device("cuda")).long()
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(images)
            loss = criterion(output, target)
        loss.backward()
        if i == 2:
            break

    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            mask = m.mask.to(m.weight.device)
            if isinstance(m, (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)):
                score_list[n] = ((m.weight.grad * m.weight + m.m * m.m.grad) * mask).detach().abs_()
            else:
                score_list[n] = (m.weight.grad * m.weight * mask).detach().abs_()
    # The SNIP paper normalizes by the global score sum; a global constant does
    # not change the ranking, so it is omitted.
    model.zero_grad()

    if not score_list:
        print("No masked layers found; skipping snip pruning.")
        return model

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # torch.kthvalue requires 1 <= k <= numel.
    threshold = None
    if k >= 1:
        threshold, _ = torch.kthvalue(global_scores, k)

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after snip pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_synflow(model: nn.Module, trainloader: Any, density: float) -> nn.Module:
    """SynFlow method pruning of the model.

    The procedure works in five steps:

    1. Replace every ``state_dict`` tensor with its absolute value.
    2. Feed a single all-ones input shaped like one sample from
       ``trainloader``, and backprop the sum of the outputs.
    3. Score each masked weight by ``|mask * grad * weight|``. For ``*MW``
       layers the ``m`` term is included.
    4. Restore the original signs.
    5. Mask out the ``k = int((1 - density) * N)`` lowest scores globally.
       When ``k < 1`` the masks are left unchanged.

    Args:
        model (nn.Module): The model to prune.
        trainloader (Any): The training data loader (used only for the input shape).
        density (float): Desired density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)

    @torch.no_grad()
    def linearize(model: nn.Module) -> Dict[str, torch.Tensor]:
        """Take the absolute value of every state_dict tensor in place; return their signs."""
        signs = {}
        for name, param in model.state_dict().items():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs

    @torch.no_grad()
    def nonlinearize(model: nn.Module, signs: Dict[str, torch.Tensor]) -> None:
        """Restore the signs recorded by ``linearize``."""
        for n, param in model.state_dict().items():
            param.mul_(signs[n])

    signs = linearize(model)
    # Clear stale gradients (e.g. from a preceding training level); otherwise
    # backward() would accumulate into them and corrupt the scores.
    model.zero_grad()

    for images, _target in trainloader:
        # Only the per-sample shape is needed; no need to copy the batch to GPU.
        input_dim = list(images[0, :].shape)
        ones_input = torch.ones([1] + input_dim).to("cuda")
        with autocast(dtype=torch.bfloat16, device_type='cuda'):
            output = model(ones_input)
            objective = torch.sum(output)
        objective.backward()
        break

    score_list = {}
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if isinstance(m, (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)):
                score_list[n] = (
                    (m.mask.to(m.weight.device) * (m.weight.grad * m.weight + m.m * m.m.grad)).detach().abs_()
                )
            else:
                score_list[n] = (
                    (m.mask.to(m.weight.device) * m.weight.grad * m.weight).detach().abs_()
                )

    model.zero_grad()
    nonlinearize(model, signs)

    if not score_list:
        print("No masked layers found; skipping synflow pruning.")
        return model

    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())
    # torch.kthvalue requires 1 <= k <= numel.
    threshold = None
    if k >= 1:
        threshold, _ = torch.kthvalue(global_scores, k)

    total_num = 0
    total_den = 0
    for n, m in model.named_modules():
        if isinstance(m, masked_types):
            if threshold is not None:
                score = score_list[n].to(m.weight.device)
                zero = torch.tensor([0.0]).to(m.weight.device)
                one = torch.tensor([1.0]).to(m.weight.device)
                m.mask = torch.where(score <= threshold, zero, one)
            total_num += (m.mask == 1).sum()
            total_den += m.mask.numel()

    print(
        "Overall model density after synflow pruning at current iteration = ",
        total_num / total_den,
    )
    return model


def prune_random_balanced(model: nn.Module, density: float) -> nn.Module:
    """Random pruning with an equal parameter budget per masked layer.

    Each layer is first given ``density * total_params / num_layers`` kept
    parameters. A layer smaller than its budget is kept dense. The excess
    ``budget - numel`` is then divided by ``num_layers - idx`` and added to
    the budget used for subsequent layers. Within each layer, the entries
    with the lowest ``|mask * randn|`` scores are masked out.

    Args:
        model (nn.Module): The model to prune (masks replaced in place).
        density (float): Desired overall density after pruning.

    Returns:
        nn.Module: The pruned model.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    layers = [(name, mod) for name, mod in model.named_modules() if isinstance(mod, masked_types)]
    num_layers = len(layers)
    total_params = sum(mod.weight.numel() for _, mod in layers)
    budget = density * total_params / num_layers

    scores = {}
    layer_density = []
    for idx, (name, mod) in enumerate(layers):
        dev = mod.weight.device
        noisy = mod.mask.to(dev) * torch.randn_like(mod.weight).to(dev)
        scores[name] = noisy.detach().abs_()

        size = mod.weight.numel()
        if budget / size < 1.0:
            layer_density.append(budget / size)
            continue
        layer_density.append(1)
        # correction for taking care of exact sparsity
        # NOTE(review): the divisor also counts the current layer — confirm intended.
        budget = budget + (budget - mod.mask.numel()) / (num_layers - idx)

    kept = 0
    seen = 0
    for (name, mod), keep_frac in zip(layers, layer_density):
        flat = torch.flatten(scores[name])
        k = int((1 - keep_frac) * flat.numel())
        threshold = torch.kthvalue(flat, k)[0] if k != 0 else 0
        print("Layer", name, " params ", k, flat.numel())

        dev = mod.weight.device
        zero = torch.tensor([0.0]).to(dev)
        one = torch.tensor([1.0]).to(dev)
        mod.mask = torch.where(scores[name].to(dev) <= threshold, zero, one)
        kept += (mod.mask == 1).sum()
        seen += mod.mask.numel()

    print(
        "Overall model density after random global (balanced) pruning at current iteration = ",
        kept / seen,
    )
    return model


def prune_er_erk(model: nn.Module, er_sparse_init: float) -> None:
    """ERK-style per-layer densities, applied via each layer's ``set_er_mask``.

    Each layer's raw ratio is ``sum(weight.shape) / weight.numel()``. All
    ratios are multiplied by a common factor so that, before clamping to
    ``[0, 1]``, the kept parameter count equals
    ``er_sparse_init * total_params``. The resulting list is printed.

    Args:
        model (nn.Module): The model to prune.
        er_sparse_init (float): Target overall density.
    """
    masked_types = (ConvMask, Conv1dMask, LinearMask)
    layers = [mod for _, mod in model.named_modules() if isinstance(mod, masked_types)]

    raw_ratios = [torch.tensor(mod.weight.shape).sum() / mod.weight.numel() for mod in layers]
    sizes = [mod.weight.numel() for mod in layers]
    total_params = sum(sizes)

    unscaled_kept = (torch.tensor(raw_ratios) * torch.tensor(sizes)).sum()
    scale = total_params * er_sparse_init / unscaled_kept
    layer_density = [torch.clamp(scale * ratio, 0, 1) for ratio in raw_ratios]

    for mod, keep_frac in zip(layers, layer_density):
        mod.set_er_mask(keep_frac)
    print(layer_density)


def prune_er_balanced(model: nn.Module, er_sparse_init: float) -> None:
    """ER-balanced pruning at initialization.

    Each masked layer (ConvMask / Conv1dMask / LinearMask) is given the same
    budget of kept weights, ``X = er_sparse_init * total_params / num_layers``,
    which becomes a per-layer density ``X / layer_numel``. A layer that is too
    small to absorb its budget is kept fully dense (density 1). Its unused
    surplus ``X - layer_numel`` is spread evenly over the layers that come
    after it, so the total number of kept weights matches the target.

    Args:
        model (nn.Module): The model to prune; its masked layers are modified
            in place through ``set_er_mask``.
        er_sparse_init (float): Target global fraction of weights to KEEP.
    """
    mask_types = (ConvMask, Conv1dMask, LinearMask)
    layers = [m for _, m in model.named_modules() if isinstance(m, mask_types)]
    num_layers = len(layers)
    if num_layers == 0:
        # Nothing to prune; avoid dividing by zero below.
        print([])
        return

    total_params = sum(m.weight.numel() for m in layers)
    budget = er_sparse_init * total_params / num_layers

    sparsity_list = []
    for idx, m in enumerate(layers):
        numel = m.weight.numel()
        if budget / numel < 1.0:
            sparsity_list.append(budget / numel)
        else:
            sparsity_list.append(1)
            # Correction for exact sparsity: this layer keeps only `numel`
            # weights, so hand the unused part of its budget to the layers that
            # follow it. There are (num_layers - idx - 1) of them; none remain
            # after the last layer.
            remaining = num_layers - idx - 1
            if remaining > 0:
                budget = budget + (budget - numel) / remaining

    for m, density in zip(layers, sparsity_list):
        m.set_er_mask(density)
    print(sparsity_list)

# Load only the final, learnt mask of a previous experiment; the model's weights are not touched.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
def load_only_final_mask(model: nn.Module, target_expt: str, target_dir: str):
    """Copy the final mask of experiment ``target_expt`` into ``model``.

    The mask is read from ``<target_dir>/<target_expt>/checkpoints/model_level_0.pt``.

    * If ``target_expt`` contains ``'str'``, the checkpoint is treated as an
      STR-style one. For every ``*weight*`` key that also contains ``'conv'``
      or ``'fc'``, the mask is ``sparseFunction(weight, <key with
      'weight'->'sparseThreshold'>) != 0``.
    * Otherwise, every key containing ``'mask'`` is used as-is.

    Masks are assigned in order to the model's ConvMask / Conv1dMask /
    LinearMask modules. For STR runs, Conv1dMask masks have their last dim
    squeezed.

    Returns:
        The same ``model`` object.
    """
    ckpt_file = os.path.join(os.path.join(target_dir, target_expt), 'checkpoints/model_level_0.pt')
    raw_state = torch.load(ckpt_file)
    # Strip DistributedDataParallel's 'module.' prefix.
    state = {key.replace('module.', ''): tensor for key, tensor in raw_state.items()}

    is_str = 'str' in target_expt
    if is_str:
        masks = []
        for key, weight in state.items():
            if 'weight' not in key or not ('conv' in key or 'fc' in key):
                continue
            threshold_key = key.replace("weight", "sparseThreshold")
            thresholded = sparseFunction(weight, state[threshold_key])
            masks.append(torch.where(thresholded != 0, 1, 0))
    else:
        masks = [tensor for key, tensor in state.items() if 'mask' in key]

    layer_idx = 0
    for _, module in model.named_modules():
        if not isinstance(module, (ConvMask, Conv1dMask, LinearMask)):
            continue
        layer_mask = masks[layer_idx]
        if is_str and isinstance(module, Conv1dMask):
            layer_mask = layer_mask.squeeze(dim=-1)
        # Install the final, learnt mask.
        module.mask = layer_mask
        layer_idx += 1

    print(f'Loaded only final mask from expt {target_expt}')
    return model


### Adding mask ablations
def sparseFunction(x, s=-12800, activation=torch.relu, f=torch.sigmoid):
    """Soft-threshold ``x``: ``sign(x) * activation(|x| - f(s))``.

    ``s`` is first multiplied by a ``[1, 1]`` tensor of ones on ``x``'s device.
    The threshold therefore broadcasts against ``x``, and a 1-D ``x`` comes
    back with a leading dim of size 1. With the default ``s = -12800``,
    ``sigmoid(s)`` is 0 in float32, so the call returns ``x`` unchanged.

    Args:
        x: tensor to threshold.
        s: threshold parameter (Python number or tensor), mapped through ``f``.
        activation: applied to ``|x| - f(s)``; ``relu`` zeroes small entries.
        f: maps ``s`` to the actual threshold value.
    """
    threshold = f(s * torch.ones([1, 1]).to(x.device))
    return x.sign() * activation(x.abs() - threshold)


# Load the final mask and the full weights (magnitude and sign) of an intermediate checkpoint to finetune from.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_final_mask_and_intermediate_weight(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Initialise ``model`` with a final mask and intermediate weights from a previous run.

    Args:
        model: model to overwrite in place. For every ConvMask / Conv1dMask /
            LinearMask module, the weight and mask are replaced. For every
            BatchNorm2d module, the weight and bias are replaced.
        target_expt: experiment folder name under ``target_dir``. If it
            contains ``'str'``, the checkpoints are treated as STR-style:
            weights are passed through ``sparseFunction`` with their layer's
            ``sparseThreshold``, and masks are derived from the non-zeros.
        target_dir: root directory of the experiment folders.
        load_sign_at: which checkpoint provides the weights. ``'final'`` →
            ``model_level_0.pt``, ``'warmup'`` → ``model_0_9.pt``, or an epoch
            string strictly between 9 and 99 → ``model_0_<epoch>.pt``.

    Returns:
        The same ``model`` object.
    """
    model_path = os.path.join(target_dir, target_expt)
    mask_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    if load_sign_at == 'final':
        sign_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    elif load_sign_at == 'warmup':
        sign_path = os.path.join(model_path, 'checkpoints/model_0_9.pt')
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) < 99), f'unsupported load_sign_at: {load_sign_at}'
        sign_path = os.path.join(model_path, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    # Strip DistributedDataParallel's 'module.' prefix from both checkpoints.
    sign_dict = torch.load(sign_path)
    sign_dict = {key.replace('module.', ''): value for key, value in sign_dict.items()}
    mask_dict = torch.load(mask_path)
    mask_dict = {key.replace('module.', ''): value for key, value in mask_dict.items()}

    is_str = 'str' in target_expt
    mask_list = []
    if is_str:
        for k in mask_dict.keys():
            if ('weight' in k) and (('conv' in k) or ('fc' in k)):
                weight = mask_dict[k]
                k_s = k.replace("weight", "sparseThreshold")
                temp = sparseFunction(weight, mask_dict[k_s])
                mask = torch.where(temp != 0, 1, 0)
                mask_list.append(mask)
    else:
        for k in mask_dict.keys():
            if 'mask' in k:
                mask_list.append(mask_dict[k])

    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print(f'The batch norm weights are also being updated to the learnt ones')
            m.weight.data = sign_dict[n + '.weight']
            m.bias.data = sign_dict[n + '.bias']
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            name = n + '.weight'
            weight = sign_dict[name]
            if is_str:
                # Use THIS layer's learnt threshold, keyed the same way as in the
                # mask construction above ('weight' -> 'sparseThreshold').
                k_s = name.replace("weight", "sparseThreshold")
                weight = sparseFunction(weight, sign_dict[k_s])
            mask = mask_list[cnt]
            if is_str and isinstance(m, Conv1dMask):
                weight = weight.squeeze(dim=-1)
                mask = mask.squeeze(dim=-1)
            # Initialize the weight (both magnitude and sign) from the intermediate checkpoint.
            m.weight.data = weight
            # Load the final, learnt mask.
            m.mask = mask
            cnt += 1

    print(f'Loaded sign {load_sign_at} and final mask from expt {target_expt}')
    return model

# Load the final mask and an intermediate sign pattern from an MW-trained run, keeping the model's current magnitudes.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_mw_final_mask_and_intermediate_sign(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Transfer signs and the final mask from an MW checkpoint into ``model``.

    In the MW checkpoint each masked layer stores ``<name>.weight`` and
    ``<name>.m``. The sign of their product ``m * weight`` replaces the sign of
    the model's current weight; the magnitude is kept. Masks are every
    ``'mask'`` key of ``model_level_0.pt``, assigned in order to the model's
    ConvMask / Conv1dMask / LinearMask modules. BatchNorm2d parameters are
    only printed, not modified.

    Args:
        model: model to modify in place.
        target_expt: experiment folder name under ``target_dir``.
        target_dir: root directory of the experiment folders.
        load_sign_at: ``'final'`` → ``model_level_0.pt``, ``'warmup'`` →
            ``model_0_9.pt``, otherwise an epoch string strictly between 9 and
            99 → ``model_0_<epoch>.pt``.

    Returns:
        The same ``model`` object.
    """
    run_dir = os.path.join(target_dir, target_expt)
    mask_file = os.path.join(run_dir, 'checkpoints/model_level_0.pt')
    named_ckpts = {
        'final': 'checkpoints/model_level_0.pt',
        'warmup': 'checkpoints/model_0_9.pt',
    }
    ckpt_name = named_ckpts.get(load_sign_at)
    if ckpt_name is None:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) <  99)
        ckpt_name = 'checkpoints/model_0_{}.pt'.format(load_sign_at)
    sign_file = os.path.join(run_dir, ckpt_name)

    sign_state = {key.replace('module.', ''): val for key, val in torch.load(sign_file).items()}
    mask_state = {key.replace('module.', ''): val for key, val in torch.load(mask_file).items()}
    masks = [val for key, val in mask_state.items() if 'mask' in key]

    layer_idx = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            print(f'The batch norm weights are {module.weight}, {module.bias}')
        if not isinstance(module, (ConvMask, Conv1dMask, LinearMask)):
            continue
        ckpt_weight = sign_state[name + '.weight']
        ckpt_scale = sign_state[name + '.m']
        layer_mask = masks[layer_idx]
        # Only the sign of the effective MW weight is transferred.
        learnt_sign = (ckpt_scale * ckpt_weight).sign()
        module.weight.data = module.weight.data.abs() * learnt_sign
        module.mask = layer_mask
        layer_idx += 1

    print(f'Loaded sign {load_sign_at} and final mask from expt {target_expt}')
    return model


# Load the final mask and an intermediate sign pattern, keeping the model's current weight magnitudes, to finetune.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_final_mask_and_intermediate_sign(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Initialise ``model`` with a final mask and intermediate signs from a previous run.

    For every ConvMask / Conv1dMask / LinearMask module, the current weight
    magnitude is kept (presumably the random initialisation; confirm at the
    call site). Its sign is replaced by the sign of the checkpoint weight, and
    the final mask is installed. BatchNorm2d parameters are only printed.

    Args:
        model: model to modify in place.
        target_expt: experiment folder name under ``target_dir``. If it
            contains ``'str'``, the checkpoints are treated as STR-style:
            weights are passed through ``sparseFunction`` with their layer's
            ``sparseThreshold``, and masks are derived from the non-zeros.
        target_dir: root directory of the experiment folders.
        load_sign_at: ``'final'`` → ``model_level_0.pt``, ``'warmup'`` →
            ``model_0_9.pt``, otherwise an epoch string strictly between 9 and
            99 → ``model_0_<epoch>.pt``.

    Returns:
        The same ``model`` object.
    """
    model_path = os.path.join(target_dir, target_expt)
    mask_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    if load_sign_at == 'final':
        sign_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    elif load_sign_at == 'warmup':
        sign_path = os.path.join(model_path, 'checkpoints/model_0_9.pt')
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) < 99), f'unsupported load_sign_at: {load_sign_at}'
        sign_path = os.path.join(model_path, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    # Strip DistributedDataParallel's 'module.' prefix from both checkpoints.
    sign_dict = torch.load(sign_path)
    sign_dict = {key.replace('module.', ''): value for key, value in sign_dict.items()}
    mask_dict = torch.load(mask_path)
    mask_dict = {key.replace('module.', ''): value for key, value in mask_dict.items()}

    is_str = 'str' in target_expt
    mask_list = []
    if is_str:
        for k in mask_dict.keys():
            if ('weight' in k) and (('conv' in k) or ('fc' in k)):
                weight = mask_dict[k]
                k_s = k.replace("weight", "sparseThreshold")
                temp = sparseFunction(weight, mask_dict[k_s])
                mask = torch.where(temp != 0, 1, 0)
                mask_list.append(mask)
    else:
        for k in mask_dict.keys():
            if 'mask' in k:
                mask_list.append(mask_dict[k])

    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print(f'The batch norm weights are {m.weight}, {m.bias}')
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            name = n + '.weight'
            weight = sign_dict[name]
            if is_str:
                # Use THIS layer's learnt threshold, keyed the same way as in the
                # mask construction above ('weight' -> 'sparseThreshold').
                k_s = name.replace("weight", "sparseThreshold")
                weight = sparseFunction(weight, sign_dict[k_s])
            mask = mask_list[cnt]
            if is_str and isinstance(m, Conv1dMask):
                weight = weight.squeeze(dim=-1)
                mask = mask.squeeze(dim=-1)
            # Keep the current magnitude; take the sign from the checkpoint.
            m.weight.data = m.weight.data.abs() * weight.sign()
            m.mask = mask
            cnt += 1

    print(f'Loaded sign {load_sign_at} and final mask from expt {target_expt}')
    return model

# Load the final mask and intermediate weight magnitudes, keeping the model's current weight signs, to finetune.
@param("prune_params.target_expt")
@param("prune_params.target_dir")
@param("prune_params.load_sign_at")
def load_final_mask_and_intermediate_mag(model: nn.Module, target_expt: str, target_dir: str, load_sign_at: str):
    """Initialise ``model`` with a final mask and intermediate magnitudes from a previous run.

    For every ConvMask / Conv1dMask / LinearMask module, the current weight
    sign is kept (presumably the random initialisation; confirm at the call
    site). The magnitude is replaced by ``|checkpoint weight|``, and the final
    mask is installed. BatchNorm2d parameters are only printed.

    Args:
        model: model to modify in place.
        target_expt: experiment folder name under ``target_dir``. If it
            contains ``'str'``, the checkpoints are treated as STR-style:
            weights are passed through ``sparseFunction`` with their layer's
            ``sparseThreshold``, and masks are derived from the non-zeros.
        target_dir: root directory of the experiment folders.
        load_sign_at: checkpoint providing the magnitudes. ``'final'`` →
            ``model_level_0.pt``, ``'warmup'`` → ``model_0_9.pt``, otherwise
            an epoch string strictly between 9 and 99 → ``model_0_<epoch>.pt``.

    Returns:
        The same ``model`` object.
    """
    model_path = os.path.join(target_dir, target_expt)
    mask_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    if load_sign_at == 'final':
        sign_path = os.path.join(model_path, 'checkpoints/model_level_0.pt')
    elif load_sign_at == 'warmup':
        sign_path = os.path.join(model_path, 'checkpoints/model_0_9.pt')
    else:
        assert (int(load_sign_at) > 9) and (int(load_sign_at) < 99), f'unsupported load_sign_at: {load_sign_at}'
        sign_path = os.path.join(model_path, 'checkpoints/model_0_{}.pt'.format(load_sign_at))

    # Strip DistributedDataParallel's 'module.' prefix from both checkpoints.
    mag_dict = torch.load(sign_path)
    mag_dict = {key.replace('module.', ''): value for key, value in mag_dict.items()}
    mask_dict = torch.load(mask_path)
    mask_dict = {key.replace('module.', ''): value for key, value in mask_dict.items()}

    is_str = 'str' in target_expt
    mask_list = []
    if is_str:
        for k in mask_dict.keys():
            if ('weight' in k) and (('conv' in k) or ('fc' in k)):
                weight = mask_dict[k]
                k_s = k.replace("weight", "sparseThreshold")
                temp = sparseFunction(weight, mask_dict[k_s])
                mask = torch.where(temp != 0, 1, 0)
                mask_list.append(mask)
    else:
        for k in mask_dict.keys():
            if 'mask' in k:
                mask_list.append(mask_dict[k])

    cnt = 0
    for n, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print(f'The batch norm weights are {m.weight}, {m.bias}')
        if isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            name = n + '.weight'
            weight = mag_dict[name]
            if is_str:
                # Use THIS layer's learnt threshold, keyed the same way as in the
                # mask construction above ('weight' -> 'sparseThreshold').
                k_s = name.replace("weight", "sparseThreshold")
                weight = sparseFunction(weight, mag_dict[k_s])
            mask = mask_list[cnt]
            if is_str and isinstance(m, Conv1dMask):
                weight = weight.squeeze(dim=-1)
                mask = mask.squeeze(dim=-1)
            # Keep the current sign; take the learnt magnitude from the checkpoint.
            m.weight.data = m.weight.data.sign() * weight.abs()
            m.mask = mask
            cnt += 1

    print(f'Loaded magnitude {load_sign_at} and final mask from expt {target_expt}')
    return model

@param("prune_params.update_sign_every_level")
@param("prune_params.dst_reinit")
@param("prune_params.dst_method")
def dst_pruner(model: nn.Module, step: int, dst_reinit: str, world_size: int, T_end: int, acc_grads: dict, dst_method: str, update_sign_every_level: bool):
    """Run one dynamic-sparse-training update on ``model`` while ``step < T_end``.

    ``dst_method`` selects the update:

    * ``'rigl'``: ``prune_grow_rigl``. The update fraction follows a cosine
      decay that starts at 0.3 at step 0 and approaches 0 at ``T_end``.
    * ``'sign-flip'``: ``update_signs_in_random_mask``. The fraction rises
      linearly to 0.3 over the first ``int(0.1 * T_end)`` steps, then
      cosine-decays towards 0.
    * anything else: no update.

    Once ``step`` reaches ``T_end`` the model is returned unchanged.
    ``update_sign_every_level`` is injected by the config but not used here.

    Returns:
        The (possibly updated) model.
    """
    # Peak update fraction; hyperparams chosen in accordance with https://arxiv.org/pdf/1911.11134
    peak_frac = 0.3

    if not step < T_end:
        print('Mask shuffling is complete, finetuning the final mask now.')
        return model

    if dst_method == 'rigl':
        frac = (peak_frac / 2) * (1 + np.cos((step * np.pi) / T_end))
        print(f'Pruning and regrowing with RiGL: {frac}')
        return prune_grow_rigl(model, frac, dst_reinit, acc_grads, world_size)

    if dst_method == 'sign-flip':
        warmup = int(0.1 * T_end)
        if step < warmup:
            # Linear warmup.
            frac = peak_frac * (step / warmup)
        else:
            # Cosine annealing over the remaining steps.
            progress = (step - warmup) / (T_end - warmup)
            frac = (peak_frac / 2) * (1 + np.cos(np.pi * progress))
        print(f'Flipping the sign of the mask based on the saliency: {frac}')
        return update_signs_in_random_mask(model, frac, acc_grads, world_size)

    print('No DST pruner provided')
    return model

def prune_grow_rigl(model: nn.Module, dst_frac: float, dst_reinit: str, acc_grads: dict, world_size: int) -> nn.Module:
    """One RiGL-style prune-and-regrow step on every masked layer.

    For each ConvMask / Conv1dMask / LinearMask module:

    1. Prune: among active (non-zero mask) positions, the ``dst_frac``
       fraction with the smallest rank-averaged ``|weight|`` is set to 0 in
       the mask.
    2. Regrow: among all inactive positions, including the ones just pruned,
       the same number with the largest rank-averaged ``|acc_grad|`` is set
       to 1.

    Args:
        model: model whose masks, weights and weight gradients are updated in
            place.
        dst_frac: fraction of active weights to prune and regrow per layer.
        dst_reinit: treatment of the regrown weights. ``'zero'`` sets them to
            0. ``'flip'`` negates those whose weight and rank-summed signed
            gradient have the same sign. Any other value keeps their stored
            values.
        acc_grads: accumulated gradients keyed by module name (collected by a
            hook elsewhere).
        world_size: number of ranks; all-reduced sums are divided by it.

    Returns:
        The same ``model`` object.
    """
    # Without an initialised process group there is nothing to reduce.
    distributed = dist.is_available() and dist.is_initialized()

    for n, m in model.named_modules():
        if not isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            continue

        # Flattened mask / weight; in-place edits below are written back at the end.
        mask_flat = m.mask.flatten()
        orig_weight = m.weight.data.flatten()
        # Detached ranking scores: smallest |weight| is pruned, largest |grad| regrown.
        weight_flat = m.weight.detach().flatten().abs()
        grad_local = acc_grads[n].detach().flatten()
        gradients_flat = grad_local.abs()

        # Average across GPUs so every rank picks the same indices.
        if distributed:
            dist.all_reduce(weight_flat, op=dist.ReduceOp.SUM)
            dist.all_reduce(gradients_flat, op=dist.ReduceOp.SUM)
        weight_flat = weight_flat / world_size
        gradients_flat = gradients_flat / world_size

        grad_signed = None
        if dst_reinit == 'flip':
            # Signed gradient summed across ranks. Only its sign is used, so no
            # division is needed. Clone so acc_grads is not modified in place.
            grad_signed = grad_local.clone()
            if distributed:
                dist.all_reduce(grad_signed, op=dist.ReduceOp.SUM)

        # flatten() (not squeeze()) keeps this 1-D even with a single active weight.
        nonzero_inds = mask_flat.nonzero().flatten()
        num_prune = int(dst_frac * nonzero_inds.numel())

        if num_prune > 0:
            nonzero_weights = weight_flat[nonzero_inds]
            smallest_nonzero_inds = nonzero_inds[nonzero_weights.argsort()[:num_prune]]
            mask_flat[smallest_nonzero_inds] = 0

            # Grow from the complement of the pruned mask, i.e. pruned params
            # can be grown back in the same step.
            zero_inds = (mask_flat == 0).nonzero().flatten()
            zero_gradients = gradients_flat[zero_inds]
            grow_inds = zero_inds[zero_gradients.argsort(descending=True)[:num_prune]]
            mask_flat[grow_inds] = 1

            if dst_reinit == 'zero':
                # Reactivated parameters start at zero, for minimal change.
                orig_weight[grow_inds] = 0
            elif dst_reinit == 'flip':
                saliency = orig_weight[grow_inds] * grad_signed[grow_inds]
                orig_weight[grow_inds] = torch.where(
                    saliency.sign() == 1, -orig_weight[grow_inds], orig_weight[grow_inds]
                )

        # Reshape the mask back to its original shape.
        m.mask = mask_flat.view_as(m.mask)
        # Mask the gradient, if one has been computed.
        if m.weight.grad is not None:
            m.weight.grad *= m.mask
        m.weight.data = orig_weight.view_as(m.weight)

    return model


# Perturb a small fraction of the active params, while keeping the mask fixed (unlike RiGL).
def update_signs_in_random_mask(model: nn.Module, dst_frac: float, acc_grads: dict, world_size: int) -> nn.Module:
    """Zero low-magnitude active weights whose saliency is positive; the mask is unchanged.

    For each ConvMask / Conv1dMask / LinearMask module:

    * Among the active (non-zero mask) positions, select the ``dst_frac``
      fraction with the smallest rank-averaged ``|weight|``.
    * Of those, every weight with ``weight * grad > 0`` is set to 0. ``grad``
      is the accumulated gradient summed across ranks, so all ranks make the
      same decision.

    NOTE(review): the function name and the log message in ``dst_pruner``
    describe a sign flip, but the selected weights are zeroed, not negated.
    Confirm which one is intended.

    Args:
        model: model whose weights are updated in place.
        dst_frac: fraction of active weights considered per layer.
        acc_grads: accumulated gradients keyed by module name.
        world_size: number of ranks; the all-reduced weight sum is divided by it.

    Returns:
        The same ``model`` object.
    """
    distributed = dist.is_available() and dist.is_initialized()

    for n, m in model.named_modules():
        if not isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            continue

        mask_flat = m.mask.flatten()
        # Flattened weights; edits below are written back at the end.
        orig_weight = m.weight.data.flatten()
        weight_flat = orig_weight.abs()
        # Clone so the in-place all_reduce does not modify acc_grads.
        grad_sum = acc_grads[n].detach().flatten().clone()

        # Reduce across GPUs so every rank selects and zeroes the same weights.
        if distributed:
            dist.all_reduce(weight_flat, op=dist.ReduceOp.SUM)
            dist.all_reduce(grad_sum, op=dist.ReduceOp.SUM)
        weight_flat = weight_flat / world_size

        # flatten() (not squeeze()) keeps this 1-D even with a single active weight.
        nonzero_inds = mask_flat.nonzero().flatten()
        num_select = int(dst_frac * nonzero_inds.numel())

        if num_select > 0:
            # Smallest-magnitude active weights.
            selected = nonzero_inds[weight_flat[nonzero_inds].argsort()[:num_select]]
            # Only the sign of the saliency matters, so the unnormalised grad sum is fine.
            saliency = orig_weight[selected] * grad_sum[selected]
            orig_weight[selected] = torch.where(saliency.sign() == 1, 0, orig_weight[selected])

        m.weight.data = orig_weight.view_as(m.weight)

    return model

