# Copyright (c) Facebook, Inc. and its affiliates.
from collections import OrderedDict

from detectron2.checkpoint import DetectionCheckpointer


def _rename_HRNet_weights(weights):
    # We detect and  rename HRNet weights for DensePose. 1956 and 1716 are values that are
    # common to all HRNet pretrained weights, and should be enough to accurately identify them
    if (
        len(weights["model"].keys()) == 1956
        and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716
    ):
        hrnet_weights = OrderedDict()
        for k in weights["model"].keys():
            hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k]
        return {"model": hrnet_weights}
    else:
        return weights


class DensePoseCheckpointer(DetectionCheckpointer):
    """
    Same as :class:`DetectionCheckpointer`, but is able to handle HRNet weights
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)

    def _load_file(self, filename: str) -> object:
        """
        Adding hrnet support
        """
        weights = super()._load_file(filename)
        return _rename_HRNet_weights(weights)
