# Copyright (c) OpenMMLab. All rights reserved.
import copy
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from ordered_set import OrderedSet
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.builder import PRUNERS
from .utils import SwitchableBatchNorm2d

# These grad_fn pattern are flags of specific a nn.Module
CONV = ('ThnnConv2DBackward', 'CudnnConvolutionBackward',
        'MkldnnConvolutionBackward')
FC = ('ThAddmmBackward', 'AddmmBackward', 'MmBackward')
BN = ('ThnnBatchNormBackward', 'CudnnBatchNormBackward',
      'NativeBatchNormBackward')
CONCAT = ('CatBackward', )
# the modules which contains NON_PASS grad_fn need to change the parameter size
# according to channels after pruning
NON_PASS = CONV + FC
NON_PASS_MODULE = (nn.Conv2d, nn.Linear)

PASS = BN
PASS_MODULE = (_BatchNorm)

BACKWARD_PARSER_DICT = dict()
MAKE_GROUP_PARSER_DICT = dict()


def register_parser(parser_dict, name=None, force=False):
    """Register a parser function.

    A record will be added to ``parser_dict``, whose key is the specified
    ``name``, and value is the function itself.
    It can be used as a decorator or a normal function.

    Example:
        >>> BACKWARD_PARSER_DICT = dict()
        >>> @register_parser(BACKWARD_PARSER_DICT, 'ThnnConv2DBackward')
        >>> def conv_backward_parser():
        >>>     pass

    Args:
        parser_dict (dict): A dict to map strings to parser functions.
        name (str | None): The function name to be registered. If not
            specified, the function name will be used.
        force (bool, optional): Whether to override an existing function with
            the same name. Default: False.
    """

    def _register(parser_func):
        parser_name = parser_func.__name__ if name is None else name
        if (parser_name not in parser_dict) or force:
            parser_dict[parser_name] = parser_func
        else:
            raise KeyError(
                f'{parser_name} is already registered in task_dict, '
                'add "force=True" if you want to override it')
        return parser_func

    return _register


