import torch
import torch.nn as nn
import typing, warnings

from torch_pruning.pruner.importance import OBDCImportance

from .scheduler import linear_scheduler
from ..import function
from ... import ops, dependency

class MetaPruner:
    """Dependency-aware structural pruner.

    Builds a dependency graph of ``model`` so that coupled layers are pruned
    together, scores the channels (and optionally attention heads) of every
    prunable group with ``importance``, and removes the lowest-scoring ones.
    Ranking happens inside "scopes": one scope per group for local pruning,
    one shared scope for global pruning, one scope per graph structure when
    ``isomorphic=True``, and user-defined scopes from ``pruning_ratio_dict``.
    The target ratio can be reached over ``iterative_steps`` calls to :meth:`step`.
    """

    def __init__(
        self,
        # Basic
        model: nn.Module, # a simple pytorch model
        example_inputs: torch.Tensor, # a dummy input for graph tracing; presumably on the same device as the model
        importance: typing.Callable, # tp.importance.Importance for group importance estimation
        global_pruning: bool = False, # rank all groups in one shared scope instead of layer by layer
        pruning_ratio: float = 0.5,  # fraction of channels/dims to remove, also known as channel sparsity
        pruning_ratio_dict: typing.Dict[typing.Union[nn.Module, typing.Tuple[nn.Module]], float] = None, # layer-specific ratios; a tuple key shares one ratio (and scope) among its modules
        max_pruning_ratio: float = 1.0, # maximum pruning ratio. useful if over-pruning happens.
        iterative_steps: int = 1,  # for iterative pruning
        iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler, # scheduler for iterative pruning.
        ignored_layers: typing.List[nn.Module] = None, # modules (including all their submodules) or nn.Parameters to leave untouched
        round_to: int = None,  # keep a multiple of round_to channels (the remaining count is rounded down)
        isomorphic: bool = False, # rank groups that share the same structure together

        # Advanced
        in_channel_groups: typing.Optional[typing.Dict[nn.Module, int]] = None, # The number of channel groups for layer input
        out_channel_groups: typing.Optional[typing.Dict[nn.Module, int]] = None, # The number of channel groups for layer output
        num_heads: typing.Optional[typing.Dict[nn.Module, int]] = None, # The number of heads for multi-head attention
        prune_num_heads: bool = False, # remove entire heads in multi-head attention
        prune_head_dims: bool = True, # remove head dimensions in multi-head attention
        head_pruning_ratio: float = 0.0, # head pruning ratio
        head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None, # layer-specific head pruning ratio
        customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner}
        unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims. For example, {ViT.pos_emb: 0}
        root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM],  # root module for each group (read-only; copied below)
        forward_fn: typing.Callable = None, # a function to execute model.forward
        output_transform: typing.Callable = None, # a function to transform network outputs

        # deprecated
        channel_groups: typing.Optional[typing.Dict[nn.Module, int]] = None, # channel grouping
        ch_sparsity: float = None,
        ch_sparsity_dict: typing.Dict[nn.Module, float] = None,
    ):
        self.model = model
        self.importance = importance

        # Map the pre-v1.3.0 argument names onto the current ones.
        if ch_sparsity is not None:
            warnings.warn("ch_sparsity is deprecated in v1.3.0. Please use pruning_ratio.")
            pruning_ratio = ch_sparsity
        if ch_sparsity_dict is not None:
            warnings.warn("ch_sparsity_dict is deprecated in v1.3.0. Please use pruning_ratio_dict instead.")
            pruning_ratio_dict = ch_sparsity_dict

        self.pruning_ratio = pruning_ratio
        self.max_pruning_ratio = max_pruning_ratio
        self.global_pruning = global_pruning
        self.isomorphic = isomorphic

        # These dicts are mutated below and during pruning (group-conv detection, head counts),
        # so copy them: they must neither alias the caller's objects nor be shared between instances.
        in_channel_groups = dict(in_channel_groups) if in_channel_groups is not None else {}
        out_channel_groups = dict(out_channel_groups) if out_channel_groups is not None else {}
        num_heads = dict(num_heads) if num_heads is not None else {}

        if channel_groups:
            warnings.warn("channel_groups is deprecated. Please use in_channel_groups and out_channel_groups instead.")
            out_channel_groups.update(channel_groups)

        # The output channels of an attention projection are grouped by head.
        if len(num_heads) > 0:
            out_channel_groups.update(num_heads)

        self.in_channel_groups = in_channel_groups
        self.out_channel_groups = out_channel_groups
        self.root_module_types = list(root_module_types)
        self.round_to = round_to

        # MHA
        self.num_heads = num_heads # current head count per layer; decremented as heads are pruned
        self.prune_num_heads = prune_num_heads
        self.prune_head_dims = prune_head_dims
        self.head_pruning_ratio = head_pruning_ratio

        ###############################################
        # Ignored layers and submodules
        self.ignored_layers = []
        self.ignored_params = []
        if ignored_layers is not None:
            for layer in ignored_layers:
                if isinstance(layer, nn.Module):
                    self.ignored_layers.extend(list(layer.modules()))
                elif isinstance(layer, nn.Parameter):
                    self.ignored_params.append(layer)

        ###############################################
        # Build dependency graph
        self.DG = dependency.DependencyGraph().build_dependency(
            model,
            example_inputs=example_inputs,
            forward_fn=forward_fn,
            output_transform=output_transform,
            unwrapped_parameters=unwrapped_parameters,
            customized_pruners=customized_pruners,
            ignored_params=self.ignored_params,
        )

        ###############################################
        # Iterative pruning
        # The pruner will prune the model iteratively for several steps to achieve the target pruning ratio
        # E.g., if iterative_steps=5, pruning_ratio=0.5, the pruning ratio of each step will be [0.1, 0.2, 0.3, 0.4, 0.5]
        self.iterative_steps = iterative_steps
        self.iterative_pruning_ratio_scheduler = iterative_pruning_ratio_scheduler
        self.current_step = 0
        # channel pruning ratio for each iterative step
        self.per_step_pruning_ratio = self.iterative_pruning_ratio_scheduler(
            self.pruning_ratio, self.iterative_steps
        )
        self.per_step_head_pruning_ratio = self.iterative_pruning_ratio_scheduler(
            self.head_pruning_ratio, self.iterative_steps
        )

        ###############################################
        # Ranking Scopes
        # Ranking is performed within each scope.
        # If a scope only contains one layer, this is local pruning.
        # If a scope contains multiple layers, global ranking is applied to the entire scope.
        # To manually specify a ranking scope, pass a tuple of modules as a key of pruning_ratio_dict.
        self._layer_to_scope = {} # module -> (scope_name, modules that define the scope)
        self._scope_initial_channels = {} # initial channels of each scope, filled during the first pruning step

        # Module classes that have a registered pruner; computed once for the loops below.
        prunable_types = tuple(
            ops.type2class(prunable_type) for prunable_type in self.DG.REGISTERED_PRUNERS.keys()
        )

        ###############################################
        # Layer-specific pruning ratios. They override the global ratio if specified.
        # A key may be a single module or a tuple of modules sharing one ratio.
        self.pruning_ratio_dict = {} # module -> per-step pruning ratios from the scheduler
        if pruning_ratio_dict is not None:
            for user_defined_scope_id, (modules, ratio) in enumerate(pruning_ratio_dict.items()):
                # A tuple is scanned module by module; a single module is wrapped in a list.
                scope = modules if isinstance(modules, tuple) else [modules]
                scope_name = f"_User_Defined_Scope_{user_defined_scope_id}"
                local_pruning_scope_postfix = 0
                for m in scope:
                    for submodule in m.modules():
                        if not isinstance(submodule, prunable_types):
                            continue
                        if not self.global_pruning:
                            # local pruning: every layer gets its own scope
                            self._layer_to_scope[submodule] = (f"{scope_name}_{local_pruning_scope_postfix}", scope)
                            local_pruning_scope_postfix += 1
                        else:
                            # global pruning: all layers of this key share one scope
                            self._layer_to_scope[submodule] = (scope_name, scope)
                        self.pruning_ratio_dict[submodule] = self.iterative_pruning_ratio_scheduler(
                            ratio, self.iterative_steps
                        )

        # Head pruning ratio
        self.head_pruning_ratio_dict = {} # module -> per-step head pruning ratios
        if head_pruning_ratio_dict is not None:
            for module, ratio in head_pruning_ratio_dict.items():
                for submodule in module.modules():
                    if isinstance(submodule, prunable_types):
                        self.head_pruning_ratio_dict[submodule] = self.iterative_pruning_ratio_scheduler(
                            ratio, self.iterative_steps
                        )

        ###############################################
        # Detect group convs & group norms
        for m in self.model.modules():
            layer_pruner = self.DG.get_pruner_of_module(m)
            in_ch_group = layer_pruner.get_in_channel_groups(m)
            out_ch_group = layer_pruner.get_out_channel_groups(m)
            # convs with groups == out_channels (depthwise) are not registered as grouped
            if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels:
                continue
            if in_ch_group > 1:
                self.in_channel_groups[m] = in_ch_group
            if out_ch_group > 1:
                self.out_channel_groups[m] = out_ch_group

        ###############################################
        # Initial channels/dims of each layer
        self.layer_init_out_ch = {}
        self.layer_init_in_ch = {}
        self.init_num_heads = {}
        for m in self.DG.module2node.keys():
            if ops.module2type(m) in self.DG.REGISTERED_PRUNERS:
                self.layer_init_out_ch[m] = self.DG.get_out_channels(m)
                self.layer_init_in_ch[m] = self.DG.get_in_channels(m)
                if m in self.num_heads:
                    self.init_num_heads[m] = self.num_heads[m]

        ###############################################
        # Count the number of total channels and heads at initialization
        initial_total_channels = 0
        initial_total_heads = 0
        for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
            group = self._downstream_node_as_root_if_attention(group)
            initial_total_channels += self.DG.get_out_channels(group[0][0].target.module) // self._get_channel_groups(group)
            for dep, _ in group:
                if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler):
                    initial_total_heads += self.num_heads[dep.target.module]
                    break # only count heads once per group
        self.initial_total_channels = initial_total_channels
        self.initial_total_heads = initial_total_heads

    def step(self, interactive=False) -> typing.Union[typing.Generator, None]:
        """Advance one iterative pruning step.

        With ``interactive=True`` the generator of pruning groups is returned and
        the caller decides whether to call ``group.prune()`` on each one;
        otherwise every yielded group is pruned immediately and None is returned.
        """
        self.current_step += 1
        if interactive: # yield groups for interactive pruning
            return self._prune()
        for group in self._prune():
            group.prune()

    def manual_prune(self, layer, pruning_fn, pruning_ratios_or_idxs):
        """Prune ``layer`` (and the layers coupled to it) with ``pruning_fn``.

        Args:
            layer: the module to prune.
            pruning_fn: an out- or in-channel pruning function of the dependency graph.
            pruning_ratios_or_idxs: either a float, the fraction of channels to
                remove (the lowest-importance ones are chosen), or an explicit
                sequence of channel indices to remove.
        """
        if isinstance(pruning_ratios_or_idxs, float):
            if self.DG.is_out_channel_pruning_fn(pruning_fn):
                prunable_channels = self.DG.get_out_channels(layer)
            else:
                prunable_channels = self.DG.get_in_channels(layer)
            full_group = self.DG.get_pruning_group(layer, pruning_fn, list(range(prunable_channels)))
            imp = self.estimate_importance(full_group)
            imp_argsort = torch.argsort(imp)
            # The ratio is the fraction to *remove*: drop the n least important channels.
            n_pruned = int(prunable_channels * pruning_ratios_or_idxs)
            pruning_idxs = imp_argsort[:n_pruned].tolist()
        else:
            pruning_idxs = pruning_ratios_or_idxs

        if len(pruning_idxs) == 0: # nothing to remove
            return
        group = self.DG.get_pruning_group(layer, pruning_fn, pruning_idxs)
        group.prune()

    def estimate_importance(self, group) -> torch.Tensor:
        """Return ``self.importance(group)``; ``_prune`` skips groups for which this is None."""
        return self.importance(group)

    def pruning_history(self) -> typing.List[typing.Tuple[str, bool, typing.Union[list, tuple]]]:
        """Return the pruning history recorded by the dependency graph."""
        return self.DG.pruning_history()

    def load_pruning_history(self, pruning_history) -> None:
        """Forward ``pruning_history`` to the dependency graph's ``load_pruning_history``."""
        self.DG.load_pruning_history(pruning_history)

    def get_target_pruning_ratio(self, module, step=-1) -> float:
        """Pruning ratio of ``module`` at ``step`` (default: the current step), capped at max_pruning_ratio."""
        if step < 0:
            step = self.current_step
        s = self.pruning_ratio_dict.get(module, self.per_step_pruning_ratio)[step]
        return min(s, self.max_pruning_ratio)

    def get_target_head_pruning_ratio(self, module) -> float:
        """Head pruning ratio of ``module`` at the current step, capped at 1."""
        s = self.head_pruning_ratio_dict.get(module, self.per_step_head_pruning_ratio)[self.current_step]
        return min(s, 1)

    def reset(self) -> None:
        """Reset the iterative step counter to 0."""
        self.current_step = 0

    def update_regularizer(self) -> None:
        """No-op here; presumably a hook for subclasses."""
        pass

    def regularize(self, model, loss) -> typing.Any:
        """ Model regularizer for sparse training. No-op here; presumably overridden by subclasses.
        """
        pass

    def _check_pruning_ratio(self, group) -> bool:
        """Return False if pruning ``group`` further would push any layer below
        its ``max_pruning_ratio`` budget or leave it with a single channel."""
        for dep, _ in group:
            module = dep.target.module
            pruning_fn = dep.handler
            if dep.target.type == ops.OPTYPE.PARAMETER:
                continue
            if self.DG.is_out_channel_pruning_fn(pruning_fn):
                layer_out_ch = self.DG.get_out_channels(module)
                if layer_out_ch is None: continue
                if layer_out_ch < self.layer_init_out_ch[module] * (
                    1 - self.max_pruning_ratio
                ) or layer_out_ch == 1:
                    return False

            elif self.DG.is_in_channel_pruning_fn(pruning_fn):
                layer_in_ch = self.DG.get_in_channels(module)
                if layer_in_ch is None: continue
                if layer_in_ch < self.layer_init_in_ch[module] * (
                    1 - self.max_pruning_ratio
                ) or layer_in_ch == 1:
                    return False
        return True

    def _is_attn_group(self, group) -> typing.Tuple[bool, list]:
        """Return (is_attention, qkv_layers): the out-channel-pruned modules of
        ``group`` that have a registered head count."""
        qkv_layers = [
            dep.target.module
            for dep, _ in group
            if self.DG.is_out_channel_pruning_fn(dep.handler) and dep.target.module in self.num_heads
        ]
        return len(qkv_layers) > 0, qkv_layers

    def _get_channel_groups(self, group) -> int:
        """Number of channel groups for ``group``: the value registered for the
        last matching module (by pruning direction), or 1 if none is registered."""
        ch_groups = 1
        for dep, _ in group:
            module = dep.target.module
            pruning_fn = dep.handler
            channel_groups = self.out_channel_groups if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.in_channel_groups
            if module in channel_groups:
                ch_groups = channel_groups[module]
        return ch_groups

    def _downstream_node_as_root_if_attention(self, group):
        """Re-root attention groups on a downstream root-type layer.

        If an out-channel dependency in ``group`` originates from a module with a
        registered head count, the group is rebuilt with the last in-channel-pruned
        root-type module as its root; otherwise ``group`` is returned unchanged.
        """
        is_attention = False
        downstream_dep = None
        idxs = None
        for _dep, _idxs in group:
            if _dep.source.module in self.num_heads and self.DG.is_out_channel_pruning_fn(_dep.handler):
                is_attention = True
            if isinstance(_dep.target.module, tuple(self.root_module_types)) and self.DG.is_in_channel_pruning_fn(_dep.handler):
                downstream_dep = _dep
                idxs = _idxs
        if is_attention and downstream_dep is not None:
            group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs)
        return group

    def _round_to(self, n_pruned, current_channels, round_to):
        """Adjust ``n_pruned`` so the remaining channel count is rounded down to a multiple of ``round_to``."""
        rounded_channels = current_channels - n_pruned
        rounded_channels = rounded_channels - rounded_channels % round_to
        n_pruned = current_channels - rounded_channels
        return max(n_pruned, 0)

    def _find_user_defined_scope(self, group):
        """Return (module, scope_name) for the first out-channel-pruned module of
        ``group`` that belongs to a user-defined scope, or None."""
        for dep, _ in group:
            for module, pruning_fn in ((dep.source.module, dep.trigger), (dep.target.module, dep.handler)):
                if module in self._layer_to_scope and self.DG.is_out_channel_pruning_fn(pruning_fn):
                    scope_name, scope = self._layer_to_scope[module]
                    if len(scope) > 0:
                        return module, scope_name
        return None

    def _prune(self) -> typing.Generator:
        """Yield the pruning groups of the current step.

        1. Estimate importance per group and file each group under a ranking scope.
        2. Per scope, decide how many channels (and attention heads) to remove.
        3. Yield one pruning group per layer with the selected indices.
        """
        if self.current_step > self.iterative_steps:
            warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.")
            return

        DEFAULT_SCOPE = "DEFAULT_SCOPE"
        ATTN_HEAD_SCOPE = "ATTN_HEAD_SCOPE"

        # ATTN_HEAD_SCOPE maps group -> (qkv_layers, head_imp) because groups are looked up again in step 3.
        # Every other scope holds records (group, ch_groups, group_size, pruning_ratio, dim_imp).
        ranking_scope = {DEFAULT_SCOPE: [], ATTN_HEAD_SCOPE: {}}

        ##############################################
        # 1. Pre-compute importance for each group and assign them to different scopes
        ##############################################
        for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
            if not self._check_pruning_ratio(group):
                continue

            # Compute raw importance score
            group = self._downstream_node_as_root_if_attention(group)
            ch_groups = self._get_channel_groups(group)
            imp = self.estimate_importance(group) # raw importance score
            if imp is None: # must be checked before len(imp)
                continue
            group_size = len(imp) // ch_groups
            if ch_groups > 1:
                # Corresponding elements of each group will be removed together,
                # so importance is averaged across groups. For example:
                # imp = [1, 2, 3, 4, 5, 6] with ch_groups=2 gives groups [1,2,3] and [4,5,6],
                # and the averaged importance is [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5]
                dim_imp = imp.view(ch_groups, -1).mean(dim=0)
            else:
                dim_imp = imp # no grouping

            # Importance scores for attention heads
            _is_attn, qkv_layers = self._is_attn_group(group)
            if _is_attn and self.prune_num_heads and self.get_target_head_pruning_ratio(qkv_layers[0]) > 0:
                # Average importance per head. For example, imp = [1, 2, 3, 4, 5, 6] with 2 heads:
                # head1 = [1, 2, 3], head2 = [4, 5, 6] -> [(1+2+3)/3, (4+5+6)/3] = [2, 5]
                head_imp = imp.view(ch_groups, -1).mean(1)
                ranking_scope[ATTN_HEAD_SCOPE][group] = (qkv_layers, head_imp)

            # Scope 1: user-defined scope, such as layer-wise pruning_ratios
            user_scope = self._find_user_defined_scope(group)
            if user_scope is not None:
                module, scope_name = user_scope
                pruning_ratio = self.get_target_pruning_ratio(module, step=self.current_step)
                ranking_scope.setdefault(scope_name, []).append((group, ch_groups, group_size, pruning_ratio, dim_imp))
                continue

            # otherwise, use the default pruning ratio
            record = (group, ch_groups, group_size, self.per_step_pruning_ratio[self.current_step], dim_imp)
            if self.isomorphic:
                # Scope 2: isomorphic pruning. The graph structure is turned into a string tag;
                # groups whose dependencies have the same layer types and pruning directions share a scope.
                scope_name = "Isomorphic_"
                for dep, _ in group:
                    direction = "out" if self.DG.is_out_channel_pruning_fn(dep.handler) else "in"
                    source = "%s_%s" % (type(dep.source.module), direction)
                    target = "%s_%s" % (type(dep.target.module), direction)
                    scope_name += "%s_%s" % (source, target)
                ranking_scope.setdefault(scope_name, []).append(record)
            elif self.global_pruning:
                # Scope 3: use the default scope for global pruning
                ranking_scope[DEFAULT_SCOPE].append(record)
            else:
                # Scope 4: always create a new scope if local pruning. Two groups may share a
                # root module (e.g. a re-rooted attention group), so never overwrite an existing scope.
                scope_name = self.DG._module2name[group[0][0].source.module]
                while scope_name in ranking_scope:
                    scope_name += "_"
                ranking_scope[scope_name] = [record]

        if len(ranking_scope[DEFAULT_SCOPE]) == 0 and len(ranking_scope[ATTN_HEAD_SCOPE]) == 0 and len(ranking_scope) <= 2:
            return

        ##############################################
        # 2. Thresholding by ranking all importance scores within each scope
        ##############################################

        # Threshold for the multi-head attention scope
        n_heads_removed = 0
        head_thres = None
        if len(ranking_scope[ATTN_HEAD_SCOPE]) > 0:
            concat_head_imp = torch.cat([head_imp for _, head_imp in ranking_scope[ATTN_HEAD_SCOPE].values()], dim=0)
            target_head_pruning_ratio = self.per_step_head_pruning_ratio[self.current_step]
            n_heads_removed = len(concat_head_imp) - int(
                self.initial_total_heads * (1 - target_head_pruning_ratio)
            )
            if n_heads_removed > 0:
                topk_head_imp, _ = torch.topk(concat_head_imp, k=n_heads_removed, largest=False)
                head_thres = topk_head_imp[-1]

        # Width-pruning scopes
        width_pruning_scope_names = [k for k in ranking_scope if k != ATTN_HEAD_SCOPE]
        for scope_name in width_pruning_scope_names:
            records = ranking_scope[scope_name]
            if not self.global_pruning:
                assert len(records) <= 1, "Internal Error: local pruning should only contain less than one layer per scope."
            if len(records) == 0:
                continue

            concat_imp = torch.cat([record[-1] for record in records], dim=0) # all importance scores in this scope
            target_pruning_ratio = records[0][-2]
            if scope_name not in self._scope_initial_channels:
                self._scope_initial_channels[scope_name] = len(concat_imp)
            # Scope-wide number of channels to remove in this step; must not be modified per layer below.
            n_pruned = len(concat_imp) - int(
                self._scope_initial_channels[scope_name] * (1 - target_pruning_ratio)
            )
            thres, topk_indices = None, None
            if n_pruned > 0:
                topk_imp, topk_indices = torch.topk(concat_imp, k=n_pruned, largest=False)
                thres = topk_imp[-1]

            ##############################################
            # 3. Pruning in each scope
            ##############################################
            for group, ch_groups, group_size, _, imp in records:
                module = group[0].dep.target.module
                pruning_fn = group[0].dep.handler
                get_channel_fn = self.DG.get_out_channels if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.DG.get_in_channels

                pruning_indices = []

                # Prune feature dims/channels
                if n_pruned > 0:
                    if ch_groups > 1:
                        # re-compute pruning indices for each channel group
                        if self.global_pruning:
                            # n_pruned is shared by the whole scope; count this layer's share
                            n_pruned_per_group = len((imp <= thres).nonzero().view(-1))
                        else:
                            # local pruning: the scope contains only this layer
                            n_pruned_per_group = n_pruned
                        if n_pruned_per_group > 0:
                            if self.round_to:
                                n_pruned_per_group = self._round_to(n_pruned_per_group, group_size, self.round_to)
                            _is_attn, _ = self._is_attn_group(group)
                            if not _is_attn or self.prune_head_dims:
                                raw_imp = self.estimate_importance(group) # re-compute raw (un-averaged) importance
                                for chg in range(ch_groups): # select indices for each channel group independently
                                    sub_group_imp = raw_imp[chg * group_size: (chg + 1) * group_size]
                                    sub_imp_argsort = torch.argsort(sub_group_imp)
                                    pruning_indices.append(sub_imp_argsort[:n_pruned_per_group] + chg * group_size)
                    else:
                        # standard pruning
                        if self.global_pruning:
                            _pruning_indices = (imp <= thres).nonzero().view(-1)
                        else:
                            _pruning_indices = topk_indices
                        if len(_pruning_indices) > 0 and self.round_to:
                            # per-layer count; the scope-wide n_pruned stays untouched for later records
                            n_pruned_layer = self._round_to(len(_pruning_indices), get_channel_fn(module), self.round_to)
                            _pruning_indices = torch.argsort(imp)[:n_pruned_layer]
                        pruning_indices.append(_pruning_indices)

                # Prune heads (independent of whether this scope prunes any width)
                if n_heads_removed > 0 and group in ranking_scope[ATTN_HEAD_SCOPE]:
                    qkv_layers, head_imp = ranking_scope[ATTN_HEAD_SCOPE][group]
                    if not self.global_pruning:
                        # local ranking: the target ratio is cumulative w.r.t. the initial head count,
                        # so subtract the heads removed in earlier steps
                        init_heads = self.init_num_heads.get(qkv_layers[0], len(head_imp))
                        target_removed = int(self.get_target_head_pruning_ratio(qkv_layers[0]) * init_heads)
                        already_removed = init_heads - len(head_imp)
                        n_heads_removed_per_group = min(max(target_removed - already_removed, 0), len(head_imp))
                        head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1]
                    else:
                        head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) # global ranking
                    for head_id in head_pruning_indices:
                        pruning_indices.append(torch.arange(head_id * group_size, (head_id + 1) * group_size, device=head_imp.device))
                    for qkv_layer in qkv_layers:
                        self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning
                        self.out_channel_groups[qkv_layer] = self.num_heads[qkv_layer] # keep channel grouping in sync

                if len(pruning_indices) == 0:
                    continue
                pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()

                if isinstance(self.importance, OBDCImportance):
                    # this importance keeps state that is adjusted for the removed indices
                    self.importance.adjust_fisher(group, pruning_indices)
                pruning_group = self.DG.get_pruning_group(module, pruning_fn, pruning_indices)
                if self.DG.check_pruning_group(pruning_group):
                    yield pruning_group # yield the group for interactive pruning
