

"""
3x3 kernel as a whole. Through all, keep all 
add log and saving
Fixe block size !!
"""

from __future__ import print_function

import json
import logging
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def add_sparse_args(parser):
    """Register the dynamic-sparse-training command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) to extend.
    """
    parser.add_argument('--growth', type=str, default='random', help='Growth mode. Choose from: momentum, random, gradient, and momentum_neuron.')
    parser.add_argument('--death', type=str, default='magnitude', help='Death mode / pruning mode. Choose from: magnitude, SET, threshold, Taylor_FO.')
    parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.')
    parser.add_argument('--death-rate', type=float, default=0.50, help='The pruning rate / death rate for DST.')
    parser.add_argument('--large-death-rate', type=float, default=0.80, help=' large exploration rate q.')
    parser.add_argument('--PF-rate', type=float, default=0.8, help='The pruning rate / death rate for Pruning and Finetuning.')
    parser.add_argument('--density', type=float, default=0.05, help='The density of the overall sparse network.')
    # store_true flags default to False; the help text must say so.
    parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. Default: False.')
    parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: False.')
    parser.add_argument('--sparse-init', type=str, default='ER', help='sparse initialization')
    parser.add_argument('--update-frequency', type=int, default=1000, metavar='N', help='how many iterations to train between parameter exploration')


class CosineDecay(object):
    """Cosine-annealed death (pruning) rate.

    A throwaway SGD optimizer over one dummy parameter stores the rate as its
    learning rate; ``CosineAnnealingLR`` then anneals it from ``death_rate``
    towards ``eta_min`` over ``T_max`` calls to ``step``.
    """

    def __init__(self, death_rate, T_max, eta_min=0.005, last_epoch=-1):
        dummy = torch.nn.Parameter(torch.zeros(1))
        self.sgd = optim.SGD(torch.nn.ParameterList([dummy]), lr=death_rate)
        self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.sgd, T_max=T_max, eta_min=eta_min, last_epoch=last_epoch)

    def step(self):
        """Advance the schedule by one step."""
        self.cosine_stepper.step()

    def get_dr(self):
        """Return the current death rate (the dummy optimizer's learning rate)."""
        first_group = self.sgd.param_groups[0]
        return first_group['lr']


class LinearDecay(object):
    """Step-wise multiplicative decay of the death rate.

    ``get_dr`` scales the rate it is given by ``factor`` whenever the internal
    step counter is a positive multiple of ``frequency``; otherwise the rate is
    returned unchanged. The ``death_rate`` constructor argument is accepted but
    not used.
    """

    def __init__(self, death_rate, factor=0.99, frequency=600):
        self.factor, self.steps, self.frequency = factor, 0, frequency

    def step(self):
        """Advance the internal step counter."""
        self.steps += 1

    def get_dr(self, death_rate):
        """Return ``death_rate``, scaled by ``factor`` on decay boundaries."""
        on_boundary = self.steps > 0 and self.steps % self.frequency == 0
        return death_rate * self.factor if on_boundary else death_rate