@PRUNERS.register_module()
class StructurePruner(BaseModule, metaclass=ABCMeta):
    """Base class for structure pruning. This class defines the basic functions
    of a structure pruner. Any pruner that inherits this class should at least
    define its own `sample_subnet` and `set_min_channel` functions. This part
    is being continuously optimized, and there may be major changes in the
    future.

    Reference to https://github.com/jshilong/FisherPruning

    Args:
        except_start_keys (List[str]): the module whose name start with a
            string in except_start_keys will not be prune.
    """

    def __init__(self, except_start_keys=['head.fc']):
        super(StructurePruner, self).__init__()
        if except_start_keys is None:
            self.except_start_keys = list()
        else:
            self.except_start_keys = except_start_keys

    def prepare_from_supernet(self, supernet):
        """Prepare for pruning."""

        module2name = OrderedDict()
        name2module = OrderedDict()
        var2module = OrderedDict()

        # record the visited module name during trace path
        visited = dict()

        for name, module in supernet.model.named_modules():
            if hasattr(module, 'weight'):
                module2name[module] = name
                name2module[name] = module
                var2module[id(module.weight)] = module
                self.add_pruning_attrs(module)
                visited[name] = False
            if isinstance(module, SwitchableBatchNorm2d):
                name2module[name] = module
        self.name2module = name2module

        pseudo_img = torch.randn(1, 3, 224, 224)
        # todo: support two stage detector and mmseg
        pseudo_img = supernet.forward_dummy(pseudo_img)
        pseudo_loss = supernet.cal_pseudo_loss(pseudo_img)

        non_pass_paths = list()
        cur_non_pass_path = list()
        self.trace_non_pass_path(pseudo_loss.grad_fn, module2name, var2module,
                                 cur_non_pass_path, non_pass_paths, visited)

        bn_conv_links = dict()
        self.trace_bn_conv_links(pseudo_loss.grad_fn, module2name, var2module,
                                 bn_conv_links, visited)
        self.bn_conv_links = bn_conv_links

        # a node can be the name of a conv module or a str like 'concat_{id}'
        node2parents = self.find_node_parents(non_pass_paths)
        self.node2parents = node2parents

        same_out_channel_groups = self.make_same_out_channel_groups(
            node2parents, name2module)

        self.module2group = dict()
        for group_name, group in same_out_channel_groups.items():
            for module_name in group:
                self.module2group[module_name] = group_name

        self.modules_have_ancest = list()
        for node_name, parents_name in node2parents.items():
            if node_name in name2module and len(parents_name) > 0:
                self.modules_have_ancest.append(node_name)

        self.modules_have_child = OrderedSet()
        for parents_name in node2parents.values():
            for name in parents_name:
                # The node is a module in supernet
                if name in name2module:
                    self.modules_have_child.add(name)

        self.channel_spaces = self.build_channel_spaces(name2module)

    @abstractmethod
    def sample_subnet(self):
        """Sample a subnet from the supernet.

        Returns:
            dict: Record the information to build the subnet from the supernet,
                its keys are the properties ``space_id`` in the pruner's search
                spaces, and its values are corresponding sampled out_mask.
        """
        pass

    def get_space_id(self, module_name):
        """Get the corresponding space_id of the module_name.

        The modules who share the same space_id will share the same out_mask.
        If the module is the output module(there is no other ``nn.Module``
        whose input is its output), this function will return None. As the
        output module can not be pruned.
        If the input of this module is the concatenation of the output
        of several ``nn.Module``, this function will return a dict object.
        If this module is in one of the groups, this function will return the
        group name. As the modules in the same group should share the same
        space_id.
        Otherwise, this function will return the module_name as space_id.

        Args:
            module_name (str): the name of a ``nn.Module``.

        Return:
            str or dict or None: the corresponding space_id of the module_name.
        """
        if 'concat' in module_name and module_name not in self.name2module:
            # each module_name in concat_parents should be in name2module
            if 'item' in module_name:
                space_id = self.get_space_id(self.node2parents[module_name][0])
            else:
                concat_parents = [
                    self.get_space_id(parent)
                    for parent in self.node2parents[module_name]
                ]
                space_id = dict(concat=concat_parents)

        elif module_name not in self.modules_have_child:
            return None
        elif module_name in self.module2group:
            space_id = self.module2group[module_name]
        else:
            space_id = module_name
        return space_id

    def set_subnet(self, subnet_dict):
        """Modify the in_mask and out_mask of modules in supernet according to
        subnet_dict.

        Args:
            subnet_dict (dict): the key is space_id and the value is the
                corresponding sampled out_mask.
        """
        for module_name in self.modules_have_child:
            space_id = self.get_space_id(module_name)
            module = self.name2module[module_name]
            module.out_mask = subnet_dict[space_id].to(module.out_mask.device)

        for bn, conv in self.bn_conv_links.items():
            module = self.name2module[bn]
            conv_space_id = self.get_space_id(conv)
            # conv_space_id is None means the conv layer in front of
            # this bn module can not be pruned. So we should not set
            # the out_mask of this bn layer
            if conv_space_id is not None:
                module.out_mask = subnet_dict[conv_space_id].to(
                    module.out_mask.device)

        for module_name in self.modules_have_ancest:
            module = self.name2module[module_name]
            parents = self.node2parents[module_name]
            # To avoid ambiguity, we only allow the following two cases:
            # 1. all elements in parents are ``Conv2d``,
            # 2. there is only one element in parents, ``concat`` or ``chunk``
            # In case 1, all the ``Conv2d`` share the same space_id and
            # out_mask.
            # So in all cases, we only need the very first element in parents
            parent = parents[0]
            space_id = self.get_space_id(parent)

            if isinstance(space_id, dict):
                if 'concat' in space_id:
                    in_mask = []
                    for parent_space_id in space_id['concat']:
                        in_mask.append(subnet_dict[parent_space_id])
                    module.in_mask = torch.cat(
                        in_mask, dim=1).to(module.in_mask.device)
            else:
                module.in_mask = subnet_dict[space_id].to(
                    module.in_mask.device)

    def export_subnet(self):
        """Generate subnet configs according to the in_mask and out_mask of a
        module."""
        channel_cfg = dict()
        for name, module in self.name2module.items():

            channel_cfg[name] = dict()
            if hasattr(module, 'in_mask'):
                channel_cfg[name]['in_channels'] = int(
                    module.in_mask.cpu().sum())
                channel_cfg[name]['raw_in_channels'] = int(
                    module.in_mask.cpu().numel())

            if hasattr(module, 'out_mask'):
                channel_cfg[name]['out_channels'] = int(
                    module.out_mask.cpu().sum())
                channel_cfg[name]['raw_out_channels'] = int(
                    module.out_mask.cpu().numel())

        return channel_cfg

    def set_max_channel(self):
        """Set the number of channels each layer to maximum."""
        subnet_dict = dict()
        for space_id, out_mask in self.channel_spaces.items():
            new_out_mask = torch.ones_like(out_mask)
            subnet_dict[space_id] = new_out_mask
        self.set_subnet(subnet_dict)

    @abstractmethod
    def set_min_channel(self):
        """Set the number of channels each layer to minimum."""
        pass

    def find_make_group_parser(self, node_name, name2module):
        """Find the corresponding make_group_parser according to the
        ``node_name``"""
        if 'concat' in node_name and node_name not in name2module:
            return MAKE_GROUP_PARSER_DICT['concat']
        elif 'chunk' in node_name and node_name not in name2module:
            return MAKE_GROUP_PARSER_DICT['chunk']
        elif (node_name in name2module
              and isinstance(name2module[node_name], nn.Conv2d)
              and name2module[node_name].in_channels
              == name2module[node_name].groups):
            return MAKE_GROUP_PARSER_DICT['depthwise']
        else:
            return MAKE_GROUP_PARSER_DICT['common']

    @register_parser(MAKE_GROUP_PARSER_DICT, 'concat')
    def concat_make_group_parser(self, node_name, parents_name, group_idx,
                                 same_in_channel_groups,
                                 same_out_channel_groups):
        if 'item' in node_name:
            return self.make_group_parser(node_name, parents_name, group_idx,
                                          same_in_channel_groups,
                                          same_out_channel_groups)
        else:
            return group_idx, same_in_channel_groups, same_out_channel_groups

    @register_parser(MAKE_GROUP_PARSER_DICT, 'chunk')
    def chunk_make_group_parser(self, group_idx, same_in_channel_groups,
                                same_out_channel_groups, **kwargs):
        return group_idx, same_in_channel_groups, same_out_channel_groups

    @register_parser(MAKE_GROUP_PARSER_DICT, 'depthwise')
    def depthwise_make_group_parser(self, node_name, parents_name, **kwargs):
        # depth wise conv should be in the same group with its parent
        parents_name.add(node_name)
        return self.make_group_parser(node_name, parents_name, **kwargs)

    @register_parser(MAKE_GROUP_PARSER_DICT, 'common')
    def make_group_parser(self, node_name, parents_name, group_idx,
                          same_in_channel_groups, same_out_channel_groups):
        added = False
        for group_name in same_in_channel_groups:
            group_parents = set(same_out_channel_groups[group_name])
            if len(parents_name.intersection(group_parents)) > 0:
                same_in_channel_groups[group_name].append(node_name)
                same_out_channel_groups[group_name] = list(
                    parents_name.union(group_parents))
                added = True
                break
        if not added:
            group_idx += 1
            same_in_channel_groups[group_idx] = [node_name]
            same_out_channel_groups[group_idx] = list(parents_name)

        return group_idx, same_in_channel_groups, same_out_channel_groups

    def make_same_out_channel_groups(self, node2parents, name2module):
        """Modules have the same child should be in the same group."""
        idx = -1
        # the nodes in same_out_channel_groups are parents of
        # nodes in same_in_channel_groups
        same_in_channel_groups, same_out_channel_groups = {}, {}
        for node_name, parents_name in node2parents.items():
            parser = self.find_make_group_parser(node_name, name2module)
            idx, same_in_channel_groups, same_out_channel_groups = \
                parser(self,
                       node_name=node_name,
                       parents_name=parents_name,
                       group_idx=idx,
                       same_in_channel_groups=same_in_channel_groups,
                       same_out_channel_groups=same_out_channel_groups)

        groups = dict()
        idx = 0
        for group in same_out_channel_groups.values():
            if len(group) > 1:
                group_name = f'group_{idx}'
                groups[group_name] = group
                idx += 1

        return groups

    @staticmethod
    def modify_conv_forward(module):
        """Modify the forward method of a conv layer."""

        def modified_forward(self, feature):
            feature = feature * self.in_mask
            return F.conv2d(feature, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

        return MethodType(modified_forward, module)

    @staticmethod
    def modify_fc_forward(module):
        """Modify the forward method of a linear layer."""

        def modified_forward(self, feature):
            if not len(self.in_mask.shape) == len(self.out_mask.shape):
                self.in_mask = self.in_mask.reshape(self.in_mask.shape[:2])

            feature = feature * self.in_mask
            return F.linear(feature, self.weight, self.bias)

        return MethodType(modified_forward, module)

    def add_pruning_attrs(self, module):
        """Add masks to a ``nn.Module``."""
        if type(module).__name__ == 'Conv2d':
            module.register_buffer(
                'in_mask',
                module.weight.new_ones((1, module.in_channels, 1, 1), ))
            module.register_buffer(
                'out_mask',
                module.weight.new_ones((1, module.out_channels, 1, 1), ))
            module.forward = self.modify_conv_forward(module)
        if type(module).__name__ == 'Linear':
            module.register_buffer(
                'in_mask', module.weight.new_ones((1, module.in_features), ))
            module.register_buffer(
                'out_mask', module.weight.new_ones((1, module.out_features), ))
            module.forward = self.modify_fc_forward(module)
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.register_buffer(
                'out_mask',
                module.weight.new_ones((1, len(module.weight), 1, 1), ))

    def find_node_parents(self, paths):
        """Find the parent node of a node.

        A node in the ``paths`` can be a module name or a operation name such
        as `concat_140719322997152`. Note that the string of numbers following
        ``concat`` do not have a particular meaning. It just make the operation
        name unique.

        Args:
            paths (list): The traced paths.
        """
        node2parents = dict()
        for path in paths:
            if len(path) == 0:
                continue
            for i, node_name in enumerate(path[:-1]):
                parent_name = path[i + 1]
                if node_name in node2parents.keys():
                    node2parents[node_name].add(parent_name)
                else:
                    node2parents[node_name] = OrderedSet([parent_name])

            leaf_name = path[-1]
            if leaf_name not in node2parents.keys():
                node2parents[leaf_name] = OrderedSet()
        return node2parents

    def build_channel_spaces(self, name2module):
        """Build channel search space.

        Args:
            name2module (dict): A mapping between module_name and module.

        Return:
            dict: The channel search space. The key is space_id and the value
                is the corresponding out_mask.
        """
        search_space = dict()

        for module_name in self.modules_have_child:
            need_prune = True
            for key in self.except_start_keys:
                if module_name.startswith(key):
                    need_prune = False
                    break
            if not need_prune:
                continue
            if module_name in self.module2group:
                space_id = self.module2group[module_name]
            else:
                space_id = module_name
            module = name2module[module_name]
            if space_id not in search_space:
                search_space[space_id] = module.out_mask

        return search_space

    def deploy_subnet(self, supernet, channel_cfg):
        """Deploy subnet according `channel_cfg`."""
        for name, module in supernet.model.named_modules():
            if name not in channel_cfg:
                continue

            channels_per_layer = channel_cfg[name]
            requires_grad = module.weight.requires_grad
            out_channels = channels_per_layer['out_channels']
            temp_weight = module.weight.data[:out_channels]

            if hasattr(module, 'out_channels'):
                module.out_channels = out_channels
            if hasattr(module, 'out_features'):
                module.out_features = out_channels
            if hasattr(module, 'out_mask'):
                module.out_mask = module.out_mask[:, :out_channels]

            if 'in_channels' in channels_per_layer:
                in_channels = channels_per_layer['in_channels']

                if in_channels > 1:
                    temp_weight = temp_weight[:, :in_channels].data
                if hasattr(module, 'in_channels'):
                    module.in_channels = in_channels
                if hasattr(module, 'in_features'):
                    module.in_features = in_channels
                if hasattr(module, 'in_mask'):
                    module.in_mask = module.in_mask[:, :in_channels]
                # TODO Seems not support GroupConv
                if getattr(module, 'groups', in_channels) > 1:
                    module.groups = in_channels

            module.weight = nn.Parameter(temp_weight.data)
            module.weight.requires_grad = requires_grad

            if hasattr(module, 'bias') and module.bias is not None:
                module.bias = nn.Parameter(module.bias.data[:out_channels])
                module.bias.requires_grad = requires_grad

            if hasattr(module, 'running_mean'):
                module.running_mean = module.running_mean[:out_channels]

            if hasattr(module, 'running_var'):
                module.running_var = module.running_var[:out_channels]

    def get_max_channel_bins(self, max_channel_bins):
        """Get the max number of channel bins of all the groups which can be
        pruned during searching.

        Args:
            max_channel_bins (int): The max number of bins in each layer.
        """
        channel_bins_dict = dict()
        for space_id in self.channel_spaces.keys():
            channel_bins_dict[space_id] = torch.ones((max_channel_bins, ))
        return channel_bins_dict

    def set_channel_bins(self, channel_bins_dict, max_channel_bins):
        """Set subnet according to the number of channel bins in a layer.

        Args:
            channel_bins_dict (dict): The number of bins in each layer. Key is
                the space_id of each layer and value is the corresponding
                mask of channel bin.
            max_channel_bins (int): The max number of bins in each layer.
        """
        subnet_dict = dict()
        for space_id, bin_mask in channel_bins_dict.items():
            mask = self.channel_spaces[space_id]
            shape = mask.shape
            channel_num = shape[1]
            channels_per_bin = channel_num // max_channel_bins
            new_mask = []
            for mask in bin_mask:
                new_mask.extend([1] * channels_per_bin if mask else [0] *
                                channels_per_bin)
            new_mask.extend([0] * (channel_num % max_channel_bins))
            new_mask = torch.tensor(new_mask).reshape(*shape)
            subnet_dict[space_id] = new_mask
        self.set_subnet(subnet_dict)

    def trace_non_pass_path(self, grad_fn, module2name, var2module, cur_path,
                            result_paths, visited):
        """Trace the topology of all the ``NON_PASS_MODULE``."""
        grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn

        if grad_fn is not None:
            parser = self.find_backward_parser(grad_fn)
            if parser is not None:
                parser(self, grad_fn, module2name, var2module, cur_path,
                       result_paths, visited)
            else:
                # If the op is AccumulateGrad, parents is (),
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        self.trace_non_pass_path(parent, module2name,
                                                 var2module, cur_path,
                                                 result_paths, visited)
        else:
            result_paths.append(copy.deepcopy(cur_path))

    def trace_bn_conv_links(self, grad_fn, module2name, var2module,
                            bn_conv_links, visited):
        """Get the convolutional layer placed before a bn layer in the model.

        Example:
            >>> conv = nn.Conv2d(3, 3, 3)
            >>> bn = nn.BatchNorm2d(3)
            >>> pseudo_img = torch.rand(1, 3, 224, 224)
            >>> out = bn(conv(pseudo_img))
            >>> print(out.grad_fn.next_functions)
            ((<ThnnConv2DBackward object at 0x0000020E40639688>, 0),
            (<AccumulateGrad object at 0x0000020E40639208>, 0),
            (<AccumulateGrad object at 0x0000020E406398C8>, 0))
            >>> # op.next_functions[0][0] is ThnnConv2DBackward means
            >>> # the parent of this NativeBatchNormBackward op is
            >>> # ThnnConv2DBackward
            >>> # op.next_functions[1][0].variable is the weight of this bn
            >>> # module
            >>> # op.next_functions[2][0].variable is the bias of this bn
            >>> # module
        """
        grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
        if grad_fn is not None:
            is_bn_grad_fn = False
            for fn_name in BN:
                if type(grad_fn).__name__.startswith(fn_name):
                    is_bn_grad_fn = True
                    break

            if is_bn_grad_fn:
                conv_grad_fn = grad_fn.next_functions[0][0]
                conv_var = conv_grad_fn.next_functions[1][0].variable
                bn_var = grad_fn.next_functions[1][0].variable
                conv_module = var2module[id(conv_var)]
                bn_module = var2module[id(bn_var)]
                conv_name = module2name[conv_module]
                bn_name = module2name[bn_module]
                if visited[bn_name]:
                    pass
                else:
                    visited[bn_name] = True
                    bn_conv_links[bn_name] = conv_name

                    self.trace_bn_conv_links(conv_grad_fn, module2name,
                                             var2module, bn_conv_links,
                                             visited)

            else:
                # If the op is AccumulateGrad, parents is (),
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        self.trace_bn_conv_links(parent, module2name,
                                                 var2module, bn_conv_links,
                                                 visited)

    def find_backward_parser(self, grad_fn):
        for name, parser in BACKWARD_PARSER_DICT.items():
            if type(grad_fn).__name__.startswith(name):
                return parser

    @register_parser(BACKWARD_PARSER_DICT, 'ThnnConv2DBackward')
    @register_parser(BACKWARD_PARSER_DICT, 'CudnnConvolutionBackward')
    @register_parser(BACKWARD_PARSER_DICT, 'MkldnnConvolutionBackward')
    def conv_backward_parser(self, grad_fn, module2name, var2module, cur_path,
                             result_paths, visited):
        """Parse the backward of a conv layer.

        Example:
            >>> conv = nn.Conv2d(3, 3, 3)
            >>> pseudo_img = torch.rand(1, 3, 224, 224)
            >>> out = conv(pseudo_img)
            >>> print(out.grad_fn.next_functions)
            ((None, 0), (<AccumulateGrad object at 0x0000020E405CBD88>, 0),
            (<AccumulateGrad object at 0x0000020E405CB588>, 0))
            >>> # op.next_functions[0][0] is None means this ThnnConv2DBackward
            >>> # op has no parents
            >>> # op.next_functions[1][0].variable is the weight of this Conv2d
            >>> # module
            >>> # op.next_functions[2][0].variable is the bias of this Conv2d
            >>> # module
        """
        variable = grad_fn.next_functions[1][0].variable
        var_id = id(variable)
        module = var2module[var_id]
        name = module2name[module]
        parent = grad_fn.next_functions[0][0]
        cur_path.append(name)
        if visited[name]:
            result_paths.append(copy.deepcopy(cur_path))
        else:
            visited[name] = True
            self.trace_non_pass_path(parent, module2name, var2module, cur_path,
                                     result_paths, visited)
        cur_path.pop(-1)

    @register_parser(BACKWARD_PARSER_DICT, 'ThAddmmBackward')
    @register_parser(BACKWARD_PARSER_DICT, 'AddmmBackward')
    @register_parser(BACKWARD_PARSER_DICT, 'MmBackward')
    def linear_backward_parser(self, grad_fn, module2name, var2module,
                               cur_path, result_paths, visited):
        """Parse the backward of a conv layer.

        Example:
            >>> fc = nn.Linear(3, 3, bias=True)
            >>> input = torch.rand(3, 3)
            >>> out = fc(input)
            >>> print(out.grad_fn.next_functions)
            ((<AccumulateGrad object at 0x0000020E405F75C8>, 0), (None, 0),
            (<TBackward object at 0x0000020E405F7D48>, 0))
            >>> # op.next_functions[0][0].variable is the bias of this Linear
            >>> # module
            >>> # op.next_functions[1][0] is None means this AddmmBackward op
            >>> # has no parents
            >>> # op.next_functions[2][0] is the TBackward op, and
            >>> # op.next_functions[2][0].next_functions[0][0].variable is
            >>> # the transpose of the weight of this Linear module
        """
        variable = grad_fn.next_functions[2][0].next_functions[0][0].variable
        var_id = id(variable)
        module = var2module[var_id]
        name = module2name[module]
        parent = grad_fn.next_functions[1][0]

        cur_path.append(name)
        if visited[name]:
            result_paths.append(copy.deepcopy(cur_path))
        else:
            visited[name] = True
            self.trace_non_pass_path(parent, module2name, var2module, cur_path,
                                     result_paths, visited)
        cur_path.pop(-1)

    @register_parser(BACKWARD_PARSER_DICT, 'CatBackward')
    def concat_backward_parser(self, grad_fn, module2name, var2module,
                               cur_path, result_paths, visited):
        """Parse the backward of a concat operation.

        Example:
            >>> conv = nn.Conv2d(3, 3, 3)
            >>> pseudo_img = torch.rand(1, 3, 224, 224)
            >>> out1 = conv(pseudo_img)
            >>> out2 = conv(pseudo_img)
            >>> out = torch.cat([out1, out2], dim=1)
            >>> print(out.grad_fn.next_functions)
            ((<ThnnConv2DBackward object at 0x0000020E405F24C8>, 0),
            (<ThnnConv2DBackward object at 0x0000020E405F2648>, 0))
            >>> # the length of ``out.grad_fn.next_functions`` is two means
            >>> # ``out`` is obtained by concatenating two tensors
        """
        parents = grad_fn.next_functions
        concat_id = '_'.join([str(id(p)) for p in parents])
        name = f'concat_{concat_id}'
        cur_path.append(name)
        if name in visited and visited[name]:
            result_paths.append(copy.deepcopy(cur_path))
        else:
            visited[name] = True
            for i, parent in enumerate(parents):
                cur_path.append(f'{name}_item_{i}')
                self.trace_non_pass_path(parent, module2name, var2module,
                                         cur_path, result_paths, visited)
                if cur_path.pop(-1) != f'{name}_item_{i}':
                    print(f'{name}_item_{i}')
        cur_path.pop(-1)
