# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
from mmcv.runner import save_checkpoint
from torch import nn as nn

from mmdet3d.apis import init_model


def fuse_conv_bn(conv, bn):
    """Fold a batch norm layer into the convolution that precedes it.

    The fusion uses the running statistics stored on ``bn``, i.e. it
    reproduces what the batch norm computes in inference (eval) mode, where
    it is a fixed per-channel affine transform. The transform is merged into
    the convolution's weight and bias:

        factor = gamma / sqrt(running_var + eps)
        weight' = weight * factor        (broadcast over output channels)
        bias'   = (bias - running_mean) * factor + beta

    Args:
        conv (nn.Module): Convolution whose output feeds ``bn``. Its weight
            must keep the output channels on dim 0 and it must expose
            ``out_channels``. The layer is modified in place.
        bn (nn.Module): Batch norm layer providing ``running_mean``,
            ``running_var`` and ``eps``. ``weight`` / ``bias`` may be
            ``None`` (``affine=False``); gamma=1 and beta=0 are used then.

    Returns:
        nn.Module: The same ``conv`` object, now carrying the fused
        ``weight`` and ``bias`` parameters.
    """
    conv_w = conv.weight
    # A bias-free conv behaves as if its bias were zero for every channel.
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)

    # BN built with affine=False has no learnable scale/shift: fall back to
    # the identity transform instead of failing on ``None``.
    bn_w = bn.weight if bn.weight is not None else torch.ones_like(
        bn.running_mean)
    bn_b = bn.bias if bn.bias is not None else torch.zeros_like(
        bn.running_mean)

    factor = bn_w / torch.sqrt(bn.running_var + bn.eps)
    # Scale each output channel; trailing singleton dims follow the rank of
    # the conv weight so 1d/2d/3d kernels are all handled. Keeping
    # ``out_channels`` explicit makes a channel mismatch raise in reshape.
    scale_shape = [conv.out_channels] + [1] * (conv_w.dim() - 1)
    conv.weight = nn.Parameter(conv_w * factor.reshape(scale_shape))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn_b)
    return conv


def fuse_module(m):
    """Recursively fuse Conv2d + BatchNorm pairs inside ``m``, in place.

    The direct children of ``m`` are visited in registration order. Every
    ``nn.Conv2d`` is remembered as the pending conv; the next
    ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` child is folded into it via
    ``fuse_conv_bn`` and replaced by ``nn.Identity``. A batch norm with no
    pending conv is left untouched. Any other child is processed
    recursively and does not clear the pending conv.

    Args:
        m (nn.Module): Module whose children are scanned and modified.

    Returns:
        nn.Module: The same module ``m``.
    """
    norm_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    pending_conv = None
    pending_name = None

    for child_name, layer in m.named_children():
        if isinstance(layer, norm_types):
            # Only a BN that comes after a Conv can be fused.
            if pending_conv is not None:
                m._modules[pending_name] = fuse_conv_bn(pending_conv, layer)
                # Keep the slot but make it a no-op, so the set of child
                # names stays the same.
                m._modules[child_name] = nn.Identity()
                pending_conv = None
            continue

        if isinstance(layer, nn.Conv2d):
            pending_conv, pending_name = layer, child_name
            continue

        # Neither conv nor BN: descend into the sub-module.
        fuse_module(layer)

    return m


def parse_args():
    """Parse the command-line arguments of the fusion script.

    Returns:
        argparse.Namespace: Holds the three required positional arguments
        ``config``, ``checkpoint`` and ``out``.
    """
    parser = argparse.ArgumentParser(
        description='fuse Conv and BN layers in a model')
    # (name, help) pairs, registered in the order they appear on the CLI.
    positionals = (
        ('config', 'config file path'),
        ('checkpoint', 'checkpoint file path'),
        ('out', 'output path of the converted model'),
    )
    for arg_name, help_text in positionals:
        parser.add_argument(arg_name, help=help_text)
    return parser.parse_args()


def main():
    """Load a model, fuse its Conv/BN layers and save the result.

    Reads ``config``, ``checkpoint`` and ``out`` from the command line,
    builds the model with ``init_model``, runs ``fuse_module`` on it and
    writes the fused model to ``out`` with ``save_checkpoint``.
    """
    cli = parse_args()
    # Build the model from the config file and the checkpoint file.
    detector = init_model(cli.config, cli.checkpoint)
    # fuse_module returns the module it was given, with Conv/BN layers fused.
    save_checkpoint(fuse_module(detector), cli.out)


# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
