import logging
import math

import numpy as np
import torch
import torch.nn as nn

from transformers import (
    TrainerCallback,
)
from accelerate.optimizer import AcceleratedOptimizer

from .layer import Linear
from .optimizer import SftAdamW, SftSM3

import matplotlib.pyplot as plt
import random, os

# Module-level logger. Its level is set to INFO so this module's INFO messages
# (e.g. selection-phase progress) pass the logger's own level check; whether
# they are actually printed still depends on the configured handlers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def update_optimizer(
    optimizer,
    param,
    changing_indices,
    init_momenta=None,
):
    """
    Re-seed optimizer state entries of a PEFT parameter tensor after dropping and regrowth.

    For each of the state tensors ``'age'``, ``'exp_avg'`` and ``'exp_avg_sq'``
    stored for ``param``, the entries at ``changing_indices`` are overwritten with
    the seed supplied in ``init_momenta`` (zero when no seed is given). All other
    entries are left untouched.

    Args:
      - optimizer: the optimizer to update. An ``AcceleratedOptimizer`` wrapper is
          unwrapped to the underlying optimizer first.
      - param: the parameter tensor; used as the key into ``optimizer.state``.
      - changing_indices: the indices in the optimizer state that need to be updated.
      - init_momenta: optional dict mapping state keys to seed values (a scalar or a
          tensor matching the selected entries). Keys that are missing or map to
          ``None`` are seeded with zero.

    Raises:
      KeyError: if the state stored for ``param`` lacks one of the three keys.
    """
    # Use None as the default instead of a shared mutable dict.
    if init_momenta is None:
        init_momenta = {}

    if isinstance(optimizer, AcceleratedOptimizer):
        optimizer = optimizer.optimizer

    optimizer_state = optimizer.state[param]
    for state_key in ('age', 'exp_avg', 'exp_avg_sq'):
        state_tensor = optimizer_state[state_key]
        init = init_momenta.get(state_key)
        if init is None:
            init = 0.0
        elif isinstance(init, torch.Tensor):
            # Match both dtype and device of the state so the indexed assignment
            # works even when the seeds were computed elsewhere.
            init = init.to(device=state_tensor.device, dtype=state_tensor.dtype)
        state_tensor[changing_indices] = init

def update_lr_factor(
    optimizer,
    param,
    changing_indices,
    lr_factor
):
    """
    Write the per-entry ``'lr_factor'`` optimizer state of ``param``.

    Entries at ``changing_indices`` receive ``lr_factor``; every other entry is
    reset to 1.0. An ``AcceleratedOptimizer`` wrapper is unwrapped first.
    """
    inner = optimizer.optimizer if isinstance(optimizer, AcceleratedOptimizer) else optimizer
    factors = inner.state[param]['lr_factor']

    selected = torch.zeros_like(factors, dtype=torch.bool)
    selected[changing_indices] = True

    # Boost the selected entries, reset all others to the neutral factor.
    factors[selected] = lr_factor
    factors[selected.logical_not()] = 1.0

