import numpy as np
import torch
import torch.distributed as dist
import os
from dst_util import get_W
from sparse_topology_initialization import create_ws_sparse_scheduler
import math
from scipy.sparse import csr_matrix
import sys
sys.path.append("../../")
# import CH_scores
from scipy.io import loadmat, savemat
from sparse_topology_initialization import update_topology_scheduler


def chain_removal(layer1, layer2):
    """Filter two consecutive adjacency matrices against each other.

    ``layer1`` is first filtered against ``layer2`` with
    ``remove_unactive_links_backward``; ``layer2`` is then filtered against
    the already-filtered ``layer1`` with ``remove_unactive_links_forward``.

    Returns:
        tuple: ``(filtered_layer1, filtered_layer2)``.
    """
    pruned_first = remove_unactive_links_backward(layer1, layer2)
    pruned_second = remove_unactive_links_forward(layer2, pruned_first)
    return pruned_first, pruned_second

def qk_chain_removal(q, k):
    """Filter the ``q`` and ``k`` adjacency matrices against each other.

    ``q`` is filtered against the transpose of ``k``; ``k`` is then filtered
    against the transpose of the already-filtered ``q``. Both passes use
    ``remove_unactive_links_backward``.

    Returns:
        tuple: ``(filtered_q, filtered_k)``.
    """
    pruned_q = remove_unactive_links_backward(q, k.transpose(1, 0))
    pruned_k = remove_unactive_links_backward(k, pruned_q.transpose(1, 0))
    return pruned_q, pruned_k


def remove_unactive_links_backward(current_adj, after_adj):
    """Scale each row of ``current_adj`` by an indicator built from ``after_adj``.

    The column sums of ``after_adj`` (``dim=0``) yield one value per row of
    ``current_adj``. Positive sums are collapsed to 1, so a row whose sum is
    0 is zeroed out while a row with a positive sum is kept unchanged.

    Args:
        current_adj: tensor whose rows are scaled (expected to be 2-D so the
            ``(-1, 1)`` reshape broadcasts over its columns).
        after_adj: tensor whose number of columns matches the number of rows
            of ``current_adj``.

    Returns:
        A new tensor; neither argument is modified in place.
    """
    outdegree = torch.sum(after_adj, dim=0)
    # Collapse positive sums to 1; non-positive sums are left untouched.
    outdegree[outdegree > 0] = 1
    return current_adj * outdegree.reshape(-1, 1)

def remove_unactive_links_forward(current_adj, before_adj):
    """Scale each column of ``current_adj`` by an indicator built from ``before_adj``.

    The row sums of ``before_adj`` (``dim=1``) yield one value per column of
    ``current_adj``. Positive sums are collapsed to 1, so a column whose sum
    is 0 is zeroed out while a column with a positive sum is kept unchanged.

    Args:
        current_adj: tensor whose columns are scaled (expected to be 2-D so
            the ``(1, -1)`` reshape broadcasts over its rows).
        before_adj: tensor whose number of rows matches the number of columns
            of ``current_adj``.

    Returns:
        A new tensor; neither argument is modified in place.
    """
    indegree = torch.sum(before_adj, dim=1)
    # Collapse positive sums to 1; non-positive sums are left untouched.
    indegree[indegree > 0] = 1
    return current_adj * indegree.reshape(1, -1)




class IndexMaskHook:
    """Gradient hook that masks one layer's gradient with the scheduler's mask.

    The hook is callable with the incoming gradient and returns
    ``grad * mask``. As side jobs it can keep a running average of the dense
    (unmasked) gradient in ``dense_grad`` and, when
    ``scheduler.args.adaptive_shield`` is set, running means of the weight
    and of the squared weight (``mean_W`` / ``mean_W2``).

    Args:
        layer: index of the layer inside ``scheduler.W`` /
            ``scheduler.backward_masks``.
        scheduler: the owning scheduler; the hook reads ``args``, ``W``,
            ``backward_masks``, ``step``, ``delta_T``,
            ``grad_accumulation_n`` and
            ``check_if_backward_hook_should_accumulate_grad`` from it.
    """

    def __init__(self, layer, scheduler):
        self.layer = layer
        self.scheduler = scheduler
        self.dense_grad = None
        # Initialised up front: the first hook call may already fall inside
        # the accumulation window, where the counter is incremented before
        # any reset branch has had a chance to create it.
        self.num_updates = 0
        if self.scheduler.args.adaptive_shield:
            self.mean_W = torch.zeros_like(self.scheduler.W[self.layer].data)
            self.mean_W2 = torch.zeros_like(self.scheduler.W[self.layer].data)

    def __name__(self):
        # NOTE(review): presumably present so that code expecting hooks to
        # expose ``__name__`` keeps working -- confirm against callers.
        return 'IndexMaskHook'

    @torch.no_grad()
    def __call__(self, grad):
        """Return ``grad`` multiplied by the layer's current backward mask."""
        mask = self.scheduler.backward_masks[self.layer]

        if self.scheduler.args.adaptive_shield:
            if self.scheduler.step % self.scheduler.delta_T > (self.scheduler.delta_T - 10):
                # Tail of the current delta_T window: update the running
                # means of W and W**2 incrementally.
                weight = self.scheduler.W[self.layer].data
                self.num_updates += 1

                self.mean_W += (weight - self.mean_W) / self.num_updates
                self.mean_W2 += (weight**2 - self.mean_W2) / self.num_updates
            else:
                # Outside the window: restart the statistics.
                self.num_updates = 0
                self.mean_W.zero_()
                self.mean_W2.zero_()

        # only calculate dense_grads when necessary
        if self.scheduler.check_if_backward_hook_should_accumulate_grad():
            if self.dense_grad is None:
                # initialize as all 0s so we can do a rolling average
                self.dense_grad = torch.zeros_like(grad)
            self.dense_grad += grad / self.scheduler.grad_accumulation_n
        else:
            self.dense_grad = None

        return grad * mask

    def get_stats(self):
        """Return ``|std / mean|`` of the tracked weights, element-wise.

        Only usable when the hook was built with ``adaptive_shield`` enabled,
        since ``mean_W`` / ``mean_W2`` exist only in that case.
        NOTE(review): entries with a zero mean divide by zero, and rounding
        can make the variance estimate slightly negative (NaN after sqrt).
        """
        std = torch.sqrt(self.mean_W2 - self.mean_W ** 2)  # Var(W) = E[W^2] - (E[W])^2
        return torch.abs(std / self.mean_W)