class Masking(object):
    def __init__(self, optimizer, death_rate=0.3, growth_death_ratio=1.0, death_rate_decay=None, death_mode='magnitude', growth_mode='gradient', redistribution_mode='none', args=None, train_loader=None, device=None):
        """Create the mask manager with empty masks and bookkeeping.

        Args:
            optimizer: optimizer whose ``step()`` is called from ``self.step``.
            death_rate: initial pruning (death) rate.
            growth_death_ratio: stored as ``self.growth_death_ratio``.
            death_rate_decay: schedule object; ``self.step`` calls its ``step()``
                and, for the cosine schedule, its ``get_dr()``.
            death_mode: pruning criterion used by ``truncate_weights``.
            growth_mode: regrowth criterion; unsupported modes only print a warning.
            redistribution_mode: stored as ``self.redistribution_mode``.
            args: parsed options; ``PF_rate``, ``fix`` and ``update_frequency``
                are read here, so it must not be None.
            train_loader: stored as ``self.loader``.
            device: device stored in ``self.device`` (a ``torch.device`` or
                anything ``torch.device`` accepts); CUDA when None.
        """
        growth_modes = ['random', 'momentum', 'momentum_neuron', 'gradient']
        if growth_mode not in growth_modes:
            print('Growth mode: {0} not supported!'.format(growth_mode))
            print('Supported modes are:', str(growth_modes))

        self.args = args
        self.loader = train_loader
        # Honour an explicit device argument; CUDA remains the default.
        if device is None:
            self.device = torch.device("cuda")
        elif isinstance(device, torch.device):
            self.device = device
        else:
            self.device = torch.device(device)
        self.growth_mode = growth_mode
        self.death_mode = death_mode
        self.growth_death_ratio = growth_death_ratio
        self.redistribution_mode = redistribution_mode
        self.death_rate_decay = death_rate_decay
        self.PF_rate = args.PF_rate

        self.death_funcs = {}
        self.death_funcs['magnitude'] = self.magnitude_death
        self.death_funcs['SET'] = self.magnitude_and_negativity_death
        self.death_funcs['threshold'] = self.threshold_death

        self.growth_funcs = {}
        self.growth_funcs['random'] = self.random_growth
        self.growth_funcs['momentum'] = self.momentum_growth
        self.growth_funcs['momentum_neuron'] = self.momentum_neuron_growth

        self.masks = {}
        self.nonzero_masks = {}
        self.new_masks = {}
        self.pre_tensor = {}
        self.pruning_rate = {}      # name -> number of mask entries removed in the last prune step
        self.modules = []
        self.names = []
        self.optimizer = optimizer

        self.adjusted_growth = 0
        self.adjustments = []
        self.baseline_nonzero = None
        self.name2baseline_nonzero = {}

        # stats
        self.name2variance = {}
        self.name2zeros = {}
        self.name2nonzeros = {}
        self.total_variance = 0
        self.total_removed = 0
        self.total_zero = 0
        self.total_nonzero = 0
        self.death_rate = death_rate
        self.name2death_rate = {}
        self.steps = 0

        # With a fixed topology there is no periodic prune & regrow.
        if self.args.fix: self.prune_every_k_steps = None
        else: self.prune_every_k_steps = self.args.update_frequency

        # Group layout per layer, filled by group_mask().
        self.group_idxs = {}

    def step(self):
        """Optimizer step, then mask enforcement and, periodically, prune & regrow.

        Order: ``optimizer.step()``, re-apply the masks, advance the death-rate
        schedule, and every ``prune_every_k_steps`` steps call
        ``truncate_weights`` with the current (scheduled) death rate.
        """
        self.steps += 1
        self.optimizer.step()
        self.apply_mask()           # re-zero pruned weights after the update

        # Death (pruning) rate for this step.
        self.death_rate_decay.step()
        if self.args.decay_schedule == 'cosine':
            self.death_rate = self.death_rate_decay.get_dr()
        elif self.args.decay_schedule == 'constant':
            self.death_rate = self.args.death_rate

        # Prune and grow.
        if self.prune_every_k_steps is not None and self.steps % self.prune_every_k_steps == 0:
            # Use the scheduled rate computed above, not the fixed command-line value.
            self.truncate_weights(self.death_rate)
            self.print_nonzero_counts()

    def init(self, mode='ERK', density=0.05, erk_power_scale=1.0):
        """Initialise the kernel-wise masks in ``self.masks`` and apply them.

        Masks created by ``add_module`` have shape ``(ch_out, ch_in)``: one entry
        per kernel, 1 keeps the whole kernel and 0 removes it. Every mode works
        on those mask shapes rather than on the 4-D weights.

        Args:
            mode: 'pruning', 'resume', 'uniform', 'ERK' (default) or 'ER'.
            density: target fraction of kernels to keep (1 - sparsity).
            erk_power_scale: exponent applied to the Erdos-Renyi ratio ('ERK' only).

        Side effects:
            Fills ``self.masks`` in place and rebuilds ``self.density_dict``
            (per-layer keep probability, 1.0 = dense), then calls
            ``self.group_mask()`` and ``self.apply_mask()`` and prints statistics.

        Raises:
            ValueError: if ``mode`` is not one of the modes above.
        """
        self.density = density      # 1 - sparsity
        # group_mask() reads density_dict for every mode, so always rebuild it.
        self.density_dict = {}

        if mode == 'pruning':
            self._init_pruning()
        elif mode == 'resume':
            self._init_resume(density)
        elif mode == 'uniform':
            self._init_uniform(density)
        elif mode == 'ERK':     # default
            self._init_erk(density, erk_power_scale)
        elif mode == 'ER':
            self._init_er(density)
        else:
            raise ValueError('Unknown sparse init mode: {0}'.format(mode))

        # Regroup the sparse kernels into dense blocks, then zero the pruned weights.
        self.group_mask()
        self.apply_mask()

        total_size = sum(mask.numel() for mask in self.masks.values())
        print('Total Model parameters:', total_size)

        sparse_size = sum((mask != 0).sum().int().item() for mask in self.masks.values())
        ratio = sparse_size / total_size if total_size else 0.0
        print('Total parameters under sparsity level of {0}: {1}'.format(density, ratio))

    @staticmethod
    def _bernoulli_mask(shape, probability, device):
        """Return a float 0/1 tensor of ``shape`` on ``device`` with P(1) = ``probability``."""
        # Sample with the CPU generator, then move next to the mask.
        return (torch.rand(shape) < probability).float().to(device)

    @staticmethod
    def _reduce_to_mask(values, mask_dim):
        """Sum ``values`` over every dim past ``mask_dim`` (e.g. the kernel dims of a conv weight)."""
        if values.dim() > mask_dim:
            values = values.flatten(start_dim=mask_dim).sum(-1)
        return values

    def _init_pruning(self):
        """Keep the (1 - PF_rate) fraction of kernels with the largest L1 norm, globally."""
        print('initialize by pruning')
        scores = {}
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                # One score per mask entry so the threshold matches the mask shape.
                scores[name] = self._reduce_to_mask(weight.detach().abs(), self.masks[name].dim())
        if not scores:
            return

        all_scores = torch.cat([torch.flatten(s) for s in scores.values()])
        # Keep at least one entry so topk never gets k == 0.
        num_params_to_keep = max(1, int(all_scores.numel() * (1 - self.PF_rate)))
        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        acceptable_score = threshold[-1]

        for name, score in scores.items():
            self.masks[name][:] = (score >= acceptable_score).float()
            self.density_dict[name] = self.masks[name].float().mean().item()

    def _init_resume(self, density):
        """Rebuild masks from the weights: a kernel is kept iff any of its weights is non-zero.

        Useful when resuming a sparse model whose mask was not saved.
        """
        print('initialize by resume')
        self.baseline_nonzero = 0
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                print((weight != 0.0).sum().item()/weight.numel())
                mask = self.masks[name]
                alive = self._reduce_to_mask((weight.detach() != 0.0).float(), mask.dim()) > 0
                mask[:] = alive.float()
                self.density_dict[name] = mask.float().mean().item()
                self.baseline_nonzero += weight.numel()*density
        self.apply_mask()

    def _init_uniform(self, density):
        """Keep every kernel independently with probability ``density``."""
        for name, mask in self.masks.items():
            self.masks[name][:] = self._bernoulli_mask(mask.shape, density, mask.device)
            self.density_dict[name] = min(density, 1.0)
        self.apply_mask()

    def _init_erk(self, density, erk_power_scale):
        """Erdos-Renyi-Kernel init: keep probability proportional to (sum of dims / product of dims).

        Layers whose probability would exceed 1 are made dense, and epsilon is
        re-solved for the remaining layers so the overall density still matches.
        """
        print('\ninitialize by fixed_ERK =========>>>>>>>>>')
        total_params = sum(mask.numel() for mask in self.masks.values())

        # Target: eps * sum_sparse(p_i * N_i) + sum_dense(N_j) = density * sum_all(N), so
        #   eps = rhs / divisor, where
        #   rhs = density * sum_sparse(N_i) - (1 - density) * sum_dense(N_j)
        #   divisor = sum_sparse(p_i * N_i).
        # While eps * max(p_i) > 1, the layer(s) with the largest p_i become dense and
        # eps is recomputed. Each round adds at least one dense layer, so this terminates.
        dense_layers = set()
        raw_probabilities = {}
        epsilon = 0.0
        while True:
            divisor = 0
            rhs = 0
            raw_probabilities = {}
            for name, mask in self.masks.items():
                n_param = np.prod(mask.shape)
                if name in dense_layers:
                    rhs -= n_param * (1 - density)
                else:
                    rhs += n_param * density
                    raw_probabilities[name] = (
                        np.sum(mask.shape) / np.prod(mask.shape)
                    ) ** erk_power_scale
                    divisor += raw_probabilities[name] * n_param
            if not raw_probabilities:
                break       # every layer is dense (or there are no masks)
            epsilon = rhs / divisor
            max_prob = max(raw_probabilities.values())
            if max_prob * epsilon <= 1:
                break
            for mask_name, mask_raw_prob in raw_probabilities.items():
                if mask_raw_prob == max_prob:
                    print(f"Sparsity of var:{mask_name} had to be set to 0.")
                    dense_layers.add(mask_name)

        total_nonzero = 0.0
        for name, mask in self.masks.items():
            if name in dense_layers:
                self.density_dict[name] = 1.0
            else:
                self.density_dict[name] = epsilon * raw_probabilities[name]
            print(
                f"  **** Layer: {name}, shape: {mask.shape}, density: {self.density_dict[name]}"
            )
            self.masks[name][:] = self._bernoulli_mask(mask.shape, self.density_dict[name], mask.device)
            total_nonzero += self.density_dict[name] * mask.numel()
        if total_params:
            print(f"Overall density {total_nonzero / total_params}")

    def _init_er(self, density):
        """SET-style Erdos-Renyi init: keep probability epsilon * (ch_out + ch_in) / (ch_out * ch_in)."""
        print('initialize by SET')
        total_params = sum(mask.numel() for mask in self.masks.values())
        total_fan = sum(sum(mask.shape) for mask in self.masks.values())
        if total_fan == 0:
            return
        # Ignoring saturation at 1, the expected number of kept entries is
        # epsilon * total_fan, so solve for epsilon directly.
        epsilon = total_params * density / total_fan
        for name, mask in self.masks.items():
            probability = epsilon * sum(mask.shape) / mask.numel()
            self.masks[name][:] = self._bernoulli_mask(mask.shape, probability, mask.device)
            self.density_dict[name] = min(probability, 1.0)


    '''
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Module >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    '''

    def add_module(self, module, density, sparse_init='ER'):
        """Register ``module`` and build kernel-wise masks for its 4-D parameters.

        Every parameter with exactly 4 dims, unpacked as (ch_out, ch_in, kh, kw),
        gets a zero mask of shape (ch_out, ch_in) on CUDA: one entry per kernel.
        Other parameters are skipped. Masks whose name contains 'bias', and masks
        of BatchNorm layers, are then removed, and ``self.init`` fills the rest.

        Args:
            module: the ``nn.Module`` to sparsify.
            density: fraction of kernels to keep (e.g. 0.2 = 80% sparsity).
            sparse_init: initialisation mode forwarded to ``self.init``.
        """
        self.sparse_init = sparse_init
        self.modules.append(module)
        print('Add Module >>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
        for name, tensor in module.named_parameters():
            if len(tensor.shape) != 4:
                # The shape must be inside the print call to actually be printed.
                print('Skip {}. Size: '.format(name), tensor.shape)
                continue

            print('module name: ', name, tensor.shape)
            self.names.append(name)
            cout, cin, _, _ = tensor.shape
            self.masks[name] = torch.zeros((cout, cin), requires_grad=False).cuda()

        print('Removing biases...')
        self.remove_weight_partial_name('bias')
        print('Removing 2D batch norms...')
        self.remove_type(nn.BatchNorm2d)
        print('Removing 1D batch norms...')
        self.remove_type(nn.BatchNorm1d)

        self.init(mode=sparse_init, density=density)


    def remove_weight(self, name):
        """Drop the mask stored under ``name``, or else under ``name + '.weight'``.

        Prints the removed size; prints ``ERROR`` and the name if neither key exists.
        """
        for key in (name, name + '.weight'):
            if key not in self.masks:
                continue
            shape, count = self.masks[key].shape, self.masks[key].numel()
            print('Removing {0} of size {1} = {2} parameters.'.format(name, shape, count))
            del self.masks[key]
            return
        print('ERROR', name)

    def remove_weight_partial_name(self, partial_name):
        """Drop every mask whose name contains ``partial_name`` and forget it in ``self.names``.

        Prints one line per removed mask and a final count.
        """
        removed = set()
        for key in list(self.masks):
            if partial_name not in key:
                continue
            shape = self.masks[key].shape
            print('Removing {0} of size {1} with {2} parameters...'.format(key, shape,
                                                                               np.prod(shape)))
            removed.add(key)
            del self.masks[key]

        print('Removed {0} layers.'.format(len(removed)))

        # Filter in place so other references to this list see the change.
        self.names[:] = [n for n in self.names if n not in removed]

    def remove_type(self, nn_type):
        """Remove the mask of every registered submodule that is an instance of ``nn_type``."""
        for root in self.modules:
            for sub_name, sub_module in root.named_modules():
                if not isinstance(sub_module, nn_type):
                    continue
                self.remove_weight(sub_name)

    def apply_mask(self):
        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name in self.masks:
                    tensor.data = tensor.data* (self.masks[name].unsqueeze(-1).unsqueeze(-1) )   # zero out weights
                    # if 'momentum_buffer' in self.optimizer.state[tensor]:
                    #     self.optimizer.state[tensor]['momentum_buffer'] = self.optimizer.state[tensor]['momentum_buffer']*self.masks[name]

    def group_mask(self):
        """Regroup each sparse kernel mask into dense blocks and save the layout as JSON.

        For every masked parameter whose ``self.density_dict`` entry is not 1.0,
        the module-level ``group_mask(mask)`` function (not this method) is called
        and must return ``(new_mask, dense_block)``. The new mask replaces the old
        one, and ``self.group_idxs[name]`` records the weight size, the new density
        and the blocks. All of ``self.group_idxs`` is then written to
        ``results/DST_<model>_<data>_density<density>_seed<seed>/mask.json``; the
        directory is created if it does not exist.
        """
        logging.info('\n+++++++++++++++++++++++++++++++ Group Mask +++++++++++++++++++++++++++++++')
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                if self.density_dict[name] == 1.0: continue     # dense layers are left alone

                mask = self.masks[name]    # 1 means keep, 0 means pruned

                logging.info('******** name: {},  shape: {} ********'.format(name, str(mask.shape)))
                # Module-level helper defined elsewhere; dense_block is presumably a list of
                # (rows, cols) index lists, as the loop below unpacks it — TODO confirm.
                new_mask, dense_block = group_mask(mask)
                logging.info('Origin density: {}, new density: {}'.format(
                    mask.sum().item() / mask.numel(), new_mask.sum().item() / mask.numel()))
                self.masks[name] = new_mask.float()         # replace

                self.group_idxs[name] = {
                    "size": list(weight.shape),
                    "density": new_mask.sum().item() / mask.numel(),
                    "block": dense_block,         # [[rows], [cols]] per block
                }

                msg = "-- # of block: {} \n".format(len(dense_block))
                for row, col in dense_block:
                    msg += '--- Block Size: {} x {}: {}, {}\n'.format(len(row), len(col), row, col)
                msg += '\n'
                logging.info(msg)

        save_path = os.path.join(
            'results',
            'DST_' + str(self.args.model) + '_' + str(self.args.data)
            + '_density' + str(self.args.density) + '_seed' + str(self.args.seed),
            'mask.json')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # Context manager so the file is flushed and closed even if json.dump fails.
        with open(save_path, "w") as f:
            json.dump(self.group_idxs, f, indent=4)



    def truncate_weights(self, pruning_rate):
        """Run one prune-and-regrow cycle of dynamic sparse training.

        1. Prune: for each masked layer, compute a smaller mask with
           ``self.death_mode`` and record how many entries were removed in
           ``self.pruning_rate[name]``.
        2. Grow: regrow that many entries with ``self.growth_mode``. Unsupported
           growth modes regrow nothing.
        3. Regroup the masks (``self.group_mask``), re-apply them, and print
           the resulting density.

        Args:
            pruning_rate: fraction of currently kept entries to remove (used by
                the 'magnitude' death mode).

        Raises:
            ValueError: if ``self.death_mode`` is not a supported death mode.
        """
        print('*** dynamic sparse training. Prune & Grow ***')

        # Pruning step: shrink each mask and remember how many entries were dropped.
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                mask = self.masks[name]    # 1 means keep, 0 means pruned

                if self.death_mode == 'magnitude':      # default
                    new_mask = self.magnitude_death(mask, weight, name, pruning_rate)
                elif self.death_mode == 'SET':
                    new_mask = self.magnitude_and_negativity_death(mask, weight, name)
                elif self.death_mode == 'Taylor_FO':
                    new_mask = self.taylor_FO(mask, weight, name)
                elif self.death_mode == 'threshold':
                    new_mask = self.threshold_death(mask, weight, name)
                else:
                    # Fail loudly instead of continuing with an undefined new_mask.
                    raise ValueError('Death mode {0} not supported!'.format(self.death_mode))

                # Entries removed now; the growth step regrows the same number.
                self.pruning_rate[name] = int(mask.sum().item() - new_mask.sum().item())
                self.masks[name][:] = new_mask

        self.apply_mask()

        # Growth step: regrow self.pruning_rate[name] entries per layer.
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                new_mask = self.masks[name].data.byte()

                if self.growth_mode == 'random':
                    new_mask = self.random_growth(name, new_mask, self.pruning_rate[name], weight)
                elif self.growth_mode == 'momentum':
                    new_mask = self.momentum_growth(name, new_mask, self.pruning_rate[name], weight)
                elif self.growth_mode == 'gradient':        # default
                    new_mask = self.gradient_growth(name, new_mask, self.pruning_rate[name], weight)
                elif self.growth_mode == 'momentum_neuron':
                    new_mask = self.momentum_neuron_growth(name, new_mask, self.pruning_rate[name], weight)
                # Any other growth mode (warned about in __init__) regrows nothing.

                # Exchange masks: pop first so the key moves to the end of the dict.
                self.masks.pop(name)
                self.masks[name] = new_mask.float()

        # Merge sparse kernels into dense blocks, then apply the final masks.
        self.group_mask()
        self.apply_mask()

        # Print statistics.
        total_size = sum(mask.numel() for mask in self.masks.values())
        print('Total Model parameters after dst:', total_size)      # always the same

        sparse_size = sum((mask != 0).sum().int().item() for mask in self.masks.values())
        ratio = sparse_size / total_size if total_size else 0.0
        print('Total parameters under sparsity level of {0}: {1} after dst'.format(self.args.density, ratio))

    def pruning(self):
        """One-shot magnitude pruning of every masked layer down to ``args.density``.

        Each mask entry is scored by the L1 norm of the weights it covers. For the
        (ch_out, ch_in) kernel masks this is the sum of |w| over the kernel. The
        ``ceil((1 - density) * entries)`` lowest-scoring entries of each layer are
        zeroed, then the masks are applied and statistics printed.
        """
        print('pruning...')
        print('density:', self.args.density)
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                mask = self.masks[name]
                scores = weight.data.abs()
                if scores.dim() > mask.dim():
                    # One score per mask entry, so sort indices line up with the mask.
                    scores = scores.flatten(start_dim=mask.dim()).sum(-1)
                num_remove = math.ceil((1 - self.args.density) * mask.numel())
                order = torch.argsort(scores.reshape(-1)).to(mask.device)
                mask.data.view(-1)[order[:num_remove]] = 0.0
        self.apply_mask()

        total_size = sum(mask.numel() for mask in self.masks.values())
        print('Total Model parameters:', total_size)

        sparse_size = sum((mask != 0).sum().int().item() for mask in self.masks.values())
        ratio = sparse_size / total_size if total_size else 0.0
        print('Total parameters under sparsity level of {0}: {1}'.format(self.args.density, ratio))



    '''
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> DEATH >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    '''

    def threshold_death(self, mask, weight, name):
        """Keep the entries whose absolute weight exceeds ``self.threshold``.

        ``mask`` and ``name`` are unused. NOTE(review): the result has the weight's
        shape, which differs from the 2-D kernel masks ``add_module`` creates for
        conv weights.
        """
        magnitudes = weight.data.abs()
        return magnitudes > self.threshold

    def taylor_FO(self, mask, weight, name):
        """First-order Taylor pruning: zero the mask entries with the smallest (w * grad)^2.

        In total, ``ceil(name2zeros[name] + ceil(name2death_rate[name] * name2nonzeros[name]))``
        entries are zeroed; presumably the already-zero entries are among the
        lowest scores. ``mask`` is modified in place and returned.

        NOTE(review): the sort indices come from the flattened ``weight``. They
        only address ``mask`` correctly when both have the same number of
        elements, which is not the case for the 2-D kernel masks ``add_module``
        builds for conv weights.
        """
        to_remove = math.ceil(self.name2death_rate[name] * self.name2nonzeros[name])
        cutoff = math.ceil(self.name2zeros[name] + to_remove)

        saliency = (weight.data * weight.grad).pow(2).flatten()
        order = torch.argsort(saliency)
        mask.data.view(-1)[order[:cutoff]] = 0.0

        return mask

    def magnitude_death(self, mask, weight, name, pruning_rate):
        """Prune the lowest-magnitude ``pruning_rate`` fraction of the kept mask entries.

        Each mask entry is scored by the L1 norm of the weights it covers. For the
        (ch_out, ch_in) kernel masks built by ``add_module``, that is the sum of
        |w| over the kernel. Assuming the mask has already been applied, removed
        entries score 0 and fill the first ``num_zeros`` sorted positions. The
        threshold is the score at position ``num_zeros + ceil(pruning_rate * kept)``,
        and only entries strictly above it survive.

        Args:
            mask: current 0/1 mask (1 = keep).
            weight: parameter the mask applies to; rank >= mask rank.
            name: unused, kept for a uniform death-function signature.
            pruning_rate: fraction of currently kept entries to remove.

        Returns:
            Boolean tensor with ``mask``'s shape; True = keep.
        """
        num_zeros = (mask == 0).sum().item()                            # already pruned
        num_remove = math.ceil(pruning_rate * mask.sum().item())        # to prune now
        if num_remove == 0:
            return mask != 0        # nothing to prune: keep the current topology

        scores = weight.data.abs()
        if scores.dim() > mask.dim():
            # Collapse the trailing (kernel) dims so there is one score per mask entry.
            scores = scores.flatten(start_dim=mask.dim()).sum(-1)

        sorted_scores, _ = torch.sort(scores.reshape(-1))               # small to large
        k = min(math.ceil(num_zeros + num_remove), sorted_scores.numel())
        threshold = sorted_scores[k - 1].item()

        return scores > threshold


    def global_magnitude_death(self):
        """Globally prune by searching for one magnitude threshold shared by all layers.

        The removal target is ``tokill = ceil(death_rate * self.baseline_nonzero)``.
        ``death_rate`` is the ``self.name2death_rate`` value of the last name (in
        dict order) that also has a mask. The search loop counts, over all masked
        layers, ``self.name2nonzeros[name]`` minus the number of weights with
        ``|w| > self.threshold``. It then scales ``self.threshold`` by
        ``1 -/+ self.increment`` (shrinking ``self.increment`` by 0.99 each time)
        until that count is within the relative ``self.tolerance`` of ``tokill``,
        or the count stops changing. Finally every mask is reset to
        ``|weight| > self.threshold``.

        NOTE(review): ``self.threshold``, ``self.increment`` and ``self.tolerance``
        are not initialised in ``__init__``. ``self.baseline_nonzero`` starts as
        None and is only set by ``init('resume')``. They must be set before this
        is called. The final mask has the weight's shape, not the 2-D kernel mask
        shape ``add_module`` creates for conv weights. TODO: confirm this path is
        still used.

        Returns:
            int: the removal count from the last counting pass.
        """
        death_rate = 0.0
        # Last matching entry wins.
        for name in self.name2death_rate:
            if name in self.masks:
                death_rate = self.name2death_rate[name]
        tokill = math.ceil(death_rate*self.baseline_nonzero)
        total_removed = 0
        prev_removed = 0
        while total_removed < tokill*(1.0-self.tolerance) or (total_removed > tokill*(1.0+self.tolerance)):
            total_removed = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    remain = (torch.abs(weight.data) > self.threshold).sum().item()
                    total_removed += self.name2nonzeros[name] - remain

            # Stop once an adjustment no longer changes the count. This also fires on
            # the first pass if the count is 0, because prev_removed starts at 0.
            if prev_removed == total_removed: break
            prev_removed = total_removed
            if total_removed > tokill*(1.0+self.tolerance):
                # Too many removed: lower the threshold.
                self.threshold *= 1.0-self.increment
                self.increment *= 0.99
            elif total_removed < tokill*(1.0-self.tolerance):
                # Too few removed: raise the threshold.
                self.threshold *= 1.0+self.increment
                self.increment *= 0.99

        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                self.masks[name][:] = torch.abs(weight.data) > self.threshold

        return int(total_removed)


    def global_momentum_growth(self, total_regrowth):
        """Globally regrow inactive entries whose |momentum| exceeds a shared threshold.

        The search loop counts, over all masked layers, the inactive entries
        (``mask == 0``) whose momentum magnitude exceeds ``self.growth_threshold``.
        It multiplies the threshold by 1.02 (too many) or 0.98 (too few) until the
        count is within the relative ``self.tolerance`` of ``total_regrowth``, or
        the count stops changing. Every mask is then OR-ed with
        ``|momentum| > self.growth_threshold`` on its inactive entries.

        NOTE(review): ``self.tolerance``, ``self.growth_threshold`` and
        ``self.growth_increment`` are not initialised in ``__init__``, and
        ``get_momentum_for_weight`` is defined elsewhere. They must exist before
        this is called. The momentum presumably has the weight's shape; multiplying
        it by a 2-D kernel mask from ``add_module`` would not broadcast for conv
        weights. TODO: confirm this path is still used.

        Returns:
            Sum of all mask entries after growth (``new_mask`` aliases
            ``self.masks[name]``, which is updated in place), not only the number
            of newly grown entries.
        """
        togrow = total_regrowth
        total_grown = 0
        last_grown = 0
        while total_grown < togrow*(1.0-self.tolerance) or (total_grown > togrow*(1.0+self.tolerance)):
            total_grown = 0
            total_possible = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue

                    new_mask = self.masks[name]
                    grad = self.get_momentum_for_weight(weight)
                    # Only inactive entries are candidates for growth.
                    grad = grad*(new_mask==0).float()
                    possible = (grad !=0.0).sum().item()
                    total_possible += possible
                    grown = (torch.abs(grad.data) > self.growth_threshold).sum().item()
                    total_grown += grown
            print(total_grown, self.growth_threshold, togrow, self.growth_increment, total_possible)
            # Stop once an adjustment no longer changes the count.
            if total_grown == last_grown: break
            last_grown = total_grown


            if total_grown > togrow*(1.0+self.tolerance):
                self.growth_threshold *= 1.02
                #self.growth_increment *= 0.95
            elif total_grown < togrow*(1.0-self.tolerance):
                self.growth_threshold *= 0.98
                #self.growth_increment *= 0.95

        total_new_nonzeros = 0
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue

                new_mask = self.masks[name]
                grad = self.get_momentum_for_weight(weight)
                grad = grad*(new_mask==0).float()
                # In-place update, so new_mask (same tensor) reflects the grown mask below.
                self.masks[name][:] = (new_mask.byte() | (torch.abs(grad.data) > self.growth_threshold)).float()
                total_new_nonzeros += new_mask.sum().item()
        return total_new_nonzeros


    def magnitude_and_negativity_death(self, mask, weight, name):
        """SET-style pruning: drop the weakest positive and the weakest negative weights.

        With ``num_remove = ceil(name2death_rate[name] * name2nonzeros[name])`` and
        ``k = ceil(num_remove / 2)``, the k smallest positive weights and the k
        negative weights closest to zero are removed. Ties at a threshold are
        removed too, and zero weights are never kept. ``mask`` is unused.

        NOTE(review): the returned mask has ``weight``'s shape, not the 2-D kernel
        mask shape ``add_module`` creates for conv weights.

        Returns:
            Boolean tensor with ``weight``'s shape; True = keep.
        """
        num_remove = math.ceil(self.name2death_rate[name] * self.name2nonzeros[name])
        k = math.ceil(num_remove / 2.0)
        data = weight.data

        pos_mask = data > 0.0
        neg_mask = data < 0.0
        if k > 0:       # with k == 0 nothing is pruned (x[k-1] would wrap to x[-1])
            # Positives ascending: the weakest come first.
            positives, _ = torch.sort(data[pos_mask])
            if positives.numel() > 0:
                threshold_magnitude = positives[min(k, positives.numel()) - 1].item()
                pos_mask = pos_mask & (data > threshold_magnitude)

            # Negatives ascending: the weakest (closest to zero) come last.
            negatives, _ = torch.sort(data[neg_mask])
            if negatives.numel() > 0:
                threshold_negativity = negatives[negatives.numel() - min(k, negatives.numel())].item()
                neg_mask = neg_mask & (data < threshold_negativity)

        return pos_mask | neg_mask


    '''
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Growth >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    '''


    def random_growth(self, name, new_mask, total_regrowth, weight):
        """Regrow connections uniformly at random among the empty slots of ``new_mask``.

        Each zero entry is switched on independently with probability
        ``total_regrowth / n``, where ``n`` is the number of zero entries. About
        ``total_regrowth`` connections are therefore added in expectation.

        Args:
            name: parameter name (unused).
            new_mask: mask after pruning.
            total_regrowth: expected number of connections to add.
            weight: parameter tensor (unused).

        Returns:
            ``new_mask`` itself if it has no zero entries. Otherwise a uint8 tensor
            ``new_mask.byte() | grown``.
        """
        n = (new_mask == 0).sum().item()
        if n == 0:
            return new_mask
        expected_growth_probability = total_regrowth / n
        # Sample with the default (CPU) generator, as before, then move the result to
        # the mask's device. This replaces a hard-coded `.cuda()` that failed for CPU
        # masks and ignored masks living on a non-current GPU.
        new_weights = torch.rand(new_mask.shape).to(new_mask.device) < expected_growth_probability
        return new_mask.byte() | new_weights

    def momentum_growth(self, name, new_mask, total_regrowth, weight):
        """Switch on the ``total_regrowth`` positions with the largest |momentum|.

        The momentum at already-active positions is multiplied by zero first, so
        those positions rank last. ``new_mask`` is updated in place and returned.
        """
        momentum = self.get_momentum_for_weight(weight)
        inactive = (new_mask == 0).float()
        scores = torch.abs(momentum * inactive).flatten()
        ranking = torch.sort(scores, descending=True)[1]
        chosen = ranking[:total_regrowth]
        new_mask.data.view(-1)[chosen] = 1.0
        return new_mask

    def gradient_growth(self, name, new_mask, total_regrowth, weight):
        """Regrow whole kernels ranked by their summed absolute gradient.

        ``new_mask`` holds one entry per leading (out, in) position, while ``weight``
        carries two extra trailing (spatial) dims. This follows from the mask being
        unsqueezed twice below to broadcast against the gradient. For each
        currently-inactive entry, |grad| is summed over the trailing two dims, and
        the ``total_regrowth`` highest-scoring entries are set to 1.

        If ``self.density_dict[name] == 1.0``, an all-ones float32 mask (on the same
        device as ``new_mask``) is returned instead.

        Args:
            name: parameter name, key into ``self.density_dict``.
            new_mask: kernel-level mask; modified in place on the sparse path.
            total_regrowth: number of mask entries to switch on.
            weight: parameter whose gradient (via ``get_gradient_for_weights``)
                drives the ranking.

        Returns:
            The updated mask.
        """
        if self.density_dict[name] == 1.0:
            # ones_like already allocates on new_mask's device; the former `.cuda()`
            # failed on CPU and could move the mask to a different GPU.
            return torch.ones_like(new_mask, dtype=torch.float32, requires_grad=False)

        grad = self.get_gradient_for_weights(weight)
        # Zero the gradients of active kernels so they rank last.
        grad = grad * (new_mask.unsqueeze(-1).unsqueeze(-1) == 0).float()
        kernel_scores = torch.abs(grad).sum(-1).sum(-1).flatten()
        _, idx = torch.sort(kernel_scores, descending=True)
        new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
        return new_mask

    def momentum_neuron_growth(self, name, new_mask, total_regrowth, weight):
        """Distribute regrowth across output neurons in proportion to their mean |momentum|.

        Neuron ``i`` (row ``i`` of the weight) is allotted
        ``floor(v_i * total_regrowth)`` connections, capped by its number of free
        slots. Here ``v_i`` is the neuron's mean absolute momentum, normalised so
        that all ``v_i`` sum to 1. Within the neuron, inactive positions whose
        |momentum| is strictly greater than the allotment-th largest value are
        switched on. Neurons with an allotment below 10, or whose threshold is
        exactly 0, are skipped.

        Args:
            name: parameter name (unused).
            new_mask: mask shaped like ``weight``; rows are updated in place.
            total_regrowth: total number of connections to distribute.
            weight: 2-D or 4-D parameter.

        Returns:
            ``new_mask``.

        Raises:
            ValueError: if ``weight``'s momentum is neither 2-D nor 4-D. Previously
                this surfaced as an UnboundLocalError.
        """
        grad = self.get_momentum_for_weight(weight)

        M = torch.abs(grad)
        if M.dim() == 2:
            sum_dim = [1]
        elif M.dim() == 4:
            sum_dim = [1, 2, 3]
        else:
            raise ValueError('momentum_neuron_growth expects a 2-D or 4-D weight, got shape {}'.format(tuple(M.shape)))

        v = M.mean(sum_dim).data
        v /= v.sum()

        slots_per_neuron = (new_mask == 0).sum(sum_dim)

        M = M * (new_mask == 0).float()
        for i, fraction in enumerate(v):
            neuron_regrowth = min(math.floor(fraction.item() * total_regrowth),
                                  slots_per_neuron[i].item())
            # Check the cheap skip condition before paying for a sort.
            if neuron_regrowth < 10:
                continue

            y, _ = torch.sort(M[i].flatten())
            threshold = y[-neuron_regrowth].item()
            if threshold == 0.0:
                continue
            # `|` is undefined for float tensors, and masks in this file are stored
            # as float, so combine the masks in uint8. The assignment casts back to
            # new_mask's dtype.
            new_mask[i] = new_mask[i].byte() | (M[i] > threshold)

        return new_mask

    '''
                UTILITY
    '''
    def get_momentum_for_weight(self, weight):
        """Return the optimizer's momentum-like statistic for ``weight``.

        If the optimizer state has ``exp_avg`` (Adam-style), the result is
        ``exp_avg / (sqrt(exp_avg_sq) + 1e-8)``, with no bias correction applied.
        Otherwise, if it has ``momentum_buffer`` (SGD-style), that buffer itself is
        returned (not a copy).

        Raises:
            RuntimeError: if the optimizer holds neither kind of state for
                ``weight``. Previously this surfaced as an UnboundLocalError.
        """
        state = self.optimizer.state[weight]
        if 'exp_avg' in state:
            adam_m1 = state['exp_avg']
            adam_m2 = state['exp_avg_sq']
            return adam_m1 / (torch.sqrt(adam_m2) + 1e-08)
        if 'momentum_buffer' in state:
            return state['momentum_buffer']
        raise RuntimeError(
            'No momentum statistics (exp_avg or momentum_buffer) found in the optimizer '
            'state for this parameter; growth modes based on momentum need an optimizer '
            'that keeps them.')

    def get_gradient_for_weights(self, weight):
        """Return a copy of ``weight.grad``.

        Because the result is a clone, callers can modify it in place without
        touching the stored gradient.
        """
        current_grad = weight.grad
        return current_grad.clone()

    def print_nonzero_counts(self):
        """Print per-parameter mask statistics, then the current death rate.

        For a masked parameter with an entry in ``self.name2variance``, the line
        shows the previous -> current nonzero count, the density and that entry.
        Other masked parameters get a raw nonzero / total count. Afterwards,
        ``self.death_rate`` is printed once for every module that owns at least one
        masked parameter.
        """
        for module in self.modules:
            for param_name, _ in module.named_parameters():
                if param_name not in self.masks:
                    continue
                current_mask = self.masks[param_name]
                nonzero_count = (current_mask != 0).sum().item()
                if param_name not in self.name2variance:
                    print('module: ', param_name, '#non_zeros: ', nonzero_count, '#total: ', np.prod(current_mask.shape))
                    continue
                print('{0}: {1}->{2}, density: {3:.3f}, proportion: {4:.4f}'.format(
                    param_name,
                    self.name2nonzeros[param_name],
                    nonzero_count,
                    nonzero_count / float(current_mask.numel()),
                    self.name2variance[param_name]))

        for module in self.modules:
            owns_mask = any(param_name in self.masks for param_name, _ in module.named_parameters())
            if owns_mask:
                print('Death rate: {0}\n'.format(self.death_rate))

    def reset_momentum(self):
        """Overwrite optimizer state at pruned positions with the mean over kept positions.

        Adapted from: https://github.com/AlliedToasters/synapses/blob/master/synapses/SET_layer.py

        For every masked parameter, the state buffers ``momentum_buffer``,
        ``square_avg``, ``exp_avg``, ``exp_avg_sq`` and ``exp_inf`` are updated:
        entries where the mask is 0 are set to the mean of the same buffer over
        entries where the mask is nonzero. Other state (e.g. ``step``) is left
        untouched.

        Edge cases:
        * If the mask has no nonzero entry, the pruned entries are set to 0. The
          mean of an empty selection would be NaN and would poison the optimizer.
        * Buffers stored as ``None`` are skipped.
        """
        reset_keys = ('momentum_buffer', 'square_avg', 'exp_avg', 'exp_avg_sq', 'exp_inf')
        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name not in self.masks:
                    continue
                mask = self.masks[name]
                pruned = mask == 0
                # Boolean selector; uint8 mask indexing is deprecated in PyTorch.
                kept = ~pruned
                state = self.optimizer.state[tensor]
                for key in list(state):
                    if key not in reset_keys:
                        continue
                    buf = state[key]
                    if buf is None:
                        continue
                    fill = buf[kept].mean() if kept.any() else 0.0
                    buf[pruned] = fill

    def fired_masks_update(self):
        """Intersect each fired mask with the current mask and report the fractions of ones.

        ``self.fired_masks[name]`` is replaced by the uint8 AND of the current mask
        and the stored fired mask. The per-layer and overall fractions of ones are
        printed.

        NOTE(review): because this is an AND, a position stays "fired" only while it
        is active at every update. If "fired" should mean "ever activated", this
        would be an OR. Confirm against how ``fired_masks`` is initialised.

        Returns:
            ``(per_layer, overall)``: a dict mapping parameter name to the fraction
            of ones in its fired mask, and the fraction over all fired masks
            combined.
        """
        fired_total = 0.0
        weight_total = 0.0
        per_layer = {}
        for module in self.modules:
            for name, _ in module.named_parameters():
                if name not in self.masks:
                    continue
                fired = self.masks[name].data.byte() & self.fired_masks[name].data.byte()
                self.fired_masks[name] = fired
                n_fired = float(fired.sum().item())
                n_weights = float(fired.numel())
                fired_total += n_fired
                weight_total += n_weights
                per_layer[name] = n_fired / n_weights
                print('Layerwise percentage of the fired weights of', name, 'is:', per_layer[name])
        overall = fired_total / weight_total
        print('The percentage of the total fired weights is:', overall)
        return per_layer, overall


# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
import copy

def matrix2graph(matrix):
    """Convert a binary matrix into KaHyPar's flat hypergraph arrays.

    Rows are vertices and each column is a hyperedge containing the rows where
    that column holds exactly 1.

    Args:
        matrix: 2-D ``torch.Tensor`` or ``numpy.ndarray`` of shape
            ``(num_rows, num_cols)``, e.g. ``(c_out, c_in * K * K)``.

    Returns:
        ``(hyperedge_indices, hyperedges)`` as lists of Python ints.
        ``hyperedges[hyperedge_indices[c]:hyperedge_indices[c + 1]]`` are the row
        ids (ascending) of column ``c``, and ``len(hyperedge_indices)`` is
        ``num_cols + 1``.
    """
    _, num_col = matrix.shape
    if torch.is_tensor(matrix):
        matrix = matrix.detach().cpu().numpy()
    # np.nonzero yields indices in row-major order. On the transposed matrix that
    # means grouped by column, rows ascending: the same order as the former
    # per-element loop, without one Python `.item()` call per entry.
    cols, rows = np.nonzero(np.asarray(matrix).T == 1)
    counts = np.bincount(cols, minlength=num_col)
    hyperedge_indices = [0] + np.cumsum(counts).tolist()
    return hyperedge_indices, rows.tolist()



def hyperGraphPartition(weights, t1=2, epsilon=0.03,
                        config_path="kahypar/config/km1_kKaHyPar_sea20.ini"):
    """Partition the rows of a binary matrix into ``t1`` blocks with KaHyPar.

    Rows of ``weights`` become hypergraph vertices and columns become hyperedges
    (built by ``matrix2graph``). All vertex and edge weights are 1.

    Args:
        weights: 2-D binary matrix of shape ``(num_rows, num_cols)``,
            e.g. ``(c_out, c_in * K * K)``.
        t1: number of blocks (k) to partition into.
        epsilon: value passed to ``context.setEpsilon`` (allowed imbalance).
        config_path: KaHyPar ``.ini`` preset. The default is the original path,
            which is relative to the current working directory.

    Returns:
        dict mapping block id to the list of row ids assigned to that block, in
        the order ``hypergraph.nodes()`` yields them.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist. This check runs
            before handing the path to KaHyPar.
    """
    import os

    if not os.path.isfile(config_path):
        raise FileNotFoundError('KaHyPar configuration not found: {}'.format(config_path))

    import kahypar as kahypar

    # Convert to KaHyPar's flat hypergraph format.
    num_nodes, num_edges = weights.shape
    hyperedge_indices, hyperedges = matrix2graph(weights)  # len(hyperedge_indices) == num_edges + 1
    edge_weights = [1] * num_edges
    node_weights = [1] * num_nodes

    hypergraph = kahypar.Hypergraph(num_nodes, num_edges, hyperedge_indices, hyperedges,
                                    t1, edge_weights, node_weights)

    context = kahypar.Context()
    context.loadINIconfiguration(config_path)
    context.setK(t1)
    context.setEpsilon(epsilon)
    kahypar.partition(hypergraph, context)

    groups = {}
    for node_id in hypergraph.nodes():
        groups.setdefault(hypergraph.blockID(node_id), []).append(node_id)
    return groups


def group_mask(input_mask, B1=8, B2=8, t1=8, t2=4):
    """Rebuild a sparse 2-D mask around dense blocks of fixed column width.

    The procedure repeats the following rounds:

    1. Partition the rows of the still-uncovered working mask into ``t1`` groups
       with ``hyperGraphPartition``.
    2. For each group with at least ``B1`` rows, find the columns holding at least
       ``t2`` ones within the group.
    3. If at least ``B2`` such columns exist, record a block made of the group's
       rows and its ``keep_col_num`` most populated columns.
    4. Clear the group's ones in the ``>= t2`` columns from the working mask.

    The rounds stop once one of them records no block.

    Ones still uncovered at the end are removed from the output mask. The same
    number of zero cells inside the recorded blocks (in block/row/column order) are
    set to 1, so the nonzero count is kept as far as block capacity allows.

    Args:
        input_mask: 2-D ``torch.Tensor`` of 0/1 values, shape ``(num_rows, num_cols)``.
        B1: minimum number of rows a group needs to form a block.
        B2: minimum number of qualifying (``>= t2``) columns a group needs.
        t1: number of row groups per partitioning round.
        t2: minimum ones per column (within a group) for the column to qualify.

    Returns:
        ``(new_mask, output_dense_block)``, where ``output_dense_block`` is a list of
        ``[row_ids, col_ids]`` pairs. If no block is found, ``input_mask`` itself is
        returned with an empty list.
    """
    num_row, num_col = input_mask.shape

    # Fixed width for all blocks: 1.5x the mask density times the number of columns.
    adjust_ratio = 1.5
    sparse_ratio = input_mask.sum() / input_mask.numel()
    keep_col_num = math.ceil(sparse_ratio * adjust_ratio * num_col)

    output_dense_block = []
    dense_mask = copy.deepcopy(input_mask)  # working copy; covered ones get cleared
    new_mask = copy.deepcopy(input_mask)
    # Each round records dense blocks until a round finds none.
    while True:
        find_dense = False
        groups = hyperGraphPartition(dense_mask, t1=t1)

        for g in groups:
            sec_rows_id = groups[g]
            if len(sec_rows_id) < B1:
                continue
            sec_rows = dense_mask[sec_rows_id]  # (len(sec_rows_id), num_col)

            # One column-wise reduction instead of a Python loop of per-column sums.
            all_cols_count = sec_rows.sum(0).tolist()
            dense_cols_id = [c for c, count in enumerate(all_cols_count) if count >= t2]
            if len(dense_cols_id) < B2:
                continue

            # Stable sort: equally populated columns keep ascending column order.
            all_cols_id_sort = sorted(range(num_col), key=all_cols_count.__getitem__, reverse=True)
            sec_cols_id = all_cols_id_sort[:keep_col_num]
            output_dense_block.append([sec_rows_id, sec_cols_id])

            # NOTE(review): ones are cleared in the qualifying columns (dense_cols_id),
            # not in the recorded block columns (sec_cols_id). Ones in qualifying
            # columns outside the block are thus neither in a block nor counted as
            # leftovers, so they stay 1 in the output. The original marked this
            # with "!!!"; behaviour preserved, confirm intent.
            for r in sec_rows_id:
                dense_mask[r, dense_cols_id] = 0
            find_dense = True

        if not find_dense:
            break

    if len(output_dense_block) == 0:
        print('No dense block found!!!')
        return input_mask, output_dense_block

    # Drop the ones no block covered...
    num_sparse = dense_mask.sum().item()
    new_mask[dense_mask == 1] = 0
    # ...and re-add as many ones inside the blocks.
    num_fill = _fill_dense_blocks(new_mask, output_dense_block, num_sparse)
    logging.info('#sparse: {},  #fill: {} '.format(num_sparse, num_fill))

    return new_mask, output_dense_block


def _fill_dense_blocks(mask, blocks, budget):
    """Set up to ``budget`` zero cells of ``mask`` inside ``blocks`` to 1; return how many were set.

    Cells are visited block by block, row by row, column by column. Filling stops
    at the first zero cell encountered once ``budget`` cells have been set.
    """
    num_fill = 0
    for rows, cols in blocks:
        for r in rows:
            for c in cols:
                if mask[r, c] == 0:
                    if num_fill >= budget:
                        return num_fill
                    mask[r, c] = 1
                    num_fill += 1
    return num_fill