class SftSelector:
    """
    Implements SFT tunable parameter reselection. Simply construct the SftSelector and call
    .step() after each training update step.
    """

    def __init__(
        self,
        model,
        optimizer,
        sft_config,
        total_update_steps,
        grad_accumulation_steps,
        output_dir='',
        completed_steps=0, # number of already completed steps if resuming from saved ckpt.
        cal_data=None,
        sparsity_ratio=0.6,
        merge_ratio=0.1
    ):
        self.model = model
        self.optimizer = optimizer
        self.sft_config = sft_config
        self.total_update_steps = total_update_steps
        self.grad_accumulation_steps = grad_accumulation_steps
        self.completed_steps = completed_steps
        self.begin_selection_phase()
        self.previous_incoming_params = {}

        self.indices_count = {}
        self.previous_indices = {}
        self.pre_remaining_values = {}
        self.pre_remaining_indices = {}
        self.remaining_diff = {}
        self.previous_delta_values={}

        '''
        self.output_dir = os.path.join(output_dir, 'pics')
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        '''
        self.changing_indices = {}
        self.cycle_step = 0
        self.warmup_steps = 5
        self.lr_factor_started = False
        self.round_steps = self.sft_config.reselection_steps + self.sft_config.selection_accumulation_steps

        self.samples_count = 0

        ## data
        self.cal_data = cal_data
        self.sparsity_ratio = sparsity_ratio
        self.merge_ratio = merge_ratio

    def step(self):
        self.completed_steps += 1
        self.cycle_step += 1

        # logger.info('completed_steps: {}'.format(self.completed_steps))
        
        #if self.sft_config.dst == 'fix':
        #   return


        if (self.completed_steps + 1) % self.sft_config.reselection_steps == 0:
            self.begin_selection_phase()
            self.lr_factor_started = False
        '''
        if (
                self.sft_config.selection_accumulation_steps + 1 <= self.completed_steps % self.sft_config.reselection_steps and
                self.completed_steps % self.sft_config.reselection_steps <= self.sft_config.selection_accumulation_steps + 5
        ):
            self.begin_selection_phase()
            
        else:
            # Remove hooks
            for n, m in self.model.named_modules():
                if (
                        isinstance(m, Linear) and
                        m.active_adapter is not None and
                        m.active_adapter in m.sft_delta
                ):
                    if any(param.requires_grad for param in m.parameters()):
                        m.apply_hook(None)
        '''

        if (
            self.completed_steps % self.sft_config.reselection_steps ==
            self.sft_config.selection_accumulation_steps
        ): 
            if False and self.completed_steps > self.sft_config.reselection_steps + self.sft_config.selection_accumulation_steps:
                self.prune_wanda()
            self.end_selection_phase()

            self.cycle_step = 0
            self.lr_factor_started = False

        if self.completed_steps > self.sft_config.reselection_steps and self.lr_factor_started:
            self.update_lr_factor()

            for module_name, changing_indices_ in self.changing_indices.items():
                m = self.model.get_submodule(module_name)
                delta = m.sft_delta[m.active_adapter]
                update_lr_factor(self.optimizer, delta.values, changing_indices_, self.lr_factor)


        '''
        if self.reselection_scores:
            for module_name, (candidate_indices, candidate_grads, candidate_grads_sq, candidate_samples, fixed_grads) in self.reselection_scores.items():
                logger.info('completed_steps: {}, samples_num: {}, candidate_grads_sq: {}, module_name: {}'.format(self.completed_steps, candidate_samples[:2], candidate_grads_sq[:2], module_name))
                break
        '''

        # print('#### sample count: {}'.format(self.samples_count))

    def update_lr_factor(self):
        '''
        if self.cycle_step < self.warmup_steps:
            # Warmup phase: increase lr_factor from 0.1 to 10
            self.lr_factor = 0.1 + (10 - 0.1) * (self.cycle_step / self.warmup_steps)
        else:
            # Decay phase: decrease lr_factor from 10 to 0.1
            total_decay_steps = self.round_steps - self.warmup_steps
            progress = self.cycle_step - self.warmup_steps
            self.lr_factor = 10 - (10 - 1) * (progress / total_decay_steps)
        '''
        self.lr_factor = 5.0

    def begin_selection_phase(self):
        if self.sft_config.selection_algorithm == "sm3":
            return

        logger.info('Beginning selection phase')
        self.reselection_scores = {}
        self.samples_count = 0
        # Apply hooks to gather gradients for growth selection
        for n, m in self.model.named_modules():
            if (
                isinstance(m, Linear) and
                m.active_adapter is not None and
                m.active_adapter in m.sft_delta
            ):
                if any(param.requires_grad for param in m.parameters()):
                    m.apply_hook(self.gradient_accumulation_hook(n))

    def end_selection_phase(self):
        logger.info('Ending selection phase')
        if self.completed_steps > self.total_update_steps:
            return

        if self.sft_config.selection_algorithm != "sm3":
            # Remove hooks
            for n, m in self.model.named_modules():
                if (
                    isinstance(m, Linear) and
                    m.active_adapter is not None and
                    m.active_adapter in m.sft_delta
                ):
                    if any(param.requires_grad for param in m.parameters()):
                        m.apply_hook(None)

        # Replace all parameters if it's the first reselection step, linear
        # decay from initial rate otherwise.
        if self.sft_config.reselection_rate_policy == "linear":
            if self.completed_steps == self.sft_config.selection_accumulation_steps:
                p = 1
            else:
                p = self.sft_config.initial_reselection_rate * (
                        1 - self.completed_steps / self.total_update_steps
                )
        elif self.sft_config.reselection_rate_policy == "cosine":
            p = self.sft_config.initial_reselection_rate * (
                    1 + math.cos(math.pi * self.completed_steps / self.total_update_steps)
            ) / 2

        elif self.sft_config.reselection_rate_policy == "constant":
            if self.completed_steps == self.sft_config.selection_accumulation_steps:
                p = 1
            else:
                p = self.sft_config.initial_reselection_rate

        else:
            raise ValueError(f'Unsupported reselection rate policy {self.sft_config.reselection_rate_policy}')

        logger.info('dst update_p: {}'.format(p))
        if p > 0:
            self.select(p)
        self.reselection_scores = {}
        self.samples_count = 0

    def select(self, p):
        algorithm = self.sft_config.selection_algorithm
        if algorithm == "sm3":
            self.select_sm3(p)
        elif algorithm == "rigl": #or (algorithm == "acc" and p == 1):
            self.select_rigl(p)
        elif algorithm == "acc" and self.sft_config.dst != 'fix':
            self.select_accumulative_rigl()
        elif algorithm == "acc" and self.sft_config.dst == 'fix':
            self.select_accumulative_rigl_fix()
        elif algorithm == "mag_soft":
            self.select_rigl_mag_soft(p)
        elif algorithm == "mest":
            self.select_mest(p)
        elif algorithm == "rc":
            self.select_rc(p)

        else:
            raise ValueError(
                f'Invalid selection method {algorithm}'
            )

        # print("selection_algorithm: {}".format(algorithm))


    def gradient_accumulation_hook(self, module_name):
        """
        Build the gradient hook that accumulates reselection statistics for one module.

        The returned function is handed to ``Linear.apply_hook`` by
        ``begin_selection_phase``. Its ``grad`` argument must have as many
        elements as ``m.weight``: it is flattened and multiplied with
        ``m.weight.view_as(grad)``. On each call the hook stores, under
        ``self.reselection_scores[module_name]``, an 8-tuple:

          0. candidate_indices - flat weight indices considered for growth,
             fixed by the first call of the phase;
          1. candidate_grads - running sum of grad at candidate_indices;
          2. candidate_grads_sq - running sum of grad**2 at candidate_indices;
          3. samples - int64 counter per candidate, incremented once per call
             (so every entry holds the same value);
          4. fixed_indices - the adapter's ``sft_delta`` indices at the first call;
          5. fixed_grads - running sum of |grad| at fixed_indices;
          6. sens_indices - flat indices outside the active set with the
             smallest nonzero |weight * grad| on the first call;
          7. candidate_sens - running sum of |weight * grad| at sens_indices.

        The hook itself returns None. Whether that leaves the propagated gradient
        unchanged depends on how ``apply_hook`` registers it — verify.

        NOTE(review): fixed_indices is stored without a copy and so aliases
        ``m.sft_delta[...].indices``. In 'fix' mode, candidate_indices is
        ``indices.to(indices.dtype)``, which returns the same tensor. In-place
        edits of the delta indices during a phase would therefore also change
        these entries.

        Args:
            module_name: name of the module, resolved with
                ``self.model.get_submodule`` on every call.

        Returns:
            The hook function.
        """

        # Gradients are only read and accumulated here, never differentiated through.
        @torch.no_grad()
        def _gradient_accumulation_hook(grad):
            m = self.model.get_submodule(module_name)
            # Flatten so entries are addressed by flat (dense) weight index.
            grad = grad.reshape(-1)

            ## sensitivity: |weight * grad| per entry
            sensitivity = torch.abs(m.weight.view_as(grad) * grad) #torch.abs(m.weight.view_as(grad) * grad).clone()
            #sensitivity = torch.abs(m.weight.view_as(grad))

            if module_name in self.reselection_scores:
                # Later calls in the phase: accumulate in place at the index
                # sets chosen by the first call.
                candidate_indices, candidate_grads, candidate_grads_sq, samples, fixed_indices, fixed_grads, sens_indices, candidate_sens = self.reselection_scores[module_name]
                new_grads = grad[candidate_indices]
                candidate_grads += new_grads
                candidate_grads_sq.addcmul_(new_grads, new_grads)
                #candidate_grads.mul_(0.9).add_(new_grads, alpha=(1.0 - 0.9))
                #candidate_grads_sq.mul_(0.999).addcmul_(new_grads, new_grads, value=1.0 - 0.999)
                samples += 1


                # Update fixed_grads with the current gradient of m.sft_delta[m.active_adapter].indices
                # fixed_grads += torch.abs(grad[m.sft_delta[m.active_adapter].indices])#.clone())

                fixed_grads += torch.abs(grad[fixed_indices])
                candidate_sens += sensitivity[sens_indices]

                # The tensors above were updated in place; the tuple is re-stored
                # with the same objects.
                self.reselection_scores[module_name] = (
                candidate_indices,
                candidate_grads,
                candidate_grads_sq,
                samples,
                fixed_indices,
                fixed_grads,
                sens_indices,
                candidate_sens
            )

            else:

                ## First call of the phase: choose the index sets from this batch.
                num_candidates = len(m.sft_delta[m.active_adapter].values)
                abs_grads = torch.abs(grad)


                if self.sft_config.dst == 'fix':
                    # Tunable set is fixed: track statistics at the current indices.
                    candidate_indices = m.sft_delta[m.active_adapter].indices

                else:
                    # Growth candidates: the num_candidates entries with the largest |grad|.
                    num_topk = num_candidates

                    # Zero-weight entries get score -1 so they rank last, either
                    # when mask_zeros is set or before the first selection completes.
                    if self.sft_config.mask_zeros or self.completed_steps < self.sft_config.selection_accumulation_steps:
                        mask = m.weight.view(-1) == 0
                        abs_grads[mask] = -1
                    _, topk_indices = torch.topk(
                        abs_grads,
                        num_topk,
                        largest=True,
                        sorted=False,
                    )

                    candidate_indices = topk_indices

                candidate_grads = grad[candidate_indices]#.clone()


                # Store the initial gradient of m.sft_delta[m.active_adapter].indices
                # fixed_grads = torch.abs(grad[m.sft_delta[m.active_adapter].indices])#.clone())

                ## need
                fixed_indices = m.sft_delta[m.active_adapter].indices
                fixed_grads = torch.abs(grad[m.sft_delta[m.active_adapter].indices])

                # Number of low-sensitivity entries to track: currently all
                # num_candidates (factor 1.0), at least 1.
                num_to_keep = max(1, int(num_candidates * 1.0))
                # fixed_grads, topk_indices_in_fixed_grads_all = torch.topk(fixed_grads_all, num_to_keep, largest=False, sorted=True)
                # fixed_indices = m.sft_delta[m.active_adapter].indices[topk_indices_in_fixed_grads_all]

                ### generate sensitivity indices
                # Exclude the currently active entries, then keep the kk entries
                # with the smallest nonzero sensitivity.
                sensitivity[m.sft_delta[m.active_adapter].indices] = 0

                non_zero_indices = torch.where(sensitivity != 0)[0]
                non_zero_sens = sensitivity[non_zero_indices]

                # Never ask topk for more entries than exist.
                kk = min(num_to_keep, non_zero_sens.numel())

                _, smallest_indices = torch.topk(
                    non_zero_sens,
                    kk,
                    largest=False,  ## True
                    sorted=False
                )
                sens_indices = non_zero_indices[smallest_indices]
                candidate_sens = sensitivity[sens_indices]


                '''
                ### generate sensity indices
                sensitivity[m.sft_delta[m.active_adapter].indices] = 0

                non_zero_indices = torch.where(sensitivity != 0)[0]
                non_zero_sens = sensitivity[non_zero_indices]
                _, smallest_indices = torch.topk(
                    non_zero_sens,
                    num_candidates,
                    largest=False,
                    sorted=False
                )
                sens_indices = non_zero_indices[smallest_indices]
                candidate_sens = sensitivity[sens_indices]
                '''

                # candidate_indices is cast to the delta's index dtype; samples
                # starts at 1 for every candidate.
                self.reselection_scores[module_name] = (
                    candidate_indices.to(m.sft_delta[m.active_adapter].indices.dtype),
                    candidate_grads,
                    candidate_grads * candidate_grads,
                    torch.ones_like(candidate_grads, dtype=torch.int64),
                    fixed_indices,
                    fixed_grads,
                    sens_indices,
                    candidate_sens,
                    )


        return _gradient_accumulation_hook

    def active_sft_deltas(self):
        for n, m in self.model.named_modules():
            if (
                isinstance(m, Linear) and
                m.active_adapter is not None and
                m.active_adapter in m.sft_delta  ## a SparseDelta: including values and indices (delta)
            ):
                yield (
                    f'{n}.sft_delta.{m.active_adapter}',
                    m.sft_delta[m.active_adapter]
                )

    def update_statistics(self, current_indices, previous_indices, occurrence_counts):
        """
        """
        new_counts = torch.ones(len(current_indices), dtype=torch.long, device=current_indices.device)

        # Find positions of current_indices in previous_indices
        mask = torch.isin(current_indices, previous_indices)

        if mask.any():  # Check if there are any matches
            matching_indices = current_indices[mask]

            # Get the positions in previous_indices
            previous_positions = torch.searchsorted(previous_indices.sort()[0], matching_indices, side='left')
            new_counts[mask] += occurrence_counts[previous_positions]

        return new_counts

    @torch.no_grad()
    def select_accumulative_rigl(self):
        num_overlaps = 0
        total_indices = 0

        betas = {}
        for group in self.optimizer.param_groups:
            for p in group['params']:
                betas[p] = group['betas']

        for module_name, (
                candidate_indices,
                candidate_grads,
                candidate_grads_sq,
                candidate_samples,
                fixed_grads,
                _,
                sens_indices,
                candidate_sens
        ) in self.reselection_scores.items():
            m = self.model.get_submodule(module_name)
            delta = m.sft_delta[m.active_adapter]
            delta.merge(m.weight)
            delta.values.grad = None

            ### prune to target sparsity
            # self.prune2target(m, sens_indices, candidate_sens, delta.indices)            

            _, new_candidate_indices = torch.topk(
                torch.abs(candidate_sens),
                len(delta.values),
                largest=True,
                sorted=False,
            )
            new_params = candidate_indices[new_candidate_indices]
            new_samples = candidate_samples[new_candidate_indices]
            new_grads = candidate_grads[new_candidate_indices]
            new_grads_sq = candidate_grads_sq[new_candidate_indices]

            is_old_param = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_old_param[delta.indices] = True
            is_new_param = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_new_param[new_params] = True
            is_incoming = (~is_old_param)[new_params]
            is_leaving = (~is_new_param)[delta.indices]

            delta.indices[is_leaving] = new_params[is_incoming]
            delta.values.zero_()

            changing_indices = torch.nonzero(is_leaving).squeeze(1)
            new_samples = new_samples[is_incoming]
            new_grads = new_grads[is_incoming]
            new_grads_sq = new_grads_sq[is_incoming]

            new_grads /= new_samples
            new_grads_sq /= new_samples
            new_ages = new_samples / self.grad_accumulation_steps
            beta1, beta2 = betas[delta.values]
            new_grads *= (1.0 - beta1 ** new_ages)
            new_grads_sq *= (1.0 - beta2 ** new_ages)
            update_optimizer(
                self.optimizer,
                delta.values,
                changing_indices,
                init_momenta={
                    'age': new_ages,
                    'exp_avg': new_grads,
                    'exp_avg_sq': new_grads_sq,
                }
            )

            self.changing_indices[module_name] = changing_indices

            is_remaining_param = is_old_param & is_new_param
            num_overlaps += torch.sum(is_remaining_param).item()
            total_indices += delta.indices.numel()

        logger.info(f'Replacement overlap: {100*num_overlaps/total_indices:.4f}%')

        self.reselection_scores = {}

    @torch.no_grad()
    def select_accumulative_rigl_fix(self):
        num_overlaps = 0
        total_indices = 0

        betas = {}
        for group in self.optimizer.param_groups:
            for p in group['params']:
                betas[p] = group['betas']

        for module_name, (
                candidate_indices,
                candidate_grads,
                candidate_grads_sq,
                candidate_samples,
                fixed_grads,
                _
        ) in self.reselection_scores.items():
            m = self.model.get_submodule(module_name)
            delta = m.sft_delta[m.active_adapter]
            delta.merge(m.weight)
            delta.values.grad = None

            _, new_candidate_indices = torch.topk(
                torch.abs(candidate_grads),
                len(delta.values),
                largest=True,
                sorted=False,
            )
            new_params = candidate_indices[new_candidate_indices]
            new_samples = candidate_samples[new_candidate_indices]
            new_grads = candidate_grads[new_candidate_indices]
            new_grads_sq = candidate_grads_sq[new_candidate_indices]

            is_old_param = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_old_param[delta.indices] = True
            is_new_param = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_new_param[new_params] = True
            is_incoming = torch.arange(len(delta.values), device=candidate_indices.device)
            is_leaving = torch.arange(len(delta.values), device=candidate_indices.device)

            delta.values.zero_()

            ## is_incoming, is_leaving is the indices for delta.values(small mlp)
            changing_indices = is_leaving
            new_samples = new_samples[is_incoming]
            new_grads = new_grads[is_incoming]
            new_grads_sq = new_grads_sq[is_incoming]

            #new_grads /= new_samples
            #new_grads_sq /= new_samples
            new_ages = new_samples / self.grad_accumulation_steps
            #beta1, beta2 = betas[delta.values]
            #new_grads *= (1.0 - beta1 ** new_ages)
            #new_grads_sq *= (1.0 - beta2 ** new_ages)
            update_optimizer(
                self.optimizer,
                delta.values,
                changing_indices,
                init_momenta={
                    'age': 1,
                    'exp_avg': new_grads,
                    'exp_avg_sq': new_grads_sq,
                }
            )

            is_remaining_param = is_old_param & is_new_param
            num_overlaps += torch.sum(is_remaining_param).item()
            total_indices += delta.indices.numel()

        logger.info(f'Replacement overlap: {100*num_overlaps/total_indices:.4f}%')

        self.reselection_scores = {}

    @torch.no_grad()
    def select_rigl(self, change_proportion):
        #n_replacements = 0
        #total_params = 0
        #n_outgoing_is_incoming = 0

        betas = {}

        intersection_info_per_layer = {}
        fixed_grads_per_layer = {}

        for group in self.optimizer.param_groups:
            for p in group['params']:
                betas[p] = group['betas']

        for module_name, (
            candidate_indices,
            candidate_grads,
            candidate_grads_sq,
            candidate_samples,
            fixed_indices,
            fixed_grads,
            sens_indices,
            candidate_sens
        ) in self.reselection_scores.items():
            m = self.model.get_submodule(module_name)
            delta = m.sft_delta[m.active_adapter]

            ### prune to target sparsity
            num_to_reallocate = int(len(delta.values) * change_proportion)

            if self.sft_config.pruned:
                self.prune2target(m, sens_indices, candidate_sens, delta.indices, num_to_reallocate, self.sft_config.initial_reselection_rate)

            # fixed_grads = fixed_grads / candidate_samples[0]

            delta.values.grad = None

            ## merge select
            if self.sft_config.sel_merge and self.completed_steps >= 0:
                num_merge = self.select_merge(m, fixed_grads, num_to_reallocate, change_proportion, fixed_indices)
                num_to_reallocate += num_merge


            # Find the k deltas with smallest absolute values
            _, changing_indices = torch.topk(
                torch.abs(delta.values),
                num_to_reallocate,
                largest=False,
                sorted=True,
            )
            outgoing_params = delta.indices[changing_indices]
            # binary mask of weights to drop
            is_outgoing = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=outgoing_params.device,
            )
            is_outgoing[outgoing_params] = True
            assert torch.sum(is_outgoing) == num_to_reallocate
            # binary mask of currently active weights
            is_current = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_current[delta.indices] = True
            # weights that will stil be active after dropping
            is_remaining = is_current & ~is_outgoing

            # don't consider growing any already active candidate
            is_valid_candidate = ~is_remaining[candidate_indices]
            candidate_indices = candidate_indices[is_valid_candidate]
            candidate_grads = candidate_grads[is_valid_candidate]
            candidate_grads_sq = candidate_grads_sq[is_valid_candidate]
            candidate_samples = candidate_samples[is_valid_candidate]
            candidate_scores = torch.abs(candidate_grads)

            # Proportion for random selection
            num_topk = num_to_reallocate

            # take the top k growth candidates with highest gradient magnitudes
            best_scores, topk_indices = torch.topk(
                candidate_scores,
                num_topk,
                largest=True,
                sorted=True,
            )

            best_candidate_indices = topk_indices

            incoming_params = candidate_indices[best_candidate_indices]
            incoming_grads = candidate_grads[best_candidate_indices]
            incoming_grads_sq = candidate_grads_sq[best_candidate_indices]
            incoming_samples = candidate_samples[best_candidate_indices]
            # binary mask of weights to grow
            is_incoming = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=incoming_params.device,
            )
            is_incoming[incoming_params] = True

            # filter out weights which have been selected to be dropped and
            # grown simultaneously
            assert torch.sum(is_incoming) == len(best_candidate_indices)
            outgoing_is_incoming = is_incoming[outgoing_params]
            changing_indices = changing_indices[~outgoing_is_incoming]
            incoming_is_outgoing = is_outgoing[incoming_params]
            assert torch.sum(outgoing_is_incoming) == torch.sum(incoming_is_outgoing)
            incoming_params = incoming_params[~incoming_is_outgoing]
            incoming_grads = incoming_grads[~incoming_is_outgoing]
            incoming_grads_sq = incoming_grads_sq[~incoming_is_outgoing]
            incoming_samples = incoming_samples[~incoming_is_outgoing]
            changing_indices = changing_indices[:len(incoming_params)]


            # update delta indices and values
            delta.indices[changing_indices] = incoming_params.to(delta.indices.dtype)
            delta.values[changing_indices] = 0.0


            # seed the optimizer momenta appropriately
            incoming_grads /= incoming_samples
            incoming_grads_sq /= incoming_samples
            incoming_ages = incoming_samples / self.grad_accumulation_steps
            beta1, beta2 = betas[delta.values]


            # bias counter-correction: these are unbiased estimates of the momenta,
            # so bias them in order that they will be unbiased after Adam's bias
            # correction
            incoming_grads *= (1.0 - beta1 ** incoming_ages)
            incoming_grads_sq *= (1.0 - beta2 ** incoming_ages)
            update_optimizer(
                self.optimizer,
                delta.values,
                changing_indices,
                init_momenta={
                    'age': incoming_ages,
                    'exp_avg': incoming_grads,
                    'exp_avg_sq': incoming_grads_sq,
                }
            )


    @torch.no_grad()
    def select_rigl_mag_soft(self, change_proportion):
        n_replacements = 0
        total_params = 0

        betas = {}
        for group in self.optimizer.param_groups:
            for p in group['params']:
                betas[p] = group['betas']

        for module_name, (
                candidate_indices,
                candidate_grads,
                candidate_grads_sq,
                candidate_samples
        ) in self.reselection_scores.items():
            m = self.model.get_submodule(module_name)
            delta = m.sft_delta[m.active_adapter]
            delta.values.grad = None

            num_to_reallocate = int(len(delta.values) * change_proportion)

            if num_to_reallocate <= 0:
                break

            # Use weight_magnitude_soft to select changing_indices
            score_drop = torch.abs(delta.values)
            T = 1 + self.completed_steps * (2 / self.total_update_steps)
            flat_matrix = (score_drop.flatten()) ** T
            probabilities = flat_matrix / flat_matrix.sum()

            # Sample indices according to probabilities
            changing_indices = torch.multinomial(probabilities, num_to_reallocate, replacement=False)

            outgoing_params = delta.indices[changing_indices]
            # binary mask of weights to drop
            is_outgoing = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=outgoing_params.device,
            )
            is_outgoing[outgoing_params] = True  ## remove: mask=1
            assert torch.sum(is_outgoing) == num_to_reallocate
            # binary mask of currently active weights
            is_current = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_current[delta.indices] = True
            # weights that will stil be active after dropping
            is_remaining = is_current & ~is_outgoing

            # don't consider growing any already active candidate
            is_valid_candidate = ~is_remaining[candidate_indices]
            candidate_indices = candidate_indices[is_valid_candidate]
            candidate_grads = candidate_grads[is_valid_candidate]
            candidate_grads_sq = candidate_grads_sq[is_valid_candidate]
            candidate_samples = candidate_samples[is_valid_candidate]
            candidate_scores = torch.abs(candidate_grads)
            # take the top k growth candidates with highest gradient magnitudes
            best_scores, best_candidate_indices = torch.topk(
                candidate_scores,
                min(num_to_reallocate, len(candidate_grads)),
                largest=True,
                sorted=True,
            )
            incoming_params = candidate_indices[best_candidate_indices]
            incoming_grads = candidate_grads[best_candidate_indices]
            incoming_grads_sq = candidate_grads_sq[best_candidate_indices]
            incoming_samples = candidate_samples[best_candidate_indices]
            # binary mask of weights to grow
            is_incoming = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=incoming_params.device,
            )
            is_incoming[incoming_params] = True

            # filter out weights which have been selected to be dropped and
            # grown simultaneously
            assert torch.sum(is_incoming) == len(best_candidate_indices)
            outgoing_is_incoming = is_incoming[outgoing_params]
            changing_indices = changing_indices[~outgoing_is_incoming]
            incoming_is_outgoing = is_outgoing[incoming_params]
            assert torch.sum(outgoing_is_incoming) == torch.sum(incoming_is_outgoing)
            incoming_params = incoming_params[~incoming_is_outgoing]
            incoming_grads = incoming_grads[~incoming_is_outgoing]
            incoming_grads_sq = incoming_grads_sq[~incoming_is_outgoing]
            incoming_samples = incoming_samples[~incoming_is_outgoing]
            changing_indices = changing_indices[:len(incoming_params)]

            n_replacements += len(changing_indices)
            total_params += len(delta.indices)

            # update delta indices and values
            delta.indices[changing_indices] = incoming_params.to(delta.indices.dtype)
            delta.values[changing_indices] = 0.0  ## new regrow weights start from 0

            # seed the optimizer momenta appropriately
            incoming_grads /= incoming_samples
            incoming_grads_sq /= incoming_samples
            incoming_ages = incoming_samples / self.grad_accumulation_steps
            beta1, beta2 = betas[delta.values]
            # bias counter-correction: these are unbiased estimates of the momenta,
            # so bias them in order that they will be unbiased after Adam's bias
            # correction
            incoming_grads *= (1.0 - beta1 ** incoming_ages)
            incoming_grads_sq *= (1.0 - beta2 ** incoming_ages)
            update_optimizer(
                self.optimizer,
                delta.values,
                changing_indices,
                init_momenta={
                    'age': incoming_ages,
                    'exp_avg': incoming_grads,
                    'exp_avg_sq': incoming_grads_sq,
                }
            )

        logger.info(
            f'Replacing {n_replacements} ({100*n_replacements/total_params:.4f}%)'
        )

    @torch.no_grad()
    def select_mest(self, change_proportion):
        """
        Reselect the SFT indices of every scored module with the MEST drop rule.

        For each module in ``self.reselection_scores``, the ``change_proportion``
        fraction of active delta entries with the smallest
        ``|value| + 0.01 * |grad|`` score is dropped. The same number of inactive
        candidates with the largest accumulated gradient magnitude is grown in
        their place. Grown entries start at value 0, and their Adam state is
        seeded from the accumulated gradient statistics via ``update_optimizer``.

        Args:
          - change_proportion: fraction of each delta's entries to reallocate.
        """
        n_replacements = 0
        total_params = 0

        betas = {}
        for group in self.optimizer.param_groups:
            for p in group['params']:
                betas[p] = group['betas']

        for module_name, (
                candidate_indices,
                candidate_grads,
                candidate_grads_sq,
                candidate_samples
        ) in self.reselection_scores.items():
            m = self.model.get_submodule(module_name)
            delta = m.sft_delta[m.active_adapter]

            num_to_reallocate = int(len(delta.values) * change_proportion)

            # Without a gradient, entries are scored by magnitude alone.
            if delta.values.grad is None:
                grad_abs = torch.zeros_like(delta.values)
            else:
                grad_abs = torch.abs(delta.values.grad)

            # Debug level only: formatting a large (possibly device) tensor on
            # every reselection is expensive and floods stdout.
            logger.debug('grad_abs: %s', grad_abs)
            score_drop = torch.abs(delta.values) + 0.01 * grad_abs

            # Drop the entries with the smallest scores (ascending order).
            _, changing_indices = torch.topk(
                score_drop.view(-1),
                num_to_reallocate,
                largest=False,
                sorted=True,
            )

            delta.values.grad = None

            outgoing_params = delta.indices[changing_indices]
            # binary mask of weights to drop
            is_outgoing = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=outgoing_params.device,
            )
            is_outgoing[outgoing_params] = True  ## remove: mask=1
            assert torch.sum(is_outgoing) == num_to_reallocate
            # binary mask of currently active weights
            is_current = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_current[delta.indices] = True
            # weights that will still be active after dropping
            is_remaining = is_current & ~is_outgoing

            # don't consider growing any already active candidate
            is_valid_candidate = ~is_remaining[candidate_indices]
            candidate_indices = candidate_indices[is_valid_candidate]
            candidate_grads = candidate_grads[is_valid_candidate]
            candidate_grads_sq = candidate_grads_sq[is_valid_candidate]
            candidate_samples = candidate_samples[is_valid_candidate]
            candidate_scores = torch.abs(candidate_grads)
            # take the top k growth candidates with highest gradient magnitudes
            best_scores, best_candidate_indices = torch.topk(
                candidate_scores,
                min(num_to_reallocate, len(candidate_grads)),
                largest=True,
                sorted=True,
            )
            incoming_params = candidate_indices[best_candidate_indices]
            incoming_grads = candidate_grads[best_candidate_indices]
            incoming_grads_sq = candidate_grads_sq[best_candidate_indices]
            incoming_samples = candidate_samples[best_candidate_indices]
            # binary mask of weights to grow
            is_incoming = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=incoming_params.device,
            )
            is_incoming[incoming_params] = True

            # filter out weights which have been selected to be dropped and
            # grown simultaneously
            assert torch.sum(is_incoming) == len(best_candidate_indices)
            outgoing_is_incoming = is_incoming[outgoing_params]
            changing_indices = changing_indices[~outgoing_is_incoming]
            incoming_is_outgoing = is_outgoing[incoming_params]
            assert torch.sum(outgoing_is_incoming) == torch.sum(incoming_is_outgoing)
            incoming_params = incoming_params[~incoming_is_outgoing]
            incoming_grads = incoming_grads[~incoming_is_outgoing]
            incoming_grads_sq = incoming_grads_sq[~incoming_is_outgoing]
            incoming_samples = incoming_samples[~incoming_is_outgoing]
            # never drop more entries than can be regrown
            changing_indices = changing_indices[:len(incoming_params)]

            n_replacements += len(changing_indices)
            total_params += len(delta.indices)

            # update delta indices and values
            delta.indices[changing_indices] = incoming_params.to(delta.indices.dtype)
            delta.values[changing_indices] = 0.0  ## new regrow weights start from 0

            # seed the optimizer momenta appropriately
            incoming_grads /= incoming_samples
            incoming_grads_sq /= incoming_samples
            incoming_ages = incoming_samples / self.grad_accumulation_steps
            beta1, beta2 = betas[delta.values]
            # bias counter-correction: these are unbiased estimates of the momenta,
            # so bias them in order that they will be unbiased after Adam's bias
            # correction
            incoming_grads *= (1.0 - beta1 ** incoming_ages)
            incoming_grads_sq *= (1.0 - beta2 ** incoming_ages)
            update_optimizer(
                self.optimizer,
                delta.values,
                changing_indices,
                init_momenta={
                    'age': incoming_ages,
                    'exp_avg': incoming_grads,
                    'exp_avg_sq': incoming_grads_sq,
                }
            )

        # Guard against an empty score table (previously a ZeroDivisionError).
        pct = 100 * n_replacements / total_params if total_params else 0.0
        logger.info(
            f'Replacing {n_replacements} ({pct:.4f}%)'
        )

    @torch.no_grad()
    def select_rc(self, change_proportion):
        """
        Reselect SFT indices using a relative-contribution (row/column) drop score.

        Each active delta entry is scored as ``|v| / colmass + |v| / rowmass``.
        Here ``rowmass`` and ``colmass`` are the L1 mass of the active delta
        entries that share the entry's row and column in the layout of
        ``m.weight``. The ``change_proportion`` fraction with the smallest score
        is dropped. The same number of inactive candidates with the largest
        accumulated gradient magnitude is grown at value 0, with Adam state
        seeded from the gradient statistics.

        ``delta.values`` is 1-D (one value per active index), so the scores are
        computed sparsely. The previous dense ``sum(dim=1)`` formulation raised
        IndexError on 1-D values.

        Args:
          - change_proportion: fraction of each delta's entries to reallocate.
        """
        n_replacements = 0
        total_params = 0

        betas = {}
        for group in self.optimizer.param_groups:
            for p in group['params']:
                betas[p] = group['betas']

        for module_name, (
                candidate_indices,
                candidate_grads,
                candidate_grads_sq,
                candidate_samples
        ) in self.reselection_scores.items():
            m = self.model.get_submodule(module_name)
            delta = m.sft_delta[m.active_adapter]

            num_to_reallocate = int(len(delta.values) * change_proportion)

            # Recover (row, col) of each active entry from its flat index.
            # NOTE(review): assumes delta.indices are row-major flat indices into
            # m.weight with the last dim as columns; confirm for transposed or
            # non-2-D weights.
            eplison = 0.00001
            abs_values = torch.abs(delta.values)
            n_cols = m.weight.shape[-1]
            n_rows = m.weight.numel() // n_cols
            flat_idx = delta.indices.long()
            row_of = torch.div(flat_idx, n_cols, rounding_mode='floor')
            col_of = flat_idx % n_cols
            row_mass = abs_values.new_zeros(n_rows).index_add_(0, row_of, abs_values)
            col_mass = abs_values.new_zeros(n_cols).index_add_(0, col_of, abs_values)
            score_drop = abs_values / (col_mass[col_of] + eplison) + \
                         abs_values / (row_mass[row_of] + eplison)

            # Use torch.topk to select the indices with the smallest scores for reallocation
            _, changing_indices = torch.topk(
                score_drop.view(-1),
                num_to_reallocate,
                largest=False,
                sorted=True,
            )

            delta.values.grad = None

            outgoing_params = delta.indices[changing_indices]
            # binary mask of weights to drop
            is_outgoing = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=outgoing_params.device,
            )
            is_outgoing[outgoing_params] = True  ## remove: mask=1
            assert torch.sum(is_outgoing) == num_to_reallocate
            # binary mask of currently active weights
            is_current = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_current[delta.indices] = True
            # weights that will still be active after dropping
            is_remaining = is_current & ~is_outgoing

            # don't consider growing any already active candidate
            is_valid_candidate = ~is_remaining[candidate_indices]
            candidate_indices = candidate_indices[is_valid_candidate]
            candidate_grads = candidate_grads[is_valid_candidate]
            candidate_grads_sq = candidate_grads_sq[is_valid_candidate]
            candidate_samples = candidate_samples[is_valid_candidate]
            candidate_scores = torch.abs(candidate_grads)
            # take the top k growth candidates with highest gradient magnitudes
            best_scores, best_candidate_indices = torch.topk(
                candidate_scores,
                min(num_to_reallocate, len(candidate_grads)),
                largest=True,
                sorted=True,
            )
            incoming_params = candidate_indices[best_candidate_indices]
            incoming_grads = candidate_grads[best_candidate_indices]
            incoming_grads_sq = candidate_grads_sq[best_candidate_indices]
            incoming_samples = candidate_samples[best_candidate_indices]
            # binary mask of weights to grow
            is_incoming = torch.zeros(
                [delta.dense_numel],
                dtype=torch.bool,
                device=incoming_params.device,
            )
            is_incoming[incoming_params] = True

            # filter out weights which have been selected to be dropped and
            # grown simultaneously
            assert torch.sum(is_incoming) == len(best_candidate_indices)
            outgoing_is_incoming = is_incoming[outgoing_params]
            changing_indices = changing_indices[~outgoing_is_incoming]
            incoming_is_outgoing = is_outgoing[incoming_params]
            assert torch.sum(outgoing_is_incoming) == torch.sum(incoming_is_outgoing)
            incoming_params = incoming_params[~incoming_is_outgoing]
            incoming_grads = incoming_grads[~incoming_is_outgoing]
            incoming_grads_sq = incoming_grads_sq[~incoming_is_outgoing]
            incoming_samples = incoming_samples[~incoming_is_outgoing]
            # never drop more entries than can be regrown
            changing_indices = changing_indices[:len(incoming_params)]

            n_replacements += len(changing_indices)
            total_params += len(delta.indices)

            # update delta indices and values
            delta.indices[changing_indices] = incoming_params.to(delta.indices.dtype)
            delta.values[changing_indices] = 0.0  ## new regrow weights start from 0

            # seed the optimizer momenta appropriately
            incoming_grads /= incoming_samples
            incoming_grads_sq /= incoming_samples
            incoming_ages = incoming_samples / self.grad_accumulation_steps
            beta1, beta2 = betas[delta.values]
            # bias counter-correction: these are unbiased estimates of the momenta,
            # so bias them in order that they will be unbiased after Adam's bias
            # correction
            incoming_grads *= (1.0 - beta1 ** incoming_ages)
            incoming_grads_sq *= (1.0 - beta2 ** incoming_ages)
            update_optimizer(
                self.optimizer,
                delta.values,
                changing_indices,
                init_momenta={
                    'age': incoming_ages,
                    'exp_avg': incoming_grads,
                    'exp_avg_sq': incoming_grads_sq,
                }
            )

        # Guard against an empty score table (previously a ZeroDivisionError).
        pct = 100 * n_replacements / total_params if total_params else 0.0
        logger.info(
            f'Replacing {n_replacements} ({pct:.4f}%)'
        )
    
    @torch.no_grad()
    def select_sm3(self, p):
        """
        Reselect SFT indices of every active delta using SM3 accumulator statistics.

        The ``p`` fraction of active entries with the smallest ``|value|`` is
        dropped. Inactive positions are ranked by the outer product of the SM3
        row and column accumulators (``accumulator_0`` and ``accumulator_1``),
        and the highest-ranked ones are grown at value 0. This assumes 2-D
        parameter tensors.

        When fewer inactive positions exist than requested, only that many
        entries are swapped. Previously ``topk`` raised in that case.

        Args:
          - p: fraction of each delta's entries to reallocate.
        """
        n_replacements = 0
        total_params = 0

        for _, delta in self.active_sft_deltas():
            num_to_reallocate = int(len(delta.values) * p)
            # Ascending by magnitude, so truncating later keeps the smallest.
            _, changing_indices = torch.topk(
                torch.abs(delta.values),
                num_to_reallocate,
                largest=False,
                sorted=True,
            )

            # Every dense position not currently held by the delta may be grown.
            is_valid_candidate = torch.ones(
                [delta.dense_numel],
                dtype=torch.bool,
                device=delta.indices.device,
            )
            is_valid_candidate[delta.indices] = False

            optimizer_state = self.optimizer.state[delta.values]
            row_grads_sq = optimizer_state['accumulator_0']
            col_grads_sq = optimizer_state['accumulator_1']
            # Outer product of row- and column-wise SM3 buffers estimates the
            # per-element second moment.
            estimated_momenta = torch.outer(row_grads_sq, col_grads_sq)
            estimated_momenta = estimated_momenta.view(-1)[is_valid_candidate]
            candidate_indices = torch.arange(
                0,
                delta.dense_numel,
                device=is_valid_candidate.device,
            )[is_valid_candidate]

            num_to_grow = min(num_to_reallocate, candidate_indices.numel())
            _, best_candidate_indices = torch.topk(
                estimated_momenta,
                num_to_grow,
                largest=True,
                sorted=False,
            )
            incoming_params = candidate_indices[best_candidate_indices]
            # Only drop as many entries as can be regrown.
            changing_indices = changing_indices[:num_to_grow]

            n_replacements += len(changing_indices)
            total_params += len(delta.indices)

            delta.indices[changing_indices] = incoming_params.to(delta.indices.dtype)
            delta.values[changing_indices] = 0.0

        # Guard against having no active deltas (previously a ZeroDivisionError).
        pct = 100 * n_replacements / total_params if total_params else 0.0
        logger.info(
            f'Replacing {n_replacements} ({pct:.4f}%)'
        )

    @torch.no_grad()
    def prune2target(self, m, sensity_indices, sensity_candidates, delta_indices, num_to_reallocate, init_p):
        """
        Zero low-score entries of ``m.weight`` to move it toward ``self.sparsity_ratio``.

        A position counts as active when its base weight magnitude exceeds
        1e-6 or it appears in ``delta_indices``. The number of new zeros is
        the remaining gap to the target zero count, scaled by
        ``completed_steps / total_update_steps``. It is capped by
        ``num_to_reallocate`` and by the number of candidate scores. Pruning
        only runs while the current zero count is positive and below the
        target. Before/after statistics are logged.

        Args:
          - m: module whose ``weight`` is pruned in place.
          - sensity_indices: flat positions in ``m.weight`` eligible for pruning.
          - sensity_candidates: one score per entry of ``sensity_indices``;
              the lowest scores are zeroed first.
          - delta_indices: flat positions currently held by the sparse delta.
          - num_to_reallocate: upper bound on new zeros for this call.
          - init_p: unused.
        """
        weight_flat = m.weight.data.view(-1)
        numel = weight_flat.numel()

        def count_active():
            # Union of numerically non-zero base entries and delta positions.
            live = torch.nonzero(torch.abs(weight_flat) > 1e-6, as_tuple=True)[0]
            return torch.unique(torch.cat((live, delta_indices))).numel()

        zeros_target = int(numel * self.sparsity_ratio)
        zeros_before = numel - count_active()

        # Linear ramp over training progress.
        progress = self.completed_steps / self.total_update_steps
        zeros_wanted = min(int((zeros_target - zeros_before) * progress), num_to_reallocate)

        if 0 < zeros_before < zeros_target:
            logger.info('Beginning prune2target: {}, {}, {}, sensity_candidates: {}'.format(
                zeros_before, zeros_target, zeros_wanted, len(sensity_indices)))

            if zeros_wanted > 0:
                zeros_wanted = min(zeros_wanted, sensity_candidates.numel())
                _, lowest = torch.topk(sensity_candidates, zeros_wanted, largest=False)
                weight_flat[sensity_indices[lowest]] = weight_flat.new_zeros(1)

        exact_zeros = (m.weight.data == 0).sum().item()
        sparsity = 1 - count_active() / numel
        logger.info('Before prune: {}, {}, {}, After prune: {}, Current sparsity: {:.4f}'.format(
            zeros_before, zeros_target, zeros_wanted, exact_zeros, sparsity))

    @torch.no_grad()
    def select_merge(self, m, fixed_grads, num_to_reallocate, p, fixed_grads_indices):
        """
        Merge a subset of the active delta entries of ``m`` into its base weight.

        The ``len(delta.values) - num_to_reallocate`` largest-magnitude delta
        entries form the candidate pool. Within that pool, the ``merge_ratio``
        fraction with the smallest ``|fixed_grads|`` is passed to
        ``delta.merge`` and then reset to 0 in ``delta.values``.

        Args:
          - m: module holding ``sft_delta[active_adapter]`` and ``weight``.
          - fixed_grads: gradient tensor indexed with the same positions as
              ``delta.values``.
          - num_to_reallocate: number of entries reserved for reallocation.
          - p: current reselection proportion; merging is skipped when < 0.01.
          - fixed_grads_indices: unused.

        Returns:
          The number of entries merged (0 when merging is skipped).
        """
        delta = m.sft_delta[m.active_adapter]
        num_candidate_merge = int(len(delta.values)) - num_to_reallocate

        if num_candidate_merge <= 0 or p < 0.01:
            return 0

        # Clamp so topk never asks for more than the pool holds.
        num_merge = min(int(num_candidate_merge * self.merge_ratio), num_candidate_merge)

        # Pool: the largest-magnitude entries (all but the num_to_reallocate smallest).
        _, candidate_indices = torch.topk(
            torch.abs(delta.values),
            num_candidate_merge,
            largest=True,
            sorted=True,
        )

        # Among them, merge those whose gradient magnitude is smallest.
        _, merge_sub_indices = torch.topk(
            torch.abs(fixed_grads[candidate_indices]),
            num_merge,
            largest=False,
            sorted=True,
        )
        merge_indices = candidate_indices[merge_sub_indices]

        # Diagnostics: how many merge targets land on non-zero vs zero base
        # weights. Gathering only the k target positions gives the same counts
        # as scanning the whole weight with nonzero()/isin().
        base_at_targets = torch.abs(m.weight.data.view(-1)[delta.indices[merge_indices]])
        num_from_nonzeros = (base_at_targets > 0).sum().item()
        num_from_zeros = (base_at_targets <= 0).sum().item()
        logger.info(f'merge to non-zero: {num_from_nonzeros}, to zero: {num_from_zeros}, total: {len(merge_indices)}')

        delta.merge(m.weight, indices_merge=merge_indices)

        delta.values.grad = None
        delta.values[merge_indices] = 0.0

        return num_merge


    @torch.no_grad()
    def find_layers(self, module, layers=(nn.Linear,), name=''):
        """
        Recursively collect this package's SFT ``Linear`` modules under ``module``.

        Args:
          - module: root module to search.
          - layers: accepted for API compatibility and forwarded on recursion.
              Matching is currently done against the package's ``Linear``
              class only. The default is an immutable tuple, so it cannot be
              shared and mutated across calls.
          - name: dotted name of ``module`` relative to the search root.

        Returns:
          dict mapping dotted module names to matching modules. If ``module``
          itself matches, it is returned alone under ``name`` and its children
          are not searched.
        """
        if isinstance(module, Linear):
            return {name: module}

        found = {}
        for child_name, child in module.named_children():
            full_name = f'{name}.{child_name}' if name else child_name
            found.update(self.find_layers(child, layers=layers, name=full_name))
        return found

    @torch.no_grad()
    def prepare_calibration_input(self, model, dataloader, device, nsamples=128):
        """
        Capture the inputs to the first decoder layer for a set of calibration batches.

        Layer 0 of ``model.model.model.layers`` is temporarily wrapped by a
        catcher. The catcher records the layer's hidden-state input and its
        ``attention_mask``, ``position_ids`` and ``cache_position`` keyword
        arguments, then aborts the forward pass. The original layer is
        restored even if a forward pass raises.

        Args:
          - model: model whose decoder layers live at ``model.model.model.layers``.
          - dataloader: iterable of batches; ``batch[0]`` is fed to ``model``.
          - device: device for inputs and buffers. It is overridden by
              ``model.hf_device_map["model.embed_tokens"]`` when that entry exists.
          - nsamples: number of batches to capture (buffer size along dim 0).
              Extra batches are ignored.

        Returns:
          ``(inps, outs, attention_mask, position_ids, cache_position)``.
          ``inps`` has shape ``(nsamples, *captured.shape[1:])``; rows beyond the
          number of captured batches stay zero. ``outs`` is a zero buffer of the
          same shape. The keyword values come from the last captured call.

        Raises:
          ValueError: if the dataloader yields no batches.
        """
        layers = model.model.model.layers

        # hf_device_map only exists when the model was loaded with a device map.
        device_map = getattr(model, 'hf_device_map', None) or {}
        if 'model.embed_tokens' in device_map:
            device = device_map['model.embed_tokens']

        dtype = next(iter(model.parameters())).dtype
        cache = {'i': 0, 'inps': None, 'attention_mask': None, 'position_ids': None, 'cache_position': None}

        class _StopForward(Exception):
            """Aborts the forward pass once layer 0's input has been recorded."""

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                if cache['inps'] is None:
                    # Size the buffer from the first activation instead of a
                    # hard-coded (seqlen, hidden) pair. Assumes a leading batch
                    # dim of 1, as the per-sample assignment below requires.
                    cache['inps'] = torch.zeros((nsamples, *inp.shape[1:]), dtype=dtype, device=device)
                cache['inps'][cache['i']] = inp
                cache['i'] += 1
                # Not every model/transformers version passes all of these.
                cache['attention_mask'] = kwargs.get('attention_mask')
                cache['position_ids'] = kwargs.get('position_ids')
                cache['cache_position'] = kwargs.get('cache_position')
                raise _StopForward

        layers[0] = Catcher(layers[0])
        try:
            for batch in dataloader:
                if cache['i'] >= nsamples:
                    break
                try:
                    model(batch[0].to(device))
                except _StopForward:
                    pass
        finally:
            layers[0] = layers[0].module

        if cache['inps'] is None:
            raise ValueError('calibration dataloader yielded no batches')

        inps = cache['inps']
        outs = torch.zeros_like(inps)
        return inps, outs, cache['attention_mask'], cache['position_ids'], cache['cache_position']

    @torch.no_grad()
    def prune_wanda(self):
        """
        Prune the base weights of every decoder layer with the Wanda metric.

        The sparse deltas of all scored modules are first added onto their
        base weights. The metric ``|W| * ||X||_2`` is then computed on the
        effective weights while calibration activations are propagated layer
        by layer. Finally the deltas are subtracted again. The subtraction runs
        in ``finally`` and only for modules that were actually merged, so an
        error never leaves the deltas double-counted in the weights.
        """
        merged = []
        try:
            for module_name in self.reselection_scores:
                self._shift_weight_by_delta(module_name, sign=1.0)
                merged.append(module_name)
            self._wanda_prune_all_layers()
        finally:
            for module_name in merged:
                self._shift_weight_by_delta(module_name, sign=-1.0)

    def _shift_weight_by_delta(self, module_name, sign):
        """Add ``sign * delta.values`` into the module's base weight at ``delta.indices``."""
        m = self.model.get_submodule(module_name)
        delta = m.sft_delta[m.active_adapter]
        values = delta.values.to(m.weight.dtype)
        m.weight.view(-1).scatter_reduce_(0, delta.indices.long(), sign * values, "sum", include_self=True)

    def _wanda_prune_all_layers(self, nsamples=128):
        """Calibrate each decoder layer, prune its SFT linears, then feed pruned outputs forward."""
        # Use the model's own device rather than a hard-coded cuda:0;
        # prepare_calibration_input may still override it from hf_device_map.
        device = next(self.model.parameters()).device
        logger.info('loading calibration data')
        inps, outs, attention_mask, position_ids, cache_position = self.prepare_calibration_input(
            self.model, self.cal_data, device)
        nsamples = min(nsamples, inps.shape[0])
        layer_kwargs = dict(
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
        )
        # The sparsity ramp depends only on training progress, so it is the
        # same for every module.
        prune_decay = self.completed_steps / self.total_update_steps

        for i, layer in enumerate(self.model.model.model.layers):
            subset = self.find_layers(layer)
            wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

            def make_hook(wrapped):
                def hook(_, inp, out):
                    wrapped.add_batch(inp[0].data, out.data)
                return hook

            handles = [
                subset[name].register_forward_hook(make_hook(wrapped))
                for name, wrapped in wrapped_layers.items()
            ]
            try:
                for j in range(nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]
            finally:
                for h in handles:
                    h.remove()

            for name, module in subset.items():
                self._wanda_prune_module(module, wrapped_layers[name], prune_decay, i, name)

            # Re-run the pruned layer so the next layer sees pruned activations.
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]
            inps, outs = outs, inps

    def _wanda_prune_module(self, module, wrapped, prune_decay, layer_idx, name):
        """Zero, per output row, the lowest-Wanda-score weights of one module."""
        weight = module.weight.data
        num_elements = weight.numel()
        w_metric = torch.abs(weight) * torch.sqrt(wrapped.scaler_row.reshape((1, -1)))

        # Ramp from the current sparsity toward the configured target (the same
        # ratio prune2target uses) as training progresses.
        target = self.sparsity_ratio
        current_sparsity = (weight == 0).sum().item() / num_elements
        target_sparsity = min(target, current_sparsity + (target - current_sparsity) * prune_decay)

        num_prune = int(w_metric.shape[1] * target_sparsity)
        prune_cols = torch.sort(w_metric, dim=-1, stable=True)[1][:, :num_prune]
        w_mask = torch.zeros_like(w_metric, dtype=torch.bool)
        w_mask.scatter_(1, prune_cols, True)
        weight[w_mask] = 0

        after_sparsity = (weight == 0).sum().item() / num_elements
        logger.info(f'pruning layer {layer_idx} name {name}, after sparsity: {after_sparsity}')
            

