# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from collections import OrderedDict

import torch


def get_layer_maps(layer_num, with_bn):
    """Build index maps from torchvision VGG ``features`` to mmpretrain.

    In the torchvision ``features`` sequence each conv is followed by an
    optional BN and a ReLU, so a conv occupies 3 slots with BN and 2 without.
    Every stage adds one extra slot after its convs. In the target indexing,
    each conv advances the index by one, and every stage boundary adds one
    more slot (presumably the max-pool). Both index sequences are therefore
    generated from the number of convs in each of the five stages.

    Args:
        layer_num (int): VGG depth, one of 11, 13, 16 or 19.
        with_bn (bool): Whether the source model has BN after every conv.

    Returns:
        dict: ``{'conv': {src_idx: dst_idx, ...}, 'bn': {src_idx: dst_idx}}``.
        The BN source index is the conv source index + 1. ``'bn'`` is empty
        when ``with_bn`` is falsy.

    Raises:
        ValueError: If ``layer_num`` is not a supported depth.
    """
    # Convs per stage for each supported depth.
    convs_per_stage = (
        (11, (1, 1, 2, 2, 2)),
        (13, (2, 2, 2, 2, 2)),
        (16, (2, 2, 3, 3, 3)),
        (19, (2, 2, 4, 4, 4)),
    )
    stages = next(
        (counts for depth, counts in convs_per_stage if depth == layer_num),
        None)
    if stages is None:
        raise ValueError(f'Invalid number of layers: {layer_num}')

    # conv [+ bn] + relu slots taken per conv in the torchvision sequence.
    src_stride = 3 if with_bn else 2
    conv_map = {}
    bn_map = {}
    src_pos = 0
    dst_pos = 0
    for num_convs in stages:
        for _ in range(num_convs):
            conv_map[src_pos] = dst_pos
            if with_bn:
                bn_map[src_pos + 1] = dst_pos
            src_pos += src_stride
            dst_pos += 1
        # One extra slot closes each stage in both index sequences.
        src_pos += 1
        dst_pos += 1

    return {'conv': conv_map, 'bn': bn_map}


def convert(src, dst, layer_num, with_bn=False):
    """Convert keys in torchvision pretrained VGG models to mmpretrain
    style.

    Keys containing ``features`` (of the form ``features.<idx>.<param>``)
    are handled as follows. When ``<idx>`` is a conv or BN index known to
    :func:`get_layer_maps`, the key is renamed to
    ``backbone.features.<new_idx>.<conv|bn>.<param>``. Otherwise it is only
    prefixed with ``backbone.``. Keys containing ``classifier`` are prefixed
    with ``backbone.``, and all other keys are copied unchanged. The result
    is saved to ``dst`` as ``{'state_dict': <converted OrderedDict>}``.

    Args:
        src (str): Path to the source checkpoint. It must load to a mapping
            of parameter names to values (iterated with ``.items()``).
        dst (str): Path the converted checkpoint is written to.
        layer_num (int): VGG depth, one of 11, 13, 16 or 19.
        with_bn (bool): Whether the source model has BN layers.

    Raises:
        FileNotFoundError: If ``src`` is not an existing file.
        ValueError: If ``layer_num`` is not a supported depth.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(src):
        raise FileNotFoundError(f'no checkpoint found at {src}')
    # NOTE(review): ``torch.load`` unpickles the file; only convert
    # checkpoints obtained from trusted sources.
    blobs = torch.load(src, map_location='cpu')

    state_dict = OrderedDict()
    layer_maps = get_layer_maps(layer_num, with_bn)

    prefix = 'backbone'
    delimiter = '.'
    for key, weight in blobs.items():
        if 'features' in key:
            # Expected shape: ``features.<idx>.<param>`` (exactly 3 parts).
            _, layer_idx, weight_type = key.split(delimiter)
            # Fallback when the index is neither a mapped conv nor a BN.
            new_key = delimiter.join([prefix, key])
            layer_idx = int(layer_idx)
            for layer_key, maps in layer_maps.items():
                if layer_idx in maps:
                    new_key = delimiter.join([
                        prefix, 'features',
                        str(maps[layer_idx]), layer_key, weight_type
                    ])
            state_dict[new_key] = weight
            print(f'Convert {key} to {new_key}')
        elif 'classifier' in key:
            new_key = delimiter.join([prefix, key])
            state_dict[new_key] = weight
            print(f'Convert {key} to {new_key}')
        else:
            state_dict[key] = weight

    torch.save({'state_dict': state_dict}, dst)


def main():
    """Command-line entry point: parse the CLI options and run ``convert``."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src torchvision model path')
    parser.add_argument('dst', help='save path')
    parser.add_argument(
        '--bn', action='store_true', help='whether original vgg has BN')
    parser.add_argument(
        '--layer-num', type=int, choices=[11, 13, 16, 19], default=11,
        help='number of VGG layers')
    opts = vars(parser.parse_args())
    convert(
        opts['src'],
        opts['dst'],
        layer_num=opts['layer_num'],
        with_bn=opts['bn'],
    )


if __name__ == '__main__':
    main()
