# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import pickle
from collections import OrderedDict

import torch

from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.registry import Registry


def _rename_basic_resnet_weights(layer_keys):
    layer_keys = [k.replace("_", ".") for k in layer_keys]
    layer_keys = [k.replace(".w", ".weight") for k in layer_keys]
    layer_keys = [k.replace(".bn", "_bn") for k in layer_keys]
    layer_keys = [k.replace(".b", ".bias") for k in layer_keys]
    layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys]
    layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys]
    layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys]
    layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys]
    layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys]

    # RPN / Faster RCNN
    layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys]
    layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys]
    layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys]
    layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys]

    # Affine-Channel -> BatchNorm enaming
    layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys]

    # Make torchvision-compatible
    layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys]

    layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys]
    layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys]
    layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys]
    layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys]

    layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
    layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys]
    layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
    layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys]
    layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
    layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys]

    layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys]
    layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys]

    # GroupNorm
    layer_keys = [k.replace("conv1.gn.s", "bn1.weight") for k in layer_keys]
    layer_keys = [k.replace("conv1.gn.bias", "bn1.bias") for k in layer_keys]
    layer_keys = [k.replace("conv2.gn.s", "bn2.weight") for k in layer_keys]
    layer_keys = [k.replace("conv2.gn.bias", "bn2.bias") for k in layer_keys]
    layer_keys = [k.replace("conv3.gn.s", "bn3.weight") for k in layer_keys]
    layer_keys = [k.replace("conv3.gn.bias", "bn3.bias") for k in layer_keys]
    layer_keys = [k.replace("downsample.0.gn.s", "downsample.1.weight") \
        for k in layer_keys]
    layer_keys = [k.replace("downsample.0.gn.bias", "downsample.1.bias") \
        for k in layer_keys]

    return layer_keys

def _rename_fpn_weights(layer_keys, stage_names):
    for mapped_idx, stage_name in enumerate(stage_names, 1):
        suffix = ""
        if mapped_idx < 4:
            suffix = ".lateral"
        layer_keys = [
            k.replace("fpn.inner.layer{}.sum{}".format(stage_name, suffix), "fpn_inner{}".format(mapped_idx)) for k in layer_keys
        ]
        layer_keys = [k.replace("fpn.layer{}.sum".format(stage_name), "fpn_layer{}".format(mapped_idx)) for k in layer_keys]


    layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys]
    layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys]
    layer_keys = [
        k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys
    ]

    return layer_keys


def _rename_weights_for_resnet(weights, stage_names):
    """Build a renamed, tensor-valued copy of a C2 ResNet weight dict.

    Args:
        weights: mapping from C2 blob name to a value accepted by
            ``torch.from_numpy``.
        stage_names: stage-name fragments forwarded to
            ``_rename_fpn_weights``.

    Returns:
        An ``OrderedDict`` mapping the renamed keys to tensors, following the
        sorted order of the original names. Entries whose original name
        contains ``"_momentum"`` are left out.
    """
    c2_names = sorted(weights.keys())

    # for X-101, rename output to fc1000 to avoid conflicts afterwards
    classifier_aliases = {"pred_b": "fc1000_b", "pred_w": "fc1000_w"}
    mapped_names = [classifier_aliases.get(name, name) for name in c2_names]

    # performs basic renaming: _ -> . , etc
    mapped_names = _rename_basic_resnet_weights(mapped_names)

    # FPN
    mapped_names = _rename_fpn_weights(mapped_names, stage_names)

    head_substitutions = (
        # Mask R-CNN
        ("mask.fcn.logits", "mask_fcn_logits"),
        (".[mask].fcn", "mask_fcn"),
        ("conv5.mask", "conv5_mask"),
        # Keypoint R-CNN
        ("kps.score.lowres", "kps_score_lowres"),
        ("kps.score", "kps_score"),
        ("conv.fcn", "conv_fcn"),
        # Rename for our RPN structure
        ("rpn.", "rpn.head."),
    )
    for old, new in head_substitutions:
        mapped_names = [name.replace(old, new) for name in mapped_names]

    # Both lists derive from the same sorted key list, so they line up 1:1.
    key_map = dict(zip(c2_names, mapped_names))

    logger = logging.getLogger(__name__)
    logger.info("Remapping C2 weights")

    kept_names = [name for name in c2_names if "_momentum" not in name]
    # Width of the longest kept name, used to align the log columns.
    max_c2_key_size = max(len(name) for name in kept_names)

    new_weights = OrderedDict()
    for c2_name in kept_names:
        tensor = torch.from_numpy(weights[c2_name])
        mapped = key_map[c2_name]
        logger.info("C2 name: {: <{}} mapped name: {}".format(c2_name, max_c2_key_size, mapped))
        new_weights[mapped] = tensor

    return new_weights