class SelectorStepCallback(TrainerCallback):
    """
    Trainer callback that advances the SFT selector after every optimizer step.

    Args:
      - trainer: the owning trainer; ``trainer._selector.step()`` is invoked
          from ``on_step_end``.
    """

    def __init__(self, trainer):
        self.trainer = trainer
        # Initialised to 0; this callback never updates it.
        self.step_completed = 0

    def on_step_end(self, args, state, control, **kwargs):
        """Run one selector step. ``args``, ``state`` and ``control`` are unused."""
        selector = self.trainer._selector
        selector.step()

def SftTrainer(_Trainer):
    """
    Wraps a Trainer or subclass thereof for SFT training. The resulting class
    should be constructed with a SftModel as the model and passing sft_config as
    a SftConfig instance.

    Args:
      - _Trainer: the Trainer class (or subclass) to extend.

    Returns:
      A new subclass ``_SftTrainer`` of ``_Trainer``.
    """

    class _SftTrainer(_Trainer):

        def __init__(
            self,
            *args,
            total_update_steps=None,
            sft_config=None,
            cal_data=None,
            **kwargs
        ):
            super().__init__(*args, **kwargs)
            logger.setLevel(self.args.get_process_log_level())

            if sft_config is None:
                raise ValueError('Missing sft_config')
            self.sft_config = sft_config

            # The selector divides by this value, so never hand it None.
            max_steps = self._resolve_total_update_steps(total_update_steps)
            logger.info('total_update_steps: {}'.format(max_steps))
            self._selector = SftSelector(
                self.model,
                self.create_optimizer(),
                self.sft_config,
                max_steps,
                self.args.gradient_accumulation_steps,
                output_dir=self.args.output_dir,
                cal_data=cal_data,
                sparsity_ratio=self.args.sparsity_ratio,
                merge_ratio=self.args.merge_ratio
            )
            self.add_callback(SelectorStepCallback(self))

        def _resolve_total_update_steps(self, total_update_steps):
            """
            Return ``total_update_steps``, or derive it from the training args if None.

            The fallback uses ``args.max_steps`` when it is positive. Otherwise
            it uses ``ceil(num_train_epochs * max(len(train_dataloader) //
            gradient_accumulation_steps, 1))``.
            """
            if total_update_steps is not None:
                return total_update_steps
            if self.args.max_steps > 0:
                return self.args.max_steps
            len_dataloader = len(self.get_train_dataloader())
            num_update_steps_per_epoch = max(
                len_dataloader // self.args.gradient_accumulation_steps, 1
            )
            return math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)

        def create_optimizer(self):
            """Create (once) and return the SFT-aware optimizer for the trainable parameters."""
            if self.optimizer is None:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in self.model.named_parameters() if p.requires_grad
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                ]

                _, optimizer_kwargs = _Trainer.get_optimizer_cls_and_kwargs(self.args)
                logger.info(f'optimizer_kwargs: {optimizer_kwargs}')

                if self.sft_config.selection_algorithm == "sm3":
                    # SM3 needs to map each values tensor back to its delta.
                    deltas = {
                        delta.values: delta
                        for _1, _2, delta in self.model.active_deltas()
                    }

                    self.optimizer = SftSM3(
                        optimizer_grouped_parameters,
                        deltas,
                        **optimizer_kwargs
                    )
                else:
                    self.optimizer = SftAdamW(optimizer_grouped_parameters, **optimizer_kwargs)

            return self.optimizer

    return _SftTrainer

# Define WrappedGPT class
class WrappedGPT:
    """
    Accumulates per-input-feature activation statistics for one layer (Wanda).

    ``scaler_row`` (length = number of weight columns) holds the running mean,
    over the samples seen so far, of each input feature's squared L2 norm.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        shape = layer.weight.data.shape
        self.layer = layer
        self.dev = layer.weight.device
        self.rows = shape[0]
        self.columns = shape[1]

        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row``; ``out`` is unused."""
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]

        if isinstance(self.layer, nn.Linear):
            # Flatten (batch, seq, features) to (tokens, features), then put
            # features first so the norm below is taken per feature.
            if inp.dim() == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        # Rescale the old mean to the new sample count before adding this batch.
        seen = self.nsamples
        self.scaler_row *= seen / (seen + batch)
        self.nsamples = seen + batch

        feats = inp.type(torch.float32)
        self.scaler_row += torch.norm(feats, p=2, dim=1) ** 2 / self.nsamples

