"""
Grouping on Imagenet
"""


from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import numpy as np
import math
from snip import SNIP, prefetched_loader, GraSP
from torch.autograd import Variable
from funcs import redistribution_funcs, growth_funcs, prune_funcs
import logging 
import random
import json
import copy



def print_and_log(msg):
    """Echo ``msg`` to stdout and record it with ``logging.info`` (root logger)."""
    for sink in (print, logging.info):
        sink(msg)




def add_sparse_args(parser):
    """Register the sparse-training command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group); options are
            added in place and nothing is returned.
    """
    parser.add_argument('--growth', type=str, default='gradient', help='Growth mode. Choose from: random, momentum, momentum_neuron, and gradient.')
    parser.add_argument('--prune', type=str, default='magnitude', help='Prune mode / pruning mode. Choose from: magnitude, SET.')
    parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.')
    parser.add_argument('--prune-rate', type=float, default=0.50, help='The pruning rate / prune rate.')
    parser.add_argument('--density', type=float, default=0.20, help='The density of the overall sparse network.')
    parser.add_argument('--dense', action='store_true', help='Enable dense mode. Default: False.')
    parser.add_argument('--verbose', action='store_true', help='Prints verbose status of pruning/growth algorithms.')
    parser.add_argument('--update_frequency', type=int, default=4000, metavar='N', help='how many iterations to train between mask update')
    # store_true flag: the default is False, not True as the old help text claimed.
    parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: False.')
    # NOTE(review): Masking.init has no 'ER' branch, so with this default the masks are
    # left unchanged — confirm whether 'ERK' was intended.
    parser.add_argument('--sparse_init', type=str, default='ER', help='sparse initialization (uniform, GraSP, snip, resume, ERK, ERK_plus)')
    parser.add_argument('--multiplier', type=int, default=1, metavar='N', help='extend training time by multiplier times')
    parser.add_argument('--fc_density', type=float, default=1, help='Density of the fc.weight layer(s) under ERK_plus sparse initialization.')
    # ------------------
    # parameters of reinitialization
    # ------------------
    parser.add_argument(
        "--scale-fan", action="store_true", default=False, help="scale fan")
    parser.add_argument(
        "--init", default="kaiming_normal", help="Weight initialization modifications"
    )
class CosineDecay(object):
    """Cosine-annealed pruning rate.

    Reuses ``torch.optim.lr_scheduler.CosineAnnealingLR``: a throwaway SGD optimizer
    over a single dummy parameter stores the current rate as its learning rate, and
    the scheduler anneals it from ``prune_rate`` towards ``eta_min`` over ``T_max`` steps.
    """
    def __init__(self, prune_rate, T_max, eta_min=0.005, last_epoch=-1):
        placeholder = torch.nn.Parameter(torch.zeros(1))
        holder = torch.nn.ParameterList([placeholder])
        self.sgd = optim.SGD(holder, lr=prune_rate)
        self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.sgd, T_max=T_max, eta_min=eta_min, last_epoch=last_epoch)

    def step(self):
        """Advance the cosine schedule by one step."""
        self.cosine_stepper.step()

    def get_dr(self, prune_rate):
        """Return the current pruning rate; the ``prune_rate`` argument is ignored."""
        (group,) = self.sgd.param_groups
        return group['lr']

class LinearDecay(object):
    """Anneals the pruning rate linearly with each step, never going below zero.

    Starting at ``prune_rate``, every ``step()`` subtracts ``prune_rate / T_max``.
    The rate therefore reaches 0 after ``T_max`` steps and stays at 0 afterwards.
    """
    def __init__(self, prune_rate, T_max):
        self.steps = 0
        self.decrement = prune_rate/float(T_max)
        self.current_prune_rate = prune_rate

    def step(self):
        """Advance one step, lowering the rate by ``decrement`` (clamped at 0)."""
        self.steps += 1
        # Clamp: stepping past T_max (or float round-off at exactly T_max) would
        # otherwise yield a negative pruning rate.
        self.current_prune_rate = max(0.0, self.current_prune_rate - self.decrement)

    def get_dr(self, prune_rate):
        """Return the current pruning rate; the ``prune_rate`` argument is ignored."""
        return self.current_prune_rate