def _create_step_wrapper(scheduler, optimizer):
    """Replace ``optimizer.step`` with a wrapper that re-applies the masks.

    The wrapper calls the original step (``optimizer._optimizer.second_step``
    when ``scheduler.args.ssam`` is set, otherwise ``optimizer.step``), then
    ``scheduler.reset_momentum()`` and ``scheduler.apply_mask_to_weights()``.

    Arguments given to the wrapper (for example a closure) are forwarded to
    the original step, and its return value is passed back to the caller.
    In the ssam case ``zero_grad=False`` is supplied unless the caller
    provides ``zero_grad`` explicitly.

    Args:
        scheduler: object exposing ``args.ssam``, ``reset_momentum`` and
            ``apply_mask_to_weights``.
        optimizer: optimizer whose ``step`` attribute is overwritten.
    """
    if scheduler.args.ssam:
        _unwrapped_step = optimizer._optimizer.second_step
    else:
        _unwrapped_step = optimizer.step

    def _wrapped_step(*args, **kwargs):
        if scheduler.args.ssam:
            kwargs.setdefault("zero_grad", False)
        result = _unwrapped_step(*args, **kwargs)
        scheduler.reset_momentum()
        scheduler.apply_mask_to_weights()
        return result

    optimizer.step = _wrapped_step