def _load_c2_pickled_weights(file_path):
    with open(file_path, "rb") as f:
        if torch._six.PY3:
            data = pickle.load(f, encoding="latin1")
        else:
            data = pickle.load(f)
    if "blobs" in data:
        weights = data["blobs"]
    else:
        weights = data
    return weights


def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
    import re
    logger = logging.getLogger(__name__)
    logger.info("Remapping conv weights for deformable conv weights")
    layer_keys = sorted(state_dict.keys())
    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
        if not stage_with_dcn:
            continue
        for old_key in layer_keys:
            pattern = ".*layer{}.*conv2.*".format(ix)
            r = re.match(pattern, old_key)
            if r is None:
                continue
            for param in ["weight", "bias"]:
                if old_key.find(param) == -1:
                    continue
                new_key = old_key.replace(
                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
                )
                logger.info("pattern: {}, old_key: {}, new_key: {}".format(
                    pattern, old_key, new_key
                ))
                state_dict[new_key] = state_dict[old_key]
                del state_dict[old_key]
    return state_dict


# Stage-name fragments per ResNet architecture. ``load_resnet_c2_format``
# looks up the entry for the configured backbone, and ``_rename_fpn_weights``
# substitutes the fragments into the "fpn.inner.layer{}.sum{}" and
# "fpn.layer{}.sum" key templates.
# NOTE(review): each fragment looks like "<stage>.<index of last block>" for
# that architecture -- confirm against the C2 checkpoints.
_C2_STAGE_NAMES = {
    "R-50": ["1.2", "2.3", "3.5", "4.2"],
    "R-101": ["1.2", "2.3", "3.22", "4.2"],
    "R-152": ["1.2", "2.7", "3.35", "4.2"],
}

# Registry of loader functions; ``load_c2_format`` indexes it with
# ``cfg.MODEL.BACKBONE.CONV_BODY``.
C2_FORMAT_LOADER = Registry()


@C2_FORMAT_LOADER.register("R-50-C4")
@C2_FORMAT_LOADER.register("R-50-C5")
@C2_FORMAT_LOADER.register("R-101-C4")
@C2_FORMAT_LOADER.register("R-101-C5")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-101-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-152-FPN")
def load_resnet_c2_format(cfg, f):
    """Load a C2-format ResNet checkpoint and remap its keys.

    Args:
        cfg: config object; ``MODEL.BACKBONE.CONV_BODY`` selects the
            architecture and ``MODEL.RESNETS.STAGE_WITH_DCN`` drives the
            deformable-conv key remapping.
        f: path of the pickled checkpoint file.

    Returns:
        A dict with a single ``"model"`` entry holding the remapped weights.
    """
    raw_weights = _load_c2_pickled_weights(f)

    # Strip the body/head suffixes so that only the bare architecture name
    # (the key used by _C2_STAGE_NAMES) is left.
    arch = cfg.MODEL.BACKBONE.CONV_BODY
    for suffix in ("-C4", "-C5", "-FPN", "-RETINANET"):
        arch = arch.replace(suffix, "")

    renamed = _rename_weights_for_resnet(raw_weights, _C2_STAGE_NAMES[arch])
    # Deformable convolution layers keep their conv parameters one level
    # deeper, so the affected keys are adjusted last.
    renamed = _rename_conv_weights_for_deformable_conv_layers(renamed, cfg)
    return {"model": renamed}


def load_c2_format(cfg, f):
    """Dispatch to the loader registered for the configured backbone.

    Args:
        cfg: config object; ``MODEL.BACKBONE.CONV_BODY`` is the registry key.
        f: checkpoint path forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.
    """
    backbone_name = cfg.MODEL.BACKBONE.CONV_BODY
    loader = C2_FORMAT_LOADER[backbone_name]
    return loader(cfg, f)
