from __future__ import print_function
import torch
import copy
import numpy as np
import math
import wandb
from sparselearning.decay import CosineDecay, LinearDecay, ConstantDecay, WSDDecay
from transformers import AutoConfig, AutoTokenizer, default_data_collator
import torch.distributed as dist
from peft_pretraining import training_utils, args_utils
# from sparselearning.optimizer import AdamDST, SftAdamW
from sparselearning.optimizer_new import Adam

import datasets


class Masking(object):
    """
    Controls the dynamic sparsity patterns in neural networks during training.
    
    This class manages the complete lifecycle of sparse training, including initialization,
    weight pruning, weight regrowth across the network.
    It supports various sparse training algorithms through different combinations of 
    prune_mode and growth_mode parameters. For example:
    - RigL: prune_mode='magnitude', growth_mode='gradient'
    - SET: prune_mode='magnitude', growth_mode='random'
    
    """
    def __init__(
            self,
            preprocess_batched,
            pad_idx,
            growth_prune_ratio=1.0,
            redistribution_mode='none',
            threshold=0.001,
            args=None,
            distributed=False,
            device=None,
            global_rank=0,
            world_size=0
    ):
        self.args = args
        self.distributed = distributed
        if device is None:
            self.device = torch.device('cuda')
        else:
            self.device = device

        self.growth_mode = args.growth
        self.prune_mode = args.prune
        self.growth_prune_ratio = growth_prune_ratio
        self.redistribution_mode = redistribution_mode

        self.prune_funcs = {}
        self.prune_funcs['magnitude'] = self.magnitude_prune
        self.prune_funcs['SET'] = self.magnitude_and_negativity_prune
        self.prune_funcs['threshold'] = self.threshold_prune

        self.growth_funcs = {}
        self.growth_funcs['random'] = self.random_growth
        self.growth_funcs['momentum'] = self.momentum_growth
        self.growth_funcs['momentum_neuron'] = self.momentum_neuron_growth

        self.masks = {}
        self.masks_pre_pre = {}
        self.final_masks = {}
        self.grads = {}
        self.scores = {}
        self.modules = []
        self.names = []

        self.adjusted_growth = 0
        self.adjustments = []
        self.baseline_nonzero = None
        self.name2baseline_nonzero = {}

        self.preprocess_batched = preprocess_batched
        self.pad_idx = pad_idx
        self.global_rank = global_rank
        self.world_size = world_size


        # stats
        self.momentum_dict = {}
        self.name2variance = {}
        self.name2zeros = {}
        self.name2nonzeros = {}
        self.total_variance = 0
        self.total_removed = 0
        self.total_zero = 0
        self.total_nonzero = 0
        self.prune_rate = args.prune_rate
        self.name2prune_rate = {}
        self.name2density = {}
        self.steps = 0

        # global growth/prune state
        self.threshold = threshold
        self.growth_threshold = threshold
        self.growth_increment = 0.2
        self.increment = 0.2
        self.tolerance = 0.02
        if self.args.fix:
            self.prune_every_k_steps = None
        else:
            self.prune_every_k_steps = args.update_frequency

        self.set_prune_rate_decay()
        self.set_density_decay()
        self.set_temperature_decay()

    def synchronize_masks(self):
        """ Synchronize masks across GPUs. """
        if self.distributed:
            for name in self.masks.keys():
                torch.distributed.broadcast(self.masks[name], src=0, async_op=False)

    def init_sparse_masks(self, erk_power_scale=1.0):
        if self.args.density_decay == 'constant':
            density = self.args.density
        else:
            density = self.args.initial_density

        if self.sparse_init == 'uniform':
            self.baseline_nonzero = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    self.masks[name][:] = (torch.rand(weight.shape) < density).float().data.cuda() #lsw
                    self.baseline_nonzero += weight.numel()*density
            # self.apply_mask()

        elif self.sparse_init == 'fixed_ERK':
            print('initialize by fixed_ERK')
            total_params = 0
            for name, weight in self.masks.items():
                total_params += weight.numel()
            is_epsilon_valid = False
            # # The following loop will terminate worst case when all masks are in the
            # custom_sparsity_map. This should probably never happen though, since once
            # we have a single variable or more with the same constant, we have a valid
            # epsilon. Note that for each iteration we add at least one variable to the
            # custom_sparsity_map and therefore this while loop should terminate.
            dense_layers = set()
            while not is_epsilon_valid:
                # We will start with all layers and try to find right epsilon. However if
                # any probablity exceeds 1, we will make that layer dense and repeat the
                # process (finding epsilon) with the non-dense layers.
                # We want the total number of connections to be the same. Let say we have
                # for layers with N_1, ..., N_4 parameters each. Let say after some
                # iterations probability of some dense layers (3, 4) exceeded 1 and
                # therefore we added them to the dense_layers set. Those layers will not
                # scale with erdos_renyi, however we need to count them so that target
                # paratemeter count is achieved. See below.
                # eps * (p_1 * N_1 + p_2 * N_2) + (N_3 + N_4) =
                #    (1 - default_sparsity) * (N_1 + N_2 + N_3 + N_4)
                # eps * (p_1 * N_1 + p_2 * N_2) =
                #    (1 - default_sparsity) * (N_1 + N_2) - default_sparsity * (N_3 + N_4)
                # eps = rhs / (\sum_i p_i * N_i) = rhs / divisor.

                divisor = 0
                rhs = 0
                raw_probabilities = {}
                for name, mask in self.masks.items():
                    n_param = np.prod(mask.shape)
                    n_zeros = n_param * (1 - density)
                    n_ones = n_param * density

                    if name in dense_layers:
                        # See `- default_sparsity * (N_3 + N_4)` part of the equation above.
                        rhs -= n_zeros

                    else:
                        # Corresponds to `(1 - default_sparsity) * (N_1 + N_2)` part of the
                        # equation above.
                        rhs += n_ones
                        # Erdos-Renyi probability: epsilon * (n_in + n_out / n_in * n_out).
                        raw_probabilities[name] = (
                                                          np.sum(mask.shape) / np.prod(mask.shape)
                                                  ) ** erk_power_scale
                        # Note that raw_probabilities[mask] * n_param gives the individual
                        # elements of the divisor.
                        divisor += raw_probabilities[name] * n_param
                # By multipliying individual probabilites with epsilon, we should get the
                # number of parameters per layer correctly.
                epsilon = rhs / divisor
                # If epsilon * raw_probabilities[mask.name] > 1. We set the sparsities of that
                # mask to 0., so they become part of dense_layers sets.
                max_prob = np.max(list(raw_probabilities.values()))
                max_prob_one = max_prob * epsilon
                if max_prob_one > 1:
                    is_epsilon_valid = False
                    for mask_name, mask_raw_prob in raw_probabilities.items():
                        if mask_raw_prob == max_prob:
                            print(f"Sparsity of var:{mask_name} had to be set to 0.")
                            dense_layers.add(mask_name)
                else:
                    is_epsilon_valid = True

            density_dict = {}
            total_nonzero = 0.0
            # With the valid epsilon, we can set sparsities of the remaning layers.
            for name, mask in self.masks.items():
                n_param = np.prod(mask.shape)
                if name in dense_layers:
                    density_dict[name] = 1.0
                else:
                    probability_one = epsilon * raw_probabilities[name]
                    density_dict[name] = probability_one
                print(
                    f"layer: {name}, shape: {mask.shape}, density: {density_dict[name]}"
                )
                self.masks[name][:] = (torch.rand(mask.shape) < density_dict[name]).float().data.cuda()

                total_nonzero += density_dict[name] * mask.numel()
            print(f"Overall density {total_nonzero / total_params}")

        elif self.sparse_init == 'uniform_ratio':
            self.baseline_nonzero = 0
            attn_params = []
            mlp_params = []

            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks:
                        continue
                    if 'attn' in name.lower() or 'attention' in name.lower():
                        attn_params.append((name, weight))
                    else:
                        mlp_params.append((name, weight))

            # 参数量
            N_attn = sum(w.numel() for _, w in attn_params)
            N_mlp = sum(w.numel() for _, w in mlp_params)

            # density ratio：attention:mlp = 2 : 1
            r = self.args.am_ratio
            D_total = density  # e.g., 0.2

            # 解密度值
            d_mlp = D_total * (N_attn + N_mlp) / (r * N_attn + N_mlp)
            d_attn = r * d_mlp

            # 应用mask
            for name, weight in attn_params:
                self.masks[name][:] = (torch.rand(weight.shape) < d_attn).float().data.cuda()
                self.baseline_nonzero += weight.numel() * d_attn

            for name, weight in mlp_params:
                self.masks[name][:] = (torch.rand(weight.shape) < d_mlp).float().data.cuda()
                self.baseline_nonzero += weight.numel() * d_mlp

        # self.apply_mask()
        self.fired_masks = copy.deepcopy(self.masks) # used for over-paremeters

        self.init_prune_rate(self.prune_rate)
        self.init_density_per_layer()
        self.print_nonzero_counts()

        total_size = 0
        for name, weight in self.masks.items():
            total_size  += weight.numel()
        print('Total Model parameters:', total_size)

        sparse_size = 0
        for name, weight in self.masks.items():
            sparse_size += (weight != 0).sum().int().item()

        print('Total initial parameters under sparsity level of {0}: {1}'.format(density, sparse_size / total_size))


    def init_prune_rate(self, prune_rate):
        for name in self.masks:
            self.name2prune_rate[name] = prune_rate

    def init_density_per_layer(self):
        """ Get the initialized density of each layer. """
        for name in self.masks:
            mask = self.masks[name]
            num_nonzero = mask.float().sum().item()
            total_weights = mask.numel()
            density = num_nonzero / total_weights
            self.name2density[name] = density

    def setting_optimizer(self, model):
        if self.args.optimizer.lower() == "adam":
            self.optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == "adamw":
            self.optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
        elif self.args.optimizer.lower() == "adamdst":
            # self.optimizer = SftAdamW(model.named_parameters(), lr=self.args.lr, masks=self.masks, masks_pre=self.masks_pre_pre, decay_steps=self.args.op_decay_steps, decay_max=self.args.op_decay_max)
            self.optimizer = Adam(model.named_parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay, masks=self.masks, masks_pre=self.masks_pre_pre, decay_steps=self.args.op_decay_steps, decay_max=self.args.op_decay_max)
        else:
            raise ValueError(f"Optimizer {self.args.optimizer} not supported")

        self.lr_scheduler = training_utils.get_scheduler(
            optimizer=self.optimizer,
            scheduler_type=self.args.scheduler,
            num_training_steps=self.args.total_training_steps,
            warmup_steps=self.args.warmup_steps,
            min_lr_ratio=self.args.min_lr_ratio,
        )

    def step(self):
        """
        Executes a single optimization step in the sparse training loop.
    
        This function:
        1. Performs the standard optimizer step to update weights
        2. Applies sparsity masks to maintain network sparsity
        3. Updates prune rates according to the configured decay schedule
        4. Periodically prunes weights and logs sparsity statistics
        
        The prune rate controls the proportion of weights pruned during topology
        evolution, and can follow either a cosine or constant decay schedule.
        """
        self.optimizer.step()
        self.apply_mask()
        self.prune_rate_decay.step()
        self.density_decay.step()
        self.temperature_decay.step()

        for name in self.masks:
            if self.args.prune_rate_decay == 'cosine':
                self.name2prune_rate[name] = self.prune_rate_decay.get_current_value()
            elif self.args.prune_rate_decay == 'constant':
                self.name2prune_rate[name] = self.args.prune_rate
            elif self.args.prune_rate_decay == 'WSD':
                self.name2prune_rate[name] = self.prune_rate_decay.get_current_value()
            self.prune_rate = self.name2prune_rate[name]

        self.steps += 1

        if self.prune_every_k_steps is not None:


            if ((self.steps-1) % self.prune_every_k_steps >= (self.prune_every_k_steps - self.args.accumulate_grad_steps)) and self.growth_mode == 'gradient_acc':
                for module in self.modules:
                    for name, tensor in module.named_parameters():
                        if name not in self.masks:
                            continue

                        grad_flat = tensor.grad.clone().abs().view(-1)
                        mask_flat = self.masks[name].view(-1)

                        if name not in self.momentum_dict:
                            maintain_num = int(self.args.maintain_num * self.name2nonzeros[name] * self.name2prune_rate[name])
                            if maintain_num == 0:
                                continue

                            inactive_mask = (mask_flat == 0)
                            inactive_grad = grad_flat * inactive_mask
                            sorted_vals, sorted_indices = torch.sort(inactive_grad, descending=True)
                            topk_values = sorted_vals[:maintain_num]
                            topk_indices = sorted_indices[:maintain_num]

                            self.momentum_dict[name] = {
                                "values": topk_values.clone(),  # |g|
                                "values_sq": topk_values.pow(2).clone(),  # g^2
                                "indices": topk_indices.clone()
                            }

                        else:
                            indices = self.momentum_dict[name]["indices"]
                            selected_grad = grad_flat[indices]

                            self.momentum_dict[name]["values"] += selected_grad.abs()
                            self.momentum_dict[name]["values_sq"] += selected_grad.pow(2)


            if self.steps % self.prune_every_k_steps == 0:
                self.set_new_layer_densities()
                self.masks_pre_pre = {k: v.clone().detach() for k, v in self.masks.items()}

                self.truncate_weights()
                self.print_nonzero_counts()
                self.masks_pre = {k: v.clone().detach() for k, v in self.masks.items()}

                ## update masks in optimizer
                if self.args.optimizer.lower() == "adamdst":
                    self.optimizer.masks = self.masks
                    self.optimizer.masks_pre = self.masks_pre_pre

                self.momentum_dict = {}

                ## check overall density
                total_size = 0
                for name, weight in self.masks.items():
                    total_size += weight.numel()
                print('Total Model parameters:', total_size)

                sparse_size = 0
                for name, weight in self.masks.items():
                    sparse_size += (weight != 0).sum().int().item()

                print('Total parameters under sparsity level of {0}'.format(sparse_size / total_size))



    def add_module(self, module, density, sparse_init='ER'):
        self.sparse_init = sparse_init
        self.modules.append(module)
        print('adding module')
        for name, tensor in module.named_parameters():
            print(f'(len: {len(tensor.size())}) size of {name}: {tensor.size()}')

            if self.args.dense_embedding and 'embed' in name:
                print(f'Keeping embedding layer dense: {name}')
                continue  # skip embedding layer, if requested

            if self.args.dense_head and 'lm_head' in name:
                print(f'Keeping lm_head layer dense: {name}')
                continue  # skip embedding layer, if requested

            if len(tensor.size()) == 4 or len(tensor.size()) == 2:
                self.names.append(name)
                # self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).cuda()  # old version
                self.masks[name] = torch.ones_like(tensor, dtype=tensor.dtype, requires_grad=False).cuda()

        self.remove_weight_partial_name('bias')
        self.init_sparse_masks()
        self.setting_optimizer(model=module)
        self.apply_mask()  # apply masks
        self.gather_statistics()
        self.masks_pre = {k: v.clone().detach() for k, v in self.masks.items()}

    def remove_weight(self, name):
        if name in self.masks:
            print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name].shape,
                                                                      self.masks[name].numel()))
            self.masks.pop(name)
        elif name + '.weight' in self.masks:
            print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name + '.weight'].shape,
                                                                      self.masks[name + '.weight'].numel()))
            self.masks.pop(name + '.weight')
        else:
            print('ERROR', name)

    def remove_weight_partial_name(self, partial_name):
        removed = set()
        for name in list(self.masks.keys()):
            if partial_name in name:

                print('Removing {0} of size {1} with {2} parameters...'.format(name, self.masks[name].shape,
                                                                                   np.prod(self.masks[name].shape)))
                removed.add(name)
                self.masks.pop(name)

        print('Removed {0} layers.'.format(len(removed)))

        i = 0
        while i < len(self.names):
            name = self.names[i]
            if name in removed:
                self.names.pop(i)
            else:
                i += 1

    def remove_type(self, nn_type):
        for module in self.modules:
            for name, module in module.named_modules():
                if isinstance(module, nn_type):
                    self.remove_weight(name)

    def apply_mask(self):
        self.synchronize_masks()
        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name not in self.masks:
                    continue

                mask_ = self.masks[name].to(tensor.dtype)
                tensor.data.mul_(mask_)

                state = self.optimizer.state.get(tensor)
                if not state:
                    continue

                # support SGD and Adam
                for key in ['momentum_buffer', 'exp_avg', 'exp_avg_sq']:
                    if key in state:
                        state[key].mul_(mask_.to(state[key].dtype))


    def truncate_weights(self):
        """
        Core function responsible for dynamic neural network topology evolution through structured sparsity.
        
        This method implements the sparse training paradigm by first pruning (truncating) weights based on
        specified criteria, then activating new parameters to maintain a constant sparsity level. The process
        follows these steps:
        
        1. Collect network statistics for informed decision-making
        2. Calculate parameter redistribution across layers
        3. Remove parameters based on the specified prune_mode
        4. Regrow parameters using the specified growth_mode
        
        Common weight pruning strategies include:
        - magnitude: Remove smallest magnitude weights (most common)
        - soft_magnitude: Rank smallest magnitude weights and remove based on probability
        - SET: Remove smallest and most negative weights
        - global_magnitude: Apply magnitude pruning globally across all layers
        
        Common weight regrowth strategies include:
        - random: Randomly activate new parameters (used in SET method)
        - gradient: Use gradient information to guide parameter activation (used in RigL method)
        - momentum: Leverage momentum data for intelligent regrowth
        """

        # ## evaluate the impact of pruning
        # print(f"Performing evaluation for befor pruning")
        # total_loss, evaluated_on_tokens = self.evaluate_model_core()
        # perplexity = math.exp(total_loss)
        # print(f"Eval for before pruning: "
        #       f"eval_loss {total_loss}  "
        #       f"perplexity {perplexity}  "
        #       f"tokens {evaluated_on_tokens}")

        self.gather_statistics()

        total_removed = 0
        if self.prune_mode == 'global_magnitude':
            total_removed = self.global_magnitude_prune()
        else:
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    mask = self.masks[name]

                    # prune
                    if self.prune_mode == 'magnitude':
                        new_mask = self.magnitude_prune(mask, weight, name)
                    elif self.prune_mode == 'magnitude_inverse':
                        new_mask = self.magnitude_inverse_prune(mask, weight, name)
                    elif self.prune_mode == 'magnitude_soft':
                        new_mask = self.magnitude_soft_prune(weight, name)
                    elif self.prune_mode == 'SET':
                        new_mask = self.magnitude_and_negativity_prune(mask, weight, name)
                    elif self.prune_mode == 'Taylor_FO':
                        new_mask = self.taylor_FO(mask, weight, name)
                    elif self.prune_mode == 'threshold':
                        new_mask = self.threshold_prune(mask, weight, name)
                    elif self.prune_mode == 'magnitude_increase':
                        new_mask = self.magnitude_increase(weight, mask, name)

                    total_removed += self.name2nonzeros[name] - new_mask.float().sum().item()
                    self.masks[name] = new_mask.to(weight.dtype)

        # Do we want to re-init weight values here (between pruning and growing)?
        # If newly grown weights should just start from value 0, then just apply_mask
        # If 'no', then no applying mask between pruning and growing (immediately regrown weights retain their values)
        if self.args.reinit == 'zero':
            self.apply_mask()
        elif self.args.reinit == 'original':
            self.reinit_weights_original_distribution()

        # ## evaluate the impact of pruning
        # print(f"Performing evaluation for after pruning")
        # total_loss, evaluated_on_tokens = self.evaluate_model_core()
        # perplexity = math.exp(total_loss)
        # print(f"Eval for after pruning: "
        #       f"eval_loss {total_loss}  "
        #       f"perplexity {perplexity}  "
        #       f"tokens {evaluated_on_tokens}")


        # growing
        if self.growth_mode == 'global_momentum':
            _ = self.global_momentum_growth(total_removed + self.adjusted_growth)
        else:

            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    new_mask = self.masks[name].data.byte()

                    left_after_prune = new_mask.float().sum().item()
                    desired_num = math.ceil(self.name2density[name] * self.masks[name].numel())
                    total_regrowth = int(desired_num - left_after_prune)
                    assert total_regrowth >= 0, "total_regrowth should be >= 0"

                    # growth
                    if self.growth_mode == 'random':
                        new_mask, regrow_idx = self.random_growth(name, new_mask, total_regrowth, weight)

                        if self.args.optimizer.lower() == "adamdst" and regrow_idx.numel() > 0:
                            acc_steps = self.args.accumulate_grad_steps

                            self.update_optimizer_momenta_for_regrowth(
                                param=weight,
                                regrow_indices=regrow_idx,
                                init_momenta={
                                    "step": acc_steps,
                                }
                            )

                    elif self.growth_mode == 'momentum':
                        new_mask = self.momentum_growth(name, new_mask, total_regrowth, weight)

                    elif self.growth_mode == 'gradient':  # RigL
                        new_mask, regrow_idx = self.gradient_growth(name, new_mask, total_regrowth, weight)

                        if self.args.optimizer.lower() == "adamdst" and regrow_idx.numel() > 0:
                            acc_steps = self.args.accumulate_grad_steps

                            self.update_optimizer_momenta_for_regrowth(
                                param=weight,
                                regrow_indices=regrow_idx,
                                init_momenta={
                                    "step": acc_steps,
                                }
                            )

                    elif self.growth_mode == 'gradient_acc':  # RigL_acc

                        new_mask, regrow_idx, topk_pos = self.gradient_growth_acc(name, new_mask, total_regrowth, weight)

                        if self.args.optimizer.lower() == "adamdst":
                            all_values = self.momentum_dict[name]["values"]  # shape: [N]
                            all_values_sq = self.momentum_dict[name]["values_sq"]

                            beta1, beta2 = self.optimizer.defaults["betas"]
                            acc_steps = self.args.accumulate_grad_steps

                            # get matched positions in all_indices
                            selected_grads = all_values[topk_pos] / acc_steps
                            selected_grads_sq = all_values_sq[topk_pos] / acc_steps

                            # acc_steps = acc_steps.to(dtype=torch.float32)
                            selected_grads *= (1.0 - beta1 ** acc_steps)
                            selected_grads_sq *= (1.0 - beta2 ** acc_steps)

                            self.update_optimizer_momenta_for_regrowth(
                                param=weight,
                                regrow_indices=regrow_idx,
                                init_momenta={
                                    "step": acc_steps,
                                    # "exp_avg": selected_grads,
                                    # "exp_avg_sq": selected_grads_sq,
                                }
                            )

                    elif self.growth_mode == 'momentum_neuron':
                        new_mask = self.momentum_neuron_growth(name, new_mask, total_regrowth, weight)

                    elif self.growth_mode == 'mix_growth':
                        new_mask = self.mix_growth(name, new_mask, total_regrowth, weight)

                    else:
                        raise ValueError(f"Unknown growth mode: {self.growth_mode}")

                    self.masks[name] = new_mask.to(weight.dtype)

        self.apply_mask()


    '''
                    REDISTRIBUTION
    '''
    def gather_statistics(self):
        self.name2nonzeros = {}
        self.name2zeros = {}
        self.name2variance = {}

        self.total_variance = 0.0
        self.total_removed = 0
        self.total_nonzero = 0
        self.total_zero = 0.0
        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name not in self.masks: continue
                mask = self.masks[name]

                self.name2nonzeros[name] = mask.float().sum().item()
                self.name2zeros[name] = mask.numel() - self.name2nonzeros[name]

                prune_rate = self.name2prune_rate[name]
                num_remove = math.ceil(prune_rate*self.name2nonzeros[name])

                self.total_removed += num_remove
                self.total_nonzero += self.name2nonzeros[name]
                self.total_zero += self.name2zeros[name]

    '''
                    prune
    '''
    def magnitude_increase(self, weight, mask, name): # lsw addition
        prune_rate = self.name2prune_rate[name]
        x, idx = torch.sort(torch.abs(weight.data.view(-1)))
        pruning_number = self.name2nonzeros[name] * prune_rate
        k = math.ceil(self.name2zeros[name] + pruning_number)
        threshold = x[k - 1].item()
        # magIN_num = (torch.abs(weight) > torch.abs(self.pre_tensor[name])).sum().item()
        # smaller_num = (torch.abs(weight) < torch.abs(self.pre_tensor[name])).sum().item()
        # bigThan_mean = (torch.abs(weight) > threshold).sum().item()
        # print('mag increase number', magIN_num/num_nonzero, 'threshold', bigThan_mean/num_nonzero)
        return (torch.abs(weight) > torch.abs(self.pre_tensor[name])) | (torch.abs(weight) > threshold)  # check if mask if right?

    def threshold_prune(self, mask, weight, name):
        return (torch.abs(weight.data) > self.threshold)

    def taylor_FO(self, mask, weight, name):

        num_remove = math.ceil(self.name2prune_rate[name] * self.name2nonzeros[name])
        num_zeros = self.name2zeros[name]
        k = math.ceil(num_zeros + num_remove)

        x, idx = torch.sort((weight.data * weight.grad).pow(2).flatten())
        mask.data.view(-1)[idx[:k]] = 0.0

        return mask


    def kernel_pruning(self, mask, weight, name):

        score = torch.clone(weight.grad * weight).detach().abs_()

        num_remove = math.ceil(self.name2prune_rate[name] * self.name2nonzeros[name])
        if num_remove == 0.0: return weight.data != 0.0
        #num_remove = math.ceil(self.name2prune_rate[name]*self.name2nonzeros[name])

        num_zeros = self.name2zeros[name]
        x, idx = torch.sort(score.data.view(-1))
        k = math.ceil(num_zeros + num_remove)
        mask.data.view(-1)[idx[:k]] = 0.0
        return mask

    def magnitude_prune(self, mask, weight, name):
        sparsity = self.name2zeros[name]/float(self.masks[name].numel())
        prune_rate = self.name2prune_rate[name]

        num_remove = math.ceil(prune_rate*self.name2nonzeros[name])
        if num_remove == 0.0: return weight.data != 0.0
        #num_remove = math.ceil(self.name2prune_rate[name]*self.name2nonzeros[name])
        num_zeros = self.name2zeros[name]

        x, idx = torch.sort(torch.abs(weight.data.view(-1)))
        n = idx.shape[0]
        num_nonzero = n-num_zeros

        k = math.ceil(num_zeros + num_remove)
        threshold = x[k-1].item()

        return (torch.abs(weight.data) > threshold)

    def magnitude_inverse_prune(self, mask, weight, name):
        prune_rate = self.name2prune_rate[name]
        current_mask = self.masks[name]

        # 获取当前非零权重
        active_weights = weight.data[current_mask.bool()]
        num_active = active_weights.numel()

        num_keep = math.ceil((1 - prune_rate) * num_active)
        if num_keep <= 0:
            return torch.zeros_like(current_mask)

        # 保留最小的部分
        threshold = torch.topk(torch.abs(active_weights), num_keep, largest=False).values.max()

        # 新 mask：只保留小于等于 threshold 的原始非零权重
        new_mask = current_mask.clone()
        new_mask[(torch.abs(weight.data) > threshold) & (current_mask.bool())] = 0
        return new_mask

    def magnitude_soft_prune(self, weight, name):
        """
        Soft magnitude pruning with temperature-scaled sampling.
        from Zhang et al. https://arxiv.org/abs/2501.19107

        To avoid errors:
        If the probability vector is too sparse to draw `num_to_stay`
        unique indices, the current mask is returned unchanged.
        """
        # take the absolute value of the masked weights
        matrix = torch.abs(weight * self.masks[name])

        num_active = self.masks[name].float().sum().item()
        num_to_stay = math.ceil(num_active * (1 - self.name2prune_rate[name]))

        flat_matrix = matrix.flatten()
        flat_matrix = torch.where(torch.isnan(flat_matrix), torch.zeros_like(flat_matrix), flat_matrix)
        flat_matrix = torch.where(torch.isinf(flat_matrix), torch.zeros_like(flat_matrix), flat_matrix)

        T = self.temperature_decay.get_current_value()
        flat_matrix = flat_matrix.float() ** T

        # define probabilities of weights to stay unpruned
        total = flat_matrix.sum()
        if total == 0:
            return self.masks[name].clone()
        probs = flat_matrix / total

        if probs.numel() > 2 ** 24:  # avoid CUDA limit of torch.multinomial
            # numpy handles > 2**24 (~16M) categories fine
            probs = probs.detach().cpu().numpy()

            if np.flatnonzero(probs).size < num_to_stay:
                # if not enough non-zero probs, return the original mask
                return self.masks[name].clone()

            keep_idx_np = np.random.choice(probs.size, size=num_to_stay, replace=False, p=probs)
            keep_idx = torch.from_numpy(keep_idx_np).to(weight.device, dtype=torch.long)
        else:
            if torch.nonzero(probs).squeeze().numel() < num_to_stay:
                # if not enough non-zero probs, return the original mask
                return self.masks[name].clone()

            keep_idx = torch.multinomial(probs, num_to_stay, replacement=False)

        new_mask = torch.zeros_like(weight, device=self.device)
        new_mask.view(-1)[keep_idx] = 1
        return new_mask

    def global_magnitude_prune(self):
        prune_rate = 0.0
        for name in self.name2prune_rate:
            if name in self.masks:
                prune_rate = self.name2prune_rate[name]
        tokill = math.ceil(prune_rate*self.baseline_nonzero)
        total_removed = 0
        prev_removed = 0
        while total_removed < tokill*(1.0-self.tolerance) or (total_removed > tokill*(1.0+self.tolerance)):
            total_removed = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue
                    remain = (torch.abs(weight.data) > self.threshold).sum().item()
                    total_removed += self.name2nonzeros[name] - remain

            if prev_removed == total_removed: break
            prev_removed = total_removed
            if total_removed > tokill*(1.0+self.tolerance):
                self.threshold *= 1.0-self.increment
                self.increment *= 0.99
            elif total_removed < tokill*(1.0-self.tolerance):
                self.threshold *= 1.0+self.increment
                self.increment *= 0.99

        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                self.masks[name][:] = torch.abs(weight.data) > self.threshold

        return int(total_removed)


    def global_momentum_growth(self, total_regrowth):
        togrow = total_regrowth
        total_grown = 0
        last_grown = 0
        while total_grown < togrow*(1.0-self.tolerance) or (total_grown > togrow*(1.0+self.tolerance)):
            total_grown = 0
            total_possible = 0
            for module in self.modules:
                for name, weight in module.named_parameters():
                    if name not in self.masks: continue

                    new_mask = self.masks[name]
                    grad = self.get_momentum_for_weight(weight)
                    grad = grad*(new_mask==0).float()
                    possible = (grad !=0.0).sum().item()
                    total_possible += possible
                    grown = (torch.abs(grad.data) > self.growth_threshold).sum().item()
                    total_grown += grown
            print(total_grown, self.growth_threshold, togrow, self.growth_increment, total_possible)
            if total_grown == last_grown: break
            last_grown = total_grown


            if total_grown > togrow*(1.0+self.tolerance):
                self.growth_threshold *= 1.02
                #self.growth_increment *= 0.95
            elif total_grown < togrow*(1.0-self.tolerance):
                self.growth_threshold *= 0.98
                #self.growth_increment *= 0.95

        total_new_nonzeros = 0
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue

                new_mask = self.masks[name]
                grad = self.get_momentum_for_weight(weight)
                grad = grad*(new_mask==0).float()
                self.masks[name][:] = (new_mask.byte() | (torch.abs(grad.data) > self.growth_threshold)).float()
                total_new_nonzeros += new_mask.float().sum().item()
        return total_new_nonzeros


    def magnitude_and_negativity_prune(self, mask, weight, name):
        num_remove = math.ceil(self.name2prune_rate[name]*self.name2nonzeros[name])
        num_zeros = self.name2zeros[name]

        # find magnitude threshold
        # remove all weights which absolute value is smaller than threshold
        x, idx = torch.sort(weight[weight > 0.0].data.view(-1))
        k = math.ceil(num_remove/2.0)
        if k >= x.shape[0]:
            k = x.shape[0]

        threshold_magnitude = x[k-1].item()

        # find negativity threshold
        # remove all weights which are smaller than threshold
        x, idx = torch.sort(weight[weight < 0.0].view(-1))
        k = math.ceil(num_remove/2.0)
        if k >= x.shape[0]:
            k = x.shape[0]
        threshold_negativity = x[k-1].item()


        pos_mask = (weight.data > threshold_magnitude) & (weight.data > 0.0)
        neg_mask = (weight.data < threshold_negativity) & (weight.data < 0.0)


        new_mask = pos_mask | neg_mask
        return new_mask

    '''
                    GROWTH
    '''

    def random_growth_ori(self, name, new_mask, total_regrowth, weight):
        """
        This function implements the random growth strategy for sparse neural networks,
        which is used in algorithms like SET. It randomly
        selects zero-valued positions in the weight matrix to be activated, with the
        total number of new connections controlled by the total_regrowth parameter.
        """
        n = (new_mask==0).sum().item()
        if n == 0: return new_mask
        expeced_growth_probability = (total_regrowth/n)
        new_weights = torch.rand(new_mask.shape).cuda() < expeced_growth_probability #lsw
        # new_weights = torch.rand(new_mask.shape) < expeced_growth_probability
        new_mask_ = new_mask.byte() | new_weights
        if (new_mask_!=0).sum().item() == 0:
            new_mask_ = new_mask
        return new_mask_

    def random_growth(self, name, new_mask, total_regrowth, weight):
        """
        Randomly regrow connections at zero-valued positions in the mask.

        Returns:
            new_mask: Updated binary mask after regrowth.
            regrow_idx: Flattened indices of newly regrown connections.
        """
        flat_mask = new_mask.view(-1)
        zero_indices = (flat_mask == 0).nonzero(as_tuple=False).view(-1)

        num_candidates = zero_indices.numel()
        if num_candidates == 0 or total_regrowth == 0:
            return new_mask, torch.tensor([], dtype=torch.long, device=flat_mask.device)

        total_regrowth = min(total_regrowth, num_candidates)

        # 随机选出 regrow 的位置
        rand_idx = torch.randperm(num_candidates, device=flat_mask.device)[:total_regrowth]
        regrow_idx = zero_indices[rand_idx]

        # 更新 mask
        flat_mask[regrow_idx] = 1.0

        return new_mask, regrow_idx

    def momentum_growth(self, name, new_mask, total_regrowth, weight):
        grad = self.get_momentum_for_weight(weight)
        grad = grad*(new_mask==0).float()
        y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
        new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0

        return new_mask

    def kernel_gradient_growth(self, name, new_mask, total_regrowth, weight):
        grad = self.grads[name]
        grad = grad * (new_mask == 0).float()
        y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
        new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0

        return new_mask

    def gradient_growth(self, name, new_mask, total_regrowth, weight):
        """
        This function implements the gradient-based growth strategy for sparse neural networks,
        which is a key component of methods like RigL (Rigged Lottery Tickets). It prioritizes
        regrowth at zero-valued positions where gradients have the highest magnitude, indicating
        where new connections would have the most immediate impact on loss reduction.
        """
        grad = weight.grad.clone()
        grad = grad*(new_mask==0).float()

        y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
        regrow_idx = idx[:total_regrowth]
        new_mask.data.view(-1)[regrow_idx] = 1.0

        return new_mask, regrow_idx


    def gradient_growth_acc(self, name, new_mask, total_regrowth, weight):
        """
        Regrow connections based on accumulated gradient scores in self.momentum_dict.
        """
        if name not in self.momentum_dict:
            raise ValueError(f"Missing momentum entry for {name} during regrowth.")

        score = self.momentum_dict[name]["values"]
        indices = self.momentum_dict[name]["indices"]

        sorted_vals, sorted_pos = torch.sort(score, descending=True)

        topk_pos = sorted_pos[:total_regrowth]
        regrow_idx = indices[topk_pos]

        # Update new_mask
        new_mask.data.view(-1)[regrow_idx] = 1.0

        return new_mask, regrow_idx, topk_pos

    def mix_growth(self, name, new_mask, total_regrowth, weight):
        gradient_grow = int(total_regrowth * self.args.mix)
        random_grow = total_regrowth - gradient_grow
        grad = weight.grad.clone()
        grad = grad * (new_mask == 0).float()

        y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
        new_mask.data.view(-1)[idx[:gradient_grow]] = 1.0

        n = (new_mask == 0).sum().item()
        expeced_growth_probability = (random_grow / n)
        new_weights = torch.rand(new_mask.shape).cuda() < expeced_growth_probability
        new_mask = new_mask.bool() | new_weights

        return new_mask

    def momentum_neuron_growth(self, name, new_mask, total_regrowth, weight):
        grad = self.get_momentum_for_weight(weight)

        M = torch.abs(grad)
        if len(M.shape) == 2: sum_dim = [1]
        elif len(M.shape) == 4: sum_dim = [1, 2, 3]

        v = M.mean(sum_dim).data
        v /= v.sum()

        slots_per_neuron = (new_mask==0).sum(sum_dim)

        M = M*(new_mask==0).float()
        for i, fraction  in enumerate(v):
            neuron_regrowth = math.floor(fraction.item()*total_regrowth)
            available = slots_per_neuron[i].item()

            y, idx = torch.sort(M[i].flatten())
            if neuron_regrowth > available:
                neuron_regrowth = available
            threshold = y[-(neuron_regrowth)].item()
            if threshold == 0.0: continue
            if neuron_regrowth < 10: continue
            new_mask[i] = new_mask[i] | (M[i] > threshold)

        return new_mask

    '''
                UTILITY
    '''
    def get_momentum_for_weight(self, weight):
        if 'exp_avg' in self.optimizer.state[weight]:
            adam_m1 = self.optimizer.state[weight]['exp_avg']
            adam_m2 = self.optimizer.state[weight]['exp_avg_sq']
            grad = adam_m1/(torch.sqrt(adam_m2) + 1e-08)
        elif 'momentum_buffer' in self.optimizer.state[weight]:
            grad = self.optimizer.state[weight]['momentum_buffer']
        return grad


    def print_nonzero_counts(self):
        total_active = 0
        total_params = 0
        total_active_incl_bias = 0
        total_params_incl_bias = 0
        for module in self.modules:
            for name, tensor in module.named_parameters():
                total_params_incl_bias += tensor.numel()
                if name not in self.masks:
                    total_active_incl_bias += tensor.numel()
                    continue
                mask = self.masks[name]
                num_nonzeros = (mask != 0).sum().item()
                total_active += num_nonzeros
                total_active_incl_bias += num_nonzeros
                total_params += mask.numel()

                ## cal diff
                if hasattr(self, 'masks_pre') and self.masks_pre is not None:
                    current_mask = mask.to(torch.bool)
                    previous_mask = self.masks_pre[name].to(torch.bool)
                    diff = (current_mask != previous_mask).sum().item()
                    diff_ratio = diff / num_nonzeros if num_nonzeros > 0 else 0
                else:
                    diff = None
                    diff_ratio = None

                if name in self.name2nonzeros:
                    val = f"{name}: {self.name2nonzeros[name]}->{num_nonzeros}, density: {num_nonzeros/float(mask.numel()):.3f}"
                    val += f", changed: {diff}, change ratio: {diff_ratio:.4f}"
                    print(val)


        if self.args.single_gpu and self.args.wandb_used:
            wandb.log({
                "sparsity/density": total_active / total_params,
                "sparsity/density_incl_bias": total_active_incl_bias / total_params_incl_bias,
                "sparsity/total_active": total_active,
                "sparsity/total_active_incl_bias": total_active_incl_bias,
                "sparsity/total_params": total_params,
                "sparsity/total_params_incl_bias": total_params_incl_bias,
                "sparsity/density_decay": self.density_decay.get_current_value(),
                "sparsity/prune_rate": self.prune_rate_decay.get_current_value(),
                "sparsity/temperature": self.temperature_decay.get_current_value(),
            })

        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name not in self.masks: continue
                print(f'prune rate {name}: {self.name2prune_rate[name]}')
                break  # only print the first tensor, prune_rate is same across layers in our experiments

    def reset_momentum(self):
        """
        Taken from: https://github.com/AlliedToasters/synapses/blob/master/synapses/SET_layer.py
        Resets buffers from memory according to passed indices.
        When connections are reset, parameters should be treated
        as freshly initialized.
        """
        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name not in self.masks: continue
                mask = self.masks[name]
                weights = list(self.optimizer.state[tensor])
                for w in weights:
                    if w == 'momentum_buffer':
                        # momentum
                        if self.args.reset_mom_zero:
                            print('zero')
                            self.optimizer.state[tensor][w][mask == 0] = 0
                        else:
                            print('mean')
                            self.optimizer.state[tensor][w][mask==0] = torch.mean(self.optimizer.state[tensor][w][mask.byte()])
                        # self.optimizer.state[tensor][w][mask==0] = 0
                    elif w == 'square_avg' or \
                        w == 'exp_avg' or \
                        w == 'exp_avg_sq' or \
                        w == 'exp_inf':
                        # Adam
                        self.optimizer.state[tensor][w][mask==0] = torch.mean(self.optimizer.state[tensor][w][mask.byte()])

    def fired_masks_update(self):
        ntotal_fired_weights = 0.0
        ntotal_weights = 0.0
        layer_fired_weights = {}
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks: continue
                self.fired_masks[name] = self.masks[name].data.byte() | self.fired_masks[name].data.byte()
                ntotal_fired_weights += float(self.fired_masks[name].sum().item())
                ntotal_weights += float(self.fired_masks[name].numel())
                layer_fired_weights[name] = float(self.fired_masks[name].sum().item())/float(self.fired_masks[name].numel())
                print('Layerwise percentage of the fired weights of', name, 'is:', layer_fired_weights[name])
        total_fired_weights = ntotal_fired_weights/ntotal_weights
        print('The percentage of the total fired weights is:', total_fired_weights)
        return layer_fired_weights, total_fired_weights

    def set_prune_rate_decay(self):
        if self.args.prune_rate_decay == 'cosine':
            self.prune_rate_decay = CosineDecay(
                init_value=self.args.prune_rate,
                T_max=self.args.total_training_steps,
                eta_min=0.005,
            )
        elif self.args.prune_rate_decay == 'linear':
            self.prune_rate_decay = LinearDecay(
                init_value=self.args.prune_rate,
                final_value=0.005,
                num_steps=self.args.total_training_steps,
            )
        elif self.args.prune_rate_decay == 'WSD':
            self.prune_rate_decay = WSDDecay(
                init_value=self.args.prune_rate,
                total_steps=self.args.total_training_steps,
            )
        elif self.args.prune_rate_decay == 'constant':
            self.prune_rate_decay = ConstantDecay(self.args.prune_rate)
        else:
            raise Exception(f'Unknown prune_rate_decay mode: {self.args.prune_rate_decay}')

    def set_density_decay(self):
        if self.args.density_decay == 'cosine':
            self.density_decay = CosineDecay(
                init_value=self.args.initial_density,
                T_max=self.args.total_training_steps,
                eta_min=self.args.density,
            )
        elif self.args.density_decay == 'linear':
            self.density_decay = LinearDecay(
                init_value=self.args.initial_density,
                final_value=self.args.density,
                num_steps=self.args.total_training_steps,
            )
        elif self.args.density_decay == 'constant':
            self.density_decay = ConstantDecay(self.args.density)
        else:
            raise Exception(f'Unknown density_decay mode: {self.args.density_decay}')

    def set_temperature_decay(self):
        if self.args.temperature_decay == 'linear':
            self.temperature_decay = LinearDecay(
                init_value=self.args.init_temperature,
                final_value=self.args.temperature,
                num_steps=self.args.total_training_steps,
            )
        elif self.args.temperature_decay == 'constant':
            self.temperature_decay = ConstantDecay(self.args.temperature)
        else:
            raise Exception(f'Unknown temperature_decay mode: {self.args.temperature_decay}')

    def set_new_layer_densities(self):
        if self.args.density_decay != 'constant':
            total = 0
            total_active = 0
            for name in self.masks:
                total_active += self.masks[name].float().sum().item()
                total += self.masks[name].numel()

            prev_density = total_active / total
            new_density = self.density_decay.get_current_value()
            cur_dens_decay_factor = new_density / prev_density
            print(f'cur_dens_decay_factor: {cur_dens_decay_factor}  prev_density: {prev_density}  new_density: {new_density}  total: {total}  total_active: {total_active}')

            if self.args.single_gpu and self.args.wandb_used:
                wandb.log({
                    "sparsity/density_decay_factor": cur_dens_decay_factor,
                })

            for name in self.masks:
                self.name2density[name] = self.name2density[name] * cur_dens_decay_factor
                self.name2density[name] = min(self.name2density[name], 1)
                # assert 0 <= self.name2density[name] <= 1, \
                #     f'Density {self.name2density[name]} of layer {name} out of range [0, 1]'

    def reinit_weights_original_distribution(self):
        """Reinitialize pruned weights using the original initialization scheme."""
        for module in self.modules:
            for name, param in module.named_parameters():
                if name not in self.masks:
                    continue
                inactive_mask = (self.masks[name] == 0)

                embed = False
                if 'embed' in name:
                    embed = True

                with torch.no_grad():
                    temp_layer = torch.nn.Parameter(torch.empty_like(param))
                    weight_init(temp_layer, embedding=embed)
                    param.data[inactive_mask] = temp_layer.data[inactive_mask]
        # all weights have non-zero values now,
        # but this will be solved when we .apply_mask after growing new connections

    @torch.no_grad()
    def evaluate_model_core(self):
        """Evaluate the current model."""

        val_dir = '/c4_sampling/c4_filtered_validation_10M'
        val_data = datasets.load_dataset("arrow", data_dir=val_dir, split="validation", streaming=True)

        if not self.args.single_gpu:
            val_data = datasets.distributed.split_dataset_by_node(val_data, rank=self.global_rank, world_size=self.world_size)

        val_data_mapped = val_data.map(
            self.preprocess_batched,
            batched=True,
            remove_columns=["text"],  # ["text", "timestamp", "url"],
        )
        # val_data_mapped.batch = lambda batch_size: training_utils.batch_fn(val_data_mapped, batch_size)
        dataloader = torch.utils.data.DataLoader(
            val_data_mapped,
            batch_size=self.args.batch_size,
            collate_fn=default_data_collator,
        )

        target_eval_tokens = 10_000_000
        evaluated_on_tokens = 0
        total_loss = torch.tensor(0.0).to(self.device)
        total_batches = 1

        # for batch in val_data_mapped.batch(batch_size=batch_size):
        for batch in dataloader:
            if evaluated_on_tokens > target_eval_tokens:
                break
            total_batches += 1

            # batch = default_data_collator(batch)
            batch = {k: v.to(self.device) for k, v in batch.items()}
            labels = batch["input_ids"].clone()
            labels[labels == self.pad_idx] = -100

            # Standard, single model
            loss = self.modules[0](**batch, labels=labels).loss
            total_loss += loss.detach()

            evaluated_on_tokens += (batch["input_ids"] != self.pad_idx).sum().item() * self.world_size

        total_loss = total_loss / total_batches

        # Gather losses across all GPUs
        gathered_losses = [torch.zeros_like(total_loss) for _ in range(self.world_size)]
        dist.all_gather(gathered_losses, total_loss)
        total_loss = sum([t.item() for t in gathered_losses]) / self.world_size

        return total_loss, evaluated_on_tokens

    def update_optimizer_momenta_for_regrowth(self, param, regrow_indices, init_momenta={}):
        if init_momenta is None:
            init_momenta = {}

        optimizer_state = self.optimizer.state[param]

        for key in ['step']:  ## 'exp_avg', 'exp_avg_sq',

            if key not in optimizer_state:
                continue

            optimizer_params = optimizer_state[key]
            init = init_momenta.get(key, None)

            if init is not None:
                if isinstance(init, torch.Tensor):
                    init = init.to(dtype=optimizer_params.dtype)

                optimizer_params.view(-1)[regrow_indices] = init
            else:
                optimizer_params.view(-1)[regrow_indices] = 0.0



def weight_init(weight, embedding=False):
    """Initialize weights using the original initialization scheme."""
    if embedding:
        std_embedding = (2 / 5) ** 0.5  # approx 0.632
        weight.data.normal_(mean=0.0, std=std_embedding)
    else:
        # std = config.initializer_range
        std = 0.02  # default value
        weight.data.normal_(mean=0.0, std=std)