class DSTScheduler:

    def __init__(self, model, optimizer, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, delta=100, alpha=0.3, static_topo=False, grad_accumulation_n=1, state_dict=None, args=None):
        """Attach the scheduler to ``model`` and ``optimizer``.

        Collects the weights returned by ``get_W``, wraps ``optimizer.step``
        so masks are re-applied after every step, builds the initial masks
        (or restores them from ``state_dict``) and registers one
        ``IndexMaskHook`` per sparse weight.

        Args:
            model: passed to ``get_W`` to obtain the managed weights.
            optimizer: optimizer whose ``step`` is wrapped.
            T_end: stored as ``self.T_end`` (fresh initialisation only).
            sparsity_distribution: ``'uniform'`` or ``'non-uniform'``.
            ignore_linear_layers: stored as-is; only reported by ``__str__``
                within the code visible here.
            delta: stored as ``self.delta_T``, the schedule period in steps.
            alpha: stored as ``self.alpha``.
            static_topo: stored as ``self.static_topo``.
            grad_accumulation_n: stored as ``self.grad_accumulation_n``;
                must be positive.
            state_dict: if given, attributes are restored from it instead of
                being initialised from the arguments above.
            args: namespace with the run configuration. Although it defaults
                to ``None``, it is dereferenced immediately, so it must be
                provided.

        Raises:
            Exception: if the dense allocation is outside ``(0, 1]`` or a
                weight already carries a scheduler hook.
        """
        self.args = args
        # granet / gmp start from their own initial sparsity; otherwise the
        # target sparsity is used directly.
        self.dense_allocation = 1 - self.args.granet_init_sparsity if self.args.granet or self.args.gmp else 1 - self.args.sparsity
        if self.dense_allocation <= 0 or self.dense_allocation > 1:
            raise Exception('Dense allocation must be on the interval (0, 1]. Got: %f' % self.dense_allocation)
        
        
        self.global_sparsity = 1 - self.dense_allocation
        self.model = model
        self.optimizer = optimizer

        # get_W returns the managed weights plus the chain / qk-chain lists
        self.W, self.chain_list, self.qk_chain_list = get_W(model, args)
        # if distributed these values will be populated
        self.is_dist = dist.is_initialized()
        self.world_size = dist.get_world_size() if self.is_dist else None

        # modify optimizer.step() function to call "reset_momentum" after
        _create_step_wrapper(self, optimizer)
            
        # number of elements of every managed weight
        self.N = [torch.numel(w) for w in self.W]
        if self.args.early_stop:
            # one flag per layer; __call__ stops updating once all are set
            self.early_stop_signal = torch.zeros(len(self.W))
        if state_dict is not None:
            # restore every attribute (masks included) and re-mask the weights
            self.load_state_dict(state_dict)
            self.apply_mask_to_weights()

        else:
            self.sparsity_distribution = sparsity_distribution
            self.static_topo = static_topo
            self.grad_accumulation_n = grad_accumulation_n
            self.ignore_linear_layers = ignore_linear_layers
            self.backward_masks = None

            # define sparsity allocation
            self.S = []
            for i, W in enumerate(self.W):
                if self.args.EM_S:
                    # EM_S starts 0.05 below the nominal sparsity
                    self.S.append((1 - self.dense_allocation - 0.05))
                elif self.args.granet or self.args.gmp:
                    self.S.append(1 - self.dense_allocation)
                else:
                    self.S.append(1 - self.dense_allocation)
            if args.init_mode == "swi" or args.init_mode == "kaiming":
                # reset the parameters with swi
                # (reset_parameters is defined outside the code shown here)
                self.reset_parameters()
            if self.args.history_weights:
                # CPU snapshot of the weights taken before sparsification
                self.history_masks = [self.W[i].detach().clone().cpu() for i in range(len(self.W))]
            # randomly sparsify model according to S
            self.random_sparsify()
            if self.args.new_history_weights:
                # CPU snapshot of the weights taken after sparsification
                self.history_masks = [self.W[i].detach().clone().cpu() for i in range(len(self.W))]
            # scheduler keeps a log of how many times it's called. this is how it does its scheduling
            self.step = 0
            self.dst_steps = 0

            # define the actual schedule
            self.delta_T = delta
            self.alpha = alpha
            self.T_end = T_end

        # also, register backward hook so sparse elements cannot be recovered during normal training
        self.backward_hook_objects = []
        for i, w in enumerate(self.W):
            # if sparsity is 0%, skip
            if self.S[i] <= 0:
                self.backward_hook_objects.append(None)
                continue
            if getattr(w, '_has_rigl_backward_hook', False):
                print(i, w.shape)
                raise Exception('This model already has been registered to a DSTScheduler.')
        
            self.backward_hook_objects.append(IndexMaskHook(i, self))
            w.register_hook(self.backward_hook_objects[-1])
            # mark the weight so a second scheduler cannot hook it again
            setattr(w, '_has_rigl_backward_hook', True)

        if self.args.save_new:
            # per-layer bookkeeping tensors with the weights' dtype
            self.count_steps_mask = [torch.zeros_like(self.W[i]).to(self.W[i].device) for i in range(len(self.W))]
            self.saving_links_mask = [torch.zeros_like(self.W[i]).to(self.W[i].device) for i in range(len(self.W))]
            self.not_saving_links_mask = [torch.zeros_like(self.W[i]).to(self.W[i].device) for i in range(len(self.W))]

                    
        if self.args.adaptive_shield:
            # boolean shield masks; these replace the ones created for
            # save_new when both flags are set
            self.saving_links_mask = [torch.zeros_like(self.W[i]).bool().to(self.W[i].device) for i in range(len(self.W))]
            self.not_saving_links_mask = [torch.zeros_like(self.W[i]).bool().to(self.W[i].device) for i in range(len(self.W))]

        # gradual-pruning schedule, expressed in multiples of delta_T
        self.pruning_T_end = self.args.pruning_T_end
        self.final_iter = int(self.pruning_T_end / self.delta_T)
        self.ini_iter = int(self.args.granet_init_step/ self.delta_T)
        print(f"Initial_iter: {self.ini_iter}, Final_iter: {self.final_iter}")
        self.total_prune_iter = self.final_iter - self.ini_iter

        assert self.grad_accumulation_n > 0
        assert self.sparsity_distribution in ('uniform', 'non-uniform')




    def state_dict(self):
        obj = {
            'dense_allocation': self.dense_allocation,
            'S': self.S,
            'N': self.N,
            'hyperparams': {
                'delta_T': self.delta_T,
                'alpha': self.alpha,
                'T_end': self.T_end,
                'ignore_linear_layers': self.ignore_linear_layers,
                'static_topo': self.static_topo,
                'sparsity_distribution': self.sparsity_distribution,
                'grad_accumulation_n': self.grad_accumulation_n,
            },
            'step': self.step,
            'dst_steps': self.dst_steps,
            'backward_masks': self.backward_masks,
            # '_linear_layers_mask': self._linear_layers_mask,
        }

        return obj

    def load_state_dict(self, state_dict):
        for k, v in state_dict.items():
            if type(v) == dict:
                self.load_state_dict(v)
            setattr(self, k, v)


    @torch.no_grad()
    def random_sparsify(self):
        self.backward_masks = []
        self.record_mask = []
        for l, w in enumerate(self.W):
            # if sparsity is 0%, skip
            if self.S[l] <= 0:
                self.backward_masks.append(None)
                continue
            
            if self.args.WS:
                # broadcast tensort must be contiguous
                mask = create_ws_sparse_scheduler(self.S[l], w, self.args).contiguous()

            else:
                n = self.N[l]
                s = int(self.S[l] * n)
                perm = torch.randperm(n)
                perm = perm[:s]
                flat_mask = torch.ones(n, device=w.device)
                flat_mask[perm] = 0
                mask = torch.reshape(flat_mask, w.shape)

            mask = mask.bool()
            if self.is_dist:
                dist.broadcast(mask, 0)

            
            w *= mask
            self.backward_masks.append(mask)
            if self.args.itop:
                self.record_mask.append(mask)


    def __str__(self):
        s = 'DSTScheduler(\n'
        s += 'layers=%i,\n' % len(self.N)

        # calculate the number of non-zero elements out of the total number of elements
        N_str = '['
        S_str = '['
        sparsity_percentages = []
        total_params = 0
        total_nonzero = 0

        for N, S, mask, W in zip(self.N, self.S, self.backward_masks, self.W):
            
            actual_S = torch.sum(W[mask == 0] == 0).item()
            # print(torch.sum(mask).item(), actual_S)
            N_str += ('%i/%i, ' % (N-actual_S, N))
            sp_p = float(N-actual_S) / float(N) * 100
            S_str += '%.2f%%, ' % sp_p
            sparsity_percentages.append(sp_p)
            total_params += N
            total_nonzero += N-actual_S
        
        N_str = N_str[:-2] + ']'
        S_str = S_str[:-2] + ']'
        
        s += 'nonzero_params=' + N_str + ',\n'
        s += 'nonzero_percentages=' + S_str + ',\n'
        s += 'total_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_nonzero, total_params, float(total_nonzero)/float(total_params)*100)) + ',\n'
        # s += 'total_CONV_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_conv_nonzero, total_conv_params, float(total_conv_nonzero)/float(total_conv_params)*100)) + ',\n'
        s += 'step=' + str(self.step) + ',\n'
        s += 'num_dst_steps=' + str(self.dst_steps) + ',\n'
        s += 'ignoring_linear_layers=' + str(self.ignore_linear_layers) + ',\n'
        s += 'sparsity_distribution=' + str(self.sparsity_distribution) + ',\n'
        if self.args.WS:
            s += 'WS=True, WS_beta=' + str(self.args.ws_beta) + ',\n'
        self.global_sparsity = 1 - total_nonzero/total_params
        if self.args.gmp:
            s += f'GMP=True, init_sparsity={self.args.granet_init_sparsity}\n'
            s += f'pruning_scheduler={self.args.pruning_scheduler}, pruning_method={self.args.pruning_method}\n'
        elif self.args.regrow_method == "fc":
            s += 'FC=True,\n'
        else:
            if self.args.granet:
                s += f'granet=True, init_sparsity={self.args.granet_init_sparsity}\n'
                s += f'pruning_scheduler={self.args.pruning_scheduler}, pruning_method={self.args.pruning_method}\n'

            if self.args.EM_S:
                s += 'EM_S=True,\n'
            elif self.args.adaptive_zeta:
                s += 'Adaptive_zeta=True,\n'
            s += 'regrow_method=' + self.args.regrow_method + ',\n'
            s += 'remove_method=' + self.args.remove_method + ',\n'
            if self.args.adaptive_shield:
                s+= 'adaptive_shield=True,\n'
                s+= f'shield_threshold={self.args.shield_threshold},\n'
            elif self.args.save_new:
                s+= f'save_new=True, saving_steps: {self.args.saving_steps}\n'
        
        if self.args.history_weights:
            s += 'history_weights=True,\n'
            
        s += 'target_sparsity=' + str(self.args.sparsity) + ',\n'
        

        return s + ')'


    @torch.no_grad()
    def reset_momentum(self):
        for w, mask, s in zip(self.W, self.backward_masks, self.S):
            # if sparsity is 0%, skip
            if s <= 0:
                continue
            
            # optimizer is ADAM
            param_state = self.optimizer.state[w]
            if self.args.ssam:
                base_param_state = self.optimizer.base_optimizer.state[w]
            
            # exit()
            # exit()
            optimizer_state_list = ["exp_avg", "exp_avg_sq", "prev_grad", "prev_u", "e_w"]
            for optimizer_state in optimizer_state_list:
                if optimizer_state in param_state:
                    # mask the momentum matrix
                    buf = param_state[optimizer_state]
                    buf *= mask
                if self.args.ssam:
                    if optimizer_state in base_param_state:
                        buf = base_param_state[optimizer_state]
                        buf *= mask
        # print(param_state.keys())
        # print(base_param_state.keys())
    @torch.no_grad()
    def apply_mask_to_weights(self):
        for w, mask, s in zip(self.W, self.backward_masks, self.S):
            # if sparsity is 0%, skip
            if s <= 0:
                continue
                
            w *= mask
    def apply_mask_to_history_weights(self):
        for l, (w, mask, s) in enumerate(zip(self.W, self.backward_masks, self.S)):
            if s <= 0:
                continue
            
            w.data = mask * self.history_masks[l].to(w.device)


    @torch.no_grad()
    def apply_mask_to_gradients(self):
        for w, mask, s in zip(self.W, self.backward_masks, self.S):
            # if sparsity is 0%, skip
            if s <= 0:
                continue
            w.grad *= mask

    
    def check_if_backward_hook_should_accumulate_grad(self):
        """
        Used by the backward hooks. Basically just checks how far away the next rigl step is, 
        if it's within `self.grad_accumulation_n` steps, return True.
        """

        if self.step >= self.T_end:
            return False

        steps_til_next_rigl_step = self.delta_T - (self.step % self.delta_T)
        return steps_til_next_rigl_step <= self.grad_accumulation_n


    def cosine_annealing(self):
        return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))


    def __call__(self):
        """Advance the scheduler by one step and run a topology update if due.

        Returns:
            ``False`` when a topology update (``_dst_step``) was performed on
            this call, ``True`` otherwise (static topology, every layer
            early-stopped, or the step is off-schedule / past ``T_end``).
        """
        self.step += 1
        if self.static_topo:
            return True
        if self.args.early_stop:
            # every layer has raised its early-stop flag
            if torch.sum(self.early_stop_signal) == len(self.W):
                return True
        if (self.step % self.delta_T) == 0 and self.step <= self.T_end: # check schedule
            if self.args.save_new:
                # Only the first layer's counter decides which branch runs.
                if torch.max(self.count_steps_mask[0]) < self.args.saving_steps:
                    # index of the current topology update (1-based)
                    current_step = self.step // self.delta_T
                    for l, m in enumerate(self.count_steps_mask):
                        
                        # shielded: counted at least once but fewer than current_step times
                        self.saving_links_mask[l] = (self.count_steps_mask[l] > 0) & (self.count_steps_mask[l] < current_step)
                        self.not_saving_links_mask[l] = self.count_steps_mask[l] >= current_step
                        # increment the counter of every currently active link
                        self.count_steps_mask[l][self.backward_masks[l]!=0] += 1
                else:
                    for l, m in enumerate(self.count_steps_mask):
                        # same split, with saving_steps as the threshold; counters are no longer incremented
                        self.saving_links_mask[l] = (self.count_steps_mask[l] > 0) & (self.count_steps_mask[l] < self.args.saving_steps)
                        self.not_saving_links_mask[l] = self.count_steps_mask[l] >= self.args.saving_steps
                        if l == 0:
                            print("Number of shielded links in current layer: ", torch.sum((self.saving_links_mask[l])))
            elif self.args.adaptive_shield:
                for l, m in enumerate(self.saving_links_mask):
                    # per-element |std / mean| gathered by the layer's hook
                    layer_score = self.backward_hook_objects[l].get_stats()
                    self.saving_links_mask[l] = (layer_score * self.saving_links_mask[l].bool()) > self.args.shield_threshold
                    # NOTE(review): this reads saving_links_mask[l] after it
                    # was reassigned on the previous line -- confirm that is
                    # intended. Also, a layer without a hook (None entry in
                    # backward_hook_objects) would fail on get_stats().
                    self.not_saving_links_mask[l] = (layer_score * self.saving_links_mask[l].bool()) <= self.args.shield_threshold
            self._dst_step()
            self.dst_steps += 1
            print(self)
            return False
        if self.step % self.delta_T == 0:
            # on-schedule step after T_end: only report the current state
            print(self)
        return True
    
    def uniform_pruning(self):
        """Prune every sparse layer to the sparsity given by the schedule.

        The target sparsity ``curr_prune_rate`` is computed from
        ``args.pruning_scheduler`` (``linear``, ``granet`` or ``s_shape``).
        Each layer is then scored according to ``args.pruning_method`` and
        only the elements scoring strictly above the layer's top-k threshold
        stay in ``self.backward_masks[l]``; ``self.S[l]`` is updated to the
        resulting sparsity. With ``save_new`` / ``adaptive_shield`` the links
        in ``saving_links_mask`` are excluded from scoring and always kept.

        Raises:
            NotImplementedError: for an unknown ``args.pruning_scheduler``.
        """
        # NOTE(review): ini_iter is already expressed in multiples of delta_T
        # (see __init__) yet it is subtracted from the raw step here, and
        # again inside the "granet" formula -- confirm the intended units.
        curr_prune_iter = int((self.step - self.ini_iter) / self.delta_T)
        
        if self.args.pruning_scheduler == "linear":
            # linear interpolation between the initial and the target sparsity
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity)* curr_prune_iter / self.total_prune_iter + self.args.granet_init_sparsity
        elif self.args.pruning_scheduler == "granet":
            # cubic decay of the remaining sparsity gap
            prune_decay = (1 - ((curr_prune_iter - self.ini_iter) / self.total_prune_iter)) ** 3
            curr_prune_rate = self.args.granet_init_sparsity + (self.args.sparsity - self.args.granet_init_sparsity) * (1 - prune_decay)
            
        elif self.args.pruning_scheduler == "s_shape":
            mid_prune_step = self.total_prune_iter / 2
            # S-shape pruning curve
            
            k = 6/mid_prune_step
            
            # sigmoid values at the first and last iteration, used to rescale
            # the curve so that it spans the whole [0, 1] range
            prune_rate_step_0 = 1/(1 + np.exp(-k * (- mid_prune_step)))
            prune_rate_step_final = 1/(1 + np.exp(-k * (self.total_prune_iter - mid_prune_step)))
            scale_factor = 1 / (prune_rate_step_final - prune_rate_step_0)
            
            

            
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity) * ((1 / (1 + np.exp(-k * (curr_prune_iter - mid_prune_step)))- 0.5) * scale_factor + 0.5)  + self.args.granet_init_sparsity
        
        else:
            raise NotImplementedError
        
        
        print('******************************************************')
        print(f'Pruning Progress is {curr_prune_iter - self.ini_iter} / {self.total_prune_iter}')
        print('******************************************************')
        
        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue
            if self.args.history_weights and not self.args.new_history_weights:
                # score the stored history values instead of the live weights
                weight = self.history_masks[l].clone().to(w.device)
                if self.args.save_new or self.args.adaptive_shield:

                    # shielded links are zeroed so they do not take part in the ranking
                    weight = weight * (self.saving_links_mask[l] ==0)
            else:
                if self.is_dist:
                    # average the weights across ranks (in place)
                    dist.all_reduce(w)
                    w /= self.world_size
                if self.args.save_new or self.args.adaptive_shield:
                    weight = w.clone() * (self.saving_links_mask[l] ==0)
                else:
                    weight = w.clone()
            
            # NOTE(review): there is no else branch below; a pruning_method
            # other than these three names leaves weight_abs unassigned for
            # this layer.
            if self.args.pruning_method == "weight_magnitude":
                weight_abs = torch.abs(weight)
                
            elif self.args.pruning_method == "ri":
                # |w| normalised by its column sum plus |w| normalised by its row sum
                eps = 0.00001
                weight_abs = torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=0) + torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=1).reshape(-1, 1)
                
            elif self.args.pruning_method == "MEST":
                # magnitude plus a scaled dense-gradient term
                score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)
                if self.is_dist:
                    dist.all_reduce(score_grow)
                    score_grow /= self.world_size
                weight_abs = torch.abs(weight) + self.args.factor * torch.abs(score_grow)
            
            
            weight_abs_flatten = torch.flatten(weight_abs)
            if self.args.save_new or self.args.adaptive_shield:
                # shielded links are kept on top, so they are taken out of the budget
                num_params_to_keep = int(len(weight_abs_flatten) * (1 - curr_prune_rate) - torch.sum(self.saving_links_mask[l]).item())
                pass
            else:
                num_params_to_keep = int(len(weight_abs_flatten) * (1 - curr_prune_rate))
            # NOTE(review): a keep count of 0 makes threshold[-1] fail, and a
            # negative one is rejected by topk -- no guard here.
            threshold, _ = torch.topk(weight_abs_flatten, num_params_to_keep, sorted=True)
            # smallest score among the top-k; the comparisons below are strict
            acceptable_score = threshold[-1]
            if self.args.save_new or self.args.adaptive_shield:
                self.backward_masks[l] = ((weight_abs > acceptable_score).bool() | self.saving_links_mask[l] ).to(w.device)
                if l == 0:
                    print("Number of shielded links in current layer: ", torch.sum((self.saving_links_mask[l])))
                
                    print("Number of links in current layer: ", torch.sum(self.backward_masks[l]).item())
            else:

                # NOTE(review): a method name containing "soft" cannot equal
                # any of the three names matched above, so whether this
                # branch can be reached with a valid weight_abs is unclear.
                if "soft" in self.args.pruning_method:
                    # temperature interpolated between start_T and end_T
                    T = self.args.start_T + self.step * (self.args.end_T - self.args.start_T) / self.total_prune_iter
                    weight_abs = weight_abs ** T
                    mask = torch.zeros_like(weight_abs_flatten).to(w.device)
                    flat_matrix = weight_abs.flatten()
                    # sample the kept links without replacement, proportionally to the score
                    probabilities = flat_matrix / flat_matrix.sum()
                    sampled_flat_indices = torch.multinomial(probabilities, max(1, num_params_to_keep), replacement=False)
                    mask[sampled_flat_indices] = 1
                    self.backward_masks[l] = mask.view(weight_abs.shape).bool().to(w.device)
                else:
                    self.backward_masks[l] = (weight_abs > acceptable_score).bool().to(w.device)

            # record the sparsity actually reached by this layer
            self.S[l] = 1 - torch.sum(self.backward_masks[l]).item() / self.N[l]
            if self.is_dist:
                dist.broadcast(self.backward_masks[l], 0)
            
    
    def non_uniform_pruning(self):
        curr_prune_iter = int((self.step - self.ini_iter) / self.delta_T)
        
        if self.args.pruning_scheduler == "linear":
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity) * curr_prune_iter / self.total_prune_iter + self.args.granet_init_sparsity
        elif self.args.pruning_scheduler == "granet":
            prune_decay = (1 - ((curr_prune_iter - self.ini_iter) / self.total_prune_iter)) ** 3
            curr_prune_rate = self.args.granet_init_sparsity + (self.args.sparsity - self.args.granet_init_sparsity) * (1 - prune_decay)
            
        elif self.args.pruning_scheduler == "s_shape":
            mid_prune_step = self.total_prune_iter / 2
            # S-shape pruning curve
            k = 6/mid_prune_step
            
            prune_rate_step_0 = 1/(1 + np.exp(-k * (- mid_prune_step)))
            prune_rate_step_final = 1/(1 + np.exp(-k * (self.total_prune_iter - mid_prune_step)))
            scale_factor = 1 / (prune_rate_step_final - prune_rate_step_0)
            
            

            
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity) * ((1 / (1 + np.exp(-k * (curr_prune_iter - mid_prune_step)))- 0.5) * scale_factor + 0.5)  + self.args.granet_init_sparsity
        
        else:
            raise NotImplementedError
        
        print('******************************************************')
        print(f'Pruning Progress is {curr_prune_iter - self.ini_iter} / {self.total_prune_iter}')
        print('******************************************************')

        weight_abs = []
        for l, w in enumerate(self.W):
            # print(f"Layer {l}: type of self.backward_masks[l] is {self.backward_masks[l].dtype}")
            if self.S[l] <= 0:
                continue
            if self.args.history_weights:
                weight = self.history_masks[l].clone().to(w.device)
            else:
                if self.is_dist:
                    dist.all_reduce(w)
                    w /= self.world_size
                weight = w.clone()
            
            if self.args.pruning_method == "weight_magnitude":
                weight_abs.append(torch.abs(weight))
                
            elif self.args.pruning_method == "ri":
                eps = 0.00001
                weight_abs.append(torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=0) + torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=1).reshape(-1, 1))
                
            elif self.args.pruning_method == "MEST":
                score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)
                if self.is_dist:
                    dist.all_reduce(score_grow)
                    score_grow /= self.world_size
                weight_abs.append(torch.abs(weight) + self.args.factor * torch.abs(score_grow))
            else:
                raise NotImplementedError

        # Gather all scores in a single vector and normalise
        all_scores = torch.cat([torch.flatten(x) for x in weight_abs])
        num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate))

        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        acceptable_score = threshold[-1]
        
        total_size = 0
        sparse_size = 0
        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue
            self.backward_masks[l] = (weight_abs[l] > acceptable_score).bool().to(w.device) # must be > to prevent acceptable_score is zero, leading to dense tensors
            if self.is_dist:
                dist.broadcast(self.backward_masks[l], 0)
        
            total_size += self.N[l]
            sparse_size += torch.sum(self.backward_masks[l]).item()
            self.S[l] = 1 - torch.sum(self.backward_masks[l]).item() / self.N[l]
            # print(f"Layer {l}: mask sparsity is {torch.sum(self.backward_masks[l]).item() / self.N[l]}")
            
        
        print('Total Model parameters:', total_size)
        print('density after pruning: {0}'.format(
            sparse_size / total_size))

    @torch.no_grad()
    def _dst_step(self):
        """Run one topology update.

        Order of operations:
          1. with ``history_weights``, copy the currently active weights
             into the stored history tensors;
          2. while ``step <= pruning_T_end`` and granet/gmp is enabled, run
             the gradual pruning pass and re-apply the masks;
          3. with ``gmp``, optionally run chain removal and return;
          4. otherwise remove links, optionally remove chains, regrow links
             and re-apply the masks to momentum, weights and gradients.

        ``chain_removal``, ``link_removal`` and ``link_regrowth`` are methods
        of this class.

        Raises:
            NotImplementedError: if ``EM_S`` and ``adaptive_zeta`` are both
                set, or for an unknown ``sparsity_distribution``.
        """
        if self.args.EM_S and self.args.adaptive_zeta:
            print("EM_S and adaptive_zeta cannot be used together!")
            raise NotImplementedError
        
        if self.args.history_weights:
            for l, w in enumerate(self.W):
                if self.is_dist:
                    # average the weights across ranks (in place)
                    dist.all_reduce(w)
                    w /= self.world_size
                # refresh the stored values of the links that are currently active
                self.history_masks[l][self.backward_masks[l] == 1] = w.detach().clone()[self.backward_masks[l] == 1].cpu()

        
        if self.step <= self.pruning_T_end and (self.args.granet or self.args.gmp):
            # gradual pruning phase
            if self.sparsity_distribution == "non-uniform":
                self.non_uniform_pruning()
            elif self.sparsity_distribution == "uniform":
                self.uniform_pruning()
            else:
                raise NotImplementedError
            
            # bring optimizer buffers, weights and gradients in line with the new masks
            self.reset_momentum()
            print(f"W dtype before multiply: {self.W[0].dtype}")
            if self.args.history_weights:
                self.apply_mask_to_history_weights()
            else:
                self.apply_mask_to_weights()
            print(f"W dtype after multiply: {self.W[0].dtype}")
            self.apply_mask_to_gradients()
        
        
        if self.args.gmp:
            # Gradual Magnitude Pruning: no removal / regrowth, only the
            # optional chain removal, then return early
            if self.args.chain_removal:
                self.chain_removal()
            
                self.reset_momentum()
                if self.args.history_weights:
                    self.apply_mask_to_history_weights()
                else:
                    self.apply_mask_to_weights()
                self.apply_mask_to_gradients()
            return
        
        # prune -> (optional) chain removal -> regrow
        self.link_removal()
        
        if self.args.chain_removal:
            self.chain_removal()
        
        self.link_regrowth()
        
        # bring optimizer buffers, weights and gradients in line with the new masks
        self.reset_momentum()
        if self.args.history_weights:
            self.apply_mask_to_history_weights()
        else:
            self.apply_mask_to_weights()
        self.apply_mask_to_gradients()
        
        torch.cuda.empty_cache()
    
    def link_removal(self):
        """Prune a fraction of the active links in every sparse layer.

        Layers with ``self.S[l] <= 0`` are skipped. For the others the drop
        fraction comes from ``args.EM_S`` (closed-form from ``S[l]`` and
        ``dense_allocation``), ``args.adaptive_zeta`` (``cosine_annealing()``)
        or ``self.alpha``. Links are scored from ``|w|`` (averaged across
        ranks when ``self.is_dist``) and the survivors are chosen according to
        ``args.remove_method``:

        * ``"weight_magnitude"`` / ``"ri"`` / ``"MEST"``: keep the ``n_keep``
          highest-scoring entries;
        * ``"weight_magnitude_soft"`` / ``"ri_soft"``: sample the survivors
          without replacement, with probability proportional to the score
          raised to a power that grows with ``self.step``.

        When ``args.save_new`` or ``args.adaptive_shield`` is set, entries in
        ``self.saving_links_mask[l]`` are excluded from scoring and always
        OR-ed back into the resulting mask.

        Side effects: overwrites ``self.backward_masks[l]`` and prints
        diagnostics for layer 0.

        Raises:
            NotImplementedError: for an unknown ``args.remove_method``.
        """

        def rank_keep_mask(scores, keep_count, total, device):
            # Order all entries by descending score, flag the first
            # `keep_count` ranks, then scatter the flags back to positions.
            _, order = torch.topk(scores.view(-1), k=total)
            flags = torch.where(
                torch.arange(total, device=device) < keep_count,
                torch.ones_like(order),
                torch.zeros_like(order))
            return flags.scatter(0, order, flags)

        def sample_keep_mask(scores, keep_count, device):
            # The exponent grows linearly with the training step, which
            # sharpens the sampling distribution over time.
            exponent = 1 + self.step * (2 / self.T_end)
            flat_mask = torch.zeros_like(scores.view(-1)).to(device)
            powered = (scores.flatten()) ** exponent
            probs = powered / powered.sum()
            chosen = torch.multinomial(probs, max(1, keep_count), replacement=False)
            flat_mask[chosen] = 1
            return flat_mask

        def relative_importance(scores):
            # Each magnitude normalised by its column sum plus the same
            # magnitude normalised by its row sum.
            eplison = 0.00001
            magnitude = torch.abs(scores)
            by_column = magnitude / torch.sum(magnitude + eplison, dim=0)
            by_row = magnitude / torch.sum(magnitude + eplison, dim=1).reshape(-1, 1)
            return by_column + by_row

        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue

            # Fraction of the currently active links to drop in this layer.
            if self.args.EM_S:
                drop_fraction = (1-self.S[l]-self.dense_allocation)/(1-self.S[l])
            elif self.args.adaptive_zeta:
                drop_fraction = self.cosine_annealing()
            else:
                drop_fraction = self.alpha

            current_mask = self.backward_masks[l]
            n_total = self.N[l]
            n_ones = torch.sum(current_mask).item()
            n_prune = int(n_ones * drop_fraction)

            score_drop = torch.abs(w)
            if self.is_dist:
                # Average the magnitudes over all ranks.
                dist.all_reduce(score_drop)
                score_drop /= self.world_size

            shielded = self.args.save_new or self.args.adaptive_shield
            if shielded:
                # Shielded links get score 0 and are excluded from the keep
                # budget; they are re-added to the mask further down.
                score_drop = score_drop * (self.saving_links_mask[l] == 0)
                n_keep = int(n_ones - n_prune - torch.sum(self.saving_links_mask[l]).item())
                if l == 0:
                    print("Number of links in current layer: ", torch.sum(current_mask).item())
                    print("Number of n_keep: ", n_keep)
                    print("Number of active links that exceed the saving steps: ", int(torch.sum(self.not_saving_links_mask[l])))
                    print("Number of saving links in current layer: ", int(torch.sum(self.saving_links_mask[l])))
            else:
                n_keep = int(n_ones - n_prune)

            if l == 0:
                print(f"number of pruning in layer {l}: {n_prune}")
                print(f"number of keeps in layer {l}: {n_keep}")

            method = self.args.remove_method
            if method == "weight_magnitude":
                mask = rank_keep_mask(score_drop, n_keep, n_total, w.device)
            elif method == "weight_magnitude_soft":
                mask = sample_keep_mask(score_drop, n_keep, w.device)
            elif method == "ri":
                score_drop = relative_importance(score_drop)
                mask = rank_keep_mask(score_drop, n_keep, n_total, w.device)
            elif method == "ri_soft":
                score_drop = relative_importance(score_drop)
                mask = sample_keep_mask(score_drop, n_keep, w.device)
            elif method == "MEST":
                # Magnitude plus a weighted gradient term on active links.
                score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)
                if self.is_dist:
                    dist.all_reduce(score_grow)
                    score_grow /= self.world_size
                score_drop = score_drop + self.args.factor * torch.abs(score_grow * current_mask)
                mask = rank_keep_mask(score_drop, n_keep, n_total, w.device)
            else:
                raise NotImplementedError

            kept = torch.reshape(mask, current_mask.shape).bool()
            if shielded:
                self.backward_masks[l] = (kept | self.saving_links_mask[l]).to(w.device)
                if l == 0:
                    print("Number of links in layer 0", torch.sum(self.backward_masks[l]).item())
            else:
                self.backward_masks[l] = kept.to(w.device)
            
    def chain_removal(self):
        """Apply chain-based link removal to the registered layer pairs.

        Each entry of ``self.qk_chain_list`` is processed with the module-level
        ``qk_chain_removal`` helper and each entry of ``self.chain_list`` with
        the module-level ``chain_removal`` helper. The two masks returned by
        the helper replace the pair's entries in ``self.backward_masks``.
        """
        masks = self.backward_masks

        # Query/key pairs first, then the ordinary consecutive-layer pairs.
        for pair in self.qk_chain_list:
            first, second = pair[0], pair[1]
            pruned_first, pruned_second = qk_chain_removal(masks[first], masks[second])
            masks[first] = pruned_first
            masks[second] = pruned_second

        for pair in self.chain_list:
            first, second = pair[0], pair[1]
            pruned_first, pruned_second = chain_removal(masks[first], masks[second])
            masks[first] = pruned_first
            masks[second] = pruned_second
            
            
    
    def link_regrowth(self):
        """Regrow links in every sparse layer after pruning.

        Layers with ``self.S[l] <= 0`` are skipped. The number of links to add
        is a fixed share of ``self.N[l]`` under ``args.EM_S`` (which also
        rewrites ``self.S[l]`` according to the training step); otherwise it is
        the gap between ``int((1 - S[l]) * N[l])`` and the number of currently
        active links. Inactive positions are scored by ``args.regrow_method``:

        * contains "ch" and "L3n": scores computed from products of the
          current mask with its transpose (variants ``CH2``/``CH3``/``CH3.1``
          taken from the text before the first ``_``);
        * contains "ch" and "L3p": scores returned by
          ``CH_scores.CH_scores_new_v2`` on the monopartite adjacency matrix;
        * ``"random"``: uniform random scores;
        * ``"gradient"``: absolute dense gradient stored by the backward hook.

        If the method name contains "soft" the new links are sampled
        proportionally to their score; otherwise every entry whose score
        reaches the n-th largest score is regrown. When that threshold is not
        positive, exactly the requested number of inactive links is selected
        by top-k instead, because a threshold of 0 would match every entry.

        Side effects: updates ``self.backward_masks[l]`` (broadcast from rank
        0 when ``self.is_dist``), ``self.count_steps_mask[l]`` under
        ``args.save_new``, ``self.saving_links_mask[l]`` under
        ``args.adaptive_shield`` and ``self.record_mask[l]`` under
        ``args.itop``; prints diagnostics.

        Raises:
            NotImplementedError: for an unsupported ``args.regrow_method``,
                including a "ch" method that names neither "L3n" nor "L3p".
        """
        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue
            if self.args.EM_S:
                if self.step <= self.T_end * 0.6:
                    self.S[l] = 1-self.dense_allocation-0.05
                    n_regrow = int(0.05 * self.N[l])
                elif self.step < (self.T_end - self.delta_T):
                    self.S[l] = 1-self.dense_allocation-0.025
                    n_regrow = int(0.025 * self.N[l])
                else:
                    self.S[l] = 1-self.dense_allocation
                    n_regrow = 0
                    print("Final sparsity: {}".format(torch.sum(self.backward_masks[l]).item()/self.N[l]))
                    continue
            else:
                n_regrow = int((1-self.S[l]) * self.N[l]) - torch.sum(self.backward_masks[l]).item()

            if l == 0:
                print(f"number of regrowth in layer {l}: {n_regrow}")
            if n_regrow <= 0:
                continue

            current_mask = self.backward_masks[l].data.clone().float()
            inactive = current_mask == 0
            regrow_method = self.args.regrow_method

            if "ch" in regrow_method.lower():
                CH_method = regrow_method.split("_")[0]

                if "L3n" in regrow_method:

                    DTPATHS1 = current_mask

                    TDPATHS1 = DTPATHS1.transpose(1, 0)
                    DDPATHS2 = torch.matmul(DTPATHS1, TDPATHS1)
                    TTPATHS2 = torch.matmul(TDPATHS1, DTPATHS1)

                    BDDPATHS2 = DDPATHS2 != 0
                    BTTPATHS2 = TTPATHS2 != 0

                    elcl_DT = (torch.sum(DTPATHS1, dim=1) - DDPATHS2) * BDDPATHS2
                    elcl_TD = (torch.sum(TDPATHS1, dim=1) - TTPATHS2) * BTTPATHS2

                    elcl_DT[elcl_DT == 0] = 1
                    elcl_TD[elcl_TD == 0] = 1

                    elcl_DT -= 1
                    elcl_TD -= 1
                    # NOTE(review): any other CH variant falls through here
                    # with the raw elcl_* values; the L3p branch raises
                    # instead. Confirm whether that is intended.
                    if CH_method == "CH2":
                        elcl_DT = 1 / (elcl_DT + 1) * (DDPATHS2 + BDDPATHS2)
                        elcl_TD = 1 / (elcl_TD + 1) * (TTPATHS2 + BTTPATHS2)
                    elif CH_method == "CH3":
                        elcl_DT = 1 / (elcl_DT + 1) * BDDPATHS2
                        elcl_TD = 1 / (elcl_TD + 1) * BTTPATHS2
                    elif CH_method == "CH3.1":
                        elcl_DT = 1 / ((elcl_DT + 1) ** (1 + (elcl_DT/ (1+elcl_DT)))) * (DDPATHS2 + BDDPATHS2)
                        elcl_TD = 1 / ((elcl_TD + 1) ** (1 + (elcl_TD/ (1+elcl_TD)))) * (TTPATHS2 + BTTPATHS2)

                    elcl_DT = torch.matmul(elcl_DT, DTPATHS1)
                    elcl_TD = torch.matmul(elcl_TD, TDPATHS1)

                    scores = elcl_DT + elcl_TD.T
                    scores = scores * inactive
                    thre = torch.sort(scores.ravel())[0][-n_regrow]
                    if thre == 0:
                        print("Regrowing threshold is 0!!!")
                        # Give every inactive link a small positive score so
                        # sampling / top-k has enough candidates.
                        scores = (scores + 0.00001)*inactive

                elif "L3p" in regrow_method:
                    # CH3_L3 path-based regrowth
                    # NOTE(review): CH_scores is not imported in the visible
                    # part of this file (the import is commented out in the
                    # header); this branch needs it to be in scope.
                    xb = np.array(current_mask.cpu())
                    x = transform_bi_to_mo(xb)

                    A = csr_matrix(x)
                    ir = A.indices
                    jc = A.indptr
                    if CH_method == "CH2":
                        scores_cell = torch.tensor(np.array(CH_scores.CH_scores_new_v2(ir, jc, x.shape[0], [3], 1, 3, [2], 1))).to(w.device)
                    elif CH_method == "CH3":
                        scores_cell = torch.tensor(np.array(CH_scores.CH_scores_new_v2(ir, jc, x.shape[0], [3], 1, 3, [3], 1))).to(w.device)
                    elif CH_method == "CH3.1":
                        scores_cell = torch.tensor(np.array(CH_scores.CH_scores_new_v2(ir, jc, x.shape[0], [3], 1, 3, [5], 1))).to(w.device)
                    else:
                        raise NotImplementedError
                    scores = torch.reshape(scores_cell, x.shape)
                    # Keep only the top-right (bipartite) block.
                    scores = scores[:xb.shape[0], xb.shape[0]:]

                    scores = scores * inactive

                    thre = torch.sort(scores.ravel())[0][-n_regrow]
                    if thre == 0:
                        print("Regrowing threshold is 0!!!")
                        print(f"# of scores: {torch.sum(scores > 0)}")
                        scores = (scores + 0.00001)*inactive

                else:
                    # Previously fell through with `scores` unbound.
                    raise NotImplementedError

            elif regrow_method == "random":
                # random regrowth
                scores = torch.rand(w.shape).to(w.device) * inactive
                thre = torch.sort(scores.ravel())[0][-n_regrow]

            elif regrow_method == "gradient":
                scores = torch.abs(self.backward_hook_objects[l].dense_grad) * inactive
                thre = torch.sort(scores.ravel())[0][-n_regrow]

            else:
                raise NotImplementedError

            if "soft" in regrow_method:
                mask = torch.zeros_like(scores.view(-1)).to(w.device)
                flat_matrix = scores.flatten()
                probabilities = flat_matrix / flat_matrix.sum()
                sampled_flat_indices = torch.multinomial(probabilities, max(1, n_regrow), replacement=False)
                mask[sampled_flat_indices] = 1
            elif thre > 0:
                # Active links have score 0, so a positive threshold can only
                # select inactive positions (ties at the threshold included).
                mask = torch.zeros_like(scores).to(w.device)
                mask[scores >= thre] = 1
            else:
                # A non-positive threshold would match every entry and make
                # the layer dense; pick exactly the requested number of
                # inactive links instead.
                candidates = scores.masked_fill(~inactive, float("-inf"))
                k = min(int(n_regrow), int(inactive.sum().item()))
                _, top_indices = torch.topk(candidates.reshape(-1), k)
                flat_mask = torch.zeros(scores.numel(), dtype=scores.dtype, device=w.device)
                flat_mask[top_indices.to(w.device)] = 1
                mask = flat_mask.reshape(scores.shape)

            new_link_mask = torch.reshape(mask, current_mask.shape)
            self.backward_masks[l] = self.backward_masks[l] | (new_link_mask.bool())
            if self.is_dist:
                # Rank 0's mask is authoritative for all ranks.
                dist.broadcast(self.backward_masks[l], 0)
            if self.args.save_new:
                # Per-link step counter, reset to 0 for inactive links.
                self.count_steps_mask[l] += 1
                self.count_steps_mask[l] *= self.backward_masks[l]
                if l == 0:
                    print(f"Number of new links in layer {l} after evolution: {torch.sum((self.count_steps_mask[l] < (self.args.saving_steps)) & (self.count_steps_mask[l] > 0)).item()}")

            elif self.args.adaptive_shield:
                self.saving_links_mask[l] = (self.saving_links_mask[l] | (new_link_mask.bool()))

            if self.args.itop:
                # Accumulate every position that has ever been active.
                self.record_mask[l] = ((self.record_mask[l] == 1) | (self.backward_masks[l]))
                print("ITOP rate is : ", (torch.sum(self.record_mask[l]) / self.N[l]).item())
    
    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialise every weight in ``self.W`` with scaled Gaussian noise.

        The standard deviation depends on ``args.init_mode``:

        * ``"swi"``: ``sqrt(2 / (((1 - S[l]) * N[l]) / w.shape[1]))``;
        * ``"kaiming"``: ``sqrt(2 / w.shape[1])``.

        The new values are sampled on the CPU and then moved to the weight's
        device and cast to the weight's dtype, so re-initialisation does not
        silently turn a non-float32 weight into float32.

        Raises:
            NotImplementedError: for any other ``args.init_mode``.
        """
        for l, w in enumerate(self.W):
            if self.args.init_mode == "swi":
                stdv = math.sqrt(2. / (((1-self.S[l]) * self.N[l]) / w.shape[1]))
            elif self.args.init_mode == "kaiming":
                stdv = math.sqrt(2 / w.shape[1])
            else:
                raise NotImplementedError
            fresh = torch.randn(w.shape[0], w.shape[1]) * stdv
            # Preserve both device and dtype of the existing weight.
            w.data = fresh.to(device=w.device, dtype=w.dtype)

def transform_bi_to_mo(xb):
    """Embed a bipartite adjacency matrix into a square monopartite one.

    For ``xb`` with ``r`` rows and ``c`` columns the result is an
    ``(r + c, r + c)`` array of zeros with ``xb`` in the top-right block and
    ``xb.T`` in the bottom-left block.
    """
    n_rows = xb.shape[0]
    size = n_rows + xb.shape[1]

    x = np.zeros((size, size))
    upper = slice(None, n_rows)
    lower = slice(n_rows, None)
    x[upper, lower] = xb      # first-set nodes -> second-set nodes
    x[lower, upper] = xb.T    # mirrored block keeps the matrix symmetric
    return x