class Masking(object):
    """Wraps PyTorch model parameters with a sparse mask.

    Creates a mask for each parameter tensor contained in the model. When
    `apply_mask()` is called, it applies the sparsity pattern to the parameters.

    Basic usage:
        model = MyModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
        decay = CosineDecay(args.prune_rate, len(train_loader)*(args.epochs))
        mask = Masking(optimizer, prune_rate_decay=decay, args=args)
        mask.add_module(model, density=args.density, sparse_init='ERK')

    Removing layers: Layers can be removed individually, by type, or by partial
    match of their name.
      - `mask.remove_weight(name)` requires an exact name of
    a parameter.
      - `mask.remove_weight_partial_name(partial_name=name)` removes all
        parameters that contain the partial name. For example 'conv' would remove all
        layers with 'conv' in their name.
      - `mask.remove_type(type)` removes all layers of a certain type. For example,
        mask.remove_type(torch.nn.BatchNorm2d) removes all 2D batch norm layers.
    """
    def __init__(self, optimizer, prune_rate_decay, prune_rate=0.5, prune_mode='magnitude', growth_mode='momentum', redistribution_mode='momentum', verbose=False, fp16=False, args=False, train_loader=False):
        growth_modes = ['random', 'momentum', 'momentum_neuron', 'gradient']
        if growth_mode not in growth_modes:
            print('Growth mode: {0} not supported!'.format(growth_mode))
            print('Supported modes are:', str(growth_modes))
        self.train_loader = train_loader
        self.args = args
        self.growth_mode = growth_mode
        self.prune_mode = prune_mode
        self.redistribution_mode = redistribution_mode
        self.prune_rate_decay = prune_rate_decay
        self.verbose = verbose
        self.device = torch.device("cuda")

        self.growth_func = growth_mode
        self.prune_func = prune_mode
        self.redistribution_func = redistribution_mode

        self.global_growth = False
        self.global_prune = False
        self.fc_density = args.fc_density
        self.masks = {}
        self.modules = []
        self.names = []
        self.optimizer = optimizer

        self.adjusted_growth = 0
        self.adjustments = []
        self.baseline_nonzero = None
        self.name2baseline_nonzero = {}

        # stats
        self.name2variance = {}
        self.name2zeros = {}
        self.name2nonzeros = {}
        self.name2removed = {}

        self.total_variance = 0
        self.total_removed = 0
        self.total_zero = 0
        self.total_nonzero = 0
        self.prune_rate = prune_rate
        self.name2prune_rate = {}
        self.steps = 0
        self.start_name = None

        # global growth/prune state
        self.prune_threshold = 0.001
        self.growth_threshold = 0.001
        self.growth_increment = 0.2
        self.increment = 0.2
        self.tolerance = 0.02
        self.prune_every_k_steps = None     # Fix Sparse Trainign
        self.half = fp16
        self.name_to_32bit = {}

        # >>>>>>>>>>>>>>>>>>>>>>>> For Group >>>>>>>>>>>>>>>>>>>>>>>>
        self.group_idxs = {}
        self.last_weights = {}


    def init_optimizer(self):
        if 'fp32_from_fp16' in self.optimizer.state_dict():
            for (name, tensor), tensor2 in zip(self.modules[0].named_parameters(), self.optimizer.state_dict()['fp32_from_fp16'][0]):
                self.name_to_32bit[name] = tensor2
            self.half = True

    def init(self, mode='ERK', density=0.05, erk_power_scale=1.0):
        """Initialize ``self.masks`` according to ``mode``.

        Modes:
          - 'uniform':  each masked layer keeps each weight with probability ``density``.
          - 'GraSP':    per-layer sparsities from ``GraSP``; positions re-sampled at random.
          - 'snip':     masks copied from ``SNIP``.
          - 'resume':   mask = current non-zero pattern of the weights.
          - 'ERK':      Erdos-Renyi-Kernel per-layer densities, then ``group_mask()``.
          - 'ERK_plus': like 'ERK', but parameters named '*fc.weight*' are excluded
                        from the ERK budget and sampled at ``self.fc_density``.
          Any other mode leaves the masks unchanged (a warning is printed).

        Args:
            mode: initialization scheme (see above).
            density: requested overall fraction of kept weights.
            erk_power_scale: exponent applied to the raw ERK layer scores.
        """
        self.density = density
        self.init_growth_prune_and_redist()
        self.init_optimizer()
        if mode == 'uniform':
            print('initialized with uniform')
            # Each layer keeps about weight.numel()*density weights
            # (weight.numel()*density == weight.numel()*(1.0-sparsity)).
            self.baseline_nonzero = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    self.masks[name][:] = (torch.rand(weight.shape) < density).float().data.cuda()
                    self.baseline_nonzero += weight.numel()*density
            self.apply_mask()

        elif mode == 'GraSP':
            print('initialize by GraSP')
            layer_wise_sparsities = GraSP(self.module, self.density, self.train_loader, self.device)
            # re-sample mask positions
            for sparsity_, name in zip(layer_wise_sparsities, self.masks):
                self.masks[name][:] = (torch.rand(self.masks[name].shape) < (1-sparsity_)).float().data.cuda()

        elif mode == 'snip':
            print('initialize by snip')
            self.baseline_nonzero = 0
            snip_masks = SNIP(self.module, self.density, self.train_loader, self.device, self.masks, self.args)
            for snip_mask, name in zip(snip_masks, self.masks):
                assert (snip_mask.shape == self.masks[name].shape)
                self.masks[name][:] = snip_mask
                self.baseline_nonzero += (self.masks[name]!=0).sum().item()

        elif mode == 'resume':
            print('initialized with resume')
            # Derive the mask from the currently zero-valued weights; needed when
            # resuming a sparse model whose mask was not saved.
            self.baseline_nonzero = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    print((weight != 0.0).sum().item())
                    if name in self.name_to_32bit:
                        print('W2')
                    self.masks[name][:] = (weight != 0.0).float().data.cuda()
                    self.baseline_nonzero += weight.numel()*density
            self.apply_mask()

        elif mode == 'ERK_plus':
            print('initialize by ERK_plus')
            total_params = 0
            self.baseline_nonzero = 0
            for name, weight in self.masks.items():
                total_params += weight.numel()
                self.baseline_nonzero += weight.numel() * density

            # Take the FC layer(s) out of the ERK problem; they are re-added below at
            # their own fixed density, and the remaining budget is spread over the rest.
            erk_density = density
            fc_budget = 0
            for name in list(self.masks):
                if 'fc.weight' in name:
                    fc_numel = self.masks[name].numel()
                    total_params = total_params - fc_numel
                    fc_budget += fc_numel * self.fc_density
                    erk_density = (self.baseline_nonzero - fc_budget) / total_params
                    self.masks.pop(name)

            epsilon, raw_probabilities, dense_layers = self._solve_erk_epsilon(erk_density, erk_power_scale)
            density_dict, total_nonzero = self._sample_erk_masks(epsilon, raw_probabilities, dense_layers)

            for name, weight in self.module.named_parameters():
                if 'fc.weight' in name:
                    self.masks[name] = (torch.rand(weight.shape) < self.fc_density).float().data.cuda()
                    total_nonzero += self.fc_density * weight.numel()
                    total_params += weight.numel()
                    print(
                        f"layer: {name}, shape: {self.masks[name].shape}, density: {self.fc_density}"
                    )
                    density_dict[name] = self.fc_density

            # group_mask() reads the per-layer densities from self.density_dict.
            self.density_dict = density_dict
            print(f"Overall density {total_nonzero / total_params}")

            self.group_mask()
            self.apply_mask()

        elif mode == 'ERK':
            print('initialize by fixed_ERK')
            total_params = 0
            self.baseline_nonzero = 0
            for name, weight in self.masks.items():
                total_params += weight.numel()
                self.baseline_nonzero += weight.numel() * density

            epsilon, raw_probabilities, dense_layers = self._solve_erk_epsilon(density, erk_power_scale)
            density_dict, total_nonzero = self._sample_erk_masks(epsilon, raw_probabilities, dense_layers)
            # group_mask() reads the per-layer densities from self.density_dict;
            # previously this branch never set it and group_mask() crashed.
            self.density_dict = density_dict
            print(f"Overall density {total_nonzero / total_params}")

            self.group_mask()
            self.apply_mask()

        else:
            print('WARNING: unknown sparse_init mode {0!r}; masks were left unchanged.'.format(mode))

        self.fired_masks = copy.deepcopy(self.masks)
        self.print_nonzero_counts()

        total_size = sum(mask.numel() for mask in self.masks.values())
        print('Total Model parameters:', total_size)

        sparse_size = sum((mask != 0).sum().int().item() for mask in self.masks.values())
        # Reported against the requested density (ERK_plus works on a local adjusted copy).
        print('Total parameters under density level of {0}: {1}'.format(
            density, sparse_size / total_size if total_size else 0.0))

    def _solve_erk_epsilon(self, density, erk_power_scale):
        """Solve the Erdos-Renyi-Kernel scale factor for the layers in ``self.masks``.

        Each non-dense layer gets the raw score ``(sum(shape) / prod(shape)) ** erk_power_scale``
        and density ``epsilon * score``, with ``epsilon`` chosen so the expected number of
        kept weights equals ``density`` times the total. Layers whose density would exceed 1
        are made fully dense and the problem is re-solved until every density is <= 1.

        Returns:
            (epsilon, raw_probabilities, dense_layers); ``raw_probabilities`` only contains
            the layers that are not in ``dense_layers``.
        """
        dense_layers = set()
        while True:
            divisor = 0
            rhs = 0
            raw_probabilities = {}
            for name, mask in self.masks.items():
                n_param = np.prod(mask.shape)
                if name in dense_layers:
                    # A dense layer uses up n_zeros more weights than its uniform share.
                    rhs -= n_param * (1 - density)
                else:
                    rhs += n_param * density
                    raw_probabilities[name] = (
                        np.sum(mask.shape) / np.prod(mask.shape)
                    ) ** erk_power_scale
                    # raw_probabilities[name] * n_param is this layer's term of the divisor.
                    divisor += raw_probabilities[name] * n_param
            if not raw_probabilities:
                # Every layer is dense (or there are no layers): nothing left to scale.
                return 0.0, raw_probabilities, dense_layers
            epsilon = rhs / divisor
            max_prob = np.max(list(raw_probabilities.values()))
            if max_prob * epsilon > 1:
                for mask_name, mask_raw_prob in raw_probabilities.items():
                    if mask_raw_prob == max_prob:
                        print(f"Sparsity of var:{mask_name} had to be set to 0.")
                        dense_layers.add(mask_name)
            else:
                return epsilon, raw_probabilities, dense_layers

    def _sample_erk_masks(self, epsilon, raw_probabilities, dense_layers):
        """Sample a random mask for every layer in ``self.masks`` from its ERK density.

        Returns:
            (density_dict, total_nonzero): per-layer densities and the expected number
            of kept weights over those layers.
        """
        density_dict = {}
        total_nonzero = 0.0
        for name, mask in self.masks.items():
            if name in dense_layers:
                density_dict[name] = 1.0
            else:
                density_dict[name] = epsilon * raw_probabilities[name]
            print(
                f"layer: {name}, shape: {mask.shape}, density: {density_dict[name]}"
            )
            self.masks[name][:] = (torch.rand(mask.shape) < density_dict[name]).float().data.cuda()
            total_nonzero += density_dict[name] * mask.numel()
        return density_dict, total_nonzero

    def init_growth_prune_and_redist(self):
        """Resolve the growth / prune / redistribution mode names to functions.

        Each of ``growth_func``, ``prune_func`` and ``redistribution_func`` that is still a
        string is looked up in ``growth_funcs`` / ``prune_funcs`` / ``redistribution_funcs``.
        A growth or prune name containing 'global' also sets ``global_growth`` /
        ``global_prune``. Non-string values are treated as custom callables and left as-is.

        Raises:
            Exception: if a string mode is not in its registry (the available names are
                printed first).
        """
        # (attribute, registry, label used in messages, flag set for 'global' modes)
        resolution_plan = (
            ('growth_func', growth_funcs, 'growth', 'global_growth'),
            ('prune_func', prune_funcs, 'prune', 'global_prune'),
            ('redistribution_func', redistribution_funcs, 'redistribution', None),
        )
        for attr, registry, label, global_flag in resolution_plan:
            mode = getattr(self, attr)
            if not isinstance(mode, str):
                continue
            if mode in registry:
                if global_flag is not None and 'global' in mode:
                    setattr(self, global_flag, True)
                setattr(self, attr, registry[mode])
                continue

            banner = '=' * 50
            print(banner, 'ERROR', banner)
            print('{0} mode function not known: {1}.'.format(label.capitalize(), mode))
            print('Use either a custom {0} function or one of the pre-defined functions:'.format(label))
            for key in registry:
                print('\t{0}'.format(key))
            print(banner, 'ERROR', banner)
            raise Exception('Unknown {0} mode.'.format(label))

    def at_end_of_epoch(self):
        """Run one prune-and-regrow pass, then refresh the fired-mask bookkeeping.

        ``fired_masks_update`` is expected to return two values; both are discarded.
        """
        self.truncate_weights()
        _ignored_first, _ignored_second = self.fired_masks_update()

    def step(self):
        """Optimizer step followed by mask bookkeeping.

        Steps the wrapped optimizer, re-applies the masks, advances the prune-rate
        schedule and the step counter. When ``prune_every_k_steps`` is set and the
        counter is a multiple of it, also runs ``truncate_weights()``.
        """
        self.optimizer.step()
        self.apply_mask()

        schedule = self.prune_rate_decay
        schedule.step()
        self.prune_rate = schedule.get_dr(self.prune_rate)

        self.steps += 1

        period = self.prune_every_k_steps
        if period is None or self.steps % period != 0:
            return
        self.truncate_weights()
        if self.verbose:
            self.print_nonzero_counts()


    def add_module(self, module, density, sparse_init='ER'):
        """Register ``module`` for masking and initialize its masks.

        Creates an all-zero float32 mask on the GPU for every parameter. Masks for
        biases and for BatchNorm2d / BatchNorm1d layers are then dropped, and
        ``init(mode=sparse_init, density=density)`` fills in the rest.
        """
        self.module = module
        self.modules.append(module)
        for param_name, param in module.named_parameters():
            self.names.append(param_name)
            self.masks[param_name] = torch.zeros_like(param, dtype=torch.float32, requires_grad=False).cuda()

        print('Removing biases...')
        self.remove_weight_partial_name('bias')
        for dim_label, bn_type in (('2D', nn.BatchNorm2d), ('1D', nn.BatchNorm1d)):
            print('Removing {0} batch norms...'.format(dim_label))
            self.remove_type(bn_type, verbose=self.verbose)

        self.init(mode=sparse_init, density=density)

    def is_at_start_of_pruning(self, name):
        """Return True if ``name`` is the first parameter name ever seen by this method."""
        if self.start_name is None:
            self.start_name = name
        return name == self.start_name

    def remove_weight(self, name):
        if name in self.masks:
            print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name].shape, self.masks[name].numel()))
            self.masks.pop(name)
        elif name+'.weight' in self.masks:
            print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name+'.weight'].shape, self.masks[name+'.weight'].numel()))
            self.masks.pop(name+'.weight')
        else:
            print('ERROR',name)

    def remove_weight_partial_name(self, partial_name, verbose=False):
        removed = set()
        for name in list(self.masks.keys()):
            if partial_name in name:
                if self.verbose:
                    print('Removing {0} of size {1} with {2} parameters...'.format(name, self.masks[name].shape, np.prod(self.masks[name].shape)))
                removed.add(name)
                self.masks.pop(name)

        print('Removed {0} layers.'.format(len(removed)))

        i = 0
        while i < len(self.names):
            name = self.names[i]
            if name in removed: self.names.pop(i)
            else: i += 1


    def remove_type(self, nn_type, verbose=False):
        for module in self.modules:
            for name, module in module.named_modules():
                if isinstance(module, nn_type):
                    self.remove_weight(name)
                    #self.remove_weight_partial_name(name, verbose=self.verbose)


    '''
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Grouping >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    '''

    def group_mask(self):
        """Replace each layer's mask with a block-structured ("grouped") mask.

        For every masked parameter, ``self.hr_pruning(mask, define_density=...)`` is called
        with the layer's density from ``self.density_dict``. It returns a new mask and the
        list of dense blocks as ``(rows, cols)`` pairs. The new mask (as float) replaces
        ``self.masks[name]``. The block layout, filter shape and resulting density are
        recorded in ``self.group_idxs`` and written to ``<self.args.save>/mask.json``.
        """
        logging.info('\n+++++++++++++++++++++++++++++++ Group Mask +++++++++++++++++++++++++++++++')
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks:
                    continue
                dense_ratio = self.density_dict[name]
                # 1 = keep, 0 = pruned; presumably (ch_out, ch_in, K, K) for conv layers.
                mask = self.masks[name]
                print_and_log('******** name: {},  shape: {} ********, Dense: {}'.format(name, str(mask.shape), dense_ratio))

                new_mask, dense_block = self.hr_pruning(mask, define_density=dense_ratio)
                new_mask = new_mask.to(mask.device)

                print_and_log('Origin density: {}, new density: {}'.format(mask.sum() / mask.numel(), new_mask.sum() / mask.numel()))
                self.masks[name] = new_mask.float()

                self.group_idxs[name] = {
                    "filter_size": list(weight.shape),
                    "density": new_mask.sum().item() / mask.numel(),
                    "block": dense_block,  # [[rows], [cols]] per block
                }

                msg = "-- # of block: {} \n".format(len(dense_block))
                for row, col in dense_block:
                    msg += '--- Block Size: {} x {}: {}, {}\n'.format(len(row), len(col), row, col)
                msg += '\n'
                print_and_log(msg)

        save_path = self.args.save + '/mask.json'
        # Context manager so the file handle is closed (the original leaked it).
        with open(save_path, "w") as mask_file:
            json.dump(self.group_idxs, mask_file, indent=4)



    def apply_mask(self):
        """Zero out masked-off weights in place.

        Calls ``synchronism_masks()`` first. Every parameter that has a mask is
        multiplied by it. In half mode the mask is cast with ``.half()``, and the fp32
        master copy in ``self.name_to_32bit`` (if any) is masked as well.
        """
        self.synchronism_masks()

        for module in self.modules:
            for param_name, param in module.named_parameters():
                if param_name not in self.masks:
                    continue
                layer_mask = self.masks[param_name]
                if self.half:
                    param.data = param.data * layer_mask.half()
                    if param_name in self.name_to_32bit:
                        master = self.name_to_32bit[param_name]
                        master.data = master.data * layer_mask
                else:
                    param.data = param.data * layer_mask

    def adjust_prune_rate(self):
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                if name not in self.name2prune_rate: self.name2prune_rate[name] = self.prune_rate

                self.name2prune_rate[name] = self.prune_rate

                sparsity = self.name2zeros[name]/float(self.masks[name].numel())
                if sparsity < 0.2:
                    # determine if matrix is relativly dense but still growing
                    expected_variance = 1.0/len(list(self.name2variance.keys()))
                    actual_variance = self.name2variance[name]
                    expected_vs_actual = expected_variance/actual_variance
                    if expected_vs_actual < 1.0:
                        # growing
                        self.name2prune_rate[name] = min(sparsity, self.name2prune_rate[name])

    def truncate_weights(self):
        """Run one prune-then-regrow update of all masks, then re-group and re-apply them.

        Steps:
          1. ``gather_statistics()`` and ``adjust_prune_rate()`` refresh per-layer stats.
          2. Prune: either globally via ``self.prune_func(self)`` (which returns the total
             removed), or per layer via ``self.prune_func(self, mask, weight, name)``,
             recording ``name2removed[name]``.
          3. Grow: either globally via ``self.growth_func(self, budget)``, or per layer via
             ``self.growth_func(self, name, mask, floor(name2removed[name]), weight)``.
          4. ``group_mask()`` converts the masks to block structure; ``apply_mask()`` zeroes
             the pruned weights.
          5. The (baseline - new non-zero count) residual is smoothed into
             ``self.adjusted_growth`` for future growth.

        NOTE(review): with ``global_prune`` but not ``global_growth``, ``name2removed`` is
        left empty by ``gather_statistics()``, so the per-layer growth lookup would raise
        KeyError — confirm this combination is never used.
        """
        self.gather_statistics()
        self.adjust_prune_rate()

        total_nonzero_new = 0
        if self.global_prune:
            self.total_removed = self.prune_func(self)
        else:
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    mask = self.masks[name]

                    # prune: removed = non-zeros before pruning minus non-zeros after
                    new_mask = self.prune_func(self, mask, weight, name)
                    removed = self.name2nonzeros[name] - new_mask.sum().item()
                    self.total_removed += removed
                    self.name2removed[name] = removed
                    self.masks[name][:] = new_mask

        # NOTE(review): name2regrowth is computed but not used below; per-layer growth
        # regrows floor(name2removed[name]) instead.
        name2regrowth = self.calc_growth_redistribution()
        if self.global_growth:
            total_nonzero_new = self.growth_func(self, self.total_removed + self.adjusted_growth)
        else:
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    new_mask = self.masks[name].data.byte()

                    # growth: regrow as many weights as were removed from this layer
                    new_mask = self.growth_func(self, name, new_mask, math.floor(self.name2removed[name]), weight)

                    new_nonzero = new_mask.sum().item()

                    # exchanging masks (pop + re-insert moves the key to the end of the dict;
                    # since every key is processed in order, the final order is unchanged)
                    self.masks.pop(name)
                    self.masks[name] = new_mask.float()
                    total_nonzero_new += new_nonzero

        # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        # Re-group the freshly pruned/regrown masks into block structure.
        # NOTE(review): total_nonzero_new above was counted before grouping, so the
        # growth adjustment below does not reflect weights changed by group_mask().
        self.group_mask()
        # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

        self.apply_mask()

        # Some growth techniques and redistribution are probabilistic and we might not grow enough weights or too many weights.
        # Here we run an exponential smoothing over (prune-growth) residuals to adjust future growth.
        self.adjustments.append(self.baseline_nonzero - total_nonzero_new)
        self.adjusted_growth = 0.25*self.adjusted_growth + (0.75*(self.baseline_nonzero - total_nonzero_new)) + np.mean(self.adjustments)
        if self.total_nonzero > 0 and self.verbose:
            print('Nonzero before/after: {0}/{1}. Growth adjustment: {2:.2f}.'.format(
                  self.total_nonzero, total_nonzero_new, self.adjusted_growth))

    def gather_statistics(self):
        self.name2nonzeros = {}
        self.name2zeros = {}
        self.name2variance = {}
        self.name2removed = {}

        self.total_variance = 0.0
        self.total_removed = 0
        self.total_nonzero = 0
        self.total_zero = 0.0
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                mask = self.masks[name]

                # redistribution
                self.name2variance[name] = self.redistribution_func(self, name, weight, mask)

                if not np.isnan(self.name2variance[name]):
                    self.total_variance += self.name2variance[name]
                self.name2nonzeros[name] = mask.sum().item()
                self.name2zeros[name] = mask.numel() - self.name2nonzeros[name]

                sparsity = self.name2zeros[name]/float(self.masks[name].numel())
                self.total_nonzero += self.name2nonzeros[name]
                self.total_zero += self.name2zeros[name]

        for name in self.name2variance:
            if self.total_variance != 0.0:
                self.name2variance[name] /= self.total_variance
            else:
                print('Total variance was zero!')
                print(self.growth_func)
                print(self.prune_func)
                print(self.redistribution_func)
                print(self.name2variance)

    def calc_growth_redistribution(self):
        """Split the regrowth budget across layers.

        Each layer's initial share is
        ``ceil(name2variance[name] * (total_removed + adjusted_growth))``. Shares above 99% of
        ``name2zeros[name] + ceil(prune_rate * name2nonzeros[name])`` are capped. The mean
        overflow ("residual") is added to every layer's share and the capping repeated, for
        at most 1000 rounds. When ``prune_mode == 'global_magnitude'`` the shares are then
        rescaled by ``total_removed / (baseline_nonzero * prune_rate)`` and floored.

        Returns:
            dict mapping parameter name -> number of weights to regrow (int or float).
        """
        max_rounds = 1000
        residual = 9999  # any positive value, so at least one round runs
        mean_residual = 0
        name2regrowth = {}
        i = 0
        while residual > 0 and i < max_rounds:
            residual = 0
            for name in self.name2variance:
                num_remove = math.ceil(self.name2prune_rate[name] * self.name2nonzeros[name])
                max_regrowth = self.name2zeros[name] + num_remove

                if name in name2regrowth:
                    regrowth = name2regrowth[name]
                else:
                    regrowth = math.ceil(self.name2variance[name] * (self.total_removed + self.adjusted_growth))
                regrowth += mean_residual

                if regrowth > 0.99 * max_regrowth:
                    name2regrowth[name] = 0.99 * max_regrowth
                    residual += regrowth - name2regrowth[name]
                else:
                    name2regrowth[name] = regrowth
            mean_residual = residual / len(name2regrowth) if name2regrowth else 0
            i += 1

        if i == max_rounds:
            print('Error resolving the residual! Layers are too full! Residual left over: {0}'.format(residual))

        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                if self.prune_mode == 'global_magnitude':
                    expected_removed = self.baseline_nonzero*self.name2prune_rate[name]
                    if expected_removed == 0.0:
                        name2regrowth[name] = 0.0
                    else:
                        expected_vs_actual = self.total_removed/expected_removed
                        name2regrowth[name] = math.floor(expected_vs_actual*name2regrowth[name])

        return name2regrowth


    '''
                UTILITY
    '''
    def get_momentum_for_weight(self, weight):
        """Return the optimizer's momentum estimate for ``weight``.

        For Adam-style state (``exp_avg`` / ``exp_avg_sq``) this is
        ``exp_avg / (sqrt(exp_avg_sq) + 1e-8)``. Otherwise it is the SGD
        ``momentum_buffer``.

        Raises:
            RuntimeError: if the optimizer state for ``weight`` holds neither.
                Previously this surfaced as an ``UnboundLocalError``.
        """
        state = self.optimizer.state[weight]
        if 'exp_avg' in state:
            adam_m1 = state['exp_avg']
            adam_m2 = state['exp_avg_sq']
            return adam_m1 / (torch.sqrt(adam_m2) + 1e-08)
        if 'momentum_buffer' in state:
            return state['momentum_buffer']
        raise RuntimeError(
            'Optimizer state for this weight has neither Adam moments '
            '(exp_avg/exp_avg_sq) nor a momentum_buffer; momentum-based '
            'growth/redistribution needs one of them.')

    def get_gradient_for_weights(self, weight):
        """Return a clone of ``weight.grad``.

        Because the result is a copy, the caller can modify it without
        altering the gradient stored on the parameter.
        """
        return weight.grad.clone()

    def print_nonzero_counts(self):
        """Print the number of non-zero mask entries for every masked parameter.

        Layers tracked in ``self.name2variance`` get a detailed line:
        previous count -> current count, the current density, and the layer's
        ``name2variance`` value. Other masked layers print just their name and
        count. The current prune rate is printed last.
        """
        for module in self.modules:
            for param_name, _ in module.named_parameters():
                if param_name not in self.masks:
                    continue
                mask = self.masks[param_name]
                kept = (mask != 0).sum().item()
                if param_name not in self.name2variance:
                    print(param_name, kept)
                    continue
                density = kept / float(mask.numel())
                line = '{0}: {1}->{2}, density: {3:.3f}, proportion: {4:.4f}'.format(
                    param_name,
                    self.name2nonzeros[param_name],
                    kept,
                    density,
                    self.name2variance[param_name],
                )
                print(line)

        print('Prune rate: {0}\n'.format(self.prune_rate))

    def fired_masks_update(self):
        """Accumulate which weights have ever been active and report the ratios.

        For every masked parameter, ``self.fired_masks[name]`` becomes the
        bitwise OR (as uint8) of the current mask and the previous fired mask.

        Returns:
            (layer_fired_weights, total_fired_weights):
              - a dict of per-layer fired fractions;
              - the overall fired fraction, which is 0.0 when no masked
                parameter is found (previously a ZeroDivisionError).
        """
        ntotal_fired_weights = 0.0
        ntotal_weights = 0.0
        layer_fired_weights = {}
        for module in self.modules:
            for name, _ in module.named_parameters():
                if name not in self.masks:
                    continue
                self.fired_masks[name] = self.masks[name].data.byte() | self.fired_masks[name].data.byte()
                num_fired = float(self.fired_masks[name].sum().item())
                num_weights = float(self.fired_masks[name].numel())
                ntotal_fired_weights += num_fired
                ntotal_weights += num_weights
                layer_fired_weights[name] = num_fired / num_weights
                print('Layerwise percentage of the fired weights of', name, 'is:', layer_fired_weights[name])
        total_fired_weights = ntotal_fired_weights / ntotal_weights if ntotal_weights else 0.0
        print('The percentage of the total fired weights is:', total_fired_weights)
        return layer_fired_weights, total_fired_weights


    def synchronism_masks(self):
        """Broadcast every tensor in ``self.masks`` from rank 0, in place and synchronously."""
        for mask_tensor in self.masks.values():
            torch.distributed.broadcast(mask_tensor, src=0, async_op=False)



    '''
    >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Grouping >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    '''
    def hr_pruning(self, unstructured_mask, define_density, num_group=8, B1=8, given_sparse=0.444444):
        """Regroup an unstructured 4-D mask into dense row/column blocks.

        The mask is unpacked as (C_out, C_in, K1, K2). The procedure:
          1. Choose an intra-kernel density ``inner_sparse``: the smallest
             s / (K1*K2), with s in 1..K1*K2-1, strictly above
             ``define_density``; it is 1.0 when K1 == 1 or K2 == 1.
          2. Keep the kernels with the most surviving weights
             (``self.get_kernel_mask``).
          3. Randomly split the output rows into groups (``random_group``).
             Every group with at least ``B1`` rows keeps the ``keep_col_num``
             columns holding the most kept kernels (ties: lower index first).
             These rows and columns form a dense block.
          4. Clear the kernel mask and randomly re-place the kept kernels
             inside the dense blocks, proportionally to each block's free slots.
          5. Expand back to (C_out, C_in, K1, K2). With ``inner_sparse == 1``
             whole kernels are kept; otherwise one shared K1xK2 pattern (the
             most frequent cells) is applied to every kernel of a block.

        Args:
            unstructured_mask: mask tensor unpacked as (C_out, C_in, K1, K2).
            define_density: layer density. Values >= 0.3 are returned unchanged.
            num_group: number of row groups requested from ``random_group``.
            B1: minimum number of rows a group needs to become a dense block.
            given_sparse: currently unused; kept for interface compatibility.

        Returns:
            (mask, groups). ``groups`` is a list of ``[row_ids, col_ids]``, where
            column ids index the flattened (C_in * K1 * K2) axis. When no
            grouping is applied, the input mask and ``[]`` are returned.
            Newly built masks live on ``unstructured_mask``'s device.
        """
        # inner_sparse can be 1/9, 2/9, ..., 9/9 for 3x3 kernels; density must stay below it.
        density_ratio = define_density

        # Low-sparsity masks are not regrouped; 0.3 is an empirically found threshold.
        if define_density >= 0.3:
            print_and_log('Density is higher than 0.3, cannot be accelerated on hardware')
            return unstructured_mask, []

        c_out, c_in, k1, k2 = unstructured_mask.shape
        device = unstructured_mask.device

        # ---- Find the minimum intra-kernel density ----
        if k1 == 1 or k2 == 1:         # 1-wide kernels: keep or drop whole kernels
            inner_sparse = 1.0
        else:
            for s in range(1, k1 * k2):
                if s > density_ratio * k1 * k2:
                    inner_sparse = s / (k1 * k2)
                    break
            # To fix the kernel sparsity instead: inner_sparse = given_sparse

        # ---- STEP1: kernel-level mask, (C_out, C_in, K, K) -> (C_out, C_in) ----
        num_keep_kernel = math.ceil((1 / inner_sparse) * density_ratio * c_out * c_in)
        kernel_mask = self.get_kernel_mask(unstructured_mask, num_keep_kernel).to(device)

        # ---- STEP2: group rows and pick dense columns per group ----
        num_col = kernel_mask.shape[1]          # == C_in
        adjust_ratio = 1.0          # >= 1.0 keeps more kernels per dense block at the cost of more compute
        keep_col_num = math.ceil((1 / inner_sparse) * density_ratio * adjust_ratio * num_col)

        print_and_log("Shape:{}x{}x{}x{}, Given Ratio:{}, #Col:{}".format(c_out, c_in, k1, k2, density_ratio, keep_col_num))

        new_kernel_mask = kernel_mask.clone()                        # (C_out, C_in)
        kernel_dense_group = []
        # Rows are grouped once. hyperGraphPartition(kernel_mask, t1=num_group) is an alternative.
        groups = random_group(kernel_mask, t1=num_group)

        for g in groups:
            sec_rows_id = groups[g]                 # list of row indices
            sec_rows = kernel_mask[sec_rows_id]     # (len(group), num_col)

            if len(sec_rows) >= B1:                 # skip groups that are too small
                # Columns ordered by how many kept kernels they hold (stable: ties keep lower index first).
                col_counts = sec_rows.sum(dim=0).tolist()
                all_cols_id_sort = sorted(range(num_col), key=col_counts.__getitem__, reverse=True)
                sec_cols_id = all_cols_id_sort[:keep_col_num]

                sec_rows_id.sort()
                sec_cols_id.sort()
                kernel_dense_group.append([sec_rows_id, sec_cols_id])

        if len(kernel_dense_group) == 0:
            print('No dense block found!!!')
            return unstructured_mask, []

        # ---- STEP3: refill ----
        total_remaining = kernel_mask.sum().item()
        # new_kernel_mask is a copy of kernel_mask, so this clears every kept kernel;
        # all `total_remaining` kernels are re-placed inside the dense blocks below.
        new_kernel_mask[kernel_mask == 1] = 0

        if total_remaining > 0:
            # 1) free slots of each dense block
            block_slot = {}
            for block_id, (rows, cols) in enumerate(kernel_dense_group):
                block_slot[block_id] = sum(1 for r in rows for c in cols if new_kernel_mask[r, c] == 0)
            total_slot = sum(block_slot.values())

            # 2) fill blocks proportionally to their free slots. ceil() may place
            #    up to one extra kernel per block.
            for block_id in block_slot:
                num_to_refill_block = math.ceil(total_remaining * block_slot[block_id] / total_slot)
                rows, cols = kernel_dense_group[block_id]
                slot_idx = [(r, c) for r in rows for c in cols if new_kernel_mask[r, c] == 0]
                random.shuffle(slot_idx)
                for (r, c) in slot_idx[:num_to_refill_block]:
                    new_kernel_mask[r, c] = 1

        # ---- STEP4: back to an unstructured (C_out, C_in, K, K) mask ----
        if inner_sparse == 1:
            # Keep or drop entire kernels. contiguous() materialises the expand()
            # view so the returned mask owns its memory (no stride-0 aliasing).
            output_unstructured_mask = new_kernel_mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, k1, k2).contiguous()
            output_unstructured_group = []
            for (rows, cols) in kernel_dense_group:
                new_cols = [c * k1 * k2 + kkk for c in cols for kkk in range(k1 * k2)]
                output_unstructured_group.append([rows, new_cols])
            return output_unstructured_mask, output_unstructured_group

        # Apply one shared cell pattern to every kernel of a dense block.
        output_unstructured_mask = torch.zeros(unstructured_mask.shape, device=device)
        keep_cell = math.ceil(inner_sparse * k1 * k2)
        output_unstructured_group = []
        for (rows, cols) in kernel_dense_group:
            group_kernels = torch.stack([unstructured_mask[r, c] for r in rows for c in cols], dim=0)   # (#kernels, K, K)
            # Count how often each cell survives in this block and keep the top cells.
            cell_count = group_kernels.sum(0).view(-1)                  # (K * K)
            top_cell_idx = cell_count.topk(k=keep_cell)[1]
            pattern = torch.zeros(k1, k2, device=device)
            pattern.view(-1)[top_cell_idx] = 1                          # (K, K)

            for r in rows:
                for c in cols:
                    output_unstructured_mask[r, c] = pattern

            # Group columns in terms of the flattened (C_in * K * K) axis.
            new_cols = [c * k1 * k2 + kkk for c in cols for kkk in top_cell_idx.tolist()]
            output_unstructured_group.append([rows, new_cols])

        return output_unstructured_mask, output_unstructured_group


    def get_kernel_mask(self, unstructured_mask, num_keep_kernel):
        """Select the kernels with the most surviving weights.

        Args:
            unstructured_mask: 4-D mask. Dims 2 and 3 are summed per kernel;
                callers unpack it as (C_out, C_in, K, K).
            num_keep_kernel: number of kernels to keep. It is clamped to the
                total number of kernels so ``topk`` cannot go out of range.

        Returns:
            float tensor of shape ``unstructured_mask.shape[:2]`` holding 1 for
            kept kernels. It is allocated on the input's device, so indexing it
            with the ``topk`` indices also works for CUDA masks.
        """
        kernel_count = unstructured_mask.sum(dim=(2, 3))                       # (C_out, C_in)
        num_keep_kernel = min(int(num_keep_kernel), kernel_count.numel())
        top_kernel_idx = kernel_count.view(-1).topk(k=num_keep_kernel)[1]
        kernel_mask = torch.zeros(kernel_count.shape, device=kernel_count.device)   # (C_out, C_in)
        kernel_mask.view(-1)[top_kernel_idx] = 1
        return kernel_mask




# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>


def matrix2graph(matrix):
    """Convert a 2-D 0/1 matrix to KaHyPar's flat hypergraph format.

    Each column is a hyperedge whose pins are the row indices holding exactly 1.

    Returns:
        (hyperedge_indices, hyperedges): ``hyperedges`` concatenates the pins of
        every column in column order. ``hyperedge_indices`` (length
        num_cols + 1) holds the start offset of each column's pins, followed by
        the total pin count.
    """
    n_rows, n_cols = matrix.shape
    offsets = [0]
    pins = []
    for col in range(n_cols):
        pins.extend(row for row in range(n_rows) if matrix[row, col].item() == 1)
        offsets.append(len(pins))
    return offsets, pins



def hyperGraphPartition(weights, t1=2, config_path="config/km1_kKaHyPar_sea20.ini", epsilon=0.03):
    """Partition the rows of a 0/1 matrix into ``t1`` blocks with KaHyPar.

    Rows are hypergraph nodes. Each column is a hyperedge connecting the rows
    whose entry equals 1 (see ``matrix2graph``). All node and edge weights are 1.

    Args:
        weights: 2-D 0/1 matrix, e.g. (C_out, C_in * K * K).
        t1: number of blocks.
        config_path: KaHyPar .ini configuration. The default is relative and
            therefore resolved against the current working directory.
        epsilon: allowed imbalance passed to ``Context.setEpsilon``.

    Returns:
        dict mapping block id -> list of node (row) ids.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
    """
    import os
    import kahypar as kahypar

    if not os.path.isfile(config_path):
        raise FileNotFoundError(
            "KaHyPar configuration not found: {} (relative paths resolve against "
            "the current working directory)".format(config_path))

    # Convert to KaHyPar's hypergraph format.
    num_nodes, num_edges = weights.shape
    hyperedge_indices, hyperedges = matrix2graph(weights)  # len(hyperedge_indices) == num_edges + 1
    edge_weights = [1] * num_edges
    node_weights = [1] * num_nodes

    hypergraph = kahypar.Hypergraph(num_nodes, num_edges, hyperedge_indices, hyperedges, t1, edge_weights, node_weights)

    context = kahypar.Context()
    context.loadINIconfiguration(config_path)
    context.setK(t1)
    context.setEpsilon(epsilon)
    kahypar.partition(hypergraph, context)

    groups = {}
    for node_id in hypergraph.nodes():
        groups.setdefault(hypergraph.blockID(node_id), []).append(node_id)
    return groups




def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` of length ``n``; the last may be shorter."""
    for start in range(0, len(lst), n):
        stop = start + n
        yield lst[start:stop]


def random_group(kernel_mask, t1):
    """Randomly partition the row indices of ``kernel_mask`` into groups.

    Rows 0..N-1 (N = ``kernel_mask.shape[0]``) are shuffled with
    ``random.shuffle`` and cut into consecutive chunks of ``max(1, N // t1)``
    rows. If N >= t1 and N is not a multiple of t1, this gives more than ``t1``
    groups, the last one smaller. If N < t1, every row forms its own group.

    Returns:
        dict mapping group id (0, 1, ...) -> list of row indices.
    """
    num_row = kernel_mask.shape[0]
    # At least one row per group: with num_row < t1 the old step of 0 made range() raise.
    group_size = max(1, num_row // t1)

    row_ids = list(range(num_row))
    random.shuffle(row_ids)

    return {
        group_id: row_ids[start:start + group_size]
        for group_id, start in enumerate(range(0, num_row, group_size))
    }


