# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Optional, Tuple

import torch.nn as nn

from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling,
                                                    SparseBatchNorm2d,
                                                    SparseConv2d,
                                                    SparseMaxPooling,
                                                    SparseSyncBatchNorm2d)
from mmpretrain.registry import MODELS
from .resnet import ResNet


@MODELS.register_module()
class SparseResNet(ResNet):
    """ResNet with sparse module conversion function.

    Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Output channels of the stem layer. Defaults to 64.
        base_channels (int): Middle channels of the first stage.
            Defaults to 64.
        num_stages (int): Stages of the network. Defaults to 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Defaults to False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Defaults to True.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
    """

    def __init__(self,
                 depth: int,
                 in_channels: int = 3,
                 stem_channels: int = 64,
                 base_channels: int = 64,
                 expansion: Optional[int] = None,
                 num_stages: int = 4,
                 strides: Tuple[int] = (1, 2, 2, 2),
                 dilations: Tuple[int] = (1, 1, 1, 1),
                 out_indices: Tuple[int] = (3, ),
                 style: str = 'pytorch',
                 deep_stem: bool = False,
                 avg_down: bool = False,
                 frozen_stages: int = -1,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 zero_init_residual: bool = False,
                 init_cfg: Optional[dict] = [
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],
                 drop_path_rate: float = 0,
                 **kwargs):
        super().__init__(
            depth=depth,
            in_channels=in_channels,
            stem_channels=stem_channels,
            base_channels=base_channels,
            expansion=expansion,
            num_stages=num_stages,
            strides=strides,
            dilations=dilations,
            out_indices=out_indices,
            style=style,
            deep_stem=deep_stem,
            avg_down=avg_down,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            norm_eval=norm_eval,
            with_cp=with_cp,
            zero_init_residual=zero_init_residual,
            init_cfg=init_cfg,
            drop_path_rate=drop_path_rate,
            **kwargs)
        norm_type = norm_cfg['type']
        enable_sync_bn = False
        if re.search('Sync', norm_type) is not None:
            enable_sync_bn = True
        self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)

    def dense_model_to_sparse(self, m: nn.Module,
                              enable_sync_bn: bool) -> nn.Module:
        """Convert regular dense modules to sparse modules."""
        output = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            output = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            output.weight.data.copy_(m.weight.data)
            if bias:
                output.bias.data.copy_(m.bias.data)

        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            output = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode)

        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            output = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override)

        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            output = (SparseSyncBatchNorm2d
                      if enable_sync_bn else SparseBatchNorm2d)(
                          m.weight.shape[0],
                          eps=m.eps,
                          momentum=m.momentum,
                          affine=m.affine,
                          track_running_stats=m.track_running_stats)
            output.weight.data.copy_(m.weight.data)
            output.bias.data.copy_(m.bias.data)
            output.running_mean.data.copy_(m.running_mean.data)
            output.running_var.data.copy_(m.running_var.data)
            output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)

        elif isinstance(m, (nn.Conv1d, )):
            raise NotImplementedError

        for name, child in m.named_children():
            output.add_module(
                name,
                self.dense_model_to_sparse(
                    child, enable_sync_bn=enable_sync_bn))
        del m
        return output
